diff --git a/.github/workflows/deploy-examples.yml b/.github/workflows/deploy-examples.yml
index c90de6a..fc19378 100644
--- a/.github/workflows/deploy-examples.yml
+++ b/.github/workflows/deploy-examples.yml
@@ -59,7 +59,7 @@ jobs:
run: pip install hatch
- name: Build examples
run: |
- hatch run -- examples:pip install -e packages/aws-durable-execution-sdk-python packages/aws-durable-execution-sdk-python-otel
+ hatch run -- examples:pip install -e packages/aws-durable-execution-sdk-python packages/aws-durable-execution-sdk-python-otel packages/aws-durable-execution-sdk-python-testing
hatch run examples:build
- name: Deploy Lambda function - ${{ matrix.example.name }}
diff --git a/.github/workflows/ecr-release.yml b/.github/workflows/ecr-release.yml
new file mode 100644
index 0000000..55ffa02
--- /dev/null
+++ b/.github/workflows/ecr-release.yml
@@ -0,0 +1,130 @@
+name: Upload Testing SDK Emulator Image
+
+on:
+ release:
+ types: [published]
+
+permissions:
+ contents: read
+ id-token: write
+
+env:
+ package_path: packages/aws-durable-execution-sdk-python-testing
+ aws_region: us-east-1
+ ecr_repository_name: durable-functions/aws-durable-execution-emulator
+
+jobs:
+ build-and-upload-image-to-ecr:
+ runs-on: ubuntu-latest
+ outputs:
+ full_image_arm64: ${{ steps.build-publish.outputs.full_image_arm64 }}
+ full_image_x86_64: ${{ steps.build-publish.outputs.full_image_x86_64 }}
+ ecr_registry_repository: ${{ steps.build-publish.outputs.ecr_registry_repository }}
+ version: ${{ steps.version.outputs.VERSION }}
+ strategy:
+ matrix:
+ include:
+ - arch: x86_64
+ platform: linux/amd64
+ - arch: arm64
+ platform: linux/arm64
+
+ steps:
+ - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+ with:
+ ref: ${{ github.event.release.tag_name }}
+
+ - name: Set up Python
+ uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
+ with:
+ python-version: "3.13"
+
+ - name: Install Hatch
+ run: python -m pip install --upgrade hatch==1.16.5
+
+ - name: Set up QEMU for multi-platform builds
+ if: matrix.arch == 'arm64'
+ uses: docker/setup-qemu-action@v3
+ with:
+ platforms: arm64
+
+ - name: Build distribution
+ working-directory: ${{ env.package_path }}
+ run: hatch build
+
+ - name: Get version from __about__.py
+ id: version
+ run: |
+ VERSION=$(grep "^__version__" "${{ env.package_path }}/src/aws_durable_execution_sdk_python_testing/__about__.py" | cut -d'"' -f2)
+ echo "VERSION=$VERSION"
+ echo "VERSION=${VERSION}" >> "$GITHUB_OUTPUT"
+
+ - name: Configure AWS Credentials
+ uses: aws-actions/configure-aws-credentials@v4
+ with:
+ role-to-assume: ${{ secrets.ECR_UPLOAD_IAM_ROLE_ARN }}
+ aws-region: ${{ env.aws_region }}
+
+ - name: Login to Amazon ECR
+ id: login-ecr-public
+ uses: aws-actions/amazon-ecr-login@v2
+ with:
+ registry-type: public
+
+ - name: Build, tag, and push image to Amazon ECR
+ id: build-publish
+ shell: bash
+ env:
+ ECR_REGISTRY: ${{ steps.login-ecr-public.outputs.registry }}
+ ECR_REPOSITORY: ${{ env.ecr_repository_name }}
+ PER_ARCH_IMAGE_TAG: v${{ steps.version.outputs.VERSION }}-${{ matrix.arch }}
+ run: |
+ docker build --platform "${{ matrix.platform }}" --provenance false "${{ env.package_path }}" -f "${{ env.package_path }}/Dockerfile" -t "$ECR_REGISTRY/$ECR_REPOSITORY:$PER_ARCH_IMAGE_TAG"
+ docker push "$ECR_REGISTRY/$ECR_REPOSITORY:$PER_ARCH_IMAGE_TAG"
+ echo "ecr_registry_repository=$ECR_REGISTRY/$ECR_REPOSITORY" >> "$GITHUB_OUTPUT"
+ echo "full_image_${{ matrix.arch }}=$ECR_REGISTRY/$ECR_REPOSITORY:$PER_ARCH_IMAGE_TAG" >> "$GITHUB_OUTPUT"
+
+ create-ecr-manifest-per-arch:
+ runs-on: ubuntu-latest
+ needs: [build-and-upload-image-to-ecr]
+ steps:
+ - name: Configure AWS Credentials
+ uses: aws-actions/configure-aws-credentials@v4
+ with:
+ role-to-assume: ${{ secrets.ECR_UPLOAD_IAM_ROLE_ARN }}
+ aws-region: ${{ env.aws_region }}
+
+ - name: Login to Amazon ECR
+ uses: aws-actions/amazon-ecr-login@v2
+ with:
+ registry-type: public
+
+ - name: Create and push explicit version manifest
+ run: |
+ docker manifest create "${{ needs.build-and-upload-image-to-ecr.outputs.ecr_registry_repository }}:v${{ needs.build-and-upload-image-to-ecr.outputs.version }}" \
+ "${{ needs.build-and-upload-image-to-ecr.outputs.full_image_x86_64 }}" \
+ "${{ needs.build-and-upload-image-to-ecr.outputs.full_image_arm64 }}"
+ docker manifest annotate "${{ needs.build-and-upload-image-to-ecr.outputs.ecr_registry_repository }}:v${{ needs.build-and-upload-image-to-ecr.outputs.version }}" \
+ "${{ needs.build-and-upload-image-to-ecr.outputs.full_image_arm64 }}" \
+ --arch arm64 \
+ --os linux
+ docker manifest annotate "${{ needs.build-and-upload-image-to-ecr.outputs.ecr_registry_repository }}:v${{ needs.build-and-upload-image-to-ecr.outputs.version }}" \
+ "${{ needs.build-and-upload-image-to-ecr.outputs.full_image_x86_64 }}" \
+ --arch amd64 \
+ --os linux
+ docker manifest push "${{ needs.build-and-upload-image-to-ecr.outputs.ecr_registry_repository }}:v${{ needs.build-and-upload-image-to-ecr.outputs.version }}"
+
+ - name: Create and push latest manifest
+ run: |
+ docker manifest create "${{ needs.build-and-upload-image-to-ecr.outputs.ecr_registry_repository }}" \
+ "${{ needs.build-and-upload-image-to-ecr.outputs.full_image_arm64 }}" \
+ "${{ needs.build-and-upload-image-to-ecr.outputs.full_image_x86_64 }}"
+ docker manifest annotate "${{ needs.build-and-upload-image-to-ecr.outputs.ecr_registry_repository }}" \
+ "${{ needs.build-and-upload-image-to-ecr.outputs.full_image_arm64 }}" \
+ --arch arm64 \
+ --os linux
+ docker manifest annotate "${{ needs.build-and-upload-image-to-ecr.outputs.ecr_registry_repository }}" \
+ "${{ needs.build-and-upload-image-to-ecr.outputs.full_image_x86_64 }}" \
+ --arch amd64 \
+ --os linux
+ docker manifest push "${{ needs.build-and-upload-image-to-ecr.outputs.ecr_registry_repository }}"
diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml
index 1b3176c..e038913 100644
--- a/.github/workflows/integration-tests.yml
+++ b/.github/workflows/integration-tests.yml
@@ -22,12 +22,6 @@ jobs:
with:
path: language-sdk
- - name: Checkout the latest Testing SDK
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- with:
- repository: aws/aws-durable-execution-sdk-python-testing
- path: language-sdk/packages/testing-sdk
-
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
with:
@@ -40,7 +34,6 @@ jobs:
working-directory: language-sdk
run: |
echo "Running SDK tests..."
- hatch run -- test:pip install -e packages/testing-sdk
hatch run types:check
hatch run test:cov
@@ -61,12 +54,6 @@ jobs:
with:
path: language-sdk
- - name: Checkout the latest Testing SDK
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- with:
- repository: aws/aws-durable-execution-sdk-python-testing
- path: language-sdk/packages/testing-sdk
-
- name: Set up Python 3.13
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
with:
@@ -105,7 +92,6 @@ jobs:
KMS_KEY_ARN: ${{ secrets.KMS_KEY_ARN }}
run: |
echo "Building examples..."
- hatch run -- examples:pip install -e packages/testing-sdk
hatch run examples:build
# Get first integration example for testing
diff --git a/.github/workflows/pypi-publish.yml b/.github/workflows/pypi-publish.yml
index 9e0fa22..42c66f8 100644
--- a/.github/workflows/pypi-publish.yml
+++ b/.github/workflows/pypi-publish.yml
@@ -26,6 +26,8 @@ jobs:
path: packages/aws-durable-execution-sdk-python
- name: aws-durable-execution-sdk-python-otel
path: packages/aws-durable-execution-sdk-python-otel
+ - name: aws-durable-execution-sdk-python-testing
+ path: packages/aws-durable-execution-sdk-python-testing
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -58,6 +60,7 @@ jobs:
package:
- name: aws-durable-execution-sdk-python
- name: aws-durable-execution-sdk-python-otel
+ - name: aws-durable-execution-sdk-python-testing
permissions:
id-token: write
diff --git a/packages/aws-durable-execution-sdk-python-examples/cli.py b/packages/aws-durable-execution-sdk-python-examples/cli.py
index ff2e839..5a5fa71 100755
--- a/packages/aws-durable-execution-sdk-python-examples/cli.py
+++ b/packages/aws-durable-execution-sdk-python-examples/cli.py
@@ -46,24 +46,12 @@ def build_examples():
shutil.rmtree(build_dir)
build_dir.mkdir()
- # Copy testing library from current environment
- try:
- import aws_durable_execution_sdk_python_testing
-
- sdk_path = Path(aws_durable_execution_sdk_python_testing.__file__).parent
- logger.info("Copying SDK from %s", sdk_path)
- shutil.copytree(
- sdk_path, build_dir / "aws_durable_execution_sdk_python_testing"
- )
- except (ImportError, OSError):
- logger.exception("Failed to copy testing library")
- return False
-
# Install local packages so their runtime dependencies are included in
# the Lambda deployment package.
runtime_packages = [
packages_dir / "aws-durable-execution-sdk-python",
packages_dir / "aws-durable-execution-sdk-python-otel",
+ packages_dir / "aws-durable-execution-sdk-python-testing",
]
try:
subprocess.run(
diff --git a/packages/aws-durable-execution-sdk-python-testing/Dockerfile b/packages/aws-durable-execution-sdk-python-testing/Dockerfile
new file mode 100644
index 0000000..30d3a7f
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/Dockerfile
@@ -0,0 +1,20 @@
+FROM python:3.13-slim
+
+# Copy and install the wheel
+COPY dist/*.whl /tmp/
+RUN pip install --no-cache-dir /tmp/*.whl && rm -rf /tmp/*.whl
+
+# AWS credentials (required for boto3)
+ENV AWS_ACCESS_KEY_ID=foo \
+ AWS_SECRET_ACCESS_KEY=bar \
+ AWS_DEFAULT_REGION=us-east-1
+
+EXPOSE 9014
+
+CMD ["dex-local-runner", "start-server", \
+ "--host", "0.0.0.0", \
+ "--port", "9014", \
+ "--log-level", "DEBUG", \
+ "--lambda-endpoint", "http://host.docker.internal:3001", \
+ "--store-type", "sqlite", \
+ "--store-path", "/tmp/.durable-executions-local/durable-executions.db"]
diff --git a/packages/aws-durable-execution-sdk-python-testing/LICENSE b/packages/aws-durable-execution-sdk-python-testing/LICENSE
new file mode 100644
index 0000000..67db858
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/LICENSE
@@ -0,0 +1,175 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
diff --git a/packages/aws-durable-execution-sdk-python-testing/NOTICE b/packages/aws-durable-execution-sdk-python-testing/NOTICE
new file mode 100644
index 0000000..616fc58
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/NOTICE
@@ -0,0 +1 @@
+Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
diff --git a/packages/aws-durable-execution-sdk-python-testing/README.md b/packages/aws-durable-execution-sdk-python-testing/README.md
new file mode 100644
index 0000000..223f07c
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/README.md
@@ -0,0 +1,199 @@
+# AWS Durable Execution Testing SDK for Python
+
+[](https://pypi.org/project/aws-durable-execution-sdk-python-testing)
+[](https://pypi.org/project/aws-durable-execution-sdk-python-testing)
+
+
+[](https://scorecard.dev/viewer/?uri=github.com/aws/aws-durable-execution-sdk-python-testing)
+
+-----
+
+## Table of Contents
+
+- [Installation](#installation)
+- [Quick Start](#quick-start)
+- [Architecture](#architecture)
+- [Documentation](#documentation)
+- [Developer Guide](#developers)
+- [License](#license)
+
+## Installation
+
+```console
+pip install aws-durable-execution-sdk-python-testing
+```
+
+## Overview
+
+Use the AWS Durable Execution Testing SDK for Python to test your Python durable functions locally.
+
+The test framework contains a local runner, so you can run and test your durable function locally
+before you deploy it.
+
+## Quick Start
+
+### A durable function under test
+
+```python
+from aws_durable_execution_sdk_python.context import (
+ DurableContext,
+ durable_step,
+ durable_with_child_context,
+)
+from aws_durable_execution_sdk_python.execution import durable_execution
+from aws_durable_execution_sdk_python.config import Duration
+
+
+@durable_step
+def one(a: int, b: int) -> str:
+ return f"{a} {b}"
+
+@durable_step
+def two_1(a: int, b: int) -> str:
+ return f"{a} {b}"
+
+@durable_step
+def two_2(a: int, b: int) -> str:
+ return f"{b} {a}"
+
+@durable_with_child_context
+def two(ctx: DurableContext, a: int, b: int) -> str:
+ two_1_result: str = ctx.step(two_1(a, b))
+ two_2_result: str = ctx.step(two_2(a, b))
+ return f"{two_1_result} {two_2_result}"
+
+@durable_step
+def three(a: int, b: int) -> str:
+ return f"{a} {b}"
+
+@durable_execution
+def function_under_test(event: Any, context: DurableContext) -> list[str]:
+ results: list[str] = []
+
+ result_one: str = context.step(one(1, 2))
+ results.append(result_one)
+
+ context.wait(Duration.from_seconds(1))
+
+ result_two: str = context.run_in_child_context(two(3, 4))
+ results.append(result_two)
+
+ result_three: str = context.step(three(5, 6))
+ results.append(result_three)
+
+ return results
+```
+
+### Your test code
+
+```python
+from aws_durable_execution_sdk_python.execution import InvocationStatus
+from aws_durable_execution_sdk_python_testing.runner import (
+ ContextOperation,
+ DurableFunctionTestResult,
+ DurableFunctionTestRunner,
+ StepOperation,
+)
+
+def test_my_durable_functions():
+ with DurableFunctionTestRunner(handler=function_under_test) as runner:
+ result: DurableFunctionTestResult = runner.run(input="input str", timeout=10)
+
+ assert result.status is InvocationStatus.SUCCEEDED
+ assert result.result == '["1 2", "3 4 4 3", "5 6"]'
+
+ one_result: StepOperation = result.get_step("one")
+ assert one_result.result == '"1 2"'
+
+ two_result: ContextOperation = result.get_context("two")
+ assert two_result.result == '"3 4 4 3"'
+
+ three_result: StepOperation = result.get_step("three")
+ assert three_result.result == '"5 6"'
+```
+## Architecture
+
+
+## Event Flow
+
+
+1. **DurableTestRunner** starts execution via **Executor**
+2. **Executor** creates **Execution** and schedules initial invocation
+3. During execution, checkpoints are processed by **CheckpointProcessor**
+4. **Individual Processors** transform operation updates and may trigger events
+5. **ExecutionNotifier** broadcasts events to **Executor** (observer)
+6. **Executor** updates **Execution** state based on events
+7. **Execution** completion triggers final event notifications
+8. **DurableTestRunner** run() blocks until it receives completion event, and then returns `DurableFunctionTestResult`.
+
+## Major Components
+
+### Core Execution Flow
+- **DurableTestRunner** - Main entry point that orchestrates test execution
+- **Executor** - Manages execution lifecycle. Mutates Execution.
+- **Execution** - Represents the state and operations of a single durable execution
+
+### Service Client Integration
+- **InMemoryServiceClient** - Replaces AWS Lambda service client for local testing. Injected into SDK via `DurableExecutionInvocationInputWithClient`
+
+### Checkpoint Processing Pipeline
+- **CheckpointProcessor** - Orchestrates operation transformations and validation
+- **Individual Validators** - Validate operation updates and state transitions
+- **Individual Processors** - Transform operation updates into operations (step, wait, callback, context, execution)
+
+### Execution status changes (Observer Pattern)
+- **ExecutionNotifier** - Notifies observers of execution events
+- **ExecutionObserver** - Interface for receiving execution lifecycle events
+- **Executor** implements `ExecutionObserver` to handle completion events
+
+## Component Relationships
+
+### 1. DurableTestRunner → Executor → Execution
+- **DurableTestRunner** serves as the main API entry point and sets up all components
+- **Executor** manages the execution lifecycle, handling invocations and state transitions
+- **Execution** maintains the state of operations and completion status
+
+### 2. Service Client Injection
+- **DurableTestRunner** creates **InMemoryServiceClient** with **CheckpointProcessor**
+- **InProcessInvoker** injects the service client into SDK via `DurableExecutionInvocationInputWithClient`
+- When durable functions call checkpoint operations, they're intercepted by **InMemoryServiceClient**
+- **InMemoryServiceClient** delegates to **CheckpointProcessor** for local processing
+
+### 3. CheckpointProcessor → Individual Validators → Individual Processors
+- **CheckpointProcessor** orchestrates the checkpoint processing pipeline
+- **Individual Validators** (CheckpointValidator, TransitionsValidator, and operation-specific validators) ensure operation updates are valid
+- **Individual Processors** (StepProcessor, WaitProcessor, etc.) transform `OperationUpdate` into `Operation`
+
+### 4. Observer Pattern Flow
+The observer pattern enables loose coupling between checkpoint processing and execution management:
+
+1. **CheckpointProcessor** processes operation updates
+2. **Individual Processors** detect state changes (completion, failures, timer scheduling)
+3. **ExecutionNotifier** broadcasts events to registered observers
+4. **Executor** (as ExecutionObserver) receives notifications and updates **Execution** state
+5. **Execution** complete_* methods finalize the execution state
+
+
+## Documentation
+
+### Error Handling
+
+The testing framework implements AWS-compliant error responses that match the exact format expected by boto3 and AWS services. For detailed information about error response formats, exception types, and troubleshooting, see:
+
+- [Error Response Documentation](docs/error-responses.md)
+
+Key features:
+- **AWS-compliant JSON format**: Matches boto3 expectations exactly
+- **Smithy model compliance**: Field names follow AWS Smithy definitions
+- **HTTP status code mapping**: Standard AWS service status codes
+- **Boto3 compatibility**: Seamless integration with boto3 error handling
+
+## Developers
+Please see [CONTRIBUTING.md](CONTRIBUTING.md). It contains the testing guide, sample commands and instructions
+for how to contribute to this package.
+
+tldr; use `hatch` and it will manage virtual envs and dependencies for you, so you don't have to do it manually.
+
+## License
+
+This project is licensed under the [Apache-2.0 License](LICENSE).
diff --git a/packages/aws-durable-execution-sdk-python-testing/assets/dar-python-test-framework-architecture.svg b/packages/aws-durable-execution-sdk-python-testing/assets/dar-python-test-framework-architecture.svg
new file mode 100644
index 0000000..0d8fd6d
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/assets/dar-python-test-framework-architecture.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/packages/aws-durable-execution-sdk-python-testing/assets/dar-python-test-framework-event-flow.svg b/packages/aws-durable-execution-sdk-python-testing/assets/dar-python-test-framework-event-flow.svg
new file mode 100644
index 0000000..fbd55ab
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/assets/dar-python-test-framework-event-flow.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/packages/aws-durable-execution-sdk-python-testing/docs/error-responses.md b/packages/aws-durable-execution-sdk-python-testing/docs/error-responses.md
new file mode 100644
index 0000000..44a3567
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/docs/error-responses.md
@@ -0,0 +1,311 @@
+# AWS-Compliant Error Response Documentation
+
+This document describes the AWS-compliant error response format used by the Durable Executions Testing Library.
+
+## Overview
+
+The testing library implements AWS-compliant error responses that match the exact format expected by boto3 and AWS services. All error responses follow Smithy model definitions for structure and field naming.
+
+## Error Response Format
+
+### HTTP Response Structure
+
+All error responses use the following HTTP structure:
+
+```
+HTTP/1.1
+Content-Type: application/json
+
+
+```
+
+### JSON Body Format
+
+The JSON body format varies by exception type based on Smithy model definitions:
+
+#### Standard Format (Most Exceptions)
+
+```json
+{
+ "Type": "ExceptionName",
+ "message": "Detailed error message"
+}
+```
+
+**Used by:**
+- `InvalidParameterValueException`
+- `CallbackTimeoutException`
+
+#### Capital Message Format
+
+```json
+{
+ "Type": "ExceptionName",
+ "Message": "Detailed error message"
+}
+```
+
+**Used by:**
+- `ResourceNotFoundException`
+- `ServiceException`
+
+#### Special Format (ExecutionAlreadyStartedException)
+
+```json
+{
+ "message": "Detailed error message",
+ "DurableExecutionArn": "arn:aws:states:region:account:execution:name"
+}
+```
+
+**Note:** This exception has no "Type" field per AWS Smithy definition.
+
+## Exception Types and Examples
+
+### InvalidParameterValueException (HTTP 400)
+
+**When:** Invalid parameter values are provided to API operations.
+
+**Example Response:**
+```json
+{
+ "Type": "InvalidParameterValueException",
+ "message": "The parameter 'executionName' cannot be empty"
+}
+```
+
+**Common Causes:**
+- Empty or null required parameters
+- Invalid parameter formats
+- Parameter values outside allowed ranges
+
+### ResourceNotFoundException (HTTP 404)
+
+**When:** Requested resource does not exist.
+
+**Example Response:**
+```json
+{
+ "Type": "ResourceNotFoundException",
+ "Message": "Execution with ID 'exec-123' not found"
+}
+```
+
+**Common Causes:**
+- Non-existent execution IDs
+- Deleted or expired resources
+- Incorrect resource identifiers
+
+### ServiceException (HTTP 500)
+
+**When:** Internal service errors occur.
+
+**Example Response:**
+```json
+{
+ "Type": "ServiceException",
+ "Message": "An internal error occurred while processing the request"
+}
+```
+
+**Common Causes:**
+- Unexpected internal errors
+- System unavailability
+- Configuration issues
+
+### CallbackTimeoutException (HTTP 408)
+
+**When:** Callback operations timeout.
+
+**Example Response:**
+```json
+{
+ "Type": "CallbackTimeoutException",
+ "message": "Callback operation timed out after 30 seconds"
+}
+```
+
+**Common Causes:**
+- Callback not received within timeout period
+- Network connectivity issues
+- Client-side delays
+
+### ExecutionAlreadyStartedException (HTTP 409)
+
+**When:** Attempting to start an execution that is already running.
+
+**Example Response:**
+```json
+{
+ "message": "Execution is already started",
+ "DurableExecutionArn": "arn:aws:states:us-east-1:123456789012:execution:MyExecution:abc123"
+}
+```
+
+**Common Causes:**
+- Duplicate start execution requests
+- Race conditions in execution management
+- Client retry logic issues
+
+## HTTP Status Code Mapping
+
+| Exception | HTTP Status | Description |
+|-----------|-------------|-------------|
+| InvalidParameterValueException | 400 | Bad Request - Invalid input parameters |
+| ResourceNotFoundException | 404 | Not Found - Resource does not exist |
+| CallbackTimeoutException | 408 | Request Timeout - Operation timed out |
+| ExecutionAlreadyStartedException | 409 | Conflict - Resource already exists |
+| ServiceException | 500 | Internal Server Error - System error |
+
+## Field Name Conventions
+
+Field names strictly follow Smithy model definitions:
+
+- **lowercase "message"**: InvalidParameterValueException, CallbackTimeoutException, ExecutionAlreadyStartedException
+- **capital "Message"**: ResourceNotFoundException, ServiceException
+- **"Type"**: Present in all exceptions except ExecutionAlreadyStartedException
+- **"DurableExecutionArn"**: Only in ExecutionAlreadyStartedException
+
+## Boto3 Compatibility
+
+All error responses are designed for boto3 compatibility:
+
+### Client Error Handling
+
+```python
+import boto3
+from botocore.exceptions import ClientError
+
+try:
+ # API call that might fail
+ response = client.some_operation()
+except ClientError as e:
+ error_code = e.response['Error']['Code']
+ error_message = e.response['Error']['Message']
+
+ if error_code == 'InvalidParameterValueException':
+ # Handle invalid parameter
+ pass
+ elif error_code == 'ResourceNotFoundException':
+ # Handle not found
+ pass
+```
+
+### Error Response Structure
+
+The testing library's error responses match the structure boto3 expects:
+
+```python
+# What boto3 receives
+{
+ 'Error': {
+ 'Code': 'InvalidParameterValueException',
+ 'Message': 'The parameter cannot be empty'
+ },
+ 'ResponseMetadata': {
+ 'HTTPStatusCode': 400,
+ 'HTTPHeaders': {...}
+ }
+}
+```
+
+## Migration from Legacy Format
+
+### Old Format (Deprecated)
+```json
+{
+ "error": {
+ "type": "InvalidParameterError",
+ "message": "Error message",
+ "code": "INVALID_PARAMETER",
+ "requestId": "req-123"
+ }
+}
+```
+
+### New AWS-Compliant Format
+```json
+{
+ "Type": "InvalidParameterValueException",
+ "message": "Error message"
+}
+```
+
+### Key Changes
+1. **No wrapper object**: Direct JSON structure, no "error" wrapper
+2. **AWS exception names**: Use official AWS exception names
+3. **Smithy field names**: Follow exact Smithy model field naming
+4. **Simplified structure**: Only essential fields per AWS standards
+5. **Consistent HTTP codes**: Match AWS service status codes
+
+## Testing Error Responses
+
+### Unit Testing
+
+```python
+from aws_durable_execution_sdk_python_testing.exceptions import InvalidParameterValueException
+from aws_durable_execution_sdk_python_testing.web.models import HTTPResponse
+
+def test_error_response():
+ exception = InvalidParameterValueException("Test error")
+ response = HTTPResponse.create_error_from_exception(exception)
+
+ assert response.status_code == 400
+ assert response.body == {
+ "Type": "InvalidParameterValueException",
+ "message": "Test error"
+ }
+```
+
+### Integration Testing
+
+```python
+import requests
+
+def test_api_error_response():
+ response = requests.post('http://localhost:8080/invalid-endpoint')
+
+ assert response.status_code == 404
+ error_data = response.json()
+ assert error_data['Type'] == 'ResourceNotFoundException'
+ assert 'Message' in error_data
+```
+
+## Best Practices
+
+### Error Message Guidelines
+
+1. **Be specific**: Include relevant details about what went wrong
+2. **Be actionable**: Suggest how to fix the issue when possible
+3. **Be consistent**: Use consistent terminology across similar errors
+4. **Avoid sensitive data**: Don't include passwords, tokens, or PII
+
+### Exception Selection
+
+1. **InvalidParameterValueException**: For all input validation errors
+2. **ResourceNotFoundException**: When requested resources don't exist
+3. **ServiceException**: For unexpected internal errors only
+4. **CallbackTimeoutException**: Specifically for callback timeouts
+5. **ExecutionAlreadyStartedException**: Only for duplicate execution starts
+
+### HTTP Status Codes
+
+1. **Use standard codes**: Follow HTTP and AWS conventions
+2. **Be consistent**: Same error types should use same status codes
+3. **Client vs Server**: 4xx for client errors, 5xx for server errors
+
+## Troubleshooting
+
+### Common Issues
+
+1. **Wrong field names**: Ensure "message" vs "Message" matches exception type
+2. **Missing Type field**: All exceptions except ExecutionAlreadyStartedException need "Type"
+3. **Wrong status codes**: Verify HTTP status matches exception type
+4. **JSON serialization**: Ensure all fields are JSON-serializable
+
+### Debugging Tips
+
+1. **Check exception type**: Verify you're using the correct AWS exception
+2. **Validate JSON structure**: Use `to_dict()` to see exact output
+3. **Test with boto3**: Verify compatibility with actual boto3 client
+4. **Compare with AWS**: Match format with real AWS service responses
\ No newline at end of file
diff --git a/packages/aws-durable-execution-sdk-python-testing/pyproject.toml b/packages/aws-durable-execution-sdk-python-testing/pyproject.toml
new file mode 100644
index 0000000..5153fea
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/pyproject.toml
@@ -0,0 +1,132 @@
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "aws-durable-execution-sdk-python-testing"
+dynamic = ["version"]
+description = 'AWS Durable Execution Testing SDK for Python'
+readme = "README.md"
+requires-python = ">=3.11"
+license = "Apache-2.0"
+keywords = []
+authors = [{ name = "AWS durable-execution-dev", email = "durable-execution-dev@amazon.com" }]
+classifiers = [
+ "Development Status :: 4 - Beta",
+ "Programming Language :: Python",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.13",
+ "Programming Language :: Python :: 3.14",
+ "Programming Language :: Python :: Implementation :: CPython",
+ "Programming Language :: Python :: Implementation :: PyPy",
+]
+dependencies = [
+ "boto3>=1.42.1",
+ "aws-durable-execution-sdk-python>=1.0.0",
+]
+
+[project.urls]
+Documentation = "https://github.com/aws/aws-durable-execution-sdk-python-testing#readme"
+Issues = "https://github.com/aws/aws-durable-execution-sdk-python-testing/issues"
+Source = "https://github.com/aws/aws-durable-execution-sdk-python-testing"
+
+[project.scripts]
+dex-local-runner = "aws_durable_execution_sdk_python_testing.cli:main"
+
+[tool.hatch.build.targets.sdist]
+packages = ["src/aws_durable_execution_sdk_python_testing"]
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/aws_durable_execution_sdk_python_testing"]
+
+[tool.hatch.metadata]
+allow-direct-references = true
+
+[tool.hatch.version]
+path = "src/aws_durable_execution_sdk_python_testing/__about__.py"
+
+# [tool.hatch.envs.default]
+# dependencies=["pytest"]
+
+# [tool.hatch.envs.default.scripts]
+# test="pytest"
+
+[tool.hatch.envs.test]
+dependencies = [
+ "coverage[toml]",
+ "pytest",
+ "pytest-cov",
+ "ruff",
+ "aws-durable-execution-sdk-python>=1.0.0",
+]
+
+[tool.hatch.envs.test.scripts]
+test = "pytest tests/ -v"
+cov = "pytest --cov-report=term-missing --cov-config=pyproject.toml --cov=src/aws_durable_execution_sdk_python_testing --cov-fail-under=95"
+
+[tool.hatch.envs.types]
+extra-dependencies = ["mypy>=1.0.0", "pytest"]
+[tool.hatch.envs.types.scripts]
+check = "mypy --install-types --non-interactive {args:src/aws_durable_execution_sdk_python_testing tests}"
+
+[tool.coverage.run]
+source_pkgs = ["aws_durable_execution_sdk_python_testing"]
+branch = true
+parallel = true
+omit = ["src/aws_durable_execution_sdk_python_testing/__about__.py", "tests/*"]
+
+[tool.coverage.paths]
+aws_durable_execution_sdk_python_testing = [
+ "src/aws_durable_execution_sdk_python_testing",
+ "*/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing",
+]
+tests = ["tests", "*/aws-durable-execution-sdk-python-testing/tests"]
+
+[tool.coverage.report]
+exclude_lines = [
+ "no cov",
+ "if __name__ == .__main__.:",
+ "if TYPE_CHECKING:",
+ "@abstractmethod",
+]
+
+[tool.ruff]
+line-length = 88
+target-version = "py311"
+
+[tool.ruff.lint]
+preview = false
+select = ["TID252"] # Enforce absolute imports (ban relative imports)
+
+[tool.ruff.lint.isort]
+known-first-party = ["aws_durable_execution_sdk_python_testing"]
+force-single-line = false
+lines-after-imports = 2
+
+[tool.ruff.lint.per-file-ignores]
+"tests/**" = [
+ "ARG001",
+ "ARG002",
+ "ARG005",
+ "S101",
+ "PLR2004",
+ "SIM117",
+ "TRY301",
+]
+"src/aws_durable_execution_sdk_python_testing/invoker.py" = [
+ "A002", # Argument `input` is shadowing a Python builtin
+]
+
+[tool.pytest.ini_options]
+# Declare custom markers to avoid warnings with --strict-markers
+markers = [
+ # Used for test selection with -m example
+ "example: marks tests as example tests (deselect with '-m \"not example\"')",
+ # Used for configuration - passes handler and lambda_function_name to durable_runner fixture
+ "durable_execution: marks tests that use the durable_runner fixture (not used for test selection)",
+]
+# Default test discovery paths
+testpaths = ["tests"]
+# Default options for all test runs
+addopts = "-v --strict-markers"
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/__about__.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/__about__.py
new file mode 100644
index 0000000..698e248
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/__about__.py
@@ -0,0 +1,4 @@
+# SPDX-FileCopyrightText: 2025-present Amazon.com, Inc. or its affiliates.
+#
+# SPDX-License-Identifier: Apache-2.0
+__version__ = "1.2.1"
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/__init__.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/__init__.py
new file mode 100644
index 0000000..c25db1c
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/__init__.py
@@ -0,0 +1,23 @@
+"""DurableExecutionsPythonTestingLibrary module."""
+
+from aws_durable_execution_sdk_python_testing.runner import (
+ DurableChildContextTestRunner,
+ DurableFunctionCloudTestRunner,
+ DurableFunctionTestResult,
+ DurableFunctionTestRunner,
+ WebRunner,
+ WebRunnerConfig,
+)
+
+from aws_durable_execution_sdk_python_testing.__about__ import __version__
+
+
+__all__ = [
+ "DurableChildContextTestRunner",
+ "DurableFunctionCloudTestRunner",
+ "DurableFunctionTestResult",
+ "DurableFunctionTestRunner",
+ "WebRunner",
+ "WebRunnerConfig",
+ "__version__",
+]
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/__init__.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/__init__.py
new file mode 100644
index 0000000..8128bfb
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/__init__.py
@@ -0,0 +1 @@
+"""Checkpoint processing module for handling OperationUpdate transformations."""
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processor.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processor.py
new file mode 100644
index 0000000..04b991c
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processor.py
@@ -0,0 +1,101 @@
+"""Main checkpoint processor that orchestrates operation transformations."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from aws_durable_execution_sdk_python.lambda_service import (
+ CheckpointOutput,
+ CheckpointUpdatedExecutionState,
+ OperationUpdate,
+ StateOutput,
+)
+
+from aws_durable_execution_sdk_python_testing.checkpoint.transformer import (
+ OperationTransformer,
+)
+from aws_durable_execution_sdk_python_testing.checkpoint.validators.checkpoint import (
+ CheckpointValidator,
+)
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier
+from aws_durable_execution_sdk_python_testing.token import CheckpointToken
+
+
+if TYPE_CHECKING:
+ from aws_durable_execution_sdk_python_testing.execution import Execution
+ from aws_durable_execution_sdk_python_testing.scheduler import Scheduler
+ from aws_durable_execution_sdk_python_testing.stores.base import ExecutionStore
+
+
+class CheckpointProcessor:
+ """Handle OperationUpdate transformations and execution state updates."""
+
+ def __init__(self, store: ExecutionStore, scheduler: Scheduler):
+ self._store = store
+ self._scheduler = scheduler
+ self._notifier = ExecutionNotifier()
+ self._transformer = OperationTransformer()
+
+ def add_execution_observer(self, observer) -> None:
+ """Add observer for execution events."""
+ self._notifier.add_observer(observer)
+
+ def process_checkpoint(
+ self,
+ checkpoint_token: str,
+ updates: list[OperationUpdate],
+ client_token: str | None, # noqa: ARG002
+ ) -> CheckpointOutput:
+ """Process checkpoint updates and return result with updated execution state."""
+ # 1. Get current execution state
+ token: CheckpointToken = CheckpointToken.from_str(checkpoint_token)
+ execution: Execution = self._store.load(token.execution_arn)
+
+ # 2. Validate checkpoint token
+ if execution.is_complete or token.token_sequence != execution.token_sequence:
+ msg: str = "Invalid checkpoint token"
+
+ raise InvalidParameterValueException(msg)
+
+ # 3. Validate all updates, state transitions are valid, sizes etc.
+ CheckpointValidator.validate_input(updates, execution)
+
+ # 4. Transform OperationUpdate -> Operation and schedule future replays
+ updated_operations, all_updates = self._transformer.process_updates(
+ updates=updates,
+ current_operations=execution.operations,
+ notifier=self._notifier,
+ execution_arn=token.execution_arn,
+ )
+
+ # 5. Generate a new checkpoint token and save updated operations
+ new_checkpoint_token = execution.get_new_checkpoint_token()
+ execution.operations = updated_operations
+ execution.updates.extend(all_updates)
+ self._store.update(execution)
+
+ # 6. Return checkpoint result
+ return CheckpointOutput(
+ checkpoint_token=new_checkpoint_token,
+ new_execution_state=CheckpointUpdatedExecutionState(
+ operations=execution.get_navigable_operations(), next_marker=None
+ ),
+ )
+
+ def get_execution_state(
+ self,
+ checkpoint_token: str,
+ next_marker: str, # noqa: ARG002
+ max_items: int = 1000, # noqa: ARG002
+ ) -> StateOutput:
+ """Get current execution state."""
+ token: CheckpointToken = CheckpointToken.from_str(checkpoint_token)
+ execution: Execution = self._store.load(token.execution_arn)
+
+ # TODO: paging when size or max
+ return StateOutput(
+ operations=execution.get_navigable_operations(), next_marker=None
+ )
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/__init__.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/__init__.py
new file mode 100644
index 0000000..0e52f40
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/__init__.py
@@ -0,0 +1 @@
+"""Checkpoint processors module."""
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/base.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/base.py
new file mode 100644
index 0000000..56933d5
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/base.py
@@ -0,0 +1,199 @@
+"""Base processor class for operation transformations."""
+
+from __future__ import annotations
+
+import datetime
+from datetime import timedelta
+from typing import TYPE_CHECKING
+
+from aws_durable_execution_sdk_python.lambda_service import (
+ CallbackDetails,
+ ChainedInvokeDetails,
+ ContextDetails,
+ ExecutionDetails,
+ Operation,
+ OperationAction,
+ OperationStatus,
+ OperationType,
+ OperationUpdate,
+ StepDetails,
+ WaitDetails,
+)
+
+
+if TYPE_CHECKING:
+ from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier
+
+
+class OperationProcessor:
+ """Base class for processing OperationUpdate to Operation transformations."""
+
+ def process(
+ self,
+ update: OperationUpdate,
+ current_op: Operation | None,
+ notifier: ExecutionNotifier,
+ execution_arn: str,
+ ) -> Operation | None:
+ """Process an operation update and return the transformed operation."""
+ raise NotImplementedError
+
+ def _get_start_time(
+ self, current_operation: Operation | None
+ ) -> datetime.datetime | None:
+ start_time: datetime.datetime | None = (
+ current_operation.start_timestamp
+ if current_operation
+ else datetime.datetime.now(tz=datetime.UTC)
+ )
+ return start_time
+
+ def _get_end_time(
+ self, current_operation: Operation | None, status: OperationStatus
+ ) -> datetime.datetime | None:
+ """Get end timestamp for operation based on current state and status."""
+ if current_operation and current_operation.end_timestamp:
+ return current_operation.end_timestamp
+ if status in {
+ OperationStatus.SUCCEEDED,
+ OperationStatus.FAILED,
+ OperationStatus.CANCELLED,
+ OperationStatus.TIMED_OUT,
+ OperationStatus.STOPPED,
+ }:
+ return datetime.datetime.now(tz=datetime.UTC)
+ return None
+
+ def _create_execution_details(
+ self, update: OperationUpdate
+ ) -> ExecutionDetails | None:
+ """Create ExecutionDetails from OperationUpdate."""
+ return (
+ ExecutionDetails(input_payload=update.payload)
+ if update.operation_type == OperationType.EXECUTION
+ else None
+ )
+
+ def _create_context_details(self, update: OperationUpdate) -> ContextDetails | None:
+ """Create ContextDetails from OperationUpdate."""
+ return (
+ ContextDetails(
+ result=update.payload,
+ error=update.error,
+ replay_children=update.context_options.replay_children
+ if update.context_options
+ else False,
+ )
+ if update.operation_type == OperationType.CONTEXT
+ else None
+ )
+
+ def _create_step_details(
+ self,
+ update: OperationUpdate,
+ current_operation: Operation | None = None,
+ ) -> StepDetails | None:
+ """Create StepDetails from OperationUpdate.
+
+ Automatically increments attempt count for RETRY, SUCCEED, and FAIL actions.
+ """
+
+ attempt: int = 0
+ next_attempt_timestamp: datetime.datetime | None = None
+
+ if update.operation_type is OperationType.STEP:
+ if current_operation and current_operation.step_details:
+ attempt = current_operation.step_details.attempt
+ next_attempt_timestamp = (
+ current_operation.step_details.next_attempt_timestamp
+ )
+ # Increment attempt for RETRY, SUCCEED, and FAIL actions
+ if update.action in {
+ OperationAction.RETRY,
+ OperationAction.SUCCEED,
+ OperationAction.FAIL,
+ }:
+ attempt += 1
+ return StepDetails(
+ attempt=attempt,
+ next_attempt_timestamp=next_attempt_timestamp,
+ result=update.payload,
+ error=update.error,
+ )
+
+ return None
+
+ def _create_callback_details(
+ self, update: OperationUpdate
+ ) -> CallbackDetails | None:
+ """Create CallbackDetails from OperationUpdate."""
+ return (
+ CallbackDetails(
+ callback_id="placeholder", result=update.payload, error=update.error
+ )
+ if update.operation_type == OperationType.CALLBACK
+ else None
+ )
+
+ def _create_invoke_details(
+ self, update: OperationUpdate
+ ) -> ChainedInvokeDetails | None:
+ """Create ChainedInvokeDetails from OperationUpdate."""
+ if (
+ update.operation_type == OperationType.CHAINED_INVOKE
+ and update.chained_invoke_options
+ ):
+ return ChainedInvokeDetails(result=update.payload, error=update.error)
+ return None
+
+ def _translate_update_to_operation(
+ self,
+ update: OperationUpdate,
+ current_operation: Operation | None,
+ status: OperationStatus,
+ ) -> Operation:
+ """Transform OperationUpdate to Operation, always creating new Operation."""
+ start_time: datetime.datetime | None = self._get_start_time(current_operation)
+ end_time: datetime.datetime | None = self._get_end_time(
+ current_operation, status
+ )
+
+ execution_details = self._create_execution_details(update)
+ context_details = self._create_context_details(update)
+ step_details = self._create_step_details(update, current_operation)
+ callback_details = self._create_callback_details(update)
+ invoke_details = self._create_invoke_details(update)
+ wait_details = self._create_wait_details(update, current_operation)
+
+ return Operation(
+ operation_id=update.operation_id,
+ parent_id=update.parent_id,
+ name=update.name,
+ start_timestamp=start_time,
+ end_timestamp=end_time,
+ operation_type=update.operation_type,
+ status=status,
+ sub_type=update.sub_type,
+ execution_details=execution_details,
+ context_details=context_details,
+ step_details=step_details,
+ callback_details=callback_details,
+ chained_invoke_details=invoke_details,
+ wait_details=wait_details,
+ )
+
+ def _create_wait_details(
+ self, update: OperationUpdate, current_operation: Operation | None
+ ) -> WaitDetails | None:
+ """Create WaitDetails from OperationUpdate."""
+ if update.operation_type == OperationType.WAIT and update.wait_options:
+ if current_operation and current_operation.wait_details:
+ scheduled_end_timestamp = (
+ current_operation.wait_details.scheduled_end_timestamp
+ )
+ else:
+ scheduled_end_timestamp = datetime.datetime.now(
+ tz=datetime.UTC
+ ) + timedelta(seconds=update.wait_options.wait_seconds)
+ return WaitDetails(scheduled_end_timestamp=scheduled_end_timestamp)
+ return None
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/callback.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/callback.py
new file mode 100644
index 0000000..48c2f01
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/callback.py
@@ -0,0 +1,89 @@
+"""Callback operation processor for handling CALLBACK operation updates."""
+
+from __future__ import annotations
+
+import datetime
+from typing import TYPE_CHECKING
+
+from aws_durable_execution_sdk_python.lambda_service import (
+ Operation,
+ OperationAction,
+ OperationStatus,
+ OperationUpdate,
+ CallbackDetails,
+ OperationType,
+ CallbackOptions,
+)
+from aws_durable_execution_sdk_python_testing.checkpoint.processors.base import (
+ OperationProcessor,
+)
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+from aws_durable_execution_sdk_python_testing.token import CallbackToken
+
+if TYPE_CHECKING:
+ from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier
+
+
+class CallbackProcessor(OperationProcessor):
+ """Processes CALLBACK operation updates with activity scheduling."""
+
+ def process(
+ self,
+ update: OperationUpdate,
+ current_op: Operation | None,
+ notifier: ExecutionNotifier, # noqa: ARG002
+ execution_arn: str, # noqa: ARG002
+ ) -> Operation:
+ """Process CALLBACK operation update with scheduler integration for activities."""
+ match update.action:
+ case OperationAction.START:
+ callback_token: CallbackToken = CallbackToken(
+ execution_arn=execution_arn,
+ operation_id=update.operation_id,
+ )
+
+ callback_id: str = callback_token.to_str()
+
+ callback_details: CallbackDetails | None = (
+ CallbackDetails(
+ callback_id=callback_id,
+ result=update.payload,
+ error=update.error,
+ )
+ if update.operation_type == OperationType.CALLBACK
+ else None
+ )
+
+ status: OperationStatus = OperationStatus.STARTED
+
+ start_time: datetime.datetime | None = self._get_start_time(current_op)
+
+ end_time: datetime.datetime | None = self._get_end_time(
+ current_op, status
+ )
+
+ operation: Operation = Operation(
+ operation_id=update.operation_id,
+ parent_id=update.parent_id,
+ name=update.name,
+ start_timestamp=start_time,
+ end_timestamp=end_time,
+ operation_type=update.operation_type,
+ status=status,
+ sub_type=update.sub_type,
+ callback_details=callback_details,
+ )
+ callback_options: CallbackOptions | None = update.callback_options
+
+ notifier.notify_callback_created(
+ execution_arn=execution_arn,
+ operation_id=update.operation_id,
+ callback_options=callback_options,
+ callback_token=callback_token,
+ )
+ return operation
+ case _:
+ msg: str = "Invalid action for CALLBACK operation."
+ raise InvalidParameterValueException(msg)
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/context.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/context.py
new file mode 100644
index 0000000..182bf91
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/context.py
@@ -0,0 +1,59 @@
+"""Context operation processor for handling CONTEXT operation updates."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from aws_durable_execution_sdk_python.lambda_service import (
+ Operation,
+ OperationAction,
+ OperationStatus,
+ OperationUpdate,
+)
+
+from aws_durable_execution_sdk_python_testing.checkpoint.processors.base import (
+ OperationProcessor,
+)
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+
+
+if TYPE_CHECKING:
+ from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier
+
+
+class ContextProcessor(OperationProcessor):
+ """Processes CONTEXT operation updates for execution context management."""
+
+ def process(
+ self,
+ update: OperationUpdate,
+ current_op: Operation | None,
+ notifier: ExecutionNotifier, # noqa: ARG002
+ execution_arn: str, # noqa: ARG002
+ ) -> Operation:
+ """Process CONTEXT operation update for context state transitions."""
+ match update.action:
+ case OperationAction.START:
+ # TODO: check for "Cannot start a CONTEXT operation that already exists."
+ return self._translate_update_to_operation(
+ update=update,
+ current_operation=current_op,
+ status=OperationStatus.STARTED,
+ )
+ case OperationAction.SUCCEED:
+ return self._translate_update_to_operation(
+ update=update,
+ current_operation=current_op,
+ status=OperationStatus.SUCCEEDED,
+ )
+ case OperationAction.FAIL:
+ return self._translate_update_to_operation(
+ update=update,
+ current_operation=current_op,
+ status=OperationStatus.FAILED,
+ )
+ case _:
+ msg: str = "Invalid action for CONTEXT operation."
+ raise InvalidParameterValueException(msg)
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/execution.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/execution.py
new file mode 100644
index 0000000..e8ad2ef
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/execution.py
@@ -0,0 +1,52 @@
+"""Execution operation processor for handling EXECUTION operation updates."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from aws_durable_execution_sdk_python.lambda_service import (
+ ErrorObject,
+ Operation,
+ OperationAction,
+ OperationUpdate,
+)
+
+from aws_durable_execution_sdk_python_testing.checkpoint.processors.base import (
+ OperationProcessor,
+)
+
+
+if TYPE_CHECKING:
+ from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier
+
+
+class ExecutionProcessor(OperationProcessor):
+ """Processes EXECUTION operation updates for workflow completion."""
+
+ def process(
+ self,
+ update: OperationUpdate,
+ current_op: Operation | None, # noqa: ARG002
+ notifier: ExecutionNotifier,
+ execution_arn: str,
+ ) -> Operation | None:
+ """Process EXECUTION operation update for workflow completion/failure."""
+ match update.action:
+ case OperationAction.SUCCEED:
+ notifier.notify_completed(
+ execution_arn=execution_arn, result=update.payload
+ )
+ case _:
+ # intentional. actual service will fail any EXECUTION update that is not SUCCEED.
+ error = (
+ update.error
+ if update.error
+ else ErrorObject.from_message(
+ "There is no error details but EXECUTION checkpoint action is not SUCCEED."
+ )
+ )
+ # All EXECUTION failures go through normal fail path
+ # Timeout/Stop status is set by executor based on the operation that caused it
+ notifier.notify_failed(execution_arn=execution_arn, error=error)
+ # TODO: Svc doesn't actually create checkpoint for EXECUTION. might have to for localrunner though.
+ return None
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/step.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/step.py
new file mode 100644
index 0000000..0db5a0b
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/step.py
@@ -0,0 +1,124 @@
+"""Step operation processor for handling STEP operation updates."""
+
+from __future__ import annotations
+
+from datetime import UTC, datetime, timedelta
+from typing import TYPE_CHECKING
+
+from aws_durable_execution_sdk_python.lambda_service import (
+ Operation,
+ OperationAction,
+ OperationStatus,
+ OperationUpdate,
+ StepDetails,
+)
+
+from aws_durable_execution_sdk_python_testing.checkpoint.processors.base import (
+ OperationProcessor,
+)
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+
+
+if TYPE_CHECKING:
+ from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier
+
+
+class StepProcessor(OperationProcessor):
+ """Processes STEP operation updates with retry scheduling."""
+
+ def process(
+ self,
+ update: OperationUpdate,
+ current_op: Operation | None,
+ notifier: ExecutionNotifier,
+ execution_arn: str,
+ ) -> Operation:
+ """Process STEP operation update with scheduler integration for retries."""
+ match update.action:
+ case OperationAction.START:
+ return self._translate_update_to_operation(
+ update=update,
+ current_operation=current_op,
+ status=OperationStatus.STARTED,
+ )
+ case OperationAction.RETRY:
+ # set Status=PENDING, next attempt time, attempt count + 1
+ delay = (
+ update.step_options.next_attempt_delay_seconds
+ if update.step_options
+ else 0
+ )
+ next_attempt_time = datetime.now(UTC) + timedelta(seconds=delay)
+
+ # Build new step_details with incremented attempt
+ current_attempt = (
+ current_op.step_details.attempt
+ if current_op and current_op.step_details
+ else 0
+ )
+ new_step_details = StepDetails(
+ attempt=current_attempt + 1,
+ next_attempt_timestamp=next_attempt_time,
+ result=(
+ current_op.step_details.result
+ if current_op and current_op.step_details
+ else None
+ ),
+ error=(
+ current_op.step_details.error
+ if current_op and current_op.step_details
+ else None
+ ),
+ )
+
+ # Create new operation with updated step_details
+ retry_operation = Operation(
+ operation_id=update.operation_id,
+ operation_type=update.operation_type,
+ status=OperationStatus.PENDING,
+ parent_id=update.parent_id,
+ name=update.name,
+ start_timestamp=(
+ current_op.start_timestamp if current_op else datetime.now(UTC)
+ ),
+ end_timestamp=None,
+ sub_type=update.sub_type,
+ execution_details=current_op.execution_details
+ if current_op
+ else None,
+ context_details=current_op.context_details if current_op else None,
+ step_details=new_step_details,
+ wait_details=current_op.wait_details if current_op else None,
+ callback_details=current_op.callback_details
+ if current_op
+ else None,
+ chained_invoke_details=current_op.chained_invoke_details
+ if current_op
+ else None,
+ )
+
+ # Schedule step retry timer to fire after delay
+ notifier.notify_step_retry_scheduled(
+ execution_arn=execution_arn,
+ operation_id=update.operation_id,
+ delay=delay,
+ )
+ return retry_operation
+ case OperationAction.SUCCEED:
+ return self._translate_update_to_operation(
+ update=update,
+ current_operation=current_op,
+ status=OperationStatus.SUCCEEDED,
+ )
+ case OperationAction.FAIL:
+ return self._translate_update_to_operation(
+ update=update,
+ current_operation=current_op,
+ status=OperationStatus.FAILED,
+ )
+ case _:
+ msg: str = "Invalid action for STEP operation."
+
+ raise InvalidParameterValueException(msg)
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/wait.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/wait.py
new file mode 100644
index 0000000..01cc69b
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/wait.py
@@ -0,0 +1,95 @@
+"""Wait operation processor for handling WAIT operation updates."""
+
+from __future__ import annotations
+
+import logging
+import os
+from datetime import UTC, datetime, timedelta
+from typing import TYPE_CHECKING
+
+from aws_durable_execution_sdk_python.lambda_service import (
+ Operation,
+ OperationAction,
+ OperationStatus,
+ OperationUpdate,
+ WaitDetails,
+)
+
+from aws_durable_execution_sdk_python_testing.checkpoint.processors.base import (
+ OperationProcessor,
+)
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+
+
+if TYPE_CHECKING:
+ from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier
+
+
+class WaitProcessor(OperationProcessor):
+ """Processes WAIT operation updates with timer scheduling."""
+
+ def process(
+ self,
+ update: OperationUpdate,
+ current_op: Operation | None,
+ notifier: ExecutionNotifier,
+ execution_arn: str,
+ ) -> Operation:
+ """Process WAIT operation update with scheduler integration for timers."""
+ match update.action:
+ case OperationAction.START:
+ wait_seconds = (
+ update.wait_options.wait_seconds if update.wait_options else 0
+ )
+ time_scale = float(os.getenv("DURABLE_EXECUTION_TIME_SCALE", "1.0"))
+ logging.info("Using DURABLE_EXECUTION_TIME_SCALE: %f", time_scale)
+ scaled_wait_seconds = wait_seconds * time_scale
+
+ scheduled_end_timestamp = datetime.now(UTC) + timedelta(
+ seconds=scaled_wait_seconds
+ )
+
+ # Create WaitDetails with scheduled timestamp
+ wait_details = WaitDetails(
+ scheduled_end_timestamp=scheduled_end_timestamp
+ )
+
+ # Create new operation with wait details
+ wait_operation = Operation(
+ operation_id=update.operation_id,
+ operation_type=update.operation_type,
+ status=OperationStatus.STARTED,
+ parent_id=update.parent_id,
+ name=update.name,
+ start_timestamp=datetime.now(UTC),
+ end_timestamp=None,
+ sub_type=update.sub_type,
+ execution_details=None,
+ context_details=None,
+ step_details=None,
+ wait_details=wait_details,
+ callback_details=None,
+ chained_invoke_details=None,
+ )
+
+ # Schedule wait timer to complete after delay
+ notifier.notify_wait_timer_scheduled(
+ execution_arn=execution_arn,
+ operation_id=update.operation_id,
+ delay=scaled_wait_seconds,
+ )
+ return wait_operation
+ case OperationAction.CANCEL:
+ # TODO: need to cancel the WAIT in the executor
+ # TODO: increase sequence id
+ return self._translate_update_to_operation(
+ update=update,
+ current_operation=current_op,
+ status=OperationStatus.CANCELLED,
+ )
+ case _:
+ msg: str = "Invalid action for WAIT operation."
+
+ raise InvalidParameterValueException(msg)
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/transformer.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/transformer.py
new file mode 100644
index 0000000..cd37b8a
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/transformer.py
@@ -0,0 +1,104 @@
+"""Operation transformer for converting OperationUpdates to Operations."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from aws_durable_execution_sdk_python.lambda_service import (
+ Operation,
+ OperationType,
+ OperationUpdate,
+)
+
+from aws_durable_execution_sdk_python_testing.checkpoint.processors.callback import (
+ CallbackProcessor,
+)
+from aws_durable_execution_sdk_python_testing.checkpoint.processors.context import (
+ ContextProcessor,
+)
+from aws_durable_execution_sdk_python_testing.checkpoint.processors.execution import (
+ ExecutionProcessor,
+)
+from aws_durable_execution_sdk_python_testing.checkpoint.processors.step import (
+ StepProcessor,
+)
+from aws_durable_execution_sdk_python_testing.checkpoint.processors.wait import (
+ WaitProcessor,
+)
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+
+
+if TYPE_CHECKING:
+ from collections.abc import MutableMapping
+
+ from aws_durable_execution_sdk_python_testing.checkpoint.processors.base import (
+ OperationProcessor,
+ )
+
+from typing import ClassVar
+
+
+class OperationTransformer:
+ """Transforms OperationUpdates to Operations while maintaining order and triggering scheduler actions."""
+
+ _DEFAULT_PROCESSORS: ClassVar[dict[OperationType, OperationProcessor]] = {
+ OperationType.STEP: StepProcessor(),
+ OperationType.WAIT: WaitProcessor(),
+ OperationType.CONTEXT: ContextProcessor(),
+ OperationType.CALLBACK: CallbackProcessor(),
+ OperationType.EXECUTION: ExecutionProcessor(),
+ }
+
+ def __init__(
+ self,
+ processors: MutableMapping[OperationType, OperationProcessor] | None = None,
+ ):
+ self.processors = processors if processors else self._DEFAULT_PROCESSORS
+
+ def process_updates(
+ self,
+ updates: list[OperationUpdate],
+ current_operations: list[Operation],
+ notifier,
+ execution_arn: str,
+ ) -> tuple[list[Operation], list[OperationUpdate]]:
+ """Transform updates maintaining operation order and return (operations, updates)."""
+ op_map = {op.operation_id: op for op in current_operations}
+
+ # Start with copy of current operations list
+ result_operations = current_operations.copy()
+
+ for update in updates:
+ processor = self.processors.get(update.operation_type)
+ if processor:
+ current_op = op_map.get(update.operation_id)
+ updated_op = processor.process(
+ update=update,
+ current_op=current_op,
+ notifier=notifier,
+ execution_arn=execution_arn,
+ )
+
+ if updated_op is not None:
+ if update.operation_id in op_map:
+ # Update existing operation in-place
+ for i, op in enumerate(result_operations): # pragma: no branch
+ # no branch coverage because result_operation empty not reachable here
+ if op.operation_id == update.operation_id:
+ result_operations[i] = updated_op
+ break
+ else:
+ # Append new operation to end
+ result_operations.append(updated_op)
+
+ # Update map for future lookups
+ op_map[update.operation_id] = updated_op
+ else:
+ msg: str = (
+ f"Checkpoint for {update.operation_type} is not implemented yet."
+ )
+ raise InvalidParameterValueException(msg)
+
+ return result_operations, updates
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/__init__.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/__init__.py
new file mode 100644
index 0000000..f97d027
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/__init__.py
@@ -0,0 +1 @@
+"""Checkpoint validation module."""
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/checkpoint.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/checkpoint.py
new file mode 100644
index 0000000..86d654d
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/checkpoint.py
@@ -0,0 +1,242 @@
+"""Main checkpoint input validator."""
+
+from __future__ import annotations
+
+import json
+from typing import TYPE_CHECKING
+
+from aws_durable_execution_sdk_python.lambda_service import (
+ OperationAction,
+ OperationType,
+ OperationUpdate,
+)
+
+from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.callback import (
+ CallbackOperationValidator,
+)
+from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.context import (
+ ContextOperationValidator,
+)
+from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.execution import (
+ ExecutionOperationValidator,
+)
+from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.invoke import (
+ ChainedInvokeOperationValidator,
+)
+from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.step import (
+ StepOperationValidator,
+)
+from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.wait import (
+ WaitOperationValidator,
+)
+from aws_durable_execution_sdk_python_testing.checkpoint.validators.transitions import (
+ ValidActionsByOperationTypeValidator,
+)
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+
+
+if TYPE_CHECKING:
+ from collections.abc import MutableMapping
+
+ from aws_durable_execution_sdk_python_testing.execution import Execution
+
+MAX_ERROR_PAYLOAD_SIZE_BYTES = 32768
+
+
+class CheckpointValidator:
+ """Validates checkpoint input based on current state."""
+
+ @staticmethod
+ def validate_input(updates: list[OperationUpdate], execution: Execution) -> None:
+ """Perform validation on the given input based on the current state."""
+ if not updates:
+ return
+
+ CheckpointValidator._validate_conflicting_execution_update(updates)
+ CheckpointValidator._validate_parent_id_and_duplicate_id(updates, execution)
+
+ for update in updates:
+ CheckpointValidator._validate_operation_update(update, execution)
+
+ @staticmethod
+ def _validate_conflicting_execution_update(updates: list[OperationUpdate]) -> None:
+ """Validate that there are no conflicting execution updates."""
+ execution_updates = [
+ update
+ for update in updates
+ if update.operation_type == OperationType.EXECUTION
+ ]
+
+ if len(execution_updates) > 1:
+ msg_multiple_exec: str = "Cannot checkpoint multiple EXECUTION updates."
+
+ raise InvalidParameterValueException(msg_multiple_exec)
+
+ if execution_updates and updates[-1].operation_type != OperationType.EXECUTION:
+ msg_exec_last: str = "EXECUTION checkpoint must be the last update."
+
+ raise InvalidParameterValueException(msg_exec_last)
+
+ @staticmethod
+ def _validate_operation_update(
+ update: OperationUpdate, execution: Execution
+ ) -> None:
+ """Validate a single operation update."""
+ CheckpointValidator._validate_inconsistent_operation_metadata(update, execution)
+ CheckpointValidator._validate_payload_sizes(update)
+ ValidActionsByOperationTypeValidator.validate(
+ update.operation_type, update.action
+ )
+ CheckpointValidator._validate_operation_status_transition(update, execution)
+
+ @staticmethod
+ def _validate_payload_sizes(update: OperationUpdate) -> None:
+ """Validate that operation payload sizes are not too large."""
+ if update.error is not None:
+ payload = json.dumps(update.error.to_dict())
+ if len(payload) > MAX_ERROR_PAYLOAD_SIZE_BYTES:
+ msg: str = f"Error object size must be less than {MAX_ERROR_PAYLOAD_SIZE_BYTES} bytes."
+ raise InvalidParameterValueException(msg)
+
+ @staticmethod
+ def _validate_operation_status_transition(
+ update: OperationUpdate, execution: Execution
+ ) -> None:
+ """Validate that the operation status transition is valid."""
+ current_state = None
+ for operation in execution.operations:
+ if operation.operation_id == update.operation_id:
+ current_state = operation
+ break
+
+ match update.operation_type:
+ case OperationType.STEP:
+ StepOperationValidator.validate(current_state, update)
+ case OperationType.CONTEXT:
+ ContextOperationValidator.validate(current_state, update)
+ case OperationType.WAIT:
+ WaitOperationValidator.validate(current_state, update)
+ case OperationType.CALLBACK:
+ CallbackOperationValidator.validate(current_state, update)
+ case OperationType.CHAINED_INVOKE:
+ ChainedInvokeOperationValidator.validate(current_state, update)
+ case OperationType.EXECUTION:
+ ExecutionOperationValidator.validate(update)
+ case _: # pragma: no cover
+ msg: str = "Invalid operation type."
+
+ raise InvalidParameterValueException(msg)
+
+ @staticmethod
+ def _validate_inconsistent_operation_metadata(
+ update: OperationUpdate, execution: Execution
+ ) -> None:
+ """Validate that operation metadata is consistent with existing operation."""
+ current_state = None
+ for operation in execution.operations:
+ if operation.operation_id == update.operation_id:
+ current_state = operation
+ break
+
+ if current_state is not None:
+ if (
+ update.operation_type is not None
+ and update.operation_type != current_state.operation_type
+ ):
+ msg: str = "Inconsistent operation type."
+ raise InvalidParameterValueException(msg)
+
+ if (
+ update.sub_type is not None
+ and update.sub_type != current_state.sub_type
+ ):
+ msg_subtype: str = "Inconsistent operation subtype."
+ raise InvalidParameterValueException(msg_subtype)
+
+ if update.name is not None and update.name != current_state.name:
+ msg_name: str = "Inconsistent operation name."
+ raise InvalidParameterValueException(msg_name)
+
+ if (
+ update.parent_id is not None
+ and update.parent_id != current_state.parent_id
+ ):
+ msg_parent: str = "Inconsistent parent operation id."
+ raise InvalidParameterValueException(msg_parent)
+
+ @staticmethod
+ def _validate_parent_id_and_duplicate_id(
+ updates: list[OperationUpdate], execution: Execution
+ ) -> None:
+ """Validate parent IDs and check for duplicate operation IDs.
+
+ Validate that any provided parentId is valid, and also validate no duplicate operation is being
+ updated at the same time (unless it is a STEP/CONTEXT starting + performing one more non-START action).
+ """
+ operations_started: MutableMapping[str, OperationUpdate] = {}
+ last_updates_seen: MutableMapping[str, OperationUpdate] = {}
+
+ for update in updates:
+ if CheckpointValidator._is_invalid_duplicate_update(
+ update, last_updates_seen
+ ):
+ msg_duplicate: str = (
+ "Cannot checkpoint multiple operations with the same ID."
+ )
+ raise InvalidParameterValueException(msg_duplicate)
+
+ if not CheckpointValidator._is_valid_parent_for_update(
+ execution, update, operations_started
+ ):
+ msg_parent: str = "Invalid parent operation id."
+ raise InvalidParameterValueException(msg_parent)
+
+ if update.action == OperationAction.START:
+ operations_started[update.operation_id] = update
+
+ last_updates_seen[update.operation_id] = update
+
+ @staticmethod
+ def _is_invalid_duplicate_update(
+ update: OperationUpdate, last_updates_seen: MutableMapping[str, OperationUpdate]
+ ) -> bool:
+ """Check if this is an invalid duplicate update."""
+ last_update = last_updates_seen.get(update.operation_id)
+ if last_update is None:
+ return False
+
+ if last_update.operation_type in (OperationType.STEP, OperationType.CONTEXT):
+ # Allow duplicate for STEP/CONTEXT if last was START and current is not START
+ allow_duplicate = (
+ last_update.action == OperationAction.START
+ and update.action != OperationAction.START
+ )
+ return not allow_duplicate
+
+ return True
+
+ @staticmethod
+ def _is_valid_parent_for_update(
+ execution: Execution,
+ update: OperationUpdate,
+ operations_started: MutableMapping[str, OperationUpdate],
+ ) -> bool:
+ """Check if the parent ID is valid for the update."""
+ parent_id = update.parent_id
+
+ if parent_id is None:
+ return True
+
+ # Check if parent is in operations started in this batch
+ if parent_id in operations_started:
+ parent_update = operations_started[parent_id]
+ return parent_update.operation_type == OperationType.CONTEXT
+
+ # Check if parent exists in current execution state
+ for operation in execution.operations:
+ if operation.operation_id == parent_id:
+ return operation.operation_type == OperationType.CONTEXT
+
+ return False
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/__init__.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/__init__.py
new file mode 100644
index 0000000..455b119
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/__init__.py
@@ -0,0 +1 @@
+"""Operation-specific validators."""
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/callback.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/callback.py
new file mode 100644
index 0000000..575db81
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/callback.py
@@ -0,0 +1,45 @@
+"""Callback operation validator."""
+
+from __future__ import annotations
+
+from aws_durable_execution_sdk_python.lambda_service import (
+ Operation,
+ OperationAction,
+ OperationStatus,
+ OperationUpdate,
+)
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+
+
+VALID_ACTIONS_FOR_CALLBACK = frozenset(
+ [
+ OperationAction.START,
+ ]
+)
+
+
+class CallbackOperationValidator:
+ """Validates CALLBACK operation transitions."""
+
+ _ALLOWED_STATUS_TO_CANCEL = frozenset(
+ [
+ OperationStatus.STARTED,
+ ]
+ )
+
+ @staticmethod
+ def validate(current_state: Operation | None, update: OperationUpdate) -> None:
+ """Validate CALLBACK operation update."""
+ match update.action:
+ case OperationAction.START:
+ if current_state is not None:
+ msg_callback_exists: str = (
+ "Cannot start a CALLBACK that already exist."
+ )
+ raise InvalidParameterValueException(msg_callback_exists)
+ case _:
+ msg: str = "Invalid action for CALLBACK operation."
+ raise InvalidParameterValueException(msg)
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/context.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/context.py
new file mode 100644
index 0000000..3104044
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/context.py
@@ -0,0 +1,73 @@
+"""Context operation validator."""
+
+from __future__ import annotations
+
+from aws_durable_execution_sdk_python.lambda_service import (
+ Operation,
+ OperationAction,
+ OperationStatus,
+ OperationUpdate,
+)
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+
+
+VALID_ACTIONS_FOR_CONTEXT = frozenset(
+ [
+ OperationAction.START,
+ OperationAction.FAIL,
+ OperationAction.SUCCEED,
+ ]
+)
+
+
+class ContextOperationValidator:
+ """Validates CONTEXT operation transitions."""
+
+ _ALLOWED_STATUS_TO_CLOSE = frozenset(
+ [
+ OperationStatus.STARTED,
+ ]
+ )
+
+ @staticmethod
+ def validate(current_state: Operation | None, update: OperationUpdate) -> None:
+ """Validate CONTEXT operation update."""
+ match update.action:
+ case OperationAction.START:
+ if current_state is not None:
+ msg_context_exists: str = (
+ "Cannot start a CONTEXT that already exist."
+ )
+
+ raise InvalidParameterValueException(msg_context_exists)
+ case OperationAction.FAIL | OperationAction.SUCCEED:
+ if (
+ current_state is not None
+ and current_state.status
+ not in ContextOperationValidator._ALLOWED_STATUS_TO_CLOSE
+ ):
+ msg_context_close: str = "Invalid current CONTEXT state to close."
+
+ raise InvalidParameterValueException(msg_context_close)
+ if update.action == OperationAction.FAIL and update.payload is not None:
+ msg_context_fail_payload: str = (
+ "Cannot provide a Payload for FAIL action."
+ )
+
+ raise InvalidParameterValueException(msg_context_fail_payload)
+ if (
+ update.action == OperationAction.SUCCEED
+ and update.error is not None
+ ):
+ msg_context_succeed_error: str = (
+ "Cannot provide an Error for SUCCEED action."
+ )
+
+ raise InvalidParameterValueException(msg_context_succeed_error)
+ case _:
+ msg_context_invalid: str = "Invalid CONTEXT action."
+
+ raise InvalidParameterValueException(msg_context_invalid)
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/execution.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/execution.py
new file mode 100644
index 0000000..5e66677
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/execution.py
@@ -0,0 +1,47 @@
+"""Execution operation validator."""
+
+from __future__ import annotations
+
+from aws_durable_execution_sdk_python.lambda_service import (
+ OperationAction,
+ OperationUpdate,
+)
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+
+
+VALID_ACTIONS_FOR_EXECUTION = frozenset(
+ [
+ OperationAction.SUCCEED,
+ OperationAction.FAIL,
+ ]
+)
+
+
+class ExecutionOperationValidator:
+ """Validates EXECUTION operation transitions."""
+
+ @staticmethod
+ def validate(update: OperationUpdate) -> None:
+ """Validate EXECUTION operation update."""
+ match update.action:
+ case OperationAction.SUCCEED:
+ if update.error is not None:
+ msg_exec_succeed_error: str = (
+ "Cannot provide an Error for SUCCEED action."
+ )
+
+ raise InvalidParameterValueException(msg_exec_succeed_error)
+ case OperationAction.FAIL:
+ if update.payload is not None:
+ msg_exec_fail_payload: str = (
+ "Cannot provide a Payload for FAIL action."
+ )
+
+ raise InvalidParameterValueException(msg_exec_fail_payload)
+ case _:
+ msg_exec_invalid: str = "Invalid EXECUTION action."
+
+ raise InvalidParameterValueException(msg_exec_invalid)
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/invoke.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/invoke.py
new file mode 100644
index 0000000..1c28712
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/invoke.py
@@ -0,0 +1,56 @@
+"""Invoke operation validator."""
+
+from __future__ import annotations
+
+from aws_durable_execution_sdk_python.lambda_service import (
+ Operation,
+ OperationAction,
+ OperationStatus,
+ OperationUpdate,
+)
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+
+
+VALID_ACTIONS_FOR_INVOKE = frozenset(
+ [
+ OperationAction.START,
+ OperationAction.CANCEL,
+ ]
+)
+
+
+class ChainedInvokeOperationValidator:
+ """Validates INVOKE operation transitions."""
+
+ _ALLOWED_STATUS_TO_CANCEL = frozenset(
+ [
+ OperationStatus.STARTED,
+ ]
+ )
+
+ @staticmethod
+ def validate(current_state: Operation | None, update: OperationUpdate) -> None:
+ """Validate INVOKE operation update."""
+ match update.action:
+ case OperationAction.START:
+ if current_state is not None:
+ msg_invoke_exists: str = (
+ "Cannot start an INVOKE that already exist."
+ )
+
+ raise InvalidParameterValueException(msg_invoke_exists)
+ case OperationAction.CANCEL:
+ if (
+ current_state is None
+ or current_state.status
+ not in ChainedInvokeOperationValidator._ALLOWED_STATUS_TO_CANCEL
+ ):
+ msg_invoke_cancel: str = "Cannot cancel an INVOKE that does not exist or has already completed."
+ raise InvalidParameterValueException(msg_invoke_cancel)
+ case _:
+ msg_invoke_invalid: str = "Invalid INVOKE action."
+
+ raise InvalidParameterValueException(msg_invoke_invalid)
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/step.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/step.py
new file mode 100644
index 0000000..52c5388
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/step.py
@@ -0,0 +1,106 @@
+"""Step operation validator."""
+
+from __future__ import annotations
+
+from aws_durable_execution_sdk_python.lambda_service import (
+ Operation,
+ OperationAction,
+ OperationStatus,
+ OperationUpdate,
+)
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+
+
+VALID_ACTIONS_FOR_STEP = frozenset(
+ [
+ OperationAction.START,
+ OperationAction.FAIL,
+ OperationAction.RETRY,
+ OperationAction.SUCCEED,
+ ]
+)
+
+
+class StepOperationValidator:
+ """Validates STEP operation transitions."""
+
+ _ALLOWED_STATUS_TO_CLOSE = frozenset(
+ [
+ OperationStatus.STARTED,
+ OperationStatus.READY,
+ ]
+ )
+
+ _ALLOWED_STATUS_TO_START = frozenset(
+ [
+ OperationStatus.READY,
+ ]
+ )
+
+ _ALLOWED_STATUS_TO_REATTEMPT = frozenset(
+ [
+ OperationStatus.STARTED,
+ OperationStatus.READY,
+ ]
+ )
+
+ @staticmethod
+ def validate(current_state: Operation | None, update: OperationUpdate) -> None:
+ """Validate STEP operation update."""
+ if current_state is None:
+ return
+
+ match update.action:
+ case OperationAction.START:
+ if (
+ current_state.status
+ not in StepOperationValidator._ALLOWED_STATUS_TO_START
+ ):
+ msg_step_start: str = "Invalid current STEP state to start."
+
+ raise InvalidParameterValueException(msg_step_start)
+ case OperationAction.FAIL | OperationAction.SUCCEED:
+ if (
+ current_state.status
+ not in StepOperationValidator._ALLOWED_STATUS_TO_CLOSE
+ ):
+ msg_step_close: str = "Invalid current STEP state to close."
+
+ raise InvalidParameterValueException(msg_step_close)
+ if update.action == OperationAction.FAIL and update.payload is not None:
+ msg_fail_payload: str = "Cannot provide a Payload for FAIL action."
+
+ raise InvalidParameterValueException(msg_fail_payload)
+ if (
+ update.action == OperationAction.SUCCEED
+ and update.error is not None
+ ):
+ msg_succeed_error: str = (
+ "Cannot provide an Error for SUCCEED action."
+ )
+
+ raise InvalidParameterValueException(msg_succeed_error)
+ case OperationAction.RETRY:
+ if (
+ current_state.status
+ not in StepOperationValidator._ALLOWED_STATUS_TO_REATTEMPT
+ ):
+ msg_step_retry: str = "Invalid current STEP state to re-attempt."
+
+ raise InvalidParameterValueException(msg_step_retry)
+ if update.step_options is None:
+ msg_step_options: str = "Invalid StepOptions for the given action."
+
+ raise InvalidParameterValueException(msg_step_options)
+ if update.error is not None and update.payload is not None:
+ msg_retry_both: str = (
+ "Cannot provide both error and payload to RETRY a STEP."
+ )
+ raise InvalidParameterValueException(msg_retry_both)
+ case _:
+ msg_step_invalid: str = "Invalid STEP action."
+
+ raise InvalidParameterValueException(msg_step_invalid)
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/wait.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/wait.py
new file mode 100644
index 0000000..1858b3b
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/wait.py
@@ -0,0 +1,54 @@
+"""Wait operation validator."""
+
+from __future__ import annotations
+
+from aws_durable_execution_sdk_python.lambda_service import (
+ Operation,
+ OperationAction,
+ OperationStatus,
+ OperationUpdate,
+)
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+
+
+VALID_ACTIONS_FOR_WAIT = frozenset(
+ [
+ OperationAction.START,
+ OperationAction.CANCEL,
+ ]
+)
+
+
+class WaitOperationValidator:
+ """Validates WAIT operation transitions."""
+
+ _ALLOWED_STATUS_TO_CANCEL = frozenset(
+ [
+ OperationStatus.STARTED,
+ ]
+ )
+
+ @staticmethod
+ def validate(current_state: Operation | None, update: OperationUpdate) -> None:
+ """Validate WAIT operation update."""
+ match update.action:
+ case OperationAction.START:
+ if current_state is not None:
+ msg_wait_exists: str = "Cannot start a WAIT that already exist."
+
+ raise InvalidParameterValueException(msg_wait_exists)
+ case OperationAction.CANCEL:
+ if (
+ current_state is None
+ or current_state.status
+ not in WaitOperationValidator._ALLOWED_STATUS_TO_CANCEL
+ ):
+ msg_wait_cancel: str = "Cannot cancel a WAIT that does not exist or has already completed."
+ raise InvalidParameterValueException(msg_wait_cancel)
+ case _:
+ msg_wait_invalid: str = "Invalid WAIT action."
+
+ raise InvalidParameterValueException(msg_wait_invalid)
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/transitions.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/transitions.py
new file mode 100644
index 0000000..fff45a1
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/transitions.py
@@ -0,0 +1,66 @@
+"""Validator for valid actions by operation type."""
+
+from __future__ import annotations
+
+from typing import ClassVar
+
+from aws_durable_execution_sdk_python.lambda_service import (
+ OperationAction,
+ OperationType,
+)
+
+from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.callback import (
+ VALID_ACTIONS_FOR_CALLBACK,
+)
+from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.context import (
+ VALID_ACTIONS_FOR_CONTEXT,
+)
+from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.execution import (
+ VALID_ACTIONS_FOR_EXECUTION,
+)
+from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.invoke import (
+ VALID_ACTIONS_FOR_INVOKE,
+)
+from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.step import (
+ VALID_ACTIONS_FOR_STEP,
+)
+from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.wait import (
+ VALID_ACTIONS_FOR_WAIT,
+)
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+
+
+class ValidActionsByOperationTypeValidator:
+ """Validates that the given action is valid for the given operation type."""
+
+ _VALID_ACTIONS_BY_OPERATION_TYPE: ClassVar[
+ dict[OperationType, frozenset[OperationAction]]
+ ] = {
+ OperationType.STEP: VALID_ACTIONS_FOR_STEP,
+ OperationType.CONTEXT: VALID_ACTIONS_FOR_CONTEXT,
+ OperationType.WAIT: VALID_ACTIONS_FOR_WAIT,
+ OperationType.CALLBACK: VALID_ACTIONS_FOR_CALLBACK,
+ OperationType.CHAINED_INVOKE: VALID_ACTIONS_FOR_INVOKE,
+ OperationType.EXECUTION: VALID_ACTIONS_FOR_EXECUTION,
+ }
+
+ @staticmethod
+ def validate(operation_type: OperationType, action: OperationAction) -> None:
+ """Validate that the action is valid for the operation type."""
+ valid_actions = (
+ ValidActionsByOperationTypeValidator._VALID_ACTIONS_BY_OPERATION_TYPE.get(
+ operation_type
+ )
+ )
+
+ if valid_actions is None:
+ msg_unknown_op: str = "Unknown operation type."
+
+ raise InvalidParameterValueException(msg_unknown_op)
+
+ if action not in valid_actions:
+ msg_invalid_action: str = "Invalid action for the given operation type."
+
+ raise InvalidParameterValueException(msg_invalid_action)
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/cli.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/cli.py
new file mode 100644
index 0000000..85dcb7d
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/cli.py
@@ -0,0 +1,498 @@
+"""Command-line interface for the AWS Durable Functions Local Runner.
+
+This module provides the dex-local-runner CLI with commands for:
+- start-server: Start the local web server
+- invoke: Invoke a durable execution
+- get-durable-execution: Get execution details
+- get-durable-execution-history: Get execution history
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import logging
+import os
+import sys
+import uuid
+from dataclasses import dataclass
+from typing import Any
+from urllib.parse import urljoin
+
+import aws_durable_execution_sdk_python
+import boto3 # type: ignore
+from urllib.error import HTTPError, URLError
+from urllib.request import Request, urlopen
+
+from botocore.exceptions import ConnectionError # type: ignore
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ DurableFunctionsLocalRunnerError,
+ DurableFunctionsTestError,
+)
+from aws_durable_execution_sdk_python_testing.model import (
+ StartDurableExecutionInput,
+)
+from aws_durable_execution_sdk_python_testing.runner import WebRunner, WebRunnerConfig
+from aws_durable_execution_sdk_python_testing.stores.base import StoreType
+from aws_durable_execution_sdk_python_testing.web.server import WebServiceConfig
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass(frozen=True)
+class CliConfig:
+ """Configuration for the CLI application with environment variable support."""
+
+ # Server configuration
+ host: str = "0.0.0.0" # noqa:S104
+ port: int = 5000
+ log_level: int = logging.INFO
+ lambda_endpoint: str = "http://127.0.0.1:3001"
+ local_runner_endpoint: str = "http://0.0.0.0:5000"
+ local_runner_region: str = "us-west-2"
+ local_runner_mode: str = "local"
+
+ # Store configuration
+ store_type: StoreType = StoreType.MEMORY
+ store_path: str | None = None
+
+ @classmethod
+ def from_environment(cls) -> CliConfig:
+ """Create configuration from environment variables with defaults."""
+ # Convert log level string to integer if provided
+ log_level_str = os.getenv("AWS_DEX_LOG_LEVEL", "INFO")
+ log_level = logging.getLevelNamesMapping().get(log_level_str, logging.INFO)
+
+ return cls(
+ host=os.getenv("AWS_DEX_HOST", "0.0.0.0"), # noqa:S104
+ port=int(os.getenv("AWS_DEX_PORT", "5000")),
+ log_level=log_level,
+ lambda_endpoint=os.getenv(
+ "AWS_DEX_LAMBDA_ENDPOINT", "http://127.0.0.1:3001"
+ ),
+ local_runner_endpoint=os.getenv(
+ "AWS_DEX_LOCAL_RUNNER_ENDPOINT", "http://0.0.0.0:5000"
+ ),
+ local_runner_region=os.getenv("AWS_DEX_LOCAL_RUNNER_REGION", "us-west-2"),
+ local_runner_mode=os.getenv("AWS_DEX_LOCAL_RUNNER_MODE", "local"),
+ store_type=StoreType(os.getenv("AWS_DEX_STORE_TYPE", "memory")),
+ store_path=os.getenv("AWS_DEX_STORE_PATH"),
+ )
+
+
+class CliApp:
+ """Main CLI application for dex-local-runner."""
+
+ def __init__(self) -> None:
+ """Initialize the CLI application."""
+ self.config = CliConfig.from_environment()
+
+ def run(self, args: list[str] | None = None) -> int:
+ """Run the CLI application with the given arguments.
+
+ Args:
+ args: Command line arguments. If None, uses sys.argv[1:]
+
+ Returns:
+ Exit code (0 for success, non-zero for error)
+ """
+ try:
+ parser = self._create_parsers()
+ parsed_args = parser.parse_args(args)
+
+ # Configure logging based on log level
+ if hasattr(parsed_args, "log_level") and isinstance(
+ parsed_args.log_level, str
+ ):
+ level = logging.getLevelNamesMapping().get(
+ parsed_args.log_level, logging.INFO
+ )
+ else:
+ # config.log_level is always an integer
+ level = self.config.log_level
+
+ logging.basicConfig(
+ level=level,
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+ )
+ logging.getLogger("botocore").setLevel(logging.WARNING)
+
+ # Execute the appropriate command
+ return parsed_args.func(parsed_args)
+
+ except SystemExit as e:
+ # argparse calls sys.exit() for help, errors, etc.
+ return int(e.code) if e.code is not None else 1
+ except KeyboardInterrupt:
+ print("\nOperation cancelled by user", file=sys.stderr) # noqa: T201
+ return 130 # Standard exit code for SIGINT
+ except DurableFunctionsTestError:
+ logger.exception("Error")
+ return 1
+ except Exception:
+ logger.exception("Unexpected error.")
+ return 1
+
+ def _create_parsers(self) -> argparse.ArgumentParser:
+ """Create the argument parsers for all commands."""
+ parser = argparse.ArgumentParser(
+ prog="dex-local-runner",
+ description="AWS Durable Functions Local Runner CLI",
+ )
+
+ subparsers = parser.add_subparsers(
+ dest="command", help="Available commands", required=True
+ )
+
+ # Create individual parsers
+ self._create_start_server_parser(subparsers)
+ self._create_invoke_parser(subparsers)
+ self._create_get_durable_execution_parser(subparsers)
+ self._create_get_durable_execution_history_parser(subparsers)
+
+ return parser
+
+ # region parsers
+
+ def _create_start_server_parser(self, subparsers) -> None:
+ """Create the start-server command parser."""
+ start_server_parser = subparsers.add_parser(
+ "start-server", help="Start the local Durable Functions Server"
+ )
+ start_server_parser.add_argument(
+ "--host",
+ default=self.config.host,
+ help=f"Server bind address (default: {self.config.host}, env: AWS_DEX_HOST)",
+ )
+ start_server_parser.add_argument(
+ "--port",
+ type=int,
+ default=self.config.port,
+ help=f"Server port (default: {self.config.port}, env: AWS_DEX_PORT)",
+ )
+ start_server_parser.add_argument(
+ "--log-level",
+ type=str,
+ choices=list(logging.getLevelNamesMapping().keys()),
+ default=logging.getLevelName(self.config.log_level),
+ help=f"Logging level (default: {logging.getLevelName(self.config.log_level)}, env: AWS_DEX_LOG_LEVEL)",
+ )
+ start_server_parser.add_argument(
+ "--lambda-endpoint",
+ default=self.config.lambda_endpoint,
+ help=f"Lambda Service endpoint (default: {self.config.lambda_endpoint}, env: AWS_DEX_LAMBDA_ENDPOINT)",
+ )
+ start_server_parser.add_argument(
+ "--local-runner-endpoint",
+ default=self.config.local_runner_endpoint,
+ help=f"Local Runner endpoint (default: {self.config.local_runner_endpoint}, env: AWS_DEX_LOCAL_RUNNER_ENDPOINT)",
+ )
+ start_server_parser.add_argument(
+ "--local-runner-region",
+ default=self.config.local_runner_region,
+ help=f"Local Runner region (default: {self.config.local_runner_region}, env: AWS_DEX_LOCAL_RUNNER_REGION)",
+ )
+ start_server_parser.add_argument(
+ "--local-runner-mode",
+ default=self.config.local_runner_mode,
+ help=f"Local Runner mode (default: {self.config.local_runner_mode}, env: AWS_DEX_LOCAL_RUNNER_MODE)",
+ )
+ start_server_parser.add_argument(
+ "--store-type",
+ choices=[store_type.value for store_type in StoreType],
+ default=self.config.store_type.value,
+ help=f"Store type for execution persistence (default: {self.config.store_type.value}, env: AWS_DEX_STORE_TYPE)",
+ )
+ start_server_parser.add_argument(
+ "--store-path",
+ default=self.config.store_path,
+ help=f"Path for filesystem store (default: {self.config.store_path or '.durable_executions'}, env: AWS_DEX_STORE_PATH)",
+ )
+ start_server_parser.set_defaults(func=self.start_server_command)
+
+ def _create_invoke_parser(self, subparsers) -> None:
+ """Create the invoke command parser."""
+ invoke_parser = subparsers.add_parser(
+ "invoke", help="Invoke a Durable Execution"
+ )
+ invoke_parser.add_argument(
+ "--function-name", required=True, help="Function name (required)"
+ )
+ invoke_parser.add_argument(
+ "--input", default="{}", help="Input data (default: {})"
+ )
+ invoke_parser.add_argument(
+ "--durable-execution-name", help="Durable execution name (optional)"
+ )
+ invoke_parser.set_defaults(func=self.invoke_command)
+
+ def _create_get_durable_execution_parser(self, subparsers) -> None:
+ """Create the get-durable-execution command parser."""
+ get_execution_parser = subparsers.add_parser(
+ "get-durable-execution", help="Get execution details"
+ )
+ get_execution_parser.add_argument(
+ "--durable-execution-arn",
+ required=True,
+ help="Durable execution ARN (required)",
+ )
+ get_execution_parser.set_defaults(func=self.get_durable_execution_command)
+
+ def _create_get_durable_execution_history_parser(self, subparsers) -> None:
+ """Create the get-durable-execution-history command parser."""
+ get_history_parser = subparsers.add_parser(
+ "get-durable-execution-history", help="Get execution history"
+ )
+ get_history_parser.add_argument(
+ "--durable-execution-arn",
+ required=True,
+ help="Durable execution ARN (required)",
+ )
+ get_history_parser.set_defaults(func=self.get_durable_execution_history_command)
+
+ # endregion parsers
+
+ # region commands
+
+ def start_server_command(self, args: argparse.Namespace) -> int:
+ """Execute the start-server command.
+
+ Args:
+ args: Parsed command line arguments
+
+ Returns:
+ Exit code (0 for success, non-zero for error)
+ """
+ try:
+ # Create web service configuration from CLI arguments
+ web_config = WebServiceConfig(
+ host=args.host,
+ port=args.port,
+ log_level=args.log_level,
+ )
+
+ # Create web runner configuration with composition
+ runner_config = WebRunnerConfig(
+ web_service=web_config,
+ lambda_endpoint=args.lambda_endpoint,
+ local_runner_endpoint=args.local_runner_endpoint,
+ local_runner_region=args.local_runner_region,
+ local_runner_mode=args.local_runner_mode,
+ store_type=StoreType(args.store_type),
+ store_path=args.store_path,
+ )
+
+ logger.info(
+ "Starting Durable Functions Local Runner on %s:%s",
+ args.host,
+ args.port,
+ )
+ logger.info("Configuration:")
+ logger.info(" Host: %s", args.host)
+ logger.info(" Port: %s", args.port)
+ logger.info(" Log Level: %s", args.log_level)
+ logger.info(" Lambda Endpoint: %s", args.lambda_endpoint)
+ logger.info(" Local Runner Endpoint: %s", args.local_runner_endpoint)
+ logger.info(" Local Runner Region: %s", args.local_runner_region)
+ logger.info(" Local Runner Mode: %s", args.local_runner_mode)
+ logger.info(" Store Type: %s", args.store_type)
+ if StoreType(args.store_type) == StoreType.FILESYSTEM:
+ store_path = args.store_path or ".durable_executions"
+ logger.info(" Store Path: %s", store_path)
+
+ # Use runner as context manager for proper lifecycle
+ with WebRunner(runner_config) as runner:
+ logger.info("Server started successfully. Press Ctrl+C to stop.")
+ runner.serve_forever()
+
+ return 0 # noqa: TRY300
+
+ except KeyboardInterrupt:
+ logger.info("Received shutdown signal, stopping server...")
+ return 130 # Standard exit code for SIGINT
+ except Exception:
+ logger.exception("Failed to start server")
+ return 1
+
+ def invoke_command(self, args: argparse.Namespace) -> int:
+ """Execute the invoke command.
+
+ Args:
+ args: Parsed command line arguments
+
+ Returns:
+ Exit code (0 for success, non-zero for error)
+ """
+ # Validate input JSON
+ try:
+ json.loads(args.input) # Just validate, don't store
+ except json.JSONDecodeError:
+ logger.exception("JSON decode error")
+ return 1
+
+ try:
+ # Create StartDurableExecutionInput
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012", # Default account ID for local testing
+ function_name=args.function_name,
+ function_qualifier="$LATEST", # Default qualifier
+ execution_name=args.durable_execution_name
+ or f"{args.function_name}-execution",
+ execution_timeout_seconds=300, # 5 minutes default
+ execution_retention_period_days=7, # 1 week default
+ invocation_id=str(uuid.uuid4()), # Generate unique invocation ID
+ input=args.input,
+ )
+
+ # Make HTTP request to start-durable-execution endpoint
+ endpoint_url = self.config.local_runner_endpoint
+ url = urljoin(endpoint_url, "/start-durable-execution")
+
+ payload = start_input.to_dict()
+ data = json.dumps(payload).encode("utf-8")
+ req = Request(
+ url,
+ data=data,
+ headers={"Content-Type": "application/json"},
+ method="POST",
+ )
+
+ try:
+ with urlopen(req, timeout=10) as response: # noqa: S310
+ result = json.loads(response.read().decode("utf-8"))
+ print(json.dumps(result, indent=2)) # noqa: T201
+ return 0
+ except HTTPError as e:
+ try:
+ error_data = json.loads(e.read().decode("utf-8"))
+ logger.exception("HTTP error response")
+ print( # noqa: T201
+ f"Error: {error_data.get('ErrorMessage', 'Unknown error')}",
+ file=sys.stderr,
+ )
+ except json.JSONDecodeError:
+ logger.exception("Non-JSON error response")
+ return 1
+
+ except URLError:
+ logger.exception(
+ "Error: Could not connect to the local runner server. Is it running?"
+ )
+ return 1
+ except Exception:
+ logger.exception("Unexpected error in invoke command")
+ return 1
+
+ def get_durable_execution_command(self, args: argparse.Namespace) -> int:
+ """Execute the get-durable-execution command.
+
+ Args:
+ args: Parsed command line arguments
+
+ Returns:
+ Exit code (0 for success, non-zero for error)
+ """
+ try:
+ # Set up boto3 client with local endpoint
+ client = self._create_boto3_client()
+
+ # Call get_durable_execution
+ response = client.get_durable_execution(
+ DurableExecutionArn=args.durable_execution_arn
+ )
+
+ # Print formatted response
+ print(json.dumps(response, indent=2, default=str)) # noqa: T201
+ return 0 # noqa: TRY300
+
+ except client.exceptions.InvalidParameterValueException as e:
+ print(f"Error: Invalid parameter - {e}", file=sys.stderr) # noqa: T201
+ return 1
+ except client.exceptions.ResourceNotFoundException as e:
+ print(f"Error: Execution not found - {e}", file=sys.stderr) # noqa: T201
+ return 1
+ except client.exceptions.TooManyRequestsException as e:
+ print(f"Error: Too many requests - {e}", file=sys.stderr) # noqa: T201
+ return 1
+ except client.exceptions.ServiceException as e:
+ print(f"Error: Service error - {e}", file=sys.stderr) # noqa: T201
+ return 1
+ except ConnectionError:
+ logger.exception(
+ "Error: Could not connect to the local runner server. Is it running?"
+ )
+ return 1
+ except Exception:
+ logger.exception("Unexpected error in get-durable-execution command")
+ return 1
+
+ def get_durable_execution_history_command(self, args: argparse.Namespace) -> int:
+ """Execute the get-durable-execution-history command.
+
+ TODO: implement - this is incomplete
+
+ Args:
+ args: Parsed command line arguments
+
+ Returns:
+ Exit code (0 for success, non-zero for error)
+ """
+ try:
+ # Set up boto3 client with local endpoint
+ client = self._create_boto3_client()
+
+ # Call get_durable_execution_history
+ response = client.get_durable_execution_history(
+ DurableExecutionArn=args.durable_execution_arn
+ )
+
+ print(json.dumps(response, indent=2, default=str)) # noqa: T201
+ return 0 # noqa: TRY300
+
+ except Exception:
+ logger.exception("General error")
+ return 1
+
+ # endregion commands
+
+ def _create_boto3_client(
+ self, endpoint_url: str | None = None, region_name: str | None = None
+ ) -> Any:
+ """Create boto3 client for Lambda service.
+
+ Args:
+ endpoint_url: Optional endpoint URL override
+ region_name: Optional region name override
+
+ Returns:
+ Configured boto3 client for local runner
+
+ Raises:
+ Exception: If client creation fails
+ """
+ try:
+ # Use provided values or fall back to config
+ final_endpoint = endpoint_url or self.config.local_runner_endpoint
+ final_region = region_name or self.config.local_runner_region
+
+ # Create client with local endpoint - no AWS access keys required
+ return boto3.client(
+ "lambda",
+ endpoint_url=final_endpoint,
+ region_name=final_region,
+ )
+ except Exception as e:
+ msg = f"Failed to create boto3 client: {e}"
+ raise DurableFunctionsLocalRunnerError(msg) from e
+
+
+def main() -> int:
+ """Main entry point for the dex-local-runner CLI."""
+ app = CliApp()
+ return app.run()
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/client.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/client.py
new file mode 100644
index 0000000..a68f0cc
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/client.py
@@ -0,0 +1,50 @@
+"""An in-memory service client, that can replace the boto lambda service client."""
+
+import datetime
+
+from aws_durable_execution_sdk_python.lambda_service import (
+ CheckpointOutput,
+ DurableServiceClient,
+ OperationUpdate,
+ StateOutput,
+)
+
+from aws_durable_execution_sdk_python_testing.checkpoint.processor import (
+ CheckpointProcessor,
+)
+
+
+class InMemoryServiceClient(DurableServiceClient):
+ """An in-memory service client, that can replace the boto lambda service client."""
+
+ def __init__(self, checkpoint_processor: CheckpointProcessor):
+ self._checkpoint_processor: CheckpointProcessor = checkpoint_processor
+
+ def checkpoint(
+ self,
+ durable_execution_arn: str, # noqa: ARG002
+ checkpoint_token: str,
+ updates: list[OperationUpdate],
+ client_token: str | None,
+ ) -> CheckpointOutput:
+ # durable_execution_arn is not used in in-memory testing
+ return self._checkpoint_processor.process_checkpoint(
+ checkpoint_token, updates, client_token
+ )
+
+ def get_execution_state(
+ self,
+ durable_execution_arn: str, # noqa: ARG002
+ checkpoint_token: str,
+ next_marker: str,
+ max_items: int = 1000,
+ ) -> StateOutput:
+ # durable_execution_arn is not used in in-memory testing
+ return self._checkpoint_processor.get_execution_state(
+ checkpoint_token, next_marker, max_items
+ )
+
+ def stop(self, execution_arn: str, payload: bytes | None) -> datetime.datetime: # noqa: ARG002
+ # TODO: implement
+ # Return current time for in-memory testing
+ return datetime.datetime.now(tz=datetime.UTC)
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/exceptions.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/exceptions.py
new file mode 100644
index 0000000..8d51e2f
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/exceptions.py
@@ -0,0 +1,288 @@
+"""Exceptions for the Durable Executions Testing Library.
+
+This module provides AWS-compliant exceptions that serialize to the exact JSON format
+expected by boto3 and AWS services. All exceptions follow Smithy model definitions
+for field names and structure.
+
+## AWS-Compliant Error Format
+
+All AWS API exceptions inherit from `AwsApiException` and implement the `to_dict()` method
+to serialize to AWS-compliant JSON format. The format varies by exception type based on
+their Smithy model definitions:
+
+### Standard Format (most exceptions):
+```json
+{
+ "Type": "ExceptionName",
+ "message": "Error message" // or "Message" depending on Smithy definition
+}
+```
+
+### Special Cases:
+- `ExecutionAlreadyStartedException`: No "Type" field, includes "DurableExecutionArn"
+```json
+{
+ "message": "Error message",
+ "DurableExecutionArn": "arn:aws:states:..."
+}
+```
+
+## Field Name Conventions
+
+Field names follow the exact Smithy model definitions:
+- `InvalidParameterValueException`: uses lowercase "message"
+- `CallbackTimeoutException`: uses lowercase "message"
+- `ResourceNotFoundException`: uses capital "Message"
+- `ServiceException`: uses capital "Message"
+- `ExecutionAlreadyStartedException`: uses lowercase "message" + "DurableExecutionArn"
+
+## HTTP Status Codes
+
+Each exception maps to appropriate HTTP status codes:
+- 400: InvalidParameterValueException (Bad Request)
+- 404: ResourceNotFoundException (Not Found)
+- 408: CallbackTimeoutException (Request Timeout)
+- 409: ExecutionAlreadyStartedException (Conflict)
+- 500: ServiceException (Internal Server Error)
+
+## Usage Examples
+
+```python
+# Create and serialize an exception
+exception = InvalidParameterValueException("Invalid parameter value")
+json_dict = exception.to_dict()
+# Result: {"Type": "InvalidParameterValueException", "message": "Invalid parameter value"}
+
+# HTTP response creation
+from aws_durable_execution_sdk_python_testing.web.models import HTTPResponse
+
+response = HTTPResponse.create_error_from_exception(exception)
+# Creates HTTP 400 response with AWS-compliant JSON body
+```
+
+## Boto3 Compatibility
+
+All exceptions are designed to be compatible with boto3's error handling:
+- JSON structure matches boto3 expectations
+- Field names match Smithy model definitions
+- Type field values match exception class names
+- Can be deserialized by boto3's error factory
+
+Avoid any non-stdlib references in this module, it is at the bottom of the dependency chain.
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+
+# region Local Runner
+class DurableFunctionsLocalRunnerError(Exception):
+ """Base class for Durable Executions exceptions"""
+
+
+class UnknownRouteError(DurableFunctionsLocalRunnerError):
+ """No route matches the requested path pattern."""
+
+ def __init__(self, method: str, path: str) -> None:
+ """Initialize UnknownRouteError with method and path.
+
+ Args:
+ method: HTTP method (GET, POST, etc.)
+ path: Request path that couldn't be matched
+ """
+ self.method = method
+ self.path = path
+ message = f"Unknown path pattern: {method} {path}"
+ super().__init__(message)
+
+
+# endregion Local Runner
+
+
+class SerializationError(DurableFunctionsLocalRunnerError):
+ """Exception for serialization errors."""
+
+
+# region Testing
+class DurableFunctionsTestError(Exception):
+ """Base class for testing errors."""
+
+
+# endregion Testing
+
+
+# region AWS API Exceptions
+class AwsApiException(DurableFunctionsLocalRunnerError): # noqa: N818
+ """Base class for AWS API-style exceptions that can be serialized to AWS format."""
+
+ http_status_code: int = 500 # Default to server error
+
+ def to_dict(self) -> dict[str, Any]:
+ """Serialize to AWS-compliant JSON structure."""
+ raise NotImplementedError
+
+
+# Smithy-Mapped Exceptions (defined in Smithy models)
+class InvalidParameterValueException(AwsApiException):
+ """Exception for invalid parameter values."""
+
+ http_status_code = 400
+
+ def __init__(self, message: str) -> None:
+ """Initialize with message field (lowercase per Smithy definition)."""
+ self.message = message
+ super().__init__(message)
+
+ def to_dict(self) -> dict[str, Any]:
+ """Serialize to AWS-compliant JSON structure."""
+ return {"Type": "InvalidParameterValueException", "message": self.message}
+
+
+class ResourceNotFoundException(AwsApiException):
+ """Exception for resource not found errors."""
+
+ http_status_code = 404
+
+ def __init__(
+ self,
+ Message: str, # noqa: N803
+ ) -> None: # Capital M per Smithy definition
+ """Initialize with Message field (capital M per Smithy definition)."""
+ self.Message = Message
+ super().__init__(Message)
+
+ def to_dict(self) -> dict[str, Any]:
+ """Serialize to AWS-compliant JSON structure."""
+ return {"Type": "ResourceNotFoundException", "Message": self.Message}
+
+
+class ServiceException(AwsApiException):
+ """Exception for general service errors."""
+
+ http_status_code = 500
+
+ def __init__(
+ self,
+ Message: str, # noqa: N803
+ ) -> None: # Capital M per Smithy definition
+ """Initialize with Message field (capital M per Smithy definition)."""
+ self.Message = Message
+ super().__init__(Message)
+
+ def to_dict(self) -> dict[str, Any]:
+ """Serialize to AWS-compliant JSON structure."""
+ return {"Type": "ServiceException", "Message": self.Message}
+
+
+class ExecutionAlreadyStartedException(AwsApiException):
+ """Exception for execution already started errors."""
+
+ http_status_code = 409
+
+ def __init__(self, message: str, DurableExecutionArn: str) -> None: # noqa: N803
+ """Initialize with message and DurableExecutionArn fields."""
+ self.message = message
+ self.DurableExecutionArn = DurableExecutionArn
+ super().__init__(message)
+
+ def to_dict(self) -> dict[str, Any]:
+ """Serialize to AWS-compliant JSON structure (no Type field per Smithy definition)."""
+ return {
+ "message": self.message,
+ "DurableExecutionArn": self.DurableExecutionArn,
+ }
+
+
+class ExecutionConflictException(AwsApiException):
+ """Exception for execution conflict errors."""
+
+ http_status_code = 409
+
+ def __init__(self, message: str) -> None:
+ """Initialize with message field."""
+ self.message = message
+ super().__init__(message)
+
+ def to_dict(self) -> dict[str, Any]:
+ """Serialize to AWS-compliant JSON structure."""
+ return {"Type": "ExecutionConflictException", "message": self.message}
+
+
+class CallbackTimeoutException(AwsApiException):
+ """Exception for callback timeout errors."""
+
+ http_status_code = 408
+
+ def __init__(self, message: str) -> None:
+ """Initialize with message field (lowercase per Smithy definition)."""
+ self.message = message
+ super().__init__(message)
+
+ def to_dict(self) -> dict[str, Any]:
+ """Serialize to AWS-compliant JSON structure."""
+ return {"Type": "CallbackTimeoutException", "message": self.message}
+
+
+class TooManyRequestsException(AwsApiException):
+ """Exception for too many requests errors."""
+
+ http_status_code = 429
+
+ def __init__(self, message: str) -> None:
+ """Initialize with message field (lowercase per Smithy definition)."""
+ self.message = message
+ super().__init__(message)
+
+ def to_dict(self) -> dict[str, Any]:
+ """Serialize to AWS-compliant JSON structure."""
+ return {"Type": "TooManyRequestsException", "message": self.message}
+
+
+# Unmapped Exceptions (thrown by services but not in Smithy)
+class IllegalStateException(AwsApiException):
+ """IllegalStateException."""
+
+ http_status_code = 500
+
+ def __init__(self, message: str) -> None:
+ """Initialize with message field."""
+ self.message = message
+ super().__init__(message)
+
+ def to_dict(self) -> dict[str, Any]:
+ """Serialize to AWS-compliant JSON structure (maps to ServiceException)."""
+ return {"Type": "ServiceException", "Message": self.message}
+
+
+class RuntimeException(AwsApiException):
+ """RuntimeException."""
+
+ http_status_code = 500
+
+ def __init__(self, message: str) -> None:
+ """Initialize with message field."""
+ self.message = message
+ super().__init__(message)
+
+ def to_dict(self) -> dict[str, Any]:
+ """Serialize to AWS-compliant JSON structure (maps to ServiceException)."""
+ return {"Type": "ServiceException", "Message": self.message}
+
+
+class IllegalArgumentException(AwsApiException):
+ """IllegalArgumentException."""
+
+ http_status_code = 400
+
+ def __init__(self, message: str) -> None:
+ """Initialize with message field."""
+ self.message = message
+ super().__init__(message)
+
+ def to_dict(self) -> dict[str, Any]:
+ """Serialize to AWS-compliant JSON structure (maps to InvalidParameterValueException)."""
+ return {"Type": "InvalidParameterValueException", "message": self.message}
+
+
+# endregion AWS API Exceptions
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/execution.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/execution.py
new file mode 100644
index 0000000..a1096f1
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/execution.py
@@ -0,0 +1,444 @@
+from __future__ import annotations
+
+from dataclasses import replace
+from datetime import UTC, datetime
+from enum import Enum
+from threading import Lock
+from typing import Any
+from uuid import uuid4
+
+from aws_durable_execution_sdk_python.execution import (
+ DurableExecutionInvocationOutput,
+ InvocationStatus,
+)
+from aws_durable_execution_sdk_python.lambda_service import (
+ ErrorObject,
+ ExecutionDetails,
+ Operation,
+ OperationStatus,
+ OperationType,
+ OperationUpdate,
+)
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ IllegalStateException,
+ InvalidParameterValueException,
+)
+
+# Import AWS exceptions
+from aws_durable_execution_sdk_python_testing.model import (
+ InvocationCompletedDetails,
+ StartDurableExecutionInput,
+)
+from aws_durable_execution_sdk_python_testing.token import (
+ CheckpointToken,
+)
+
+
+class ExecutionStatus(Enum):
+ """Execution status for API responses."""
+
+ RUNNING = "RUNNING"
+ SUCCEEDED = "SUCCEEDED"
+ FAILED = "FAILED"
+ STOPPED = "STOPPED"
+ TIMED_OUT = "TIMED_OUT"
+
+
+class Execution:
+ """Execution state."""
+
+ def __init__(
+ self,
+ durable_execution_arn: str,
+ start_input: StartDurableExecutionInput,
+ operations: list[Operation],
+ ):
+ self.durable_execution_arn: str = durable_execution_arn
+ # operation is frozen, it won't mutate - no need to clone/deep-copy
+ self.start_input: StartDurableExecutionInput = start_input
+ self.operations: list[Operation] = operations
+ self.updates: list[OperationUpdate] = []
+ self.invocation_completions: list[InvocationCompletedDetails] = []
+ self.used_tokens: set[str] = set()
+ # TODO: this will need to persist/rehydrate depending on inmemory vs sqllite store
+ self._token_sequence: int = 0
+ self._state_lock: Lock = Lock()
+ self.is_complete: bool = False
+ self.result: DurableExecutionInvocationOutput | None = None
+ self.consecutive_failed_invocation_attempts: int = 0
+ self.close_status: ExecutionStatus | None = None
+
+ @property
+ def token_sequence(self) -> int:
+ """Get current token sequence value."""
+ return self._token_sequence
+
+ def current_status(self) -> ExecutionStatus:
+ """Get execution status."""
+ if not self.is_complete:
+ return ExecutionStatus.RUNNING
+
+ if not self.close_status:
+ msg: str = "close_status must be set when execution is complete"
+ raise IllegalStateException(msg)
+
+ return self.close_status
+
+ @staticmethod
+ def new(input: StartDurableExecutionInput) -> Execution: # noqa: A002
+ # make a nicer arn
+ # Pattern: arn:(aws[a-zA-Z-]*)?:lambda:[a-z]{2}(-gov)?-[a-z]+-\d{1}:\d{12}:durable-execution:[a-zA-Z0-9-_\.]+:[a-zA-Z0-9-_\.]+:[a-zA-Z0-9-_\.]+
+ # Example: arn:aws:lambda:us-east-1:123456789012:durable-execution:myDurableFunction:myDurableExecutionName:ce67da72-3701-4f83-9174-f4189d27b0a5
+ return Execution(
+ durable_execution_arn=str(uuid4())
+ + "/"
+ + (input.invocation_id or str(uuid4())),
+ start_input=input,
+ operations=[],
+ )
+
+ def to_json_dict(self) -> dict[str, Any]:
+ """Serialize execution to JSON-serializable dictionary"""
+ return {
+ "DurableExecutionArn": self.durable_execution_arn,
+ "StartInput": self.start_input.to_dict(),
+ "Operations": [op.to_json_dict() for op in self.operations],
+ "Updates": [update.to_dict() for update in self.updates],
+ "InvocationCompletions": [
+ completion.to_json_dict() for completion in self.invocation_completions
+ ],
+ "UsedTokens": list(self.used_tokens),
+ "TokenSequence": self._token_sequence,
+ "IsComplete": self.is_complete,
+ "Result": self.result.to_dict() if self.result else None,
+ "ConsecutiveFailedInvocationAttempts": self.consecutive_failed_invocation_attempts,
+ "CloseStatus": self.close_status.value if self.close_status else None,
+ }
+
+ @classmethod
+ def from_json_dict(cls, data: dict[str, Any]) -> Execution:
+ """Deserialize execution from dictionary."""
+ # Reconstruct start_input
+ start_input = StartDurableExecutionInput.from_dict(data["StartInput"])
+
+ # Reconstruct operations
+ operations = [
+ Operation.from_json_dict(op_data) for op_data in data["Operations"]
+ ]
+
+ # Create execution
+ execution = cls(
+ durable_execution_arn=data["DurableExecutionArn"],
+ start_input=start_input,
+ operations=operations,
+ )
+
+ # Set additional fields
+ execution.updates = [
+ OperationUpdate.from_dict(update_data) for update_data in data["Updates"]
+ ]
+ execution.invocation_completions = [
+ InvocationCompletedDetails.from_json_dict(item)
+ for item in data.get("InvocationCompletions", [])
+ ]
+ execution.used_tokens = set(data["UsedTokens"])
+ execution._token_sequence = data["TokenSequence"] # noqa: SLF001
+ execution.is_complete = data["IsComplete"]
+ execution.result = (
+ DurableExecutionInvocationOutput.from_dict(data["Result"])
+ if data["Result"]
+ else None
+ )
+ execution.consecutive_failed_invocation_attempts = data[
+ "ConsecutiveFailedInvocationAttempts"
+ ]
+ close_status_str = data.get("CloseStatus")
+ execution.close_status = (
+ ExecutionStatus(close_status_str) if close_status_str else None
+ )
+
+ return execution
+
+ def start(self) -> None:
+ if self.start_input.invocation_id is None:
+ msg: str = "invocation_id is required"
+ raise InvalidParameterValueException(msg)
+ with self._state_lock:
+ self.operations.append(
+ Operation(
+ operation_id=self.start_input.invocation_id,
+ parent_id=None,
+ name=self.start_input.execution_name,
+ start_timestamp=datetime.now(UTC),
+ operation_type=OperationType.EXECUTION,
+ status=OperationStatus.STARTED,
+ execution_details=ExecutionDetails(
+ input_payload=self.start_input.get_normalized_input()
+ ),
+ )
+ )
+
+ def get_operation_execution_started(self) -> Operation:
+ if not self.operations:
+ msg: str = "execution not started."
+
+ raise IllegalStateException(msg)
+
+ return self.operations[0]
+
+ def get_new_checkpoint_token(self) -> str:
+ """Generate a new checkpoint token with incremented sequence"""
+ with self._state_lock:
+ self._token_sequence += 1
+ new_token_sequence = self._token_sequence
+ token = CheckpointToken(
+ execution_arn=self.durable_execution_arn,
+ token_sequence=new_token_sequence,
+ )
+ token_str = token.to_str()
+ self.used_tokens.add(token_str)
+ return token_str
+
+ def get_navigable_operations(self) -> list[Operation]:
+ """Get list of operations, but exclude child operations where the parent has already completed."""
+ return self.operations
+
+ def get_assertable_operations(self) -> list[Operation]:
+ """Get list of operations, but exclude the EXECUTION operations"""
+ # TODO: this excludes EXECUTION at start, but can there be an EXECUTION at the end if there was a checkpoint with large payload?
+ return self.operations[1:]
+
+ def has_pending_operations(self, execution: Execution) -> bool:
+ """True if execution has pending operations."""
+
+ for operation in execution.operations:
+ if (
+ operation.operation_type == OperationType.STEP
+ and operation.status == OperationStatus.PENDING
+ ) or (
+ operation.operation_type
+ in [
+ OperationType.WAIT,
+ OperationType.CALLBACK,
+ OperationType.CHAINED_INVOKE,
+ ]
+ and operation.status == OperationStatus.STARTED
+ ):
+ return True
+ return False
+
+ def record_invocation_completion(
+ self, start_timestamp: datetime, end_timestamp: datetime, request_id: str
+ ) -> None:
+ """Record an invocation completion event."""
+ self.invocation_completions.append(
+ InvocationCompletedDetails(
+ start_timestamp=start_timestamp,
+ end_timestamp=end_timestamp,
+ request_id=request_id,
+ )
+ )
+
+ def complete_success(self, result: str | None) -> None:
+ """Complete execution successfully (DecisionType.COMPLETE_WORKFLOW_EXECUTION)."""
+ self.result = DurableExecutionInvocationOutput(
+ status=InvocationStatus.SUCCEEDED, result=result
+ )
+ self.is_complete = True
+ self.close_status = ExecutionStatus.SUCCEEDED
+ self._end_execution(OperationStatus.SUCCEEDED)
+
+ def complete_fail(self, error: ErrorObject) -> None:
+ """Complete execution with failure (DecisionType.FAIL_WORKFLOW_EXECUTION)."""
+ self.result = DurableExecutionInvocationOutput(
+ status=InvocationStatus.FAILED, error=error
+ )
+ self.is_complete = True
+ self.close_status = ExecutionStatus.FAILED
+ self._end_execution(OperationStatus.FAILED)
+
+ def complete_timeout(self, error: ErrorObject) -> None:
+ """Complete execution with timeout."""
+ self.result = DurableExecutionInvocationOutput(
+ status=InvocationStatus.FAILED, error=error
+ )
+ self.is_complete = True
+ self.close_status = ExecutionStatus.TIMED_OUT
+ self._end_execution(OperationStatus.TIMED_OUT)
+
+ def complete_stopped(self, error: ErrorObject) -> None:
+ """Complete execution as terminated (TerminateWorkflowExecutionV2Request)."""
+ self.result = DurableExecutionInvocationOutput(
+ status=InvocationStatus.FAILED, error=error
+ )
+ self.is_complete = True
+ self.close_status = ExecutionStatus.STOPPED
+ self._end_execution(OperationStatus.STOPPED)
+
+ def find_operation(self, operation_id: str) -> tuple[int, Operation]:
+ """Find operation by ID, return index and operation."""
+ for i, operation in enumerate(self.operations):
+ if operation.operation_id == operation_id:
+ return i, operation
+ msg: str = f"Attempting to update state of an Operation [{operation_id}] that doesn't exist"
+ raise IllegalStateException(msg)
+
+ def find_callback_operation(self, callback_id: str) -> tuple[int, Operation]:
+ """Find callback operation by callback_id, return index and operation."""
+ for i, operation in enumerate(self.operations):
+ if (
+ operation.operation_type == OperationType.CALLBACK
+ and operation.callback_details
+ and operation.callback_details.callback_id == callback_id
+ ):
+ return i, operation
+ msg: str = f"Callback operation with callback_id [{callback_id}] not found"
+ raise IllegalStateException(msg)
+
+ def complete_wait(self, operation_id: str) -> Operation:
+ """Complete WAIT operation when timer fires."""
+ index, operation = self.find_operation(operation_id)
+
+ # Validate
+ if operation.status != OperationStatus.STARTED:
+ msg_wait_not_started: str = f"Attempting to transition a Wait Operation[{operation_id}] to SUCCEEDED when it's not STARTED"
+ raise IllegalStateException(msg_wait_not_started)
+ if operation.operation_type != OperationType.WAIT:
+ msg_not_wait: str = (
+ f"Expected WAIT operation, got {operation.operation_type}"
+ )
+ raise IllegalStateException(msg_not_wait)
+
+ # Thread-safe increment sequence and operation update
+ with self._state_lock:
+ self._token_sequence += 1
+ # Build and assign updated operation
+ self.operations[index] = replace(
+ operation,
+ status=OperationStatus.SUCCEEDED,
+ end_timestamp=datetime.now(UTC),
+ )
+ return self.operations[index]
+
+ def complete_retry(self, operation_id: str) -> Operation:
+ """Complete STEP retry when timer fires."""
+ index, operation = self.find_operation(operation_id)
+
+ # Validate
+ if operation.status != OperationStatus.PENDING:
+ msg_step_not_pending: str = f"Attempting to transition a Step Operation[{operation_id}] to READY when it's not PENDING"
+ raise IllegalStateException(msg_step_not_pending)
+ if operation.operation_type != OperationType.STEP:
+ msg_not_step: str = (
+ f"Expected STEP operation, got {operation.operation_type}"
+ )
+ raise IllegalStateException(msg_not_step)
+
+ # Thread-safe increment sequence and operation update
+ with self._state_lock:
+ self._token_sequence += 1
+ # Build updated step_details with cleared next_attempt_timestamp
+ new_step_details = None
+ if operation.step_details:
+ new_step_details = replace(
+ operation.step_details, next_attempt_timestamp=None
+ )
+
+ # Build updated operation
+ updated_operation = replace(
+ operation, status=OperationStatus.READY, step_details=new_step_details
+ )
+
+ # Assign
+ self.operations[index] = updated_operation
+ return updated_operation
+
+ def complete_callback_success(
+ self, callback_id: str, result: bytes | None = None
+ ) -> Operation:
+ """Complete CALLBACK operation with success."""
+ index, operation = self.find_callback_operation(callback_id)
+ if operation.status != OperationStatus.STARTED:
+ msg: str = f"Callback operation [{callback_id}] is not in STARTED state"
+ raise IllegalStateException(msg)
+
+ with self._state_lock:
+ self._token_sequence += 1
+ updated_callback_details = None
+ if operation.callback_details:
+ updated_callback_details = replace(
+ operation.callback_details,
+ result=result.decode() if result else None,
+ )
+
+ self.operations[index] = replace(
+ operation,
+ status=OperationStatus.SUCCEEDED,
+ end_timestamp=datetime.now(UTC),
+ callback_details=updated_callback_details,
+ )
+ return self.operations[index]
+
+ def complete_callback_failure(
+ self, callback_id: str, error: ErrorObject
+ ) -> Operation:
+ """Complete CALLBACK operation with failure."""
+ index, operation = self.find_callback_operation(callback_id)
+
+ if operation.status != OperationStatus.STARTED:
+ msg: str = f"Callback operation [{callback_id}] is not in STARTED state"
+ raise IllegalStateException(msg)
+
+ with self._state_lock:
+ self._token_sequence += 1
+ updated_callback_details = None
+ if operation.callback_details:
+ updated_callback_details = replace(
+ operation.callback_details, error=error
+ )
+
+ self.operations[index] = replace(
+ operation,
+ status=OperationStatus.FAILED,
+ end_timestamp=datetime.now(UTC),
+ callback_details=updated_callback_details,
+ )
+ return self.operations[index]
+
+ def complete_callback_timeout(
+ self, callback_id: str, error: ErrorObject
+ ) -> Operation:
+ """Complete CALLBACK operation with timeout."""
+ index, operation = self.find_callback_operation(callback_id)
+
+ if operation.status != OperationStatus.STARTED:
+ msg: str = f"Callback operation [{callback_id}] is not in STARTED state"
+ raise IllegalStateException(msg)
+
+ with self._state_lock:
+ self._token_sequence += 1
+ updated_callback_details = None
+ if operation.callback_details:
+ updated_callback_details = replace(
+ operation.callback_details, error=error
+ )
+
+ self.operations[index] = replace(
+ operation,
+ status=OperationStatus.TIMED_OUT,
+ end_timestamp=datetime.now(UTC),
+ callback_details=updated_callback_details,
+ )
+ return self.operations[index]
+
+ def _end_execution(self, status: OperationStatus) -> None:
+ """Set the end_timestamp on the main EXECUTION operation when execution completes."""
+ execution_op: Operation = self.get_operation_execution_started()
+ if execution_op.operation_type == OperationType.EXECUTION:
+ with self._state_lock:
+ self.operations[0] = replace(
+ execution_op,
+ status=status,
+ end_timestamp=datetime.now(UTC),
+ )
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/executor.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/executor.py
new file mode 100644
index 0000000..02a1504
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/executor.py
@@ -0,0 +1,1234 @@
+"""Execution life-cycle logic."""
+
+from __future__ import annotations
+
+import logging
+import time
+import uuid
+from datetime import UTC, datetime
+from typing import TYPE_CHECKING
+
+from aws_durable_execution_sdk_python.execution import (
+ DurableExecutionInvocationInput,
+ DurableExecutionInvocationOutput,
+ InvocationStatus,
+)
+from aws_durable_execution_sdk_python.lambda_service import (
+ CallbackTimeoutType,
+ ErrorObject,
+ Operation,
+ OperationUpdate,
+ OperationStatus,
+ OperationType,
+ CallbackOptions,
+)
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ ExecutionAlreadyStartedException,
+ IllegalStateException,
+ InvalidParameterValueException,
+ ResourceNotFoundException,
+)
+from aws_durable_execution_sdk_python_testing.execution import Execution
+from aws_durable_execution_sdk_python_testing.model import (
+ CheckpointDurableExecutionResponse,
+ CheckpointUpdatedExecutionState,
+ EventCreationContext,
+ EventType,
+ GetDurableExecutionHistoryResponse,
+ GetDurableExecutionResponse,
+ GetDurableExecutionStateResponse,
+ ListDurableExecutionsByFunctionResponse,
+ ListDurableExecutionsResponse,
+ SendDurableExecutionCallbackFailureResponse,
+ SendDurableExecutionCallbackHeartbeatResponse,
+ SendDurableExecutionCallbackSuccessResponse,
+ StartDurableExecutionInput,
+ StartDurableExecutionOutput,
+ StopDurableExecutionResponse,
+ TERMINAL_STATUSES,
+)
+from aws_durable_execution_sdk_python_testing.model import (
+ Event as HistoryEvent,
+)
+from aws_durable_execution_sdk_python_testing.model import (
+ Execution as ExecutionSummary,
+)
+from aws_durable_execution_sdk_python_testing.observer import ExecutionObserver
+from aws_durable_execution_sdk_python_testing.token import CallbackToken
+
+
+if TYPE_CHECKING:
+ from collections.abc import Awaitable, Callable
+ from concurrent.futures import Future
+
+ from aws_durable_execution_sdk_python_testing.checkpoint.processor import (
+ CheckpointProcessor,
+ )
+ from aws_durable_execution_sdk_python_testing.invoker import Invoker
+ from aws_durable_execution_sdk_python_testing.scheduler import Event, Scheduler
+ from aws_durable_execution_sdk_python_testing.stores.base import ExecutionStore
+
+logger = logging.getLogger(__name__)
+
+
+class Executor(ExecutionObserver):
+ MAX_CONSECUTIVE_FAILED_ATTEMPTS: int = 5
+ RETRY_BACKOFF_SECONDS: int = 5
+
+ def __init__(
+ self,
+ store: ExecutionStore,
+ scheduler: Scheduler,
+ invoker: Invoker,
+ checkpoint_processor: CheckpointProcessor,
+ ):
+ self._store = store
+ self._scheduler = scheduler
+ self._invoker = invoker
+ self._checkpoint_processor = checkpoint_processor
+ self._completion_events: dict[str, Event] = {}
+ self._callback_timeouts: dict[str, Future] = {}
+ self._callback_heartbeats: dict[str, Future] = {}
+ self._execution_timeout: Future | None = None
+
+ def start_execution(
+ self,
+ input: StartDurableExecutionInput, # noqa: A002
+ ) -> StartDurableExecutionOutput:
+ # Generate invocation_id if not provided
+ if input.invocation_id is None:
+ input = StartDurableExecutionInput(
+ account_id=input.account_id,
+ function_name=input.function_name,
+ function_qualifier=input.function_qualifier,
+ execution_name=input.execution_name,
+ execution_timeout_seconds=input.execution_timeout_seconds,
+ execution_retention_period_days=input.execution_retention_period_days,
+ invocation_id=str(uuid.uuid4()),
+ trace_fields=input.trace_fields,
+ tenant_id=input.tenant_id,
+ input=input.input,
+ lambda_endpoint=input.lambda_endpoint,
+ )
+
+ execution = Execution.new(input=input)
+ execution.start()
+ self._store.save(execution)
+ logger.debug("Created execution with ARN: %s", execution.durable_execution_arn)
+
+ completion_event = self._scheduler.create_event()
+ self._completion_events[execution.durable_execution_arn] = completion_event
+
+ # Schedule execution timeout
+ if input.execution_timeout_seconds > 0:
+
+ def timeout_handler():
+ error = ErrorObject.from_message(
+ f"Execution timed out after {input.execution_timeout_seconds} seconds."
+ )
+ self.on_timed_out(execution.durable_execution_arn, error)
+
+ self._execution_timeout = self._scheduler.call_later(
+ timeout_handler,
+ delay=input.execution_timeout_seconds,
+ completion_event=completion_event,
+ )
+
+ # Schedule initial invocation to run immediately
+ self._invoke_execution(execution.durable_execution_arn)
+
+ return StartDurableExecutionOutput(
+ execution_arn=execution.durable_execution_arn
+ )
+
+ def get_execution(self, execution_arn: str) -> Execution:
+ """Get execution by ARN.
+
+ Args:
+ execution_arn: The execution ARN to retrieve
+
+ Returns:
+ Execution: The execution object
+
+ Raises:
+ ResourceNotFoundException: If execution does not exist
+ """
+ try:
+ return self._store.load(execution_arn)
+ except KeyError as e:
+ msg: str = f"Execution {execution_arn} not found"
+ raise ResourceNotFoundException(msg) from e
+
+ def get_execution_details(self, execution_arn: str) -> GetDurableExecutionResponse:
+ """Get detailed execution information for web API response.
+
+ Args:
+ execution_arn: The execution ARN to retrieve
+
+ Returns:
+ GetDurableExecutionResponse: Detailed execution information
+
+ Raises:
+ ResourceNotFoundException: If execution does not exist
+ """
+ execution = self.get_execution(execution_arn)
+
+ # Extract execution details from the first operation (EXECUTION type)
+ execution_op = execution.get_operation_execution_started()
+ status = execution.current_status().value
+
+ # Extract result and error from execution result
+ result = None
+ error = None
+ if execution.result:
+ if execution.result.status == InvocationStatus.SUCCEEDED:
+ result = execution.result.result
+ elif execution.result.status == InvocationStatus.FAILED:
+ error = execution.result.error
+
+ return GetDurableExecutionResponse(
+ durable_execution_arn=execution.durable_execution_arn,
+ durable_execution_name=execution.start_input.execution_name,
+ function_arn=f"arn:aws:lambda:us-east-1:123456789012:function:{execution.start_input.function_name}",
+ status=status,
+ start_timestamp=execution_op.start_timestamp
+ if execution_op.start_timestamp
+ else datetime.now(UTC),
+ input_payload=execution_op.execution_details.input_payload
+ if execution_op.execution_details
+ else None,
+ result=result,
+ error=error,
+ end_timestamp=execution_op.end_timestamp
+ if execution_op.end_timestamp
+ else None,
+ version="1.0",
+ )
+
+ def list_executions(
+ self,
+ function_name: str | None = None,
+ function_version: str | None = None, # noqa: ARG002
+ execution_name: str | None = None,
+ status_filter: str | None = None,
+ started_after: str | None = None,
+ started_before: str | None = None,
+ marker: str | None = None,
+ max_items: int | None = None,
+ reverse_order: bool = False, # noqa: FBT001, FBT002
+ ) -> ListDurableExecutionsResponse:
+ """List executions with filtering and pagination.
+
+ Args:
+ function_name: Filter by function name
+ function_version: Filter by function version
+ execution_name: Filter by execution name
+ status_filter: Filter by status (RUNNING, SUCCEEDED, FAILED)
+ started_after: Filter executions started after this time
+ started_before: Filter executions started before this time
+ marker: Pagination marker
+ max_items: Maximum items to return (default 50)
+ reverse_order: Return results in reverse chronological order
+
+ Returns:
+ ListDurableExecutionsResponse: List of executions with pagination
+ """
+ # Convert marker to offset
+ offset: int = 0
+ if marker:
+ try:
+ offset = int(marker)
+ except ValueError:
+ offset = 0
+
+ # Query store directly with parameters
+ executions, next_marker = self._store.query(
+ function_name=function_name,
+ execution_name=execution_name,
+ status_filter=status_filter,
+ started_after=started_after,
+ started_before=started_before,
+ limit=max_items or 50,
+ offset=offset,
+ reverse_order=reverse_order,
+ )
+
+ # Convert to ExecutionSummary objects
+ execution_summaries: list[ExecutionSummary] = [
+ ExecutionSummary.from_execution(execution, execution.current_status().value)
+ for execution in executions
+ ]
+
+ return ListDurableExecutionsResponse(
+ durable_executions=execution_summaries, next_marker=next_marker
+ )
+
+ def list_executions_by_function(
+ self,
+ function_name: str,
+ qualifier: str | None = None, # noqa: ARG002
+ execution_name: str | None = None,
+ status_filter: str | None = None,
+ started_after: str | None = None,
+ started_before: str | None = None,
+ marker: str | None = None,
+ max_items: int | None = None,
+ reverse_order: bool = False, # noqa: FBT001, FBT002
+ ) -> ListDurableExecutionsByFunctionResponse:
+ """List executions for a specific function.
+
+ Args:
+ function_name: The function name to filter by
+ qualifier: Function qualifier/version
+ execution_name: Filter by execution name
+ status_filter: Filter by status (RUNNING, SUCCEEDED, FAILED)
+ started_after: Filter executions started after this time
+ started_before: Filter executions started before this time
+ marker: Pagination marker
+ max_items: Maximum items to return (default 50)
+ reverse_order: Return results in reverse chronological order
+
+ Returns:
+ ListDurableExecutionsByFunctionResponse: List of executions for the function
+ """
+ # Use the general list_executions method with function_name filter
+ list_response = self.list_executions(
+ function_name=function_name,
+ execution_name=execution_name,
+ status_filter=status_filter,
+ started_after=started_after,
+ started_before=started_before,
+ marker=marker,
+ max_items=max_items,
+ reverse_order=reverse_order,
+ )
+
+ return ListDurableExecutionsByFunctionResponse(
+ durable_executions=list_response.durable_executions,
+ next_marker=list_response.next_marker,
+ )
+
+ def stop_execution(
+ self, execution_arn: str, error: ErrorObject | None = None
+ ) -> StopDurableExecutionResponse:
+ """Stop a running execution.
+
+ Args:
+ execution_arn: The execution ARN to stop
+ error: Optional error to use when stopping the execution
+
+ Returns:
+ StopDurableExecutionResponse: Response containing end timestamp
+
+ Raises:
+ ResourceNotFoundException: If execution does not exist
+ """
+ execution = self.get_execution(execution_arn)
+
+ if execution.is_complete:
+ # Idempotent: return the existing stop timestamp
+ execution_op = execution.get_operation_execution_started()
+ stop_timestamp = execution_op.end_timestamp or datetime.now(UTC)
+ return StopDurableExecutionResponse(stop_timestamp=stop_timestamp)
+
+ # Use provided error or create a default one
+ stop_error = error or ErrorObject.from_message(
+ "Execution stopped by user request"
+ )
+
+ # Stop sets TERMINATED close status (different from fail)
+ logger.exception("[%s] Stopping execution.", execution_arn)
+ execution.complete_stopped(error=stop_error) # Sets CloseStatus.TERMINATED
+ self._store.update(execution)
+ self._complete_events(execution_arn=execution_arn)
+
+ return StopDurableExecutionResponse(stop_timestamp=datetime.now(UTC))
+
+ def get_execution_state(
+ self,
+ execution_arn: str,
+ checkpoint_token: str | None = None,
+ marker: str | None = None,
+ max_items: int | None = None,
+ ) -> GetDurableExecutionStateResponse:
+ """Get execution state with operations.
+
+ Args:
+ execution_arn: The execution ARN
+ checkpoint_token: Checkpoint token for state consistency
+ marker: Pagination marker
+ max_items: Maximum items to return
+
+ Returns:
+ GetDurableExecutionStateResponse: Execution state with operations
+
+ Raises:
+ ResourceNotFoundException: If execution does not exist
+ InvalidParameterValueException: If checkpoint token is invalid
+ """
+ execution = self.get_execution(execution_arn)
+
+ # TODO: Validate checkpoint token if provided
+ if checkpoint_token and checkpoint_token not in execution.used_tokens:
+ msg: str = f"Invalid checkpoint token: {checkpoint_token}"
+ raise InvalidParameterValueException(msg)
+
+ # Get operations (excluding the initial EXECUTION operation for state)
+ operations = execution.get_assertable_operations()
+
+ # Apply pagination
+ if max_items is None:
+ max_items = 100
+
+ # Simple pagination - in real implementation would need proper marker handling
+ start_index = 0
+ if marker:
+ try:
+ start_index = int(marker)
+ except ValueError:
+ start_index = 0
+
+ end_index = start_index + max_items
+ paginated_operations = operations[start_index:end_index]
+
+ next_marker = None
+ if end_index < len(operations):
+ next_marker = str(end_index)
+
+ return GetDurableExecutionStateResponse(
+ operations=paginated_operations, next_marker=next_marker
+ )
+
+ def get_execution_history(
+ self,
+ execution_arn: str,
+ include_execution_data: bool = False, # noqa: FBT001, FBT002
+ reverse_order: bool = False, # noqa: FBT001, FBT002
+ marker: str | None = None,
+ max_items: int | None = None,
+ ) -> GetDurableExecutionHistoryResponse:
+ """Get execution history with events.
+
+ Args:
+ execution_arn: The execution ARN
+ include_execution_data: Whether to include execution data in events
+ reverse_order: Return events in reverse chronological order
+ marker: Pagination marker (event_id)
+ max_items: Maximum items to return
+
+ Returns:
+ GetDurableExecutionHistoryResponse: Execution history with events
+
+ Raises:
+ ResourceNotFoundException: If execution does not exist
+ """
+ execution: Execution = self.get_execution(execution_arn)
+
+ # Generate events
+ all_events: list[HistoryEvent] = []
+ ops: list[Operation] = execution.operations
+ updates: list[OperationUpdate] = execution.updates
+ updates_dict: dict[str, OperationUpdate] = {u.operation_id: u for u in updates}
+ durable_execution_arn: str = execution.durable_execution_arn
+
+ # Add InvocationCompleted events
+ for completion in execution.invocation_completions:
+ invocation_event = HistoryEvent.create_invocation_completed(
+ event_id=0, # Temporary, will be reassigned
+ event_timestamp=completion.end_timestamp,
+ start_timestamp=completion.start_timestamp,
+ end_timestamp=completion.end_timestamp,
+ request_id=completion.request_id,
+ )
+ all_events.append(invocation_event)
+
+ # Generate all events first (without final event IDs)
+ for op in ops:
+ operation_update: OperationUpdate | None = updates_dict.get(
+ op.operation_id, None
+ )
+
+ if op.status is OperationStatus.PENDING:
+ if (
+ op.operation_type is not OperationType.CHAINED_INVOKE
+ or op.start_timestamp is None
+ ):
+ continue
+ context: EventCreationContext = EventCreationContext(
+ op,
+ 0, # Temporary event_id, will be reassigned after sorting
+ durable_execution_arn,
+ execution.start_input,
+ execution.result,
+ operation_update,
+ include_execution_data,
+ )
+ pending = HistoryEvent.create_chained_invoke_event_pending(context)
+ all_events.append(pending)
+ if op.start_timestamp is not None:
+ context = EventCreationContext(
+ op,
+ 0, # Temporary event_id, will be reassigned after sorting
+ durable_execution_arn,
+ execution.start_input,
+ execution.result,
+ operation_update,
+ include_execution_data,
+ )
+ started = HistoryEvent.create_event_started(context)
+ all_events.append(started)
+ if op.end_timestamp is not None and op.status in TERMINAL_STATUSES:
+ context = EventCreationContext(
+ op,
+ 0, # Temporary event_id, will be reassigned after sorting
+ durable_execution_arn,
+ execution.start_input,
+ execution.result,
+ operation_update,
+ include_execution_data,
+ )
+ finished = HistoryEvent.create_event_terminated(context)
+ all_events.append(finished)
+
+ # Sort events by timestamp to get correct chronological order
+ all_events.sort(key=lambda event: event.event_timestamp)
+
+ # Reassign event IDs based on chronological order
+ all_events = [
+ HistoryEvent.from_event_with_id(event, i)
+ for i, event in enumerate(all_events, 1)
+ ]
+
+ # Apply cursor-based pagination
+ if max_items is None:
+ max_items = 100
+
+ # Handle pagination marker
+ if reverse_order:
+ all_events.reverse()
+ start_index: int = 0
+ if marker:
+ try:
+ marker_event_id: int = int(marker)
+ # Find the index of the first event with event_id >= marker
+ start_index = len(all_events)
+ for i, e in enumerate(all_events):
+ is_valid_page_start: bool = (
+ e.event_id < marker_event_id
+ if reverse_order
+ else e.event_id >= marker_event_id
+ )
+ if is_valid_page_start:
+ start_index = i
+ break
+ except ValueError:
+ start_index = 0
+
+ # Get paginated events
+ end_index: int = start_index + max_items
+ paginated_events: list[HistoryEvent] = all_events[start_index:end_index]
+
+ # Generate next marker
+ next_marker: str | None = None
+ if end_index < len(all_events):
+ if reverse_order:
+ # Next marker is the event_id of the last returned event
+ next_marker = (
+ str(paginated_events[-1].event_id) if paginated_events else None
+ )
+ else:
+ # Next marker is the event_id of the next event after the last returned
+ next_marker = (
+ str(all_events[end_index].event_id)
+ if end_index < len(all_events)
+ else None
+ )
+
+ return GetDurableExecutionHistoryResponse(
+ events=paginated_events, next_marker=next_marker
+ )
+
+ def checkpoint_execution(
+ self,
+ execution_arn: str,
+ checkpoint_token: str,
+ updates: list[OperationUpdate] | None = None,
+ client_token: str | None = None,
+ ) -> CheckpointDurableExecutionResponse:
+ """Process checkpoint for an execution.
+
+ Args:
+ execution_arn: The execution ARN
+ checkpoint_token: Current checkpoint token
+ updates: List of operation updates to process
+ client_token: Client token for idempotency
+
+ Returns:
+ CheckpointDurableExecutionResponse: Updated checkpoint token and state
+
+ Raises:
+ ResourceNotFoundException: If execution does not exist
+ InvalidParameterValueException: If checkpoint token is invalid
+ """
+ execution = self.get_execution(execution_arn)
+
+ # Validate checkpoint token
+ if checkpoint_token not in execution.used_tokens:
+ msg: str = f"Invalid checkpoint token: {checkpoint_token}"
+ raise InvalidParameterValueException(msg)
+
+ if updates:
+ checkpoint_output = self._checkpoint_processor.process_checkpoint(
+ checkpoint_token=checkpoint_token,
+ updates=updates,
+ client_token=client_token,
+ )
+
+ new_execution_state = None
+ if checkpoint_output.new_execution_state:
+ new_execution_state = CheckpointUpdatedExecutionState(
+ operations=checkpoint_output.new_execution_state.operations,
+ next_marker=checkpoint_output.new_execution_state.next_marker,
+ )
+
+ return CheckpointDurableExecutionResponse(
+ checkpoint_token=checkpoint_output.checkpoint_token,
+ new_execution_state=new_execution_state,
+ )
+
+ # Save execution state after generating new token
+ new_checkpoint_token = execution.get_new_checkpoint_token()
+ self._store.update(execution)
+
+ return CheckpointDurableExecutionResponse(
+ checkpoint_token=new_checkpoint_token,
+ new_execution_state=None,
+ )
+
+ def send_callback_success(
+ self,
+ callback_id: str,
+ result: bytes | None = None,
+ ) -> SendDurableExecutionCallbackSuccessResponse:
+ """Send callback success response.
+
+ Args:
+ callback_id: The callback ID to respond to
+ result: Optional result data for the callback
+
+ Returns:
+ SendDurableExecutionCallbackSuccessResponse: Empty response
+
+ Raises:
+ InvalidParameterValueException: If callback_id is invalid
+ ResourceNotFoundException: If callback does not exist
+ """
+ if not callback_id:
+ msg: str = "callback_id is required"
+ raise InvalidParameterValueException(msg)
+
+ try:
+ callback_token = CallbackToken.from_str(callback_id)
+ execution = self.get_execution(callback_token.execution_arn)
+ execution.complete_callback_success(callback_id, result)
+ self._store.update(execution)
+ self._cleanup_callback_timeouts(callback_id)
+ self._invoke_execution(callback_token.execution_arn)
+ logger.info("Callback success completed for callback_id: %s", callback_id)
+ except Exception as e:
+ msg = f"Failed to process callback success: {e}"
+ raise ResourceNotFoundException(msg) from e
+
+ return SendDurableExecutionCallbackSuccessResponse()
+
+ def send_callback_failure(
+ self,
+ callback_id: str,
+ error: ErrorObject | None = None,
+ ) -> SendDurableExecutionCallbackFailureResponse:
+ """Send callback failure response.
+
+ Args:
+ callback_id: The callback ID to respond to
+ error: Optional error object for the callback failure
+
+ Returns:
+ SendDurableExecutionCallbackFailureResponse: Empty response
+
+ Raises:
+ InvalidParameterValueException: If callback_id is invalid
+ ResourceNotFoundException: If callback does not exist
+ """
+ if not callback_id:
+ msg: str = "callback_id is required"
+ raise InvalidParameterValueException(msg)
+
+ callback_error: ErrorObject = error or ErrorObject.from_message("")
+
+ try:
+ callback_token: CallbackToken = CallbackToken.from_str(callback_id)
+ execution: Execution = self.get_execution(callback_token.execution_arn)
+ execution.complete_callback_failure(callback_id, callback_error)
+ self._store.update(execution)
+ self._cleanup_callback_timeouts(callback_id)
+ self._invoke_execution(callback_token.execution_arn)
+ logger.info("Callback failure completed for callback_id: %s", callback_id)
+ except Exception as e:
+ msg = f"Failed to process callback failure: {e}"
+ raise ResourceNotFoundException(msg) from e
+
+ return SendDurableExecutionCallbackFailureResponse()
+
+ def send_callback_heartbeat(
+ self, callback_id: str
+ ) -> SendDurableExecutionCallbackHeartbeatResponse:
+ """Send callback heartbeat to keep callback alive.
+
+ Args:
+ callback_id: The callback ID to send heartbeat for
+
+ Returns:
+ SendDurableExecutionCallbackHeartbeatResponse: Empty response
+
+ Raises:
+ InvalidParameterValueException: If callback_id is invalid
+ ResourceNotFoundException: If callback does not exist
+ """
+ if not callback_id:
+ msg: str = "callback_id is required"
+ raise InvalidParameterValueException(msg)
+
+ try:
+ callback_token: CallbackToken = CallbackToken.from_str(callback_id)
+ execution: Execution = self.get_execution(callback_token.execution_arn)
+
+ # Find callback operation to verify it exists and is active
+ _, operation = execution.find_callback_operation(callback_id)
+ if operation.status != OperationStatus.STARTED:
+ msg = f"Callback {callback_id} is not active"
+ raise ResourceNotFoundException(msg)
+
+ # Reset heartbeat timeout if configured
+ self._reset_callback_heartbeat_timeout(
+ callback_id, execution.durable_execution_arn
+ )
+ logger.info("Callback heartbeat processed for callback_id: %s", callback_id)
+ except Exception as e:
+ msg = f"Failed to process callback heartbeat: {e}"
+ raise ResourceNotFoundException(msg) from e
+
+ return SendDurableExecutionCallbackHeartbeatResponse()
+
+ def _validate_invocation_response_and_store(
+ self,
+ execution_arn: str,
+ response: DurableExecutionInvocationOutput,
+ execution: Execution,
+ ):
+ """Validate response status and save it to the store if fine.
+
+ Raises:
+ InvalidParameterValueException: If the response status is invalid.
+ IllegalStateException: If the response status is valid but the execution is already completed.
+ """
+ if execution.is_complete:
+ msg_already_complete: str = "Execution already completed, ignoring result"
+
+ raise IllegalStateException(msg_already_complete)
+
+ if response.status is None:
+ msg_status_required: str = "Response status is required"
+
+ raise InvalidParameterValueException(msg_status_required)
+
+ match response.status:
+ case InvocationStatus.FAILED:
+ if response.result is not None:
+ msg_failed_result: str = (
+ "Cannot provide a Result for FAILED status."
+ )
+ raise InvalidParameterValueException(msg_failed_result)
+ logger.info("[%s] Execution failed", execution_arn)
+ self._complete_workflow(
+ execution_arn, result=None, error=response.error
+ )
+
+ case InvocationStatus.SUCCEEDED:
+ if response.error is not None:
+ msg_success_error: str = (
+ "Cannot provide an Error for SUCCEEDED status."
+ )
+ raise InvalidParameterValueException(msg_success_error)
+ logger.info("[%s] Execution succeeded", execution_arn)
+ self._complete_workflow(
+ execution_arn, result=response.result, error=None
+ )
+
+ case InvocationStatus.PENDING:
+ if not execution.has_pending_operations(execution):
+ msg_pending_ops: str = (
+ "Cannot return PENDING status with no pending operations."
+ )
+ raise InvalidParameterValueException(msg_pending_ops)
+ logger.info("[%s] Execution pending async work", execution_arn)
+
+ case _:
+ msg_unexpected_status: str = (
+ f"Unexpected invocation status: {response.status}"
+ )
+ raise IllegalStateException(msg_unexpected_status)
+
+ def _invoke_handler(self, execution_arn: str) -> Callable[[], Awaitable[None]]:
+ """Create a parameterless callable that captures execution arn for the scheduler."""
+
+ async def invoke() -> None:
+ execution: Execution = self._store.load(execution_arn)
+
+ # Early exit if execution is already completed - like Java's COMPLETED check
+ if execution.is_complete:
+ logger.info(
+ "[%s] Execution already completed, ignoring result", execution_arn
+ )
+ return
+
+ try:
+ invocation_input: DurableExecutionInvocationInput = (
+ self._invoker.create_invocation_input(execution=execution)
+ )
+
+ self._store.save(execution)
+
+ invocation_start = datetime.now(UTC)
+ invoke_response = self._invoker.invoke(
+ execution.start_input.function_name,
+ invocation_input,
+ execution.start_input.lambda_endpoint,
+ )
+ invocation_end = datetime.now(UTC)
+
+ # Reload execution after invocation in case it was completed via checkpoint
+ execution = self._store.load(execution_arn)
+
+ # Record invocation completion and save immediately
+ execution.record_invocation_completion(
+ invocation_start, invocation_end, invoke_response.request_id
+ )
+ self._store.save(execution)
+
+ if execution.is_complete:
+ logger.info(
+ "[%s] Execution completed during invocation, ignoring result",
+ execution_arn,
+ )
+ return
+
+ # Process successful received response - validate status and handle accordingly
+ response = invoke_response.invocation_output
+ try:
+ self._validate_invocation_response_and_store(
+ execution_arn, response, execution
+ )
+ except (InvalidParameterValueException, IllegalStateException) as e:
+ logger.warning(
+ "[%s] Lambda output validation failure: %s", execution_arn, e
+ )
+ error_obj = ErrorObject.from_exception(e)
+ self._retry_invocation(execution, error_obj)
+
+ except ResourceNotFoundException:
+ logger.warning(
+ "[%s] Function No longer exists: %s",
+ execution_arn,
+ execution.start_input.function_name,
+ )
+ error_obj = ErrorObject.from_message(
+ message=f"Function not found: {execution.start_input.function_name}"
+ )
+ self._fail_workflow(execution_arn, error_obj)
+
+ except Exception as e: # noqa: BLE001
+ # Handle invocation errors (network, function not found, etc.)
+ logger.warning("[%s] Invocation failed: %s", execution_arn, e)
+ error_obj = ErrorObject.from_exception(e)
+ self._retry_invocation(execution, error_obj)
+
+ return invoke
+
+ def _invoke_execution(self, execution_arn: str, delay: float = 0) -> None:
+ """Invoke execution after delay in seconds."""
+ completion_event = self._completion_events.get(execution_arn)
+ self._scheduler.call_later(
+ self._invoke_handler(execution_arn),
+ delay=delay,
+ completion_event=completion_event,
+ )
+
+ def _complete_workflow(
+ self, execution_arn: str, result: str | None, error: ErrorObject | None
+ ):
+ """Complete workflow - handles both success and failure with terminal state validation."""
+ execution = self._store.load(execution_arn)
+
+ if execution.is_complete:
+ msg: str = "Cannot make multiple close workflow decisions."
+
+ raise IllegalStateException(msg)
+
+ if error is not None:
+ self.fail_execution(execution_arn, error)
+ else:
+ self.complete_execution(execution_arn, result)
+
+ def _fail_workflow(self, execution_arn: str, error: ErrorObject):
+ """Fail workflow with terminal state validation."""
+ execution = self._store.load(execution_arn)
+
+ if execution.is_complete:
+ msg: str = "Cannot make multiple close workflow decisions."
+
+ raise IllegalStateException(msg)
+
+ self.fail_execution(execution_arn, error)
+
+ def _retry_invocation(self, execution: Execution, error: ErrorObject):
+ """Handle retry logic or fail execution if retries exhausted."""
+ if (
+ execution.consecutive_failed_invocation_attempts
+ > self.MAX_CONSECUTIVE_FAILED_ATTEMPTS
+ ):
+ # Exhausted retries - fail the execution
+ self._fail_workflow(
+ execution_arn=execution.durable_execution_arn, error=error
+ )
+ else:
+ # Schedule retry with backoff
+ execution.consecutive_failed_invocation_attempts += 1
+ self._store.save(execution)
+ self._invoke_execution(
+ execution_arn=execution.durable_execution_arn,
+ delay=self.RETRY_BACKOFF_SECONDS,
+ )
+
+ def _complete_events(self, execution_arn: str):
+ # complete doesn't actually checkpoint explicitly
+ if event := self._completion_events.get(execution_arn):
+ event.set()
+ if self._execution_timeout:
+ self._execution_timeout.cancel()
+ self._execution_timeout = None
+
+ def wait_until_complete(
+ self, execution_arn: str, timeout: float | None = None
+ ) -> bool:
+ """Block until execution completion. Don't do this unless you actually want to block.
+
+ Args
+ timeout (int|float|None): Wait for event to set until this timeout.
+
+ Returns:
+ True when set. False if the event timed out without being set.
+ """
+ if event := self._completion_events.get(execution_arn):
+ return event.wait(timeout)
+
+ # this really shouldn't happen - implies execution timed out?
+ msg: str = "execution does not exist."
+
+ raise ResourceNotFoundException(msg)
+
+ def complete_execution(self, execution_arn: str, result: str | None = None) -> None:
+ """Complete execution successfully (COMPLETE_WORKFLOW_EXECUTION decision)."""
+ logger.debug("[%s] Completing execution with result: %s", execution_arn, result)
+ execution: Execution = self._store.load(execution_arn=execution_arn)
+ execution.complete_success(result=result) # Sets CloseStatus.COMPLETED
+ self._store.update(execution)
+ if execution.result is None:
+ msg: str = "Execution result is required"
+ raise IllegalStateException(msg)
+ self._complete_events(execution_arn=execution_arn)
+
+ def fail_execution(self, execution_arn: str, error: ErrorObject) -> None:
+ """Fail execution with error (FAIL_WORKFLOW_EXECUTION decision)."""
+ logger.error("[%s] Completing execution with error: %s", execution_arn, error)
+ execution: Execution = self._store.load(execution_arn=execution_arn)
+ execution.complete_fail(error=error) # Sets CloseStatus.FAILED
+ self._store.update(execution)
+ # set by complete_fail
+ if execution.result is None:
+ msg: str = "Execution result is required"
+ raise IllegalStateException(msg)
+ self._complete_events(execution_arn=execution_arn)
+
+ def _on_wait_succeeded(self, execution_arn: str, operation_id: str) -> None:
+ """Private method - called when a wait operation completes successfully."""
+ execution = self._store.load(execution_arn)
+
+ if execution.is_complete:
+ logger.info(
+ "[%s] Execution already completed, ignoring wait succeeded event",
+ execution_arn,
+ )
+ return
+
+ try:
+ execution.complete_wait(operation_id=operation_id)
+ self._store.update(execution)
+ logger.debug(
+ "[%s] Wait succeeded for operation %s", execution_arn, operation_id
+ )
+ except Exception:
+ logger.exception("[%s] Error processing wait succeeded.", execution_arn)
+
+ def _on_retry_ready(self, execution_arn: str, operation_id: str) -> None:
+ """Private method - called when a retry delay has elapsed and retry is ready."""
+ execution = self._store.load(execution_arn)
+
+ if execution.is_complete:
+ logger.info(
+ "[%s] Execution already completed, ignoring retry", execution_arn
+ )
+ return
+
+ try:
+ execution.complete_retry(operation_id=operation_id)
+ self._store.update(execution)
+ logger.debug(
+ "[%s] Retry ready for operation %s", execution_arn, operation_id
+ )
+ except Exception:
+ logger.exception("[%s] Error processing retry ready.", execution_arn)
+
+ # region ExecutionObserver
+ def on_completed(self, execution_arn: str, result: str | None = None) -> None:
+ """Complete execution successfully. Observer method triggered by notifier."""
+ self.complete_execution(execution_arn, result)
+
+ def on_failed(self, execution_arn: str, error: ErrorObject) -> None:
+ """Fail execution. Observer method triggered by notifier."""
+ self.fail_execution(execution_arn, error)
+
+ def on_timed_out(self, execution_arn: str, error: ErrorObject) -> None:
+ """Handle execution timeout (workflow timeout). Observer method triggered by notifier."""
+ logger.exception("[%s] Execution timed out.", execution_arn)
+ execution: Execution = self._store.load(execution_arn=execution_arn)
+ execution.complete_timeout(error=error) # Sets CloseStatus.TIMED_OUT
+ self._store.update(execution)
+ self._complete_events(execution_arn=execution_arn)
+
+ def on_stopped(self, execution_arn: str, error: ErrorObject) -> None:
+ """Handle execution stop. Observer method triggered by notifier."""
+ # This should not be called directly - stop_execution handles termination
+ self.fail_execution(execution_arn, error)
+
+ def on_wait_timer_scheduled(
+ self, execution_arn: str, operation_id: str, delay: float
+ ) -> None:
+ """Schedule a wait operation. Observer method triggered by notifier."""
+ logger.debug("[%s] scheduling wait with delay: %d", execution_arn, delay)
+
+ def wait_handler() -> None:
+ self._on_wait_succeeded(execution_arn, operation_id)
+ self._invoke_execution(execution_arn, delay=0)
+
+ completion_event = self._completion_events.get(execution_arn)
+ self._scheduler.call_later(
+ wait_handler, delay=delay, completion_event=completion_event
+ )
+
+ def on_step_retry_scheduled(
+ self, execution_arn: str, operation_id: str, delay: float
+ ) -> None:
+ """Schedule a retry a step. Observer method triggered by notifier."""
+ logger.debug(
+ "[%s] scheduling retry for %s with delay: %d",
+ execution_arn,
+ operation_id,
+ delay,
+ )
+
+ def retry_handler() -> None:
+ self._on_retry_ready(execution_arn, operation_id)
+ self._invoke_execution(execution_arn, delay=0)
+
+ completion_event = self._completion_events.get(execution_arn)
+ self._scheduler.call_later(
+ retry_handler, delay=delay, completion_event=completion_event
+ )
+
+ def on_callback_created(
+ self,
+ execution_arn: str,
+ operation_id: str,
+ callback_options: CallbackOptions | None,
+ callback_token: CallbackToken,
+ ) -> None:
+ """Handle callback creation. Observer method triggered by notifier."""
+ callback_id = callback_token.to_str()
+ logger.debug(
+ "[%s] Callback created for operation %s with callback_id: %s",
+ execution_arn,
+ operation_id,
+ callback_id,
+ )
+
+ # Schedule callback timeouts if configured
+ self._schedule_callback_timeouts(execution_arn, callback_options, callback_id)
+
+ # endregion ExecutionObserver
+
+ # region Callback Timeouts
+ def _schedule_callback_timeouts(
+ self,
+ execution_arn: str,
+ callback_options: CallbackOptions | None,
+ callback_id: str,
+ ) -> None:
+ """Schedule callback timeout and heartbeat timeout if configured."""
+ try:
+ if not callback_options:
+ return
+
+ completion_event = self._completion_events.get(execution_arn)
+
+ # Schedule main timeout if configured
+ if callback_options.timeout_seconds > 0:
+
+ def timeout_handler():
+ self._on_callback_timeout(execution_arn, callback_id)
+
+ timeout_future = self._scheduler.call_later(
+ timeout_handler,
+ delay=callback_options.timeout_seconds,
+ completion_event=completion_event,
+ )
+ self._callback_timeouts[callback_id] = timeout_future
+
+ # Schedule heartbeat timeout if configured
+ if callback_options.heartbeat_timeout_seconds > 0:
+
+ def heartbeat_timeout_handler():
+ self._on_callback_heartbeat_timeout(execution_arn, callback_id)
+
+ heartbeat_future = self._scheduler.call_later(
+ heartbeat_timeout_handler,
+ delay=callback_options.heartbeat_timeout_seconds,
+ completion_event=completion_event,
+ )
+ self._callback_heartbeats[callback_id] = heartbeat_future
+
+ except Exception:
+ logger.exception(
+ "[%s] Error scheduling callback timeouts for %s",
+ execution_arn,
+ callback_id,
+ )
+
+ def _reset_callback_heartbeat_timeout(
+ self, callback_id: str, execution_arn: str
+ ) -> None:
+ """Reset the heartbeat timeout for a callback."""
+ # Cancel existing heartbeat timeout
+ if heartbeat_future := self._callback_heartbeats.pop(callback_id, None):
+ heartbeat_future.cancel()
+
+ # Find callback options to reschedule heartbeat timeout
+ try:
+ callback_token = CallbackToken.from_str(callback_id)
+ execution = self.get_execution(callback_token.execution_arn)
+
+ callback_options = None
+ for update in execution.updates:
+ if (
+ update.operation_id == callback_token.operation_id
+ and update.callback_options
+ and update.action.value == "START"
+ ):
+ callback_options = update.callback_options
+ break
+
+ if callback_options and callback_options.heartbeat_timeout_seconds > 0:
+
+ def heartbeat_timeout_handler():
+ self._on_callback_heartbeat_timeout(execution_arn, callback_id)
+
+ completion_event = self._completion_events.get(execution_arn)
+
+ heartbeat_future = self._scheduler.call_later(
+ heartbeat_timeout_handler,
+ delay=callback_options.heartbeat_timeout_seconds,
+ completion_event=completion_event,
+ )
+ self._callback_heartbeats[callback_id] = heartbeat_future
+
+ except Exception:
+ logger.exception(
+ "[%s] Error resetting callback heartbeat timeout for %s",
+ execution_arn,
+ callback_id,
+ )
+
+ def _cleanup_callback_timeouts(self, callback_id: str) -> None:
+ """Clean up timeout events for a completed callback."""
+ # Clean up main timeout
+ if timeout_future := self._callback_timeouts.pop(callback_id, None):
+ timeout_future.cancel()
+
+ # Clean up heartbeat timeout
+ if heartbeat_future := self._callback_heartbeats.pop(callback_id, None):
+ heartbeat_future.cancel()
+
+ def _on_callback_timeout(self, execution_arn: str, callback_id: str) -> None:
+ """Handle callback timeout."""
+ try:
+ callback_token = CallbackToken.from_str(callback_id)
+ execution = self.get_execution(callback_token.execution_arn)
+
+ if execution.is_complete:
+ return
+
+ # Fail the callback with timeout error
+ timeout_error = ErrorObject.from_message(
+ f"Callback timed out: {CallbackTimeoutType.TIMEOUT.value}"
+ )
+ execution.complete_callback_timeout(callback_id, timeout_error)
+ self._store.update(execution)
+ logger.warning("[%s] Callback %s timed out", execution_arn, callback_id)
+ self._invoke_execution(callback_token.execution_arn)
+ except Exception:
+ logger.exception(
+ "[%s] Error processing callback timeout for %s",
+ execution_arn,
+ callback_id,
+ )
+
+ def _on_callback_heartbeat_timeout(
+ self, execution_arn: str, callback_id: str
+ ) -> None:
+ """Handle callback heartbeat timeout."""
+ try:
+ callback_token = CallbackToken.from_str(callback_id)
+ execution = self.get_execution(callback_token.execution_arn)
+
+ if execution.is_complete:
+ return
+
+ # Fail the callback with heartbeat timeout error
+
+ heartbeat_error = ErrorObject.from_message(
+ f"Callback heartbeat timed out: {CallbackTimeoutType.HEARTBEAT.value}"
+ )
+ execution.complete_callback_timeout(callback_id, heartbeat_error)
+ self._store.update(execution)
+ logger.warning(
+ "[%s] Callback %s heartbeat timed out", execution_arn, callback_id
+ )
+ self._invoke_execution(callback_token.execution_arn)
+ except Exception:
+ logger.exception(
+ "[%s] Error processing callback heartbeat timeout for %s",
+ execution_arn,
+ callback_id,
+ )
+
+ # endregion Callback Timeouts
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/invoker.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/invoker.py
new file mode 100644
index 0000000..26143c9
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/invoker.py
@@ -0,0 +1,340 @@
+from __future__ import annotations
+
+import json
+from dataclasses import dataclass
+from threading import Lock
+from typing import TYPE_CHECKING, Any, Protocol
+from uuid import uuid4
+
+import boto3 # type: ignore
+from botocore.config import Config # type: ignore
+
+from aws_durable_execution_sdk_python.execution import (
+ DurableExecutionInvocationInput,
+ DurableExecutionInvocationInputWithClient,
+ DurableExecutionInvocationOutput,
+ InitialExecutionState,
+)
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ DurableFunctionsTestError,
+ InvalidParameterValueException,
+ ResourceNotFoundException,
+)
+from aws_durable_execution_sdk_python_testing.model import LambdaContext
+
+
+if TYPE_CHECKING:
+ from collections.abc import Callable
+
+ from aws_durable_execution_sdk_python_testing.client import InMemoryServiceClient
+ from aws_durable_execution_sdk_python_testing.execution import Execution
+
+
+# Max Lambda function timeout is 15 minutes (900s); we give headroom for
+# network round-trip and RIE startup.
+_LAMBDA_READ_TIMEOUT_SECONDS = 960
+_LAMBDA_CLIENT_CONFIG = Config(
+ read_timeout=_LAMBDA_READ_TIMEOUT_SECONDS,
+ retries={"max_attempts": 0},
+)
+
+
+def create_lambda_client(endpoint_url: str | None, region_name: str) -> Any:
+ """Create a boto3 Lambda client configured for durable function invocations."""
+
+ return boto3.client(
+ "lambda",
+ endpoint_url=endpoint_url,
+ region_name=region_name,
+ config=_LAMBDA_CLIENT_CONFIG,
+ )
+
+
+@dataclass(frozen=True)
+class InvokeResponse:
+ """Response from invoking a durable function."""
+
+ invocation_output: DurableExecutionInvocationOutput
+ request_id: str
+
+
+def create_test_lambda_context() -> LambdaContext:
+ # Create client context as a dictionary, not as objects
+ # LambdaContext.__init__ expects dictionaries and will create the objects internally
+ client_context_dict = {
+ "custom": {"test_key": "test_value"},
+ "env": {"platform": "test", "make": "test", "model": "test"},
+ "client": {
+ "installation_id": "test-installation-123",
+ "app_title": "TestApp",
+ "app_version_name": "1.0.0",
+ "app_version_code": "100",
+ "app_package_name": "com.test.app",
+ },
+ }
+
+ cognito_identity_dict = {
+ "cognitoIdentityId": "test-cognito-identity-123",
+ "cognitoIdentityPoolId": "us-west-2:test-pool-456",
+ }
+
+ return LambdaContext(
+ aws_request_id="test-invoke-12345",
+ client_context=client_context_dict,
+ identity=cognito_identity_dict,
+ invoked_function_arn="arn:aws:lambda:us-west-2:123456789012:function:test-function",
+ tenant_id="test-tenant-789",
+ )
+
+
+class Invoker(Protocol):
+ def create_invocation_input(
+ self, execution: Execution
+ ) -> DurableExecutionInvocationInput: ... # pragma: no cover
+
+ def invoke(
+ self,
+ function_name: str,
+ input: DurableExecutionInvocationInput,
+ endpoint_url: str | None = None,
+ ) -> InvokeResponse: ... # pragma: no cover
+
+ def update_endpoint(
+ self, endpoint_url: str, region_name: str
+ ) -> None: ... # pragma: no cover
+
+
+class InProcessInvoker(Invoker):
+ def __init__(self, handler: Callable, service_client: InMemoryServiceClient):
+ self.handler = handler
+ self.service_client = service_client
+
+ def create_invocation_input(
+ self, execution: Execution
+ ) -> DurableExecutionInvocationInput:
+ return DurableExecutionInvocationInputWithClient(
+ durable_execution_arn=execution.durable_execution_arn,
+ # TODO: this needs better logic - use existing if not used yet, vs create new
+ checkpoint_token=execution.get_new_checkpoint_token(),
+ initial_execution_state=InitialExecutionState(
+ operations=execution.operations,
+ next_marker="",
+ ),
+ service_client=self.service_client,
+ )
+
+ def invoke(
+ self,
+ function_name: str, # noqa: ARG002
+ input: DurableExecutionInvocationInput,
+ endpoint_url: str | None = None, # noqa: ARG002
+ ) -> InvokeResponse:
+ # TODO: reasses if function_name will be used in future
+ input_with_client = DurableExecutionInvocationInputWithClient.from_durable_execution_invocation_input(
+ input, self.service_client
+ )
+ context = create_test_lambda_context()
+ response_dict = self.handler(input_with_client, context)
+ output = DurableExecutionInvocationOutput.from_dict(response_dict)
+ return InvokeResponse(
+ invocation_output=output, request_id=context.aws_request_id
+ )
+
+ def update_endpoint(self, endpoint_url: str, region_name: str) -> None:
+ """No-op for in-process invoker."""
+
+
+class LambdaInvoker(Invoker):
+ def __init__(self, lambda_client: Any) -> None:
+ self.lambda_client = lambda_client
+ # Maps execution_arn -> endpoint for that execution
+ # Maps endpoint -> client to reuse clients across executions
+ self._execution_endpoints: dict[str, str] = {}
+ self._endpoint_clients: dict[str, Any] = {}
+ self._current_endpoint: str = "" # Track current endpoint for new executions
+ self._lock = Lock()
+
+ @staticmethod
+ def create(endpoint_url: str, region_name: str) -> LambdaInvoker:
+ """Create with the boto lambda client."""
+ invoker = LambdaInvoker(create_lambda_client(endpoint_url, region_name))
+ invoker._current_endpoint = endpoint_url
+ invoker._endpoint_clients[endpoint_url] = invoker.lambda_client
+ return invoker
+
+ def update_endpoint(self, endpoint_url: str, region_name: str) -> None:
+ """Update the Lambda client endpoint."""
+ # Cache client by endpoint to reuse across executions
+ with self._lock:
+ if endpoint_url not in self._endpoint_clients:
+ self._endpoint_clients[endpoint_url] = create_lambda_client(
+ endpoint_url, region_name
+ )
+ self.lambda_client = self._endpoint_clients[endpoint_url]
+ self._current_endpoint = endpoint_url
+
+ def _get_client_for_execution(
+ self,
+ durable_execution_arn: str,
+ lambda_endpoint: str | None = None,
+ region_name: str | None = None,
+ ) -> Any:
+ """Get the appropriate client for this execution."""
+ # Use provided endpoint or fall back to cached endpoint for this execution
+ if lambda_endpoint:
+ if lambda_endpoint not in self._endpoint_clients:
+ self._endpoint_clients[lambda_endpoint] = create_lambda_client(
+ lambda_endpoint, region_name or "us-east-1"
+ )
+ return self._endpoint_clients[lambda_endpoint]
+
+ # Fallback to cached endpoint
+ if durable_execution_arn not in self._execution_endpoints:
+ with self._lock:
+ if durable_execution_arn not in self._execution_endpoints:
+ self._execution_endpoints[durable_execution_arn] = (
+ self._current_endpoint
+ )
+
+ endpoint = self._execution_endpoints[durable_execution_arn]
+
+ # If no endpoint configured, fall back to default client
+ if not endpoint:
+ return self.lambda_client
+
+ return self._endpoint_clients[endpoint]
+
+ def create_invocation_input(
+ self, execution: Execution
+ ) -> DurableExecutionInvocationInput:
+ return DurableExecutionInvocationInput(
+ durable_execution_arn=execution.durable_execution_arn,
+ checkpoint_token=execution.get_new_checkpoint_token(),
+ initial_execution_state=InitialExecutionState(
+ operations=execution.operations,
+ next_marker="",
+ ),
+ )
+
+ def invoke(
+ self,
+ function_name: str,
+ input: DurableExecutionInvocationInput,
+ endpoint_url: str | None = None,
+ ) -> InvokeResponse:
+ """Invoke AWS Lambda function and return durable execution result.
+
+ Args:
+ function_name: Name of the Lambda function to invoke
+ input: Durable execution invocation input
+ endpoint_url: Lambda endpoint url
+
+ Returns:
+ InvokeResponse: Response containing invocation output and request ID
+
+ Raises:
+ ResourceNotFoundException: If function does not exist
+ InvalidParameterValueException: If parameters are invalid
+ DurableFunctionsTestError: For other invocation failures
+ """
+
+ # Parameter validation
+ if not function_name or not function_name.strip():
+ msg = "Function name is required"
+ raise InvalidParameterValueException(msg)
+
+ # Get the client for this execution
+ client = self._get_client_for_execution(
+ input.durable_execution_arn, endpoint_url
+ )
+
+ try:
+ # Invoke AWS Lambda function using standard invoke method
+ response = client.invoke(
+ FunctionName=function_name,
+ InvocationType="RequestResponse", # Synchronous invocation
+ Payload=json.dumps(input.to_json_dict()),
+ )
+
+ # Check HTTP status code
+ status_code = response.get("StatusCode")
+ if status_code not in (200, 202, 204):
+ msg = f"Lambda invocation failed with status code: {status_code}"
+ raise DurableFunctionsTestError(msg)
+
+ # Check for function errors
+ if "FunctionError" in response:
+ error_payload = response["Payload"].read().decode("utf-8")
+ msg = f"Lambda invocation failed with status {status_code}: {error_payload}"
+ raise DurableFunctionsTestError(msg)
+
+ # Parse response payload
+ response_payload = response["Payload"].read().decode("utf-8")
+ response_dict = json.loads(response_payload)
+
+ # Extract request ID from response headers (x-amzn-RequestId or x-amzn-request-id)
+ headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
+ request_id = (
+ headers.get("x-amzn-RequestId")
+ or headers.get("x-amzn-request-id")
+ or f"local-{uuid4()}"
+ )
+
+ # Convert to DurableExecutionInvocationOutput
+ output = DurableExecutionInvocationOutput.from_dict(response_dict)
+ return InvokeResponse(invocation_output=output, request_id=request_id)
+
+ except client.exceptions.ResourceNotFoundException as e:
+ msg = f"Function not found: {function_name}"
+ raise ResourceNotFoundException(msg) from e
+ except client.exceptions.InvalidParameterValueException as e:
+ msg = f"Invalid parameter: {e}"
+ raise InvalidParameterValueException(msg) from e
+ except (
+ client.exceptions.TooManyRequestsException,
+ client.exceptions.ServiceException,
+ client.exceptions.ResourceConflictException,
+ client.exceptions.InvalidRequestContentException,
+ client.exceptions.RequestTooLargeException,
+ client.exceptions.UnsupportedMediaTypeException,
+ client.exceptions.InvalidRuntimeException,
+ client.exceptions.InvalidZipFileException,
+ client.exceptions.ResourceNotReadyException,
+ client.exceptions.SnapStartTimeoutException,
+ client.exceptions.SnapStartNotReadyException,
+ client.exceptions.SnapStartException,
+ client.exceptions.RecursiveInvocationException,
+ ) as e:
+ msg = f"Lambda invocation failed: {e}"
+ raise DurableFunctionsTestError(msg) from e
+ except (
+ client.exceptions.InvalidSecurityGroupIDException,
+ client.exceptions.EC2ThrottledException,
+ client.exceptions.EFSMountConnectivityException,
+ client.exceptions.SubnetIPAddressLimitReachedException,
+ client.exceptions.EC2UnexpectedException,
+ client.exceptions.InvalidSubnetIDException,
+ client.exceptions.EC2AccessDeniedException,
+ client.exceptions.EFSIOException,
+ client.exceptions.ENILimitReachedException,
+ client.exceptions.EFSMountTimeoutException,
+ client.exceptions.EFSMountFailureException,
+ ) as e:
+ msg = f"Lambda infrastructure error: {e}"
+ raise DurableFunctionsTestError(msg) from e
+ except (
+ client.exceptions.KMSAccessDeniedException,
+ client.exceptions.KMSDisabledException,
+ client.exceptions.KMSNotFoundException,
+ client.exceptions.KMSInvalidStateException,
+ ) as e:
+ msg = f"Lambda KMS error: {e}"
+ raise DurableFunctionsTestError(msg) from e
+ except Exception as e:
+ # Handle any remaining exceptions, including custom ones like DurableExecutionAlreadyStartedException
+ if "DurableExecutionAlreadyStartedException" in str(type(e)):
+ msg = f"Durable execution already started: {e}"
+ raise DurableFunctionsTestError(msg) from e
+ msg = f"Unexpected error during Lambda invocation: {e}"
+ raise DurableFunctionsTestError(msg) from e
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/model.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/model.py
new file mode 100644
index 0000000..12e69a8
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/model.py
@@ -0,0 +1,3296 @@
+"""Model classes for the web API."""
+
+from __future__ import annotations
+
+import datetime
+import json
+from dataclasses import dataclass, replace
+from enum import Enum
+from typing import Any
+
+from aws_durable_execution_sdk_python.execution import DurableExecutionInvocationOutput
+
+# Import existing types from the main SDK - REUSE EVERYTHING POSSIBLE
+from aws_durable_execution_sdk_python.lambda_service import (
+ CallbackDetails,
+ CallbackOptions,
+ ChainedInvokeDetails,
+ ChainedInvokeOptions,
+ ContextDetails,
+ ContextOptions,
+ ErrorObject,
+ ExecutionDetails,
+ Operation,
+ OperationAction,
+ OperationStatus,
+ OperationSubType,
+ OperationType,
+ OperationUpdate,
+ StepDetails,
+ StepOptions,
+ TimestampConverter,
+ WaitDetails,
+ WaitOptions,
+)
+from aws_durable_execution_sdk_python.types import (
+ LambdaContext as LambdaContextProtocol,
+)
+from dateutil.tz import UTC
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+
+
+class EventType(Enum):
+ """Event types for durable execution events."""
+
+ EXECUTION_STARTED = "ExecutionStarted"
+ EXECUTION_SUCCEEDED = "ExecutionSucceeded"
+ EXECUTION_FAILED = "ExecutionFailed"
+ EXECUTION_TIMED_OUT = "ExecutionTimedOut"
+ EXECUTION_STOPPED = "ExecutionStopped"
+ CONTEXT_STARTED = "ContextStarted"
+ CONTEXT_SUCCEEDED = "ContextSucceeded"
+ CONTEXT_FAILED = "ContextFailed"
+ WAIT_STARTED = "WaitStarted"
+ WAIT_SUCCEEDED = "WaitSucceeded"
+ WAIT_CANCELLED = "WaitCancelled"
+ STEP_STARTED = "StepStarted"
+ STEP_SUCCEEDED = "StepSucceeded"
+ STEP_FAILED = "StepFailed"
+ CHAINED_INVOKE_STARTED = "ChainedInvokeStarted"
+ CHAINED_INVOKE_SUCCEEDED = "ChainedInvokeSucceeded"
+ CHAINED_INVOKE_FAILED = "ChainedInvokeFailed"
+ CHAINED_INVOKE_TIMED_OUT = "ChainedInvokeTimedOut"
+ CHAINED_INVOKE_STOPPED = "ChainedInvokeStopped"
+ CALLBACK_STARTED = "CallbackStarted"
+ CALLBACK_SUCCEEDED = "CallbackSucceeded"
+ CALLBACK_FAILED = "CallbackFailed"
+ CALLBACK_TIMED_OUT = "CallbackTimedOut"
+ INVOCATION_COMPLETED = "InvocationCompleted"
+
+
+TERMINAL_STATUSES: set[OperationStatus] = {
+ OperationStatus.SUCCEEDED,
+ OperationStatus.FAILED,
+ OperationStatus.TIMED_OUT,
+ OperationStatus.STOPPED,
+ OperationStatus.CANCELLED,
+}
+
+
+@dataclass(frozen=True)
+class LambdaContext(LambdaContextProtocol):
+ """Lambda context for testing."""
+
+ aws_request_id: str
+ log_group_name: str | None = None
+ log_stream_name: str | None = None
+ function_name: str | None = None
+ memory_limit_in_mb: str | None = None
+ function_version: str | None = None
+ invoked_function_arn: str | None = None
+ tenant_id: str | None = None
+ client_context: dict | None = None
+ identity: dict | None = None
+
+ def get_remaining_time_in_millis(self) -> int:
+ return 900000 # 15 minutes default
+
+ def log(self, msg) -> None:
+ pass # No-op for testing
+
+
+# region web_api_models
+# Web API specific models (not in Smithy but needed for web interface)
+@dataclass(frozen=True)
+class StartDurableExecutionInput:
+ """Input for starting a durable execution via web API."""
+
+ account_id: str
+ function_name: str
+ function_qualifier: str
+ execution_name: str
+ execution_timeout_seconds: int
+ execution_retention_period_days: int
+ invocation_id: str | None = None
+ trace_fields: dict | None = None
+ tenant_id: str | None = None
+ input: str | None = None
+ lambda_endpoint: str | None = None # Endpoint for this specific execution
+
+ @classmethod
+ def from_dict(cls, data: dict) -> StartDurableExecutionInput:
+ # Validate required fields and raise AWS-compliant exceptions
+ required_fields = [
+ "AccountId",
+ "FunctionName",
+ "FunctionQualifier",
+ "ExecutionName",
+ "ExecutionTimeoutSeconds",
+ "ExecutionRetentionPeriodDays",
+ ]
+
+ for field in required_fields:
+ if field not in data:
+ msg: str = f"Missing required field: {field}"
+ raise InvalidParameterValueException(msg)
+
+ return cls(
+ account_id=data["AccountId"],
+ function_name=data["FunctionName"],
+ function_qualifier=data["FunctionQualifier"],
+ execution_name=data["ExecutionName"],
+ execution_timeout_seconds=data["ExecutionTimeoutSeconds"],
+ execution_retention_period_days=data["ExecutionRetentionPeriodDays"],
+ invocation_id=data.get("InvocationId"),
+ trace_fields=data.get("TraceFields"),
+ tenant_id=data.get("TenantId"),
+ input=data.get("Input"),
+ lambda_endpoint=data.get("LambdaEndpoint", None),
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ result = {
+ "AccountId": self.account_id,
+ "FunctionName": self.function_name,
+ "FunctionQualifier": self.function_qualifier,
+ "ExecutionName": self.execution_name,
+ "ExecutionTimeoutSeconds": self.execution_timeout_seconds,
+ "ExecutionRetentionPeriodDays": self.execution_retention_period_days,
+ }
+ if self.invocation_id is not None:
+ result["InvocationId"] = self.invocation_id
+ if self.trace_fields is not None:
+ result["TraceFields"] = self.trace_fields
+ if self.tenant_id is not None:
+ result["TenantId"] = self.tenant_id
+ if self.input is not None:
+ result["Input"] = self.input
+ if self.lambda_endpoint is not None:
+ result["LambdaEndpoint"] = self.lambda_endpoint
+ return result
+
+ def get_normalized_input(self):
+ """
+ Normalize input string to be JSON deserializable.
+ Avoid double coding json input.
+ """
+ # Try to parse once
+ try:
+ _ = json.loads(self.input)
+ return self.input
+ except (json.JSONDecodeError, TypeError):
+ # Not valid JSON, treat as plain string and encode it
+ return json.dumps(self.input)
+
+
+@dataclass(frozen=True)
+class StartDurableExecutionOutput:
+ """Output from starting a durable execution via web API."""
+
+ execution_arn: str | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> StartDurableExecutionOutput:
+ return cls(execution_arn=data.get("ExecutionArn"))
+
+ def to_dict(self) -> dict[str, Any]:
+ result = {}
+ if self.execution_arn is not None:
+ result["ExecutionArn"] = self.execution_arn
+ return result
+
+
+# endregion web_api_models
+
+
+# region smithy_api_models
+# Smithy-based API models
+@dataclass(frozen=True)
+class GetDurableExecutionRequest:
+ """Request to get durable execution details."""
+
+ durable_execution_arn: str
+
+ @classmethod
+ def from_dict(cls, data: dict) -> GetDurableExecutionRequest:
+ return cls(durable_execution_arn=data["DurableExecutionArn"])
+
+ def to_dict(self) -> dict[str, Any]:
+ return {"DurableExecutionArn": self.durable_execution_arn}
+
+
+@dataclass(frozen=True)
+class GetDurableExecutionResponse:
+ """Response containing durable execution details."""
+
+ durable_execution_arn: str
+ durable_execution_name: str
+ function_arn: str
+ status: str
+ start_timestamp: datetime.datetime
+ input_payload: str | None = None
+ result: str | None = None
+ error: ErrorObject | None = None
+ end_timestamp: datetime.datetime | None = None
+ version: str | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> GetDurableExecutionResponse:
+ error = None
+ if error_data := data.get("Error"):
+ error = ErrorObject.from_dict(error_data)
+
+ return cls(
+ durable_execution_arn=data["DurableExecutionArn"],
+ durable_execution_name=data["DurableExecutionName"],
+ function_arn=data["FunctionArn"],
+ status=data["Status"],
+ start_timestamp=data["StartTimestamp"],
+ input_payload=data.get("InputPayload"),
+ result=data.get("Result"),
+ error=error,
+ end_timestamp=data.get("EndTimestamp"),
+ version=data.get("Version"),
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {
+ "DurableExecutionArn": self.durable_execution_arn,
+ "DurableExecutionName": self.durable_execution_name,
+ "FunctionArn": self.function_arn,
+ "Status": self.status,
+ "StartTimestamp": self.start_timestamp,
+ }
+ if self.input_payload is not None:
+ result["InputPayload"] = self.input_payload
+ if self.result is not None:
+ result["Result"] = self.result
+ if self.error is not None:
+ result["Error"] = self.error.to_dict()
+ if self.end_timestamp is not None:
+ result["EndTimestamp"] = self.end_timestamp
+ if self.end_timestamp is not None:
+ result["EndTimestamp"] = self.end_timestamp
+ if self.version is not None:
+ result["Version"] = self.version
+ return result
+
+
+@dataclass(frozen=True)
+class Execution:
+ """Execution summary structure from Smithy model."""
+
+ durable_execution_arn: str
+ durable_execution_name: str
+ function_arn: str
+ status: str
+ start_timestamp: datetime.datetime
+ end_timestamp: datetime.datetime | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> Execution:
+ return cls(
+ durable_execution_arn=data["DurableExecutionArn"],
+ durable_execution_name=data["DurableExecutionName"],
+ function_arn=data.get(
+ "FunctionArn", ""
+ ), # Make optional for backward compatibility
+ status=data["Status"],
+ start_timestamp=data["StartTimestamp"],
+ end_timestamp=data.get("EndTimestamp"),
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ result = {
+ "DurableExecutionArn": self.durable_execution_arn,
+ "DurableExecutionName": self.durable_execution_name,
+ "Status": self.status,
+ "StartTimestamp": self.start_timestamp,
+ }
+ if self.function_arn: # Only include if not empty
+ result["FunctionArn"] = self.function_arn
+ if self.end_timestamp is not None:
+ result["EndTimestamp"] = self.end_timestamp
+ return result
+
+ @classmethod
+ def from_execution(cls, execution, status: str) -> Execution:
+ """Create ExecutionSummary from Execution object."""
+
+ execution_op = execution.get_operation_execution_started()
+ return cls(
+ durable_execution_arn=execution.durable_execution_arn,
+ durable_execution_name=execution.start_input.execution_name,
+ function_arn=f"arn:aws:lambda:us-east-1:123456789012:function:{execution.start_input.function_name}",
+ status=status,
+ start_timestamp=execution_op.start_timestamp
+ if execution_op.start_timestamp
+ else datetime.datetime.now(datetime.UTC),
+ end_timestamp=execution_op.end_timestamp
+ if execution_op.end_timestamp
+ else None,
+ )
+
+
+@dataclass(frozen=True)
+class ListDurableExecutionsRequest:
+ """Request to list durable executions."""
+
+ function_name: str | None = None
+ function_version: str | None = None
+ durable_execution_name: str | None = None
+ status_filter: list[str] | None = None
+ started_after: str | None = None
+ started_before: str | None = None
+ marker: str | None = None
+ max_items: int = 0
+ reverse_order: bool | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> ListDurableExecutionsRequest:
+ # Handle query parameters that may be lists
+ function_name = data.get("FunctionName")
+ if isinstance(function_name, list):
+ function_name = function_name[0] if function_name else None
+
+ function_version = data.get("FunctionVersion")
+ if isinstance(function_version, list):
+ function_version = function_version[0] if function_version else None
+
+ durable_execution_name = data.get("DurableExecutionName")
+ if isinstance(durable_execution_name, list):
+ durable_execution_name = (
+ durable_execution_name[0] if durable_execution_name else None
+ )
+
+ status_filter = data.get("StatusFilter")
+ if isinstance(status_filter, list):
+ status_filter = status_filter if status_filter else None
+ elif status_filter:
+ status_filter = [status_filter]
+
+ started_after = data.get("StartedAfter")
+ if isinstance(started_after, list):
+ started_after = started_after[0] if started_after else None
+
+ started_before = data.get("StartedBefore")
+ if isinstance(started_before, list):
+ started_before = started_before[0] if started_before else None
+
+ marker = data.get("Marker")
+ if isinstance(marker, list):
+ marker = marker[0] if marker else None
+
+ max_items = data.get("MaxItems", 0)
+ if isinstance(max_items, list):
+ max_items = int(max_items[0]) if max_items else 0
+
+ reverse_order = data.get("ReverseOrder")
+ if isinstance(reverse_order, list):
+ reverse_order = (
+ reverse_order[0].lower() in ("true", "1", "yes")
+ if reverse_order
+ else None
+ )
+ elif isinstance(reverse_order, str):
+ reverse_order = reverse_order.lower() in ("true", "1", "yes")
+
+ return cls(
+ function_name=function_name,
+ function_version=function_version,
+ durable_execution_name=durable_execution_name,
+ status_filter=status_filter,
+ started_after=started_after,
+ started_before=started_before,
+ marker=marker,
+ max_items=max_items,
+ reverse_order=reverse_order,
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {}
+ if self.function_name is not None:
+ result["FunctionName"] = self.function_name
+ if self.function_version is not None:
+ result["FunctionVersion"] = self.function_version
+ if self.durable_execution_name is not None:
+ result["DurableExecutionName"] = self.durable_execution_name
+ if self.status_filter is not None:
+ result["StatusFilter"] = self.status_filter
+ if self.started_after is not None:
+ result["StartedAfter"] = self.started_after
+ if self.started_before is not None:
+ result["StartedBefore"] = self.started_before
+ if self.marker is not None:
+ result["Marker"] = self.marker
+ if self.max_items is not None:
+ result["MaxItems"] = self.max_items
+ if self.reverse_order is not None:
+ result["ReverseOrder"] = self.reverse_order
+ return result
+
+
+@dataclass(frozen=True)
+class ListDurableExecutionsResponse:
+ """Response containing list of durable executions."""
+
+ durable_executions: list[Execution]
+ next_marker: str | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> ListDurableExecutionsResponse:
+ executions = [
+ Execution.from_dict(exec_data)
+ for exec_data in data.get("DurableExecutions", [])
+ ]
+ return cls(
+ durable_executions=executions,
+ next_marker=data.get("NextMarker"),
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {
+ "DurableExecutions": [exe.to_dict() for exe in self.durable_executions]
+ }
+ if self.next_marker is not None:
+ result["NextMarker"] = self.next_marker
+ return result
+
+
+@dataclass(frozen=True)
+class StopDurableExecutionRequest:
+ """Request to stop a durable execution."""
+
+ durable_execution_arn: str
+ error: ErrorObject | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> StopDurableExecutionRequest:
+ error = None
+ if error_data := data.get("Error"):
+ error = ErrorObject.from_dict(error_data)
+
+ return cls(
+ durable_execution_arn=data["DurableExecutionArn"],
+ error=error,
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {"DurableExecutionArn": self.durable_execution_arn}
+ if self.error is not None:
+ result["Error"] = self.error.to_dict()
+ return result
+
+
+@dataclass(frozen=True)
+class StopDurableExecutionResponse:
+ """Response from stopping a durable execution."""
+
+ stop_timestamp: datetime.datetime
+
+ @classmethod
+ def from_dict(cls, data: dict) -> StopDurableExecutionResponse:
+ return cls(stop_timestamp=data["StopTimestamp"])
+
+ def to_dict(self) -> dict[str, Any]:
+ return {"StopTimestamp": self.stop_timestamp}
+
+
+@dataclass(frozen=True)
+class GetDurableExecutionStateRequest:
+ """Request to get durable execution state."""
+
+ durable_execution_arn: str
+ checkpoint_token: str
+ marker: str | None = None
+ max_items: int = 0
+
+ @classmethod
+ def from_dict(cls, data: dict) -> GetDurableExecutionStateRequest:
+ return cls(
+ durable_execution_arn=data["DurableExecutionArn"],
+ checkpoint_token=data["CheckpointToken"],
+ marker=data.get("Marker"),
+ max_items=data.get("MaxItems", 0),
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {
+ "DurableExecutionArn": self.durable_execution_arn,
+ "CheckpointToken": self.checkpoint_token,
+ }
+ if self.marker is not None:
+ result["Marker"] = self.marker
+ if self.max_items is not None:
+ result["MaxItems"] = self.max_items
+ return result
+
+
+@dataclass(frozen=True)
+class GetDurableExecutionStateResponse:
+ """Response containing durable execution state operations."""
+
+ operations: list[Operation]
+ next_marker: str | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> GetDurableExecutionStateResponse:
+ operations = [
+ Operation.from_dict(op_data) for op_data in data.get("Operations", [])
+ ]
+ return cls(
+ operations=operations,
+ next_marker=data.get("NextMarker"),
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {
+ "Operations": [op.to_dict() for op in self.operations]
+ }
+ if self.next_marker is not None:
+ result["NextMarker"] = self.next_marker
+ return result
+
+
+# endregion smithy_api_models
+
+
+# region event_structures
+# Event-related structures from Smithy model
+@dataclass(frozen=True)
+class EventInput:
+ """Event input structure."""
+
+ payload: str | None = None
+ truncated: bool = False
+
+ @classmethod
+ def from_dict(cls, data: dict) -> EventInput:
+ return cls(
+ payload=data.get("Payload"),
+ truncated=data.get("Truncated", False),
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {"Truncated": self.truncated}
+ if self.payload is not None:
+ result["Payload"] = self.payload
+ return result
+
+ @classmethod
+ def from_details(
+ cls,
+ details: ExecutionDetails,
+ include: bool = False, # noqa: FBT001, FBT002
+ ) -> EventInput:
+ details_input: str | None = details.input_payload if details else None
+ payload: str | None = details_input if include else None
+ truncated: bool = not include
+ return cls(payload=payload, truncated=truncated)
+
+ @classmethod
+ def from_start_durable_execution_input(
+ cls,
+ start_durable_execution_input: StartDurableExecutionInput,
+ include: bool = False, # noqa: FBT001, FBT002
+ ) -> EventInput:
+ input: str | None = start_durable_execution_input.input
+ truncated: bool = not include
+ return cls(input, truncated)
+
+
+@dataclass(frozen=True)
+class EventResult:
+ """Event result structure."""
+
+ payload: str | None = None
+ truncated: bool = False
+
+ @classmethod
+ def from_dict(cls, data: dict) -> EventResult:
+ return cls(
+ payload=data.get("Payload"),
+ truncated=data.get("Truncated", False),
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {"Truncated": self.truncated}
+ if self.payload is not None:
+ result["Payload"] = self.payload
+ return result
+
+ @classmethod
+ def from_details(
+ cls,
+ details: CallbackDetails | StepDetails | ChainedInvokeDetails | ContextDetails,
+ include: bool = False, # noqa: FBT001, FBT002
+ ) -> EventResult:
+ details_result: str | None = details.result if details else None
+ payload: str | None = details_result if include else None
+ truncated: bool = not include
+ return cls(payload=payload, truncated=truncated)
+
+ @classmethod
+ def from_durable_execution_invocation_output(
+ cls,
+ durable_execution_invocation_output: DurableExecutionInvocationOutput,
+ include: bool = False, # noqa: FBT001, FBT002
+ ) -> EventResult:
+ truncated: bool = not include
+ return cls(durable_execution_invocation_output.result, truncated)
+
+
+@dataclass(frozen=True)
+class EventError:
+ """Event error structure."""
+
+ payload: ErrorObject | None = None
+ truncated: bool = False
+
+ @classmethod
+ def from_dict(cls, data: dict) -> EventError:
+ payload = None
+ if payload_data := data.get("Payload"):
+ payload = ErrorObject.from_dict(payload_data)
+
+ return cls(
+ payload=payload,
+ truncated=data.get("Truncated", False),
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {"Truncated": self.truncated}
+ if self.payload is not None:
+ result["Payload"] = self.payload.to_dict()
+ return result
+
+ @classmethod
+ def from_details(
+ cls,
+ details: CallbackDetails | StepDetails | ChainedInvokeDetails | ContextDetails,
+ include: bool = False, # noqa: FBT001, FBT002
+ ) -> EventError:
+ error_object: ErrorObject | None = details.error if details else None
+ truncated: bool = not include
+ return cls(error_object, truncated)
+
+ @classmethod
+ def from_durable_execution_invocation_output(
+ cls,
+ durable_execution_invocation_output: DurableExecutionInvocationOutput,
+ include: bool = False, # noqa: FBT001, FBT002
+ ) -> EventError:
+ truncated: bool = not include
+ return cls(durable_execution_invocation_output.error, truncated)
+
+
+@dataclass(frozen=True)
+class RetryDetails:
+ """Retry details structure."""
+
+ current_attempt: int = 0
+ next_attempt_delay_seconds: int | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> RetryDetails:
+ return cls(
+ current_attempt=data.get("CurrentAttempt", 0),
+ next_attempt_delay_seconds=data.get("NextAttemptDelaySeconds"),
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {"CurrentAttempt": self.current_attempt}
+ if self.next_attempt_delay_seconds is not None:
+ result["NextAttemptDelaySeconds"] = self.next_attempt_delay_seconds
+ return result
+
+
+# Event detail structures
+@dataclass(frozen=True)
+class ExecutionStartedDetails:
+ """Execution started event details."""
+
+ input: EventInput | None = None
+ execution_timeout: int | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> ExecutionStartedDetails:
+ input_data = None
+ if input_dict := data.get("Input"):
+ input_data = EventInput.from_dict(input_dict)
+
+ return cls(
+ input=input_data,
+ execution_timeout=data.get("ExecutionTimeout"),
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {}
+ if self.input is not None:
+ result["Input"] = self.input.to_dict()
+ if self.execution_timeout is not None:
+ result["ExecutionTimeout"] = self.execution_timeout
+ return result
+
+
+@dataclass(frozen=True)
+class ExecutionSucceededDetails:
+ """Execution succeeded event details."""
+
+ result: EventResult | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> ExecutionSucceededDetails:
+ result_data = None
+ if result_dict := data.get("Result"):
+ result_data = EventResult.from_dict(result_dict)
+
+ return cls(result=result_data)
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {}
+ if self.result is not None:
+ result["Result"] = self.result.to_dict()
+ return result
+
+
+@dataclass(frozen=True)
+class ExecutionFailedDetails:
+ """Execution failed event details."""
+
+ error: EventError | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> ExecutionFailedDetails:
+ error_data = None
+ if error_dict := data.get("Error"):
+ error_data = EventError.from_dict(error_dict)
+
+ return cls(error=error_data)
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {}
+ if self.error is not None:
+ result["Error"] = self.error.to_dict()
+ return result
+
+
+@dataclass(frozen=True)
+class ExecutionTimedOutDetails:
+ """Execution timed out event details."""
+
+ error: EventError | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> ExecutionTimedOutDetails:
+ error_data = None
+ if error_dict := data.get("Error"):
+ error_data = EventError.from_dict(error_dict)
+
+ return cls(error=error_data)
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {}
+ if self.error is not None:
+ result["Error"] = self.error.to_dict()
+ return result
+
+
+@dataclass(frozen=True)
+class ExecutionStoppedDetails:
+ """Execution stopped event details."""
+
+ error: EventError | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> ExecutionStoppedDetails:
+ error_data = None
+ if error_dict := data.get("Error"):
+ error_data = EventError.from_dict(error_dict)
+
+ return cls(error=error_data)
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {}
+ if self.error is not None:
+ result["Error"] = self.error.to_dict()
+ return result
+
+
+@dataclass(frozen=True)
+class ContextStartedDetails:
+ """Context started event details."""
+
+ @classmethod
+ def from_dict(cls, data: dict) -> ContextStartedDetails: # noqa: ARG003
+ return cls()
+
+ def to_dict(self) -> dict[str, Any]:
+ return {}
+
+
+@dataclass(frozen=True)
+class ContextSucceededDetails:
+ """Context succeeded event details."""
+
+ result: EventResult | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> ContextSucceededDetails:
+ result_data = None
+ if result_dict := data.get("Result"):
+ result_data = EventResult.from_dict(result_dict)
+
+ return cls(result=result_data)
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {}
+ if self.result is not None:
+ result["Result"] = self.result.to_dict()
+ return result
+
+
+@dataclass(frozen=True)
+class ContextFailedDetails:
+ """Context failed event details."""
+
+ error: EventError | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> ContextFailedDetails:
+ error_data = None
+ if error_dict := data.get("Error"):
+ error_data = EventError.from_dict(error_dict)
+
+ return cls(error=error_data)
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {}
+ if self.error is not None:
+ result["Error"] = self.error.to_dict()
+ return result
+
+
+@dataclass(frozen=True)
+class WaitStartedDetails:
+ """Wait started event details."""
+
+ duration: int | None = None
+ scheduled_end_timestamp: datetime.datetime | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> WaitStartedDetails:
+ return cls(
+ duration=data.get("Duration"),
+ scheduled_end_timestamp=data.get("ScheduledEndTimestamp"),
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {}
+ if self.duration is not None:
+ result["Duration"] = self.duration
+ if self.scheduled_end_timestamp is not None:
+ result["ScheduledEndTimestamp"] = self.scheduled_end_timestamp
+ return result
+
+
+@dataclass(frozen=True)
+class WaitSucceededDetails:
+ """Wait succeeded event details."""
+
+ duration: int | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> WaitSucceededDetails:
+ return cls(duration=data.get("Duration"))
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {}
+ if self.duration is not None:
+ result["Duration"] = self.duration
+ return result
+
+
+@dataclass(frozen=True)
+class WaitCancelledDetails:
+ """Wait cancelled event details."""
+
+ error: EventError | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> WaitCancelledDetails:
+ error_data = None
+ if error_dict := data.get("Error"):
+ error_data = EventError.from_dict(error_dict)
+
+ return cls(error=error_data)
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {}
+ if self.error is not None:
+ result["Error"] = self.error.to_dict()
+ return result
+
+
+@dataclass(frozen=True)
+class StepStartedDetails:
+ """Step started event details."""
+
+ @classmethod
+ def from_dict(cls, data: dict) -> StepStartedDetails: # noqa: ARG003
+ return cls()
+
+ def to_dict(self) -> dict[str, Any]:
+ return {}
+
+
+@dataclass(frozen=True)
+class StepSucceededDetails:
+ """Step succeeded event details."""
+
+ result: EventResult | None = None
+ retry_details: RetryDetails | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> StepSucceededDetails:
+ result_data = None
+ if result_dict := data.get("Result"):
+ result_data = EventResult.from_dict(result_dict)
+
+ retry_details_data = None
+ if retry_dict := data.get("RetryDetails"):
+ retry_details_data = RetryDetails.from_dict(retry_dict)
+
+ return cls(result=result_data, retry_details=retry_details_data)
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {}
+ if self.result is not None:
+ result["Result"] = self.result.to_dict()
+ if self.retry_details is not None:
+ result["RetryDetails"] = self.retry_details.to_dict()
+ return result
+
+
+@dataclass(frozen=True)
+class StepFailedDetails:
+ """Step failed event details."""
+
+ error: EventError | None = None
+ retry_details: RetryDetails | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> StepFailedDetails:
+ error_data = None
+ if error_dict := data.get("Error"):
+ error_data = EventError.from_dict(error_dict)
+
+ retry_details_data = None
+ if retry_dict := data.get("RetryDetails"):
+ retry_details_data = RetryDetails.from_dict(retry_dict)
+
+ return cls(error=error_data, retry_details=retry_details_data)
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {}
+ if self.error is not None:
+ result["Error"] = self.error.to_dict()
+ if self.retry_details is not None:
+ result["RetryDetails"] = self.retry_details.to_dict()
+ return result
+
+
+@dataclass(frozen=True)
+class ChainedInvokePendingDetails:
+ """Chained Invoke Pending event details."""
+
+ input: EventInput | None = None
+ function_name: str | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> ChainedInvokePendingDetails:
+ input_data = None
+ if input_dict := data.get("Input"):
+ input_data = EventInput.from_dict(input_dict)
+
+ return cls(
+ input=input_data,
+ function_name=data.get("FunctionName"),
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {}
+ if self.input is not None:
+ result["Input"] = self.input.to_dict()
+ if self.function_name is not None:
+ result["FunctionName"] = self.function_name
+ return result
+
+
+@dataclass(frozen=True)
+class ChainedInvokeStartedDetails:
+ """Chained invoke started event details."""
+
+ durable_execution_arn: str | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> ChainedInvokeStartedDetails:
+ return cls(
+ durable_execution_arn=data.get("DurableExecutionArn"),
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {}
+ if self.durable_execution_arn is not None:
+ result["DurableExecutionArn"] = self.durable_execution_arn
+ return result
+
+
+@dataclass(frozen=True)
+class ChainedInvokeSucceededDetails:
+ """Chained invoke succeeded event details."""
+
+ result: EventResult | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> ChainedInvokeSucceededDetails:
+ result_data = None
+ if result_dict := data.get("Result"):
+ result_data = EventResult.from_dict(result_dict)
+
+ return cls(result=result_data)
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {}
+ if self.result is not None:
+ result["Result"] = self.result.to_dict()
+ return result
+
+
+@dataclass(frozen=True)
+class ChainedInvokeFailedDetails:
+ """Chained invoke failed event details."""
+
+ error: EventError | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> ChainedInvokeFailedDetails:
+ error_data = None
+ if error_dict := data.get("Error"):
+ error_data = EventError.from_dict(error_dict)
+
+ return cls(error=error_data)
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {}
+ if self.error is not None:
+ result["Error"] = self.error.to_dict()
+ return result
+
+
+@dataclass(frozen=True)
+class ChainedInvokeTimedOutDetails:
+ """Chained invoke timed out event details."""
+
+ error: EventError | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> ChainedInvokeTimedOutDetails:
+ error_data = None
+ if error_dict := data.get("Error"):
+ error_data = EventError.from_dict(error_dict)
+
+ return cls(error=error_data)
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {}
+ if self.error is not None:
+ result["Error"] = self.error.to_dict()
+ return result
+
+
+@dataclass(frozen=True)
+class ChainedInvokeStoppedDetails:
+ """Chained invoke stopped event details."""
+
+ error: EventError | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> ChainedInvokeStoppedDetails:
+ error_data = None
+ if error_dict := data.get("Error"):
+ error_data = EventError.from_dict(error_dict)
+
+ return cls(error=error_data)
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {}
+ if self.error is not None:
+ result["Error"] = self.error.to_dict()
+ return result
+
+
+@dataclass(frozen=True)
+class CallbackStartedDetails:
+ """Callback started event details."""
+
+ callback_id: str | None = None
+ heartbeat_timeout: int | None = None
+ timeout: int | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> CallbackStartedDetails:
+ return cls(
+ callback_id=data.get("CallbackId"),
+ heartbeat_timeout=data.get("HeartbeatTimeout"),
+ timeout=data.get("Timeout"),
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {}
+ if self.callback_id is not None:
+ result["CallbackId"] = self.callback_id
+ if self.heartbeat_timeout is not None:
+ result["HeartbeatTimeout"] = self.heartbeat_timeout
+ if self.timeout is not None:
+ result["Timeout"] = self.timeout
+ return result
+
+
+@dataclass(frozen=True)
+class CallbackSucceededDetails:
+ """Callback succeeded event details."""
+
+ result: EventResult | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> CallbackSucceededDetails:
+ result_data = None
+ if result_dict := data.get("Result"):
+ result_data = EventResult.from_dict(result_dict)
+
+ return cls(result=result_data)
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {}
+ if self.result is not None:
+ result["Result"] = self.result.to_dict()
+ return result
+
+
+@dataclass(frozen=True)
+class CallbackFailedDetails:
+ """Callback failed event details."""
+
+ error: EventError | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> CallbackFailedDetails:
+ error_data = None
+ if error_dict := data.get("Error"):
+ error_data = EventError.from_dict(error_dict)
+
+ return cls(error=error_data)
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {}
+ if self.error is not None:
+ result["Error"] = self.error.to_dict()
+ return result
+
+
+@dataclass(frozen=True)
+class CallbackTimedOutDetails:
+ """Callback timed out event details."""
+
+ error: EventError | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> CallbackTimedOutDetails:
+ error_data = None
+ if error_dict := data.get("Error"):
+ error_data = EventError.from_dict(error_dict)
+
+ return cls(error=error_data)
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {}
+ if self.error is not None:
+ result["Error"] = self.error.to_dict()
+ return result
+
+
+@dataclass(frozen=True)
+class InvocationCompletedDetails:
+ """Invocation completed event details."""
+
+ start_timestamp: datetime.datetime
+ end_timestamp: datetime.datetime
+ request_id: str
+
+ @classmethod
+ def from_dict(cls, data: dict) -> InvocationCompletedDetails:
+ return cls(
+ start_timestamp=data["StartTimestamp"],
+ end_timestamp=data["EndTimestamp"],
+ request_id=data["RequestId"],
+ )
+
+ @classmethod
+ def from_json_dict(cls, data: dict) -> InvocationCompletedDetails:
+ """Deserialize from JSON dict with Unix millisecond timestamps."""
+ start_ts: datetime.datetime | None = TimestampConverter.from_unix_millis(
+ data["StartTimestamp"]
+ ) # type: ignore[arg-type]
+ end_ts: datetime.datetime | None = TimestampConverter.from_unix_millis(
+ data["EndTimestamp"]
+ ) # type: ignore[arg-type]
+
+ if start_ts is None or end_ts is None:
+ raise InvalidParameterValueException(
+ "StartTimestamp and EndTimestamp cannot be null"
+ )
+
+ return cls(
+ start_timestamp=start_ts,
+ end_timestamp=end_ts,
+ request_id=data["RequestId"],
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ return {
+ "StartTimestamp": self.start_timestamp,
+ "EndTimestamp": self.end_timestamp,
+ "RequestId": self.request_id,
+ }
+
+ def to_json_dict(self) -> dict[str, Any]:
+ """Convert to JSON-serializable dict with Unix millisecond timestamps."""
+ return {
+ "StartTimestamp": TimestampConverter.to_unix_millis(self.start_timestamp),
+ "EndTimestamp": TimestampConverter.to_unix_millis(self.end_timestamp),
+ "RequestId": self.request_id,
+ }
+
+
+# endregion event_structures
+
+
+@dataclass(frozen=True)
+class EventCreationContext:
+ operation: Operation
+ event_id: int
+ durable_execution_arn: str
+ start_durable_execution_input: StartDurableExecutionInput
+ durable_execution_invocation_output: DurableExecutionInvocationOutput | None = None
+ operation_update: OperationUpdate | None = None
+ include_execution_data: bool = False # noqa: FBT001, FBT002
+
+ @classmethod
+ def create(
+ cls,
+ operation: Operation,
+ event_id: int,
+ durable_execution_arn: str,
+ start_input: StartDurableExecutionInput,
+ result: DurableExecutionInvocationOutput | None = None,
+ operation_update: OperationUpdate | None = None,
+ include_execution_data: bool = False, # noqa: FBT001, FBT002
+ ) -> EventCreationContext:
+ return cls(
+ operation=operation,
+ event_id=event_id,
+ durable_execution_arn=durable_execution_arn,
+ start_durable_execution_input=start_input,
+ durable_execution_invocation_output=result,
+ operation_update=operation_update,
+ include_execution_data=include_execution_data,
+ )
+
+ @property
+ def sub_type(self) -> str | None:
+ return self.operation.sub_type.value if self.operation.sub_type else None
+
+ def get_retry_details(self) -> RetryDetails | None:
+ if not self.operation.step_details or not self.operation_update:
+ return None
+
+ delay = 0
+ if (
+ self.operation_update.operation_type == OperationType.STEP
+ and self.operation_update.step_options
+ ):
+ delay = self.operation_update.step_options.next_attempt_delay_seconds
+
+ return RetryDetails(
+ current_attempt=self.operation.step_details.attempt,
+ next_attempt_delay_seconds=delay,
+ )
+
+ @property
+ def start_timestamp(self) -> datetime.datetime:
+ return (
+ self.operation.start_timestamp
+ if self.operation.start_timestamp is not None
+ else datetime.datetime.now(UTC)
+ )
+
+ @property
+ def end_timestamp(self) -> datetime.datetime:
+ return (
+ self.operation.end_timestamp
+ if self.operation.end_timestamp is not None
+ else datetime.datetime.now(UTC)
+ )
+
+
+# region event_class
+@dataclass(frozen=True)
+class Event:
+ """Event structure from Smithy model."""
+
+ event_type: str
+ event_timestamp: datetime.datetime
+ sub_type: str | None = None
+ event_id: int = 1
+ operation_id: str | None = None
+ name: str | None = None
+ parent_id: str | None = None
+ execution_started_details: ExecutionStartedDetails | None = None
+ execution_succeeded_details: ExecutionSucceededDetails | None = None
+ execution_failed_details: ExecutionFailedDetails | None = None
+ execution_timed_out_details: ExecutionTimedOutDetails | None = None
+ execution_stopped_details: ExecutionStoppedDetails | None = None
+ context_started_details: ContextStartedDetails | None = None
+ context_succeeded_details: ContextSucceededDetails | None = None
+ context_failed_details: ContextFailedDetails | None = None
+ wait_started_details: WaitStartedDetails | None = None
+ wait_succeeded_details: WaitSucceededDetails | None = None
+ wait_cancelled_details: WaitCancelledDetails | None = None
+ step_started_details: StepStartedDetails | None = None
+ step_succeeded_details: StepSucceededDetails | None = None
+ step_failed_details: StepFailedDetails | None = None
+ chained_invoke_pending_details: ChainedInvokePendingDetails | None = None
+ chained_invoke_started_details: ChainedInvokeStartedDetails | None = None
+ chained_invoke_succeeded_details: ChainedInvokeSucceededDetails | None = None
+ chained_invoke_failed_details: ChainedInvokeFailedDetails | None = None
+ chained_invoke_timed_out_details: ChainedInvokeTimedOutDetails | None = None
+ chained_invoke_stopped_details: ChainedInvokeStoppedDetails | None = None
+ callback_started_details: CallbackStartedDetails | None = None
+ callback_succeeded_details: CallbackSucceededDetails | None = None
+ callback_failed_details: CallbackFailedDetails | None = None
+ callback_timed_out_details: CallbackTimedOutDetails | None = None
+ invocation_completed_details: InvocationCompletedDetails | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> Event:
+ # Parse all the detail structures
+ execution_started_details = None
+ if details_data := data.get("ExecutionStartedDetails"):
+ execution_started_details = ExecutionStartedDetails.from_dict(details_data)
+
+ execution_succeeded_details = None
+ if details_data := data.get("ExecutionSucceededDetails"):
+ execution_succeeded_details = ExecutionSucceededDetails.from_dict(
+ details_data
+ )
+
+ execution_failed_details = None
+ if details_data := data.get("ExecutionFailedDetails"):
+ execution_failed_details = ExecutionFailedDetails.from_dict(details_data)
+
+ execution_timed_out_details = None
+ if details_data := data.get("ExecutionTimedOutDetails"):
+ execution_timed_out_details = ExecutionTimedOutDetails.from_dict(
+ details_data
+ )
+
+ execution_stopped_details = None
+ if details_data := data.get("ExecutionStoppedDetails"):
+ execution_stopped_details = ExecutionStoppedDetails.from_dict(details_data)
+
+ context_started_details = None
+ if details_data := data.get("ContextStartedDetails"):
+ context_started_details = ContextStartedDetails.from_dict(details_data)
+
+ context_succeeded_details = None
+ if details_data := data.get("ContextSucceededDetails"):
+ context_succeeded_details = ContextSucceededDetails.from_dict(details_data)
+
+ context_failed_details = None
+ if details_data := data.get("ContextFailedDetails"):
+ context_failed_details = ContextFailedDetails.from_dict(details_data)
+
+ wait_started_details = None
+ if details_data := data.get("WaitStartedDetails"):
+ wait_started_details = WaitStartedDetails.from_dict(details_data)
+
+ wait_succeeded_details = None
+ if details_data := data.get("WaitSucceededDetails"):
+ wait_succeeded_details = WaitSucceededDetails.from_dict(details_data)
+
+ wait_cancelled_details = None
+ if details_data := data.get("WaitCancelledDetails"):
+ wait_cancelled_details = WaitCancelledDetails.from_dict(details_data)
+
+ step_started_details = None
+ if details_data := data.get("StepStartedDetails"):
+ step_started_details = StepStartedDetails.from_dict(details_data)
+
+ step_succeeded_details = None
+ if details_data := data.get("StepSucceededDetails"):
+ step_succeeded_details = StepSucceededDetails.from_dict(details_data)
+
+ step_failed_details = None
+ if details_data := data.get("StepFailedDetails"):
+ step_failed_details = StepFailedDetails.from_dict(details_data)
+
+ chained_invoke_pending_details = None
+ if details_data := data.get("ChainedInvokePendingDetails"):
+ chained_invoke_pending_details = ChainedInvokePendingDetails.from_dict(
+ details_data
+ )
+
+ chained_invoke_started_details = None
+ if details_data := data.get("ChainedInvokeStartedDetails"):
+ chained_invoke_started_details = ChainedInvokeStartedDetails.from_dict(
+ details_data
+ )
+
+ chained_invoke_succeeded_details = None
+ if details_data := data.get("ChainedInvokeSucceededDetails"):
+ chained_invoke_succeeded_details = ChainedInvokeSucceededDetails.from_dict(
+ details_data
+ )
+
+ chained_invoke_failed_details = None
+ if details_data := data.get("ChainedInvokeFailedDetails"):
+ chained_invoke_failed_details = ChainedInvokeFailedDetails.from_dict(
+ details_data
+ )
+
+ chained_invoke_timed_out_details = None
+ if details_data := data.get("ChainedInvokeTimedOutDetails"):
+ chained_invoke_timed_out_details = ChainedInvokeTimedOutDetails.from_dict(
+ details_data
+ )
+
+ chained_invoke_stopped_details = None
+ if details_data := data.get("ChainedInvokeStoppedDetails"):
+ chained_invoke_stopped_details = ChainedInvokeStoppedDetails.from_dict(
+ details_data
+ )
+
+ callback_started_details = None
+ if details_data := data.get("CallbackStartedDetails"):
+ callback_started_details = CallbackStartedDetails.from_dict(details_data)
+
+ callback_succeeded_details = None
+ if details_data := data.get("CallbackSucceededDetails"):
+ callback_succeeded_details = CallbackSucceededDetails.from_dict(
+ details_data
+ )
+
+ callback_failed_details = None
+ if details_data := data.get("CallbackFailedDetails"):
+ callback_failed_details = CallbackFailedDetails.from_dict(details_data)
+
+ callback_timed_out_details = None
+ if details_data := data.get("CallbackTimedOutDetails"):
+ callback_timed_out_details = CallbackTimedOutDetails.from_dict(details_data)
+
+ invocation_completed_details = None
+ if details_data := data.get("InvocationCompletedDetails"):
+ invocation_completed_details = InvocationCompletedDetails.from_dict(
+ details_data
+ )
+
+ return cls(
+ event_type=data["EventType"],
+ event_timestamp=data["EventTimestamp"],
+ sub_type=data.get("SubType"),
+ event_id=data.get("EventId", 1),
+ operation_id=data.get("Id"),
+ name=data.get("Name"),
+ parent_id=data.get("ParentId"),
+ execution_started_details=execution_started_details,
+ execution_succeeded_details=execution_succeeded_details,
+ execution_failed_details=execution_failed_details,
+ execution_timed_out_details=execution_timed_out_details,
+ execution_stopped_details=execution_stopped_details,
+ context_started_details=context_started_details,
+ context_succeeded_details=context_succeeded_details,
+ context_failed_details=context_failed_details,
+ wait_started_details=wait_started_details,
+ wait_succeeded_details=wait_succeeded_details,
+ wait_cancelled_details=wait_cancelled_details,
+ step_started_details=step_started_details,
+ step_succeeded_details=step_succeeded_details,
+ step_failed_details=step_failed_details,
+ chained_invoke_pending_details=chained_invoke_pending_details,
+ chained_invoke_started_details=chained_invoke_started_details,
+ chained_invoke_succeeded_details=chained_invoke_succeeded_details,
+ chained_invoke_failed_details=chained_invoke_failed_details,
+ chained_invoke_timed_out_details=chained_invoke_timed_out_details,
+ chained_invoke_stopped_details=chained_invoke_stopped_details,
+ callback_started_details=callback_started_details,
+ callback_succeeded_details=callback_succeeded_details,
+ callback_failed_details=callback_failed_details,
+ callback_timed_out_details=callback_timed_out_details,
+ invocation_completed_details=invocation_completed_details,
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {
+ "EventType": self.event_type,
+ "EventTimestamp": self.event_timestamp,
+ "EventId": self.event_id,
+ }
+ if self.sub_type is not None:
+ result["SubType"] = self.sub_type
+ if self.operation_id is not None:
+ result["Id"] = self.operation_id
+ if self.name is not None:
+ result["Name"] = self.name
+ if self.parent_id is not None:
+ result["ParentId"] = self.parent_id
+ if self.execution_started_details is not None:
+ result["ExecutionStartedDetails"] = self.execution_started_details.to_dict()
+ if self.execution_succeeded_details is not None:
+ result["ExecutionSucceededDetails"] = (
+ self.execution_succeeded_details.to_dict()
+ )
+ if self.execution_failed_details is not None:
+ result["ExecutionFailedDetails"] = self.execution_failed_details.to_dict()
+ if self.execution_timed_out_details is not None:
+ result["ExecutionTimedOutDetails"] = (
+ self.execution_timed_out_details.to_dict()
+ )
+ if self.execution_stopped_details is not None:
+ result["ExecutionStoppedDetails"] = self.execution_stopped_details.to_dict()
+ if self.context_started_details is not None:
+ result["ContextStartedDetails"] = self.context_started_details.to_dict()
+ if self.context_succeeded_details is not None:
+ result["ContextSucceededDetails"] = self.context_succeeded_details.to_dict()
+ if self.context_failed_details is not None:
+ result["ContextFailedDetails"] = self.context_failed_details.to_dict()
+ if self.wait_started_details is not None:
+ result["WaitStartedDetails"] = self.wait_started_details.to_dict()
+ if self.wait_succeeded_details is not None:
+ result["WaitSucceededDetails"] = self.wait_succeeded_details.to_dict()
+ if self.wait_cancelled_details is not None:
+ result["WaitCancelledDetails"] = self.wait_cancelled_details.to_dict()
+ if self.step_started_details is not None:
+ result["StepStartedDetails"] = self.step_started_details.to_dict()
+ if self.step_succeeded_details is not None:
+ result["StepSucceededDetails"] = self.step_succeeded_details.to_dict()
+ if self.step_failed_details is not None:
+ result["StepFailedDetails"] = self.step_failed_details.to_dict()
+ if self.chained_invoke_pending_details is not None:
+ result["ChainedInvokePendingDetails"] = (
+ self.chained_invoke_pending_details.to_dict()
+ )
+ if self.chained_invoke_started_details is not None:
+ result["ChainedInvokeStartedDetails"] = (
+ self.chained_invoke_started_details.to_dict()
+ )
+ if self.chained_invoke_succeeded_details is not None:
+ result["ChainedInvokeSucceededDetails"] = (
+ self.chained_invoke_succeeded_details.to_dict()
+ )
+ if self.chained_invoke_failed_details is not None:
+ result["ChainedInvokeFailedDetails"] = (
+ self.chained_invoke_failed_details.to_dict()
+ )
+ if self.chained_invoke_timed_out_details is not None:
+ result["ChainedInvokeTimedOutDetails"] = (
+ self.chained_invoke_timed_out_details.to_dict()
+ )
+ if self.chained_invoke_stopped_details is not None:
+ result["ChainedInvokeStoppedDetails"] = (
+ self.chained_invoke_stopped_details.to_dict()
+ )
+ if self.callback_started_details is not None:
+ result["CallbackStartedDetails"] = self.callback_started_details.to_dict()
+ if self.callback_succeeded_details is not None:
+ result["CallbackSucceededDetails"] = (
+ self.callback_succeeded_details.to_dict()
+ )
+ if self.callback_failed_details is not None:
+ result["CallbackFailedDetails"] = self.callback_failed_details.to_dict()
+ if self.callback_timed_out_details is not None:
+ result["CallbackTimedOutDetails"] = (
+ self.callback_timed_out_details.to_dict()
+ )
+ if self.invocation_completed_details is not None:
+ result["InvocationCompletedDetails"] = (
+ self.invocation_completed_details.to_dict()
+ )
+ return result
+
+ # region execution
+ @classmethod
+ def create_execution_event_started(cls, context: EventCreationContext) -> Event:
+ execution_details: ExecutionDetails | None = context.operation.execution_details
+ event_input: EventInput | None = (
+ EventInput.from_details(execution_details, context.include_execution_data)
+ if execution_details
+ else None
+ )
+ execution_timeout: int | None = (
+ context.start_durable_execution_input.execution_timeout_seconds
+ )
+
+ return cls(
+ event_type=EventType.EXECUTION_STARTED.value,
+ event_timestamp=context.start_timestamp,
+ sub_type=context.sub_type,
+ event_id=context.event_id,
+ operation_id=context.operation.operation_id,
+ name=context.operation.name,
+ parent_id=context.operation.parent_id,
+ execution_started_details=ExecutionStartedDetails(
+ input=event_input,
+ execution_timeout=execution_timeout,
+ ),
+ )
+
+ @classmethod
+ def create_execution_event_succeeded(cls, context: EventCreationContext) -> Event:
+ result: EventResult | None = (
+ EventResult.from_durable_execution_invocation_output(
+ context.durable_execution_invocation_output,
+ context.include_execution_data,
+ )
+ if context.durable_execution_invocation_output
+ else None
+ )
+ return cls(
+ event_type=EventType.EXECUTION_SUCCEEDED.value,
+ event_timestamp=context.end_timestamp,
+ sub_type=context.sub_type,
+ event_id=context.event_id,
+ operation_id=context.operation.operation_id,
+ name=context.operation.name,
+ parent_id=context.operation.parent_id,
+ execution_succeeded_details=ExecutionSucceededDetails(result=result),
+ )
+
+ @classmethod
+ def create_execution_event_failed(cls, context: EventCreationContext) -> Event:
+ error: EventError | None = (
+ EventError.from_durable_execution_invocation_output(
+ context.durable_execution_invocation_output,
+ include=context.include_execution_data,
+ )
+ if context.durable_execution_invocation_output
+ else None
+ )
+ return cls(
+ event_type=EventType.EXECUTION_FAILED.value,
+ event_timestamp=context.end_timestamp,
+ sub_type=context.sub_type,
+ event_id=context.event_id,
+ operation_id=context.operation.operation_id,
+ name=context.operation.name,
+ parent_id=context.operation.parent_id,
+ execution_failed_details=ExecutionFailedDetails(error=error),
+ )
+
+ @classmethod
+ def create_execution_event_timed_out(cls, context: EventCreationContext) -> Event:
+ error: EventError | None = (
+ EventError.from_durable_execution_invocation_output(
+ context.durable_execution_invocation_output,
+ include=context.include_execution_data,
+ )
+ if context.durable_execution_invocation_output
+ else None
+ )
+ return cls(
+ event_type=EventType.EXECUTION_TIMED_OUT.value,
+ event_timestamp=context.end_timestamp,
+ sub_type=context.sub_type,
+ event_id=context.event_id,
+ operation_id=context.operation.operation_id,
+ name=context.operation.name,
+ parent_id=context.operation.parent_id,
+ execution_timed_out_details=ExecutionTimedOutDetails(error=error),
+ )
+
+ @classmethod
+ def create_execution_event_stopped(cls, context: EventCreationContext) -> Event:
+ error: EventError | None = (
+ EventError.from_durable_execution_invocation_output(
+ context.durable_execution_invocation_output,
+ include=context.include_execution_data,
+ )
+ if context.durable_execution_invocation_output
+ else None
+ )
+ return cls(
+ event_type=EventType.EXECUTION_STOPPED.value,
+ event_timestamp=context.end_timestamp,
+ sub_type=context.sub_type,
+ event_id=context.event_id,
+ operation_id=context.operation.operation_id,
+ name=context.operation.name,
+ parent_id=context.operation.parent_id,
+ execution_stopped_details=ExecutionStoppedDetails(error=error),
+ )
+
+ @classmethod
+ def create_execution_event(cls, context: EventCreationContext) -> Event:
+ """Create execution event based on action."""
+ match context.operation.status:
+ case OperationStatus.STARTED:
+ return cls.create_execution_event_started(context)
+ case OperationStatus.SUCCEEDED:
+ return cls.create_execution_event_succeeded(context)
+ case OperationStatus.FAILED:
+ return cls.create_execution_event_failed(context)
+ case OperationStatus.TIMED_OUT:
+ return cls.create_execution_event_timed_out(context)
+ case OperationStatus.STOPPED:
+ return cls.create_execution_event_stopped(context)
+ case _:
+ msg = f"Operation status {context.operation.status} is not valid for execution operations. Valid statuses are: STARTED, SUCCEEDED, FAILED, TIMED_OUT, STOPPED"
+ raise InvalidParameterValueException(msg)
+
+ # endregion execution
+
+ # region context
+ @classmethod
+ def create_context_event_started(cls, context: EventCreationContext) -> Event:
+ return cls(
+ event_type=EventType.CONTEXT_STARTED.value,
+ event_timestamp=context.start_timestamp,
+ sub_type=context.sub_type,
+ event_id=context.event_id,
+ operation_id=context.operation.operation_id,
+ name=context.operation.name,
+ parent_id=context.operation.parent_id,
+ context_started_details=ContextStartedDetails(),
+ )
+
+ @classmethod
+ def create_context_event_succeeded(cls, context: EventCreationContext) -> Event:
+ context_details: ContextDetails | None = context.operation.context_details
+ event_result: EventResult | None = (
+ EventResult.from_details(context_details, context.include_execution_data)
+ if context_details
+ else None
+ )
+ return cls(
+ event_type=EventType.CONTEXT_SUCCEEDED.value,
+ event_timestamp=context.end_timestamp,
+ sub_type=context.sub_type,
+ event_id=context.event_id,
+ operation_id=context.operation.operation_id,
+ name=context.operation.name,
+ parent_id=context.operation.parent_id,
+ context_succeeded_details=ContextSucceededDetails(result=event_result),
+ )
+
+ @classmethod
+ def create_context_event_failed(cls, context: EventCreationContext) -> Event:
+ context_details: ContextDetails | None = context.operation.context_details
+ event_error: EventError | None = (
+ EventError.from_details(context_details) if context_details else None
+ )
+ return cls(
+ event_type=EventType.CONTEXT_FAILED.value,
+ event_timestamp=context.end_timestamp,
+ sub_type=context.sub_type,
+ event_id=context.event_id,
+ operation_id=context.operation.operation_id,
+ name=context.operation.name,
+ parent_id=context.operation.parent_id,
+ context_failed_details=ContextFailedDetails(error=event_error),
+ )
+
+ @classmethod
+ def create_context_event(cls, context: EventCreationContext) -> Event:
+ """Create context event based on action."""
+ match context.operation.status:
+ case OperationStatus.STARTED:
+ return cls.create_context_event_started(context)
+ case OperationStatus.SUCCEEDED:
+ return cls.create_context_event_succeeded(context)
+ case OperationStatus.FAILED:
+ return cls.create_context_event_failed(context)
+ case _:
+ msg = (
+ f"Operation status {context.operation.status} is not valid for context operations. "
+ f"Valid statuses are: STARTED, SUCCEEDED, FAILED"
+ )
+ raise InvalidParameterValueException(msg)
+
+ # endregion context
+
+ # region wait
+ @classmethod
+ def create_wait_event_started(cls, context: EventCreationContext) -> Event:
+ wait_details: WaitDetails | None = context.operation.wait_details
+ scheduled_end_timestamp: datetime.datetime | None = (
+ wait_details.scheduled_end_timestamp if wait_details else None
+ )
+ duration: int | None = None
+ if (
+ wait_details
+ and wait_details.scheduled_end_timestamp
+ and context.operation.start_timestamp
+ ):
+ duration = round(
+ (
+ wait_details.scheduled_end_timestamp
+ - context.operation.start_timestamp
+ ).total_seconds()
+ )
+ return cls(
+ event_type=EventType.WAIT_STARTED.value,
+ event_timestamp=context.start_timestamp,
+ sub_type=context.sub_type,
+ event_id=context.event_id,
+ operation_id=context.operation.operation_id,
+ name=context.operation.name,
+ parent_id=context.operation.parent_id,
+ wait_started_details=WaitStartedDetails(
+ duration=duration,
+ scheduled_end_timestamp=scheduled_end_timestamp,
+ ),
+ )
+
+ @classmethod
+ def create_wait_event_succeeded(cls, context: EventCreationContext) -> Event:
+ wait_details: WaitDetails | None = context.operation.wait_details
+ duration: int | None = None
+ if (
+ wait_details
+ and wait_details.scheduled_end_timestamp
+ and context.operation.start_timestamp
+ ):
+ duration = round(
+ (
+ wait_details.scheduled_end_timestamp - context.start_timestamp
+ ).total_seconds()
+ )
+ return cls(
+ event_type=EventType.WAIT_SUCCEEDED.value,
+ event_timestamp=context.end_timestamp,
+ sub_type=context.sub_type,
+ event_id=context.event_id,
+ operation_id=context.operation.operation_id,
+ name=context.operation.name,
+ parent_id=context.operation.parent_id,
+ wait_succeeded_details=WaitSucceededDetails(duration=duration),
+ )
+
+ @classmethod
+ def create_wait_event_cancelled(cls, context: EventCreationContext) -> Event:
+ error: EventError | None = None
+ if (
+ context.operation_update
+ and context.operation_update.operation_type == OperationType.WAIT
+ and context.operation_update.action == OperationAction.CANCEL
+ ):
+ error = EventError(
+ context.operation_update.error, not context.include_execution_data
+ )
+ return cls(
+ event_type=EventType.WAIT_CANCELLED.value,
+ event_timestamp=context.end_timestamp,
+ sub_type=context.sub_type,
+ event_id=context.event_id,
+ operation_id=context.operation.operation_id,
+ name=context.operation.name,
+ parent_id=context.operation.parent_id,
+ wait_cancelled_details=WaitCancelledDetails(error=error),
+ )
+
+ @classmethod
+ def create_wait_event(cls, context: EventCreationContext) -> Event:
+ """Create wait event based on action."""
+ match context.operation.status:
+ case OperationStatus.STARTED:
+ return cls.create_wait_event_started(context)
+ case OperationStatus.SUCCEEDED:
+ return cls.create_wait_event_succeeded(context)
+ case OperationStatus.CANCELLED:
+ return cls.create_wait_event_cancelled(context)
+ case _:
+ msg = (
+ f"Operation status {context.operation.status} is not valid for wait operations. "
+ f"Valid statuses are: STARTED, SUCCEEDED, CANCELLED"
+ )
+ raise InvalidParameterValueException(msg)
+
+ # endregion wait
+
+ # region step
+ @classmethod
+ def create_step_event_started(cls, context: EventCreationContext) -> Event:
+ return cls(
+ event_type=EventType.STEP_STARTED.value,
+ event_timestamp=context.start_timestamp,
+ sub_type=context.sub_type,
+ event_id=context.event_id,
+ operation_id=context.operation.operation_id,
+ name=context.operation.name,
+ parent_id=context.operation.parent_id,
+ step_started_details=StepStartedDetails(),
+ )
+
+ @classmethod
+ def create_step_event_succeeded(cls, context: EventCreationContext) -> Event:
+ step_details: StepDetails | None = context.operation.step_details
+ event_result: EventResult | None = (
+ EventResult.from_details(step_details, context.include_execution_data)
+ if step_details
+ else None
+ )
+ return cls(
+ event_type=EventType.STEP_SUCCEEDED.value,
+ event_timestamp=context.end_timestamp,
+ sub_type=context.sub_type,
+ event_id=context.event_id,
+ operation_id=context.operation.operation_id,
+ name=context.operation.name,
+ parent_id=context.operation.parent_id,
+ step_succeeded_details=StepSucceededDetails(
+ result=event_result,
+ retry_details=context.get_retry_details(),
+ ),
+ )
+
+ @classmethod
+ def create_step_event_failed(cls, context: EventCreationContext) -> Event:
+ step_details: StepDetails | None = context.operation.step_details
+ event_error: EventError | None = (
+ EventError.from_details(
+ step_details, include=context.include_execution_data
+ )
+ if step_details
+ else None
+ )
+ return cls(
+ event_type=EventType.STEP_FAILED.value,
+ event_timestamp=context.end_timestamp,
+ sub_type=context.sub_type,
+ event_id=context.event_id,
+ operation_id=context.operation.operation_id,
+ name=context.operation.name,
+ parent_id=context.operation.parent_id,
+ step_failed_details=StepFailedDetails(
+ error=event_error,
+ retry_details=context.get_retry_details(),
+ ),
+ )
+
+ @classmethod
+ def create_step_event(cls, context: EventCreationContext) -> Event:
+ """Create step event based on action."""
+ match context.operation.status:
+ case OperationStatus.STARTED:
+ return cls.create_step_event_started(context)
+ case OperationStatus.SUCCEEDED:
+ return cls.create_step_event_succeeded(context)
+ case OperationStatus.FAILED:
+ return cls.create_step_event_failed(context)
+ case _:
+ msg = (
+ f"Operation status {context.operation.status} is not valid for step operations. "
+ f"Valid statuses are: STARTED, SUCCEEDED, FAILED"
+ )
+ raise InvalidParameterValueException(msg)
+
+ # endregion step
+
+ # region chained_invoke
+ @classmethod
+ def create_chained_invoke_event_pending(
+ cls, context: EventCreationContext
+ ) -> Event:
+ input: EventInput = EventInput.from_start_durable_execution_input(
+ context.start_durable_execution_input, context.include_execution_data
+ )
+ return cls(
+ event_type=EventType.CHAINED_INVOKE_STARTED.value,
+ event_timestamp=context.start_timestamp,
+ sub_type=context.sub_type,
+ event_id=context.event_id,
+ operation_id=context.operation.operation_id,
+ name=context.operation.name,
+ parent_id=context.operation.parent_id,
+ chained_invoke_pending_details=ChainedInvokePendingDetails(
+ input=input,
+ function_name=context.start_durable_execution_input.function_name,
+ ),
+ )
+
+ @classmethod
+ def create_chained_invoke_event_started(
+ cls, context: EventCreationContext
+ ) -> Event:
+ return cls(
+ event_type=EventType.CHAINED_INVOKE_STARTED.value,
+ event_timestamp=context.start_timestamp,
+ sub_type=context.sub_type,
+ event_id=context.event_id,
+ operation_id=context.operation.operation_id,
+ name=context.operation.name,
+ parent_id=context.operation.parent_id,
+ chained_invoke_started_details=ChainedInvokeStartedDetails(
+ durable_execution_arn=context.durable_execution_arn
+ ),
+ )
+
+ @classmethod
+ def create_chained_invoke_event_succeeded(
+ cls, context: EventCreationContext
+ ) -> Event:
+ chained_invoke_details: ChainedInvokeDetails | None = (
+ context.operation.chained_invoke_details
+ )
+ event_result: EventResult | None = (
+ EventResult.from_details(
+ chained_invoke_details, context.include_execution_data
+ )
+ if chained_invoke_details
+ else None
+ )
+ return cls(
+ event_type=EventType.CHAINED_INVOKE_SUCCEEDED.value,
+ event_timestamp=context.end_timestamp,
+ sub_type=context.sub_type,
+ event_id=context.event_id,
+ operation_id=context.operation.operation_id,
+ name=context.operation.name,
+ parent_id=context.operation.parent_id,
+ chained_invoke_succeeded_details=ChainedInvokeSucceededDetails(
+ result=event_result
+ ),
+ )
+
+ @classmethod
+ def create_chained_invoke_event_failed(cls, context: EventCreationContext) -> Event:
+ chained_invoke_details: ChainedInvokeDetails | None = (
+ context.operation.chained_invoke_details
+ )
+ event_error: EventError | None = (
+ EventError.from_details(
+ chained_invoke_details, include=context.include_execution_data
+ )
+ if chained_invoke_details
+ else None
+ )
+ return cls(
+ event_type=EventType.CHAINED_INVOKE_FAILED.value,
+ event_timestamp=context.end_timestamp,
+ sub_type=context.sub_type,
+ event_id=context.event_id,
+ operation_id=context.operation.operation_id,
+ name=context.operation.name,
+ parent_id=context.operation.parent_id,
+ chained_invoke_failed_details=ChainedInvokeFailedDetails(error=event_error),
+ )
+
+ @classmethod
+ def create_chained_invoke_event_timed_out(
+ cls, context: EventCreationContext
+ ) -> Event:
+ chained_invoke_details: ChainedInvokeDetails | None = (
+ context.operation.chained_invoke_details
+ )
+ event_error: EventError | None = (
+ EventError.from_details(
+ chained_invoke_details, include=context.include_execution_data
+ )
+ if chained_invoke_details
+ else None
+ )
+ return cls(
+ event_type=EventType.CHAINED_INVOKE_TIMED_OUT.value,
+ event_timestamp=context.end_timestamp,
+ sub_type=context.sub_type,
+ event_id=context.event_id,
+ operation_id=context.operation.operation_id,
+ name=context.operation.name,
+ parent_id=context.operation.parent_id,
+ chained_invoke_timed_out_details=ChainedInvokeTimedOutDetails(
+ error=event_error
+ ),
+ )
+
+ @classmethod
+ def create_chained_invoke_event_stopped(
+ cls, context: EventCreationContext
+ ) -> Event:
+ chained_invoke_details: ChainedInvokeDetails | None = (
+ context.operation.chained_invoke_details
+ )
+ event_error: EventError | None = (
+ EventError.from_details(
+ chained_invoke_details, include=context.include_execution_data
+ )
+ if chained_invoke_details
+ else None
+ )
+ return cls(
+ event_type=EventType.CHAINED_INVOKE_STOPPED.value,
+ event_timestamp=context.end_timestamp,
+ sub_type=context.sub_type,
+ event_id=context.event_id,
+ operation_id=context.operation.operation_id,
+ name=context.operation.name,
+ parent_id=context.operation.parent_id,
+ chained_invoke_stopped_details=ChainedInvokeStoppedDetails(
+ error=event_error
+ ),
+ )
+
+ @classmethod
+ def create_chained_invoke_event(cls, context: EventCreationContext) -> Event:
+ """Create chained invoke event based on action."""
+ match context.operation.status:
+ case OperationStatus.PENDING:
+ return cls.create_chained_invoke_event_pending(context)
+ case OperationStatus.STARTED:
+ return cls.create_chained_invoke_event_started(context)
+ case OperationStatus.SUCCEEDED:
+ return cls.create_chained_invoke_event_succeeded(context)
+ case OperationStatus.FAILED:
+ return cls.create_chained_invoke_event_failed(context)
+ case OperationStatus.TIMED_OUT:
+ return cls.create_chained_invoke_event_timed_out(context)
+ case OperationStatus.STOPPED:
+ return cls.create_chained_invoke_event_stopped(context)
+ case _:
+ msg = (
+ f"Operation status {context.operation.status} is not valid for chained invoke operations. Valid statuses are: "
+ f"STARTED, SUCCEEDED, FAILED, TIMED_OUT, STOPPED"
+ )
+ raise InvalidParameterValueException(msg)
+
+ # endregion chained_invoke
+
+ # region callback
+ @classmethod
+ def create_callback_event_started(cls, context: EventCreationContext) -> Event:
+ callback_details: CallbackDetails | None = context.operation.callback_details
+ callback_id: str | None = (
+ callback_details.callback_id if callback_details else None
+ )
+ callback_options: CallbackOptions | None = (
+ context.operation_update.callback_options
+ if context.operation_update
+ else None
+ )
+ timeout: int | None = (
+ callback_options.timeout_seconds if callback_options else None
+ )
+ heartbeat_timeout: int | None = (
+ callback_options.heartbeat_timeout_seconds if callback_options else None
+ )
+ return cls(
+ event_type=EventType.CALLBACK_STARTED.value,
+ event_timestamp=context.start_timestamp,
+ sub_type=context.sub_type,
+ event_id=context.event_id,
+ operation_id=context.operation.operation_id,
+ name=context.operation.name,
+ parent_id=context.operation.parent_id,
+ callback_started_details=CallbackStartedDetails(
+ callback_id=callback_id,
+ timeout=timeout,
+ heartbeat_timeout=heartbeat_timeout,
+ ),
+ )
+
+ @classmethod
+ def create_callback_event_succeeded(cls, context: EventCreationContext) -> Event:
+ callback_details: CallbackDetails | None = context.operation.callback_details
+ event_result: EventResult | None = (
+ EventResult.from_details(callback_details, context.include_execution_data)
+ if callback_details
+ else None
+ )
+ return cls(
+ event_type=EventType.CALLBACK_SUCCEEDED.value,
+ event_timestamp=context.end_timestamp,
+ sub_type=context.sub_type,
+ event_id=context.event_id,
+ operation_id=context.operation.operation_id,
+ name=context.operation.name,
+ parent_id=context.operation.parent_id,
+ callback_succeeded_details=CallbackSucceededDetails(result=event_result),
+ )
+
+ @classmethod
+ def create_callback_event_failed(cls, context: EventCreationContext) -> Event:
+ callback_details: CallbackDetails | None = context.operation.callback_details
+ event_error: EventError | None = (
+ EventError.from_details(callback_details) if callback_details else None
+ )
+ return cls(
+ event_type=EventType.CALLBACK_FAILED.value,
+ event_timestamp=context.end_timestamp,
+ sub_type=context.sub_type,
+ event_id=context.event_id,
+ operation_id=context.operation.operation_id,
+ name=context.operation.name,
+ parent_id=context.operation.parent_id,
+ callback_failed_details=CallbackFailedDetails(error=event_error),
+ )
+
+ @classmethod
+ def create_callback_event_timed_out(cls, context: EventCreationContext) -> Event:
+ callback_details: CallbackDetails | None = context.operation.callback_details
+ event_error: EventError | None = (
+ EventError.from_details(callback_details) if callback_details else None
+ )
+ return cls(
+ event_type=EventType.CALLBACK_TIMED_OUT.value,
+ event_timestamp=context.end_timestamp,
+ sub_type=context.sub_type,
+ event_id=context.event_id,
+ operation_id=context.operation.operation_id,
+ name=context.operation.name,
+ parent_id=context.operation.parent_id,
+ callback_timed_out_details=CallbackTimedOutDetails(error=event_error),
+ )
+
+ @classmethod
+ def create_callback_event(cls, context: EventCreationContext) -> Event:
+ """Create callback event based on action."""
+ match context.operation.status:
+ case OperationStatus.STARTED:
+ return cls.create_callback_event_started(context)
+ case OperationStatus.SUCCEEDED:
+ return cls.create_callback_event_succeeded(context)
+ case OperationStatus.FAILED:
+ return cls.create_callback_event_failed(context)
+ case OperationStatus.TIMED_OUT:
+ return cls.create_callback_event_timed_out(context)
+ case _:
+ msg = (
+ f"Operation status {context.operation.status} is not valid for callback operations. "
+ f"Valid statuses are: STARTED, SUCCEEDED, FAILED, TIMED_OUT"
+ )
+ raise InvalidParameterValueException(msg)
+
+ # endregion callback
+
+ # region invocation_completed
+ @classmethod
+ def create_invocation_completed(
+ cls,
+ event_id: int,
+ event_timestamp: datetime.datetime,
+ start_timestamp: datetime.datetime,
+ end_timestamp: datetime.datetime,
+ request_id: str,
+ ) -> Event:
+ """Create invocation completed event."""
+ return cls(
+ event_type=EventType.INVOCATION_COMPLETED.value,
+ event_timestamp=event_timestamp,
+ event_id=event_id,
+ invocation_completed_details=InvocationCompletedDetails(
+ start_timestamp=start_timestamp,
+ end_timestamp=end_timestamp,
+ request_id=request_id,
+ ),
+ )
+
+ # endregion invocation_completed
+
+ @classmethod
+ def create_event_started(cls, context: EventCreationContext) -> Event:
+ """Convert operation to started event."""
+ if context.operation.start_timestamp is None:
+ msg: str = "Operation start timestamp cannot be None when converting to started event"
+ raise InvalidParameterValueException(msg)
+
+ match context.operation.operation_type:
+ case OperationType.EXECUTION:
+ return cls.create_execution_event_started(context)
+ case OperationType.CONTEXT:
+ return cls.create_context_event_started(context)
+ case OperationType.WAIT:
+ return cls.create_wait_event_started(context)
+ case OperationType.STEP:
+ return cls.create_step_event_started(context)
+ case OperationType.CHAINED_INVOKE:
+ return cls.create_chained_invoke_event_started(context)
+ case OperationType.CALLBACK:
+ return cls.create_callback_event_started(context)
+ case _:
+ msg = f"Unknown operation type: {context.operation.operation_type}"
+ raise InvalidParameterValueException(msg)
+
+ @classmethod
+ def from_event_with_id(cls, event: Event, event_id: int) -> Event:
+ """Create a new Event from an existing event with updated event_id."""
+ return cls(
+ event_type=event.event_type,
+ event_timestamp=event.event_timestamp,
+ sub_type=event.sub_type,
+ event_id=event_id,
+ operation_id=event.operation_id,
+ name=event.name,
+ parent_id=event.parent_id,
+ execution_started_details=event.execution_started_details,
+ execution_succeeded_details=event.execution_succeeded_details,
+ execution_failed_details=event.execution_failed_details,
+ execution_timed_out_details=event.execution_timed_out_details,
+ execution_stopped_details=event.execution_stopped_details,
+ context_started_details=event.context_started_details,
+ context_succeeded_details=event.context_succeeded_details,
+ context_failed_details=event.context_failed_details,
+ wait_started_details=event.wait_started_details,
+ wait_succeeded_details=event.wait_succeeded_details,
+ wait_cancelled_details=event.wait_cancelled_details,
+ step_started_details=event.step_started_details,
+ step_succeeded_details=event.step_succeeded_details,
+ step_failed_details=event.step_failed_details,
+ chained_invoke_pending_details=event.chained_invoke_pending_details,
+ chained_invoke_started_details=event.chained_invoke_started_details,
+ chained_invoke_succeeded_details=event.chained_invoke_succeeded_details,
+ chained_invoke_failed_details=event.chained_invoke_failed_details,
+ chained_invoke_timed_out_details=event.chained_invoke_timed_out_details,
+ chained_invoke_stopped_details=event.chained_invoke_stopped_details,
+ callback_started_details=event.callback_started_details,
+ callback_succeeded_details=event.callback_succeeded_details,
+ callback_failed_details=event.callback_failed_details,
+ callback_timed_out_details=event.callback_timed_out_details,
+ )
+
+ @classmethod
+ def create_event_terminated(cls, context: EventCreationContext) -> Event:
+ """Convert operation to finished event."""
+ operation: Operation = context.operation
+ if operation.end_timestamp is None:
+ msg: str = "Operation end timestamp cannot be None when converting to finished event"
+ raise InvalidParameterValueException(msg)
+
+ if operation.status not in TERMINAL_STATUSES:
+ msg = f"Operation status must be one of SUCCEEDED, FAILED, TIMED_OUT, STOPPED, or CANCELLED. Got: {operation.status}"
+ raise InvalidParameterValueException(msg)
+
+ match operation.operation_type:
+ case OperationType.EXECUTION:
+ return cls.create_execution_event(context)
+ case OperationType.CONTEXT:
+ return cls.create_context_event(context)
+ case OperationType.WAIT:
+ return cls.create_wait_event(context)
+ case OperationType.STEP:
+ return cls.create_step_event(context)
+ case OperationType.CHAINED_INVOKE:
+ return cls.create_chained_invoke_event(context)
+ case OperationType.CALLBACK:
+ return cls.create_callback_event(context)
+ case _:
+ msg = f"Unknown operation type: {operation.operation_type}"
+ raise InvalidParameterValueException(msg)
+
+
+# endregion event_class
+
+
+# region history_models
+@dataclass(frozen=True)
+class HistoryEventTypeConfig:
+ """Configuration for how to process a specific event type."""
+
+ operation_type: OperationType | None
+ operation_status: OperationStatus | None
+ is_start_event: bool
+ is_end_event: bool
+ has_result: bool # Whether this event type contains result/error data
+
+
+# Mapping of event types to their processing configuration
+# This matches the TypeScript historyEventTypes constant
+HISTORY_EVENT_TYPES: dict[str, HistoryEventTypeConfig] = {
+ "ExecutionStarted": HistoryEventTypeConfig(
+ operation_type=OperationType.EXECUTION,
+ operation_status=OperationStatus.STARTED,
+ is_start_event=True,
+ is_end_event=False,
+ has_result=False,
+ ),
+ "ExecutionFailed": HistoryEventTypeConfig(
+ operation_type=OperationType.EXECUTION,
+ operation_status=OperationStatus.FAILED,
+ is_start_event=False,
+ is_end_event=True,
+ has_result=False,
+ ),
+ "ExecutionStopped": HistoryEventTypeConfig(
+ operation_type=OperationType.EXECUTION,
+ operation_status=OperationStatus.STOPPED,
+ is_start_event=False,
+ is_end_event=True,
+ has_result=False,
+ ),
+ "ExecutionSucceeded": HistoryEventTypeConfig(
+ operation_type=OperationType.EXECUTION,
+ operation_status=OperationStatus.SUCCEEDED,
+ is_start_event=False,
+ is_end_event=True,
+ has_result=False,
+ ),
+ "ExecutionTimedOut": HistoryEventTypeConfig(
+ operation_type=OperationType.EXECUTION,
+ operation_status=OperationStatus.TIMED_OUT,
+ is_start_event=False,
+ is_end_event=True,
+ has_result=False,
+ ),
+ "CallbackStarted": HistoryEventTypeConfig(
+ operation_type=OperationType.CALLBACK,
+ operation_status=OperationStatus.STARTED,
+ is_start_event=True,
+ is_end_event=False,
+ has_result=False,
+ ),
+ "CallbackFailed": HistoryEventTypeConfig(
+ operation_type=OperationType.CALLBACK,
+ operation_status=OperationStatus.FAILED,
+ is_start_event=False,
+ is_end_event=True,
+ has_result=True,
+ ),
+ "CallbackSucceeded": HistoryEventTypeConfig(
+ operation_type=OperationType.CALLBACK,
+ operation_status=OperationStatus.SUCCEEDED,
+ is_start_event=False,
+ is_end_event=True,
+ has_result=True,
+ ),
+ "CallbackTimedOut": HistoryEventTypeConfig(
+ operation_type=OperationType.CALLBACK,
+ operation_status=OperationStatus.TIMED_OUT,
+ is_start_event=False,
+ is_end_event=True,
+ has_result=True,
+ ),
+ "ContextStarted": HistoryEventTypeConfig(
+ operation_type=OperationType.CONTEXT,
+ operation_status=OperationStatus.STARTED,
+ is_start_event=True,
+ is_end_event=False,
+ has_result=False,
+ ),
+ "ContextFailed": HistoryEventTypeConfig(
+ operation_type=OperationType.CONTEXT,
+ operation_status=OperationStatus.FAILED,
+ is_start_event=False,
+ is_end_event=True,
+ has_result=True,
+ ),
+ "ContextSucceeded": HistoryEventTypeConfig(
+ operation_type=OperationType.CONTEXT,
+ operation_status=OperationStatus.SUCCEEDED,
+ is_start_event=False,
+ is_end_event=True,
+ has_result=True,
+ ),
+ "ChainedInvokeStarted": HistoryEventTypeConfig(
+ operation_type=OperationType.CHAINED_INVOKE,
+ operation_status=OperationStatus.STARTED,
+ is_start_event=True,
+ is_end_event=False,
+ has_result=False,
+ ),
+ "ChainedInvokeFailed": HistoryEventTypeConfig(
+ operation_type=OperationType.CHAINED_INVOKE,
+ operation_status=OperationStatus.FAILED,
+ is_start_event=False,
+ is_end_event=True,
+ has_result=True,
+ ),
+ "ChainedInvokeSucceeded": HistoryEventTypeConfig(
+ operation_type=OperationType.CHAINED_INVOKE,
+ operation_status=OperationStatus.SUCCEEDED,
+ is_start_event=False,
+ is_end_event=True,
+ has_result=True,
+ ),
+ "ChainedInvokeTimedOut": HistoryEventTypeConfig(
+ operation_type=OperationType.CHAINED_INVOKE,
+ operation_status=OperationStatus.TIMED_OUT,
+ is_start_event=False,
+ is_end_event=True,
+ has_result=True,
+ ),
+ "ChainedInvokeCancelled": HistoryEventTypeConfig(
+ operation_type=OperationType.CHAINED_INVOKE,
+ operation_status=OperationStatus.CANCELLED,
+ is_start_event=False,
+ is_end_event=True,
+ has_result=True,
+ ),
+ "StepStarted": HistoryEventTypeConfig(
+ operation_type=OperationType.STEP,
+ operation_status=OperationStatus.STARTED,
+ is_start_event=True,
+ is_end_event=False,
+ has_result=False,
+ ),
+ "StepFailed": HistoryEventTypeConfig(
+ operation_type=OperationType.STEP,
+ operation_status=OperationStatus.FAILED,
+ is_start_event=False,
+ is_end_event=True,
+ has_result=True,
+ ),
+ "StepSucceeded": HistoryEventTypeConfig(
+ operation_type=OperationType.STEP,
+ operation_status=OperationStatus.SUCCEEDED,
+ is_start_event=False,
+ is_end_event=True,
+ has_result=True,
+ ),
+ "WaitStarted": HistoryEventTypeConfig(
+ operation_type=OperationType.WAIT,
+ operation_status=OperationStatus.STARTED,
+ is_start_event=True,
+ is_end_event=False,
+ has_result=True,
+ ),
+ "WaitSucceeded": HistoryEventTypeConfig(
+ operation_type=OperationType.WAIT,
+ operation_status=OperationStatus.SUCCEEDED,
+ is_start_event=False,
+ is_end_event=True,
+ has_result=True,
+ ),
+ "WaitCancelled": HistoryEventTypeConfig(
+ operation_type=OperationType.WAIT,
+ operation_status=OperationStatus.CANCELLED,
+ is_start_event=False,
+ is_end_event=True,
+ has_result=True,
+ ),
+ # TODO: add support for populating invocation information from InvocationCompleted event
+ "InvocationCompleted": HistoryEventTypeConfig(
+ operation_type=None,
+ operation_status=None,
+ is_start_event=False,
+ is_end_event=False,
+ has_result=True,
+ ),
+}
+
+
+def events_to_operations(events: list[Event]) -> list[Operation]:
+ """Convert a list of history events into operations.
+
+ This function processes raw history events and groups them by operation ID,
+ creating comprehensive operation objects following the TypeScript pattern from
+ aws-durable-execution-sdk-js-testing.
+
+ Multiple events for the same operation_id are merged together, with each event
+ contributing its specific fields (e.g., CallbackStarted provides callback_id,
+ CallbackSucceeded provides result).
+
+ Args:
+ events: List of history events to process
+
+ Returns:
+ List of operations, one per unique operation ID
+
+ Raises:
+ InvalidParameterValueException: When required fields are missing from an event
+
+ Note:
+ InvocationCompleted events are currently skipped as they don't represent
+ operations. Future enhancement: populate invocation information from these
+ events (TODO).
+ """
+ operations_map: dict[str, Operation] = {}
+
+ for event in events:
+ if not event.event_type:
+ msg = "Missing required 'event_type' field in event"
+ raise InvalidParameterValueException(msg)
+
+ # Get event type configuration
+ event_config: HistoryEventTypeConfig | None = HISTORY_EVENT_TYPES.get(
+ event.event_type
+ )
+ if not event_config:
+ msg = f"Unknown event type: {event.event_type}"
+ raise InvalidParameterValueException(msg)
+
+ # TODO: add support for populating invocation information from InvocationCompleted event
+ if event.event_type == "InvocationCompleted":
+ continue
+
+ if not event.operation_id:
+ msg = f"Missing required 'operation_id' field in event {event.event_id}"
+ raise InvalidParameterValueException(msg)
+
+ # Get previous operation if it exists
+ previous_operation: Operation | None = operations_map.get(event.operation_id)
+
+ # Get operation type and status from configuration
+ operation_type: OperationType = (
+ event_config.operation_type or OperationType.EXECUTION
+ )
+ status: OperationStatus = (
+ event_config.operation_status or OperationStatus.PENDING
+ )
+
+ # Parse sub_type
+ sub_type: OperationSubType | None = None
+ if event.sub_type:
+ try:
+ sub_type = OperationSubType(event.sub_type)
+ except ValueError as e:
+ raise InvalidParameterValueException(str(e)) from e
+
+ # Create base operation
+ operation = Operation(
+ operation_id=event.operation_id,
+ operation_type=operation_type,
+ status=status,
+ name=event.name,
+ parent_id=event.parent_id,
+ sub_type=sub_type,
+ start_timestamp=datetime.datetime.now(tz=datetime.timezone.utc),
+ )
+
+ # Merge with previous operation if it exists
+ # Most fields are immutable, so they get preserved from previous events
+ if previous_operation:
+ operation = replace(
+ operation,
+ name=operation.name or previous_operation.name,
+ parent_id=operation.parent_id or previous_operation.parent_id,
+ sub_type=operation.sub_type or previous_operation.sub_type,
+ start_timestamp=previous_operation.start_timestamp,
+ end_timestamp=previous_operation.end_timestamp,
+ execution_details=previous_operation.execution_details,
+ context_details=previous_operation.context_details,
+ step_details=previous_operation.step_details,
+ wait_details=previous_operation.wait_details,
+ callback_details=previous_operation.callback_details,
+ chained_invoke_details=previous_operation.chained_invoke_details,
+ )
+
+ # Set timestamps based on event configuration
+ if event_config.is_start_event:
+ operation = replace(operation, start_timestamp=event.event_timestamp)
+ if event_config.is_end_event:
+ operation = replace(operation, end_timestamp=event.event_timestamp)
+
+ # Add operation-specific details incrementally
+ # Each event type contributes only the fields it has
+
+ # EXECUTION details
+ if (
+ operation_type == OperationType.EXECUTION
+ and event.execution_started_details
+ and event.execution_started_details.input
+ ):
+ operation = replace(
+ operation,
+ execution_details=ExecutionDetails(
+ input_payload=event.execution_started_details.input.payload
+ ),
+ )
+
+ # CALLBACK details - merge callback_id, result, and error from different events
+ if operation_type == OperationType.CALLBACK:
+ existing_cb: CallbackDetails | None = operation.callback_details
+ callback_id: str = existing_cb.callback_id if existing_cb else ""
+ result: str | None = existing_cb.result if existing_cb else None
+ error: ErrorObject | None = existing_cb.error if existing_cb else None
+
+ # CallbackStarted provides callback_id
+ if event.callback_started_details:
+ callback_id = event.callback_started_details.callback_id or callback_id
+
+ # CallbackSucceeded provides result
+ if (
+ event.callback_succeeded_details
+ and event.callback_succeeded_details.result
+ ):
+ result = event.callback_succeeded_details.result.payload
+
+ # CallbackFailed provides error
+ if event.callback_failed_details and event.callback_failed_details.error:
+ error = event.callback_failed_details.error.payload
+
+ # CallbackTimedOut provides error
+ if (
+ event.callback_timed_out_details
+ and event.callback_timed_out_details.error
+ ):
+ error = event.callback_timed_out_details.error.payload
+
+ operation = replace(
+ operation,
+ callback_details=CallbackDetails(
+ callback_id=callback_id,
+ result=result,
+ error=error,
+ ),
+ )
+
+ # STEP details - only update if this event type has result data
+ if operation_type == OperationType.STEP and event_config.has_result:
+ existing_step: StepDetails | None = operation.step_details
+ result_val: str | None = existing_step.result if existing_step else None
+ error_val: ErrorObject | None = (
+ existing_step.error if existing_step else None
+ )
+ attempt: int = existing_step.attempt if existing_step else 0
+ next_attempt_ts: datetime.datetime | None = (
+ existing_step.next_attempt_timestamp if existing_step else None
+ )
+
+ # StepSucceeded provides result
+ if event.step_succeeded_details:
+ if event.step_succeeded_details.result:
+ result_val = event.step_succeeded_details.result.payload
+ if event.step_succeeded_details.retry_details:
+ attempt = event.step_succeeded_details.retry_details.current_attempt
+
+ # StepFailed provides error and retry details
+ if event.step_failed_details:
+ if event.step_failed_details.error:
+ error_val = event.step_failed_details.error.payload
+ if event.step_failed_details.retry_details:
+ attempt = event.step_failed_details.retry_details.current_attempt
+ if (
+ event.step_failed_details.retry_details.next_attempt_delay_seconds
+ is not None
+ ):
+ next_attempt_ts = event.event_timestamp + datetime.timedelta(
+ seconds=event.step_failed_details.retry_details.next_attempt_delay_seconds
+ )
+
+ operation = replace(
+ operation,
+ step_details=StepDetails(
+ result=result_val,
+ error=error_val,
+ attempt=attempt,
+ next_attempt_timestamp=next_attempt_ts,
+ ),
+ )
+
+ # WAIT details
+ if operation_type == OperationType.WAIT and event.wait_started_details:
+ operation = replace(
+ operation,
+ wait_details=WaitDetails(
+ scheduled_end_timestamp=event.wait_started_details.scheduled_end_timestamp
+ ),
+ )
+
+ # CONTEXT details - only update if this event type has result data (matching TypeScript hasResult)
+ if operation_type == OperationType.CONTEXT and event_config.has_result:
+ if (
+ event.context_succeeded_details
+ and event.context_succeeded_details.result
+ ):
+ operation = replace(
+ operation,
+ context_details=ContextDetails(
+ result=event.context_succeeded_details.result.payload,
+ error=None,
+ ),
+ )
+ elif event.context_failed_details and event.context_failed_details.error:
+ operation = replace(
+ operation,
+ context_details=ContextDetails(
+ result=None,
+ error=event.context_failed_details.error.payload,
+ ),
+ )
+
+ # CHAINED_INVOKE details - only update if this event type has result data (matching TypeScript hasResult)
+ if operation_type == OperationType.CHAINED_INVOKE and event_config.has_result:
+ if (
+ event.chained_invoke_succeeded_details
+ and event.chained_invoke_succeeded_details.result
+ ):
+ operation = replace(
+ operation,
+ chained_invoke_details=ChainedInvokeDetails(
+ result=event.chained_invoke_succeeded_details.result.payload,
+ error=None,
+ ),
+ )
+ elif (
+ event.chained_invoke_failed_details
+ and event.chained_invoke_failed_details.error
+ ):
+ operation = replace(
+ operation,
+ chained_invoke_details=ChainedInvokeDetails(
+ result=None,
+ error=event.chained_invoke_failed_details.error.payload,
+ ),
+ )
+
+ # Store in map
+ operations_map[event.operation_id] = operation
+
+ return list(operations_map.values())
+
+
+@dataclass(frozen=True)
+class GetDurableExecutionHistoryRequest:
+ """Request to get durable execution history."""
+
+ durable_execution_arn: str
+ include_execution_data: bool | None = None
+ reverse_order: bool | None = None
+ marker: str | None = None
+ max_items: int = 0
+
+ @classmethod
+ def from_dict(cls, data: dict) -> GetDurableExecutionHistoryRequest:
+ return cls(
+ durable_execution_arn=data["DurableExecutionArn"],
+ include_execution_data=data.get("IncludeExecutionData"),
+ reverse_order=data.get("ReverseOrder"),
+ marker=data.get("Marker"),
+ max_items=data.get("MaxItems", 0),
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {"DurableExecutionArn": self.durable_execution_arn}
+ if self.include_execution_data is not None:
+ result["IncludeExecutionData"] = self.include_execution_data
+ if self.reverse_order is not None:
+ result["ReverseOrder"] = self.reverse_order
+ if self.marker is not None:
+ result["Marker"] = self.marker
+ if self.max_items is not None:
+ result["MaxItems"] = self.max_items
+ return result
+
+
+@dataclass(frozen=True)
+class GetDurableExecutionHistoryResponse:
+ """Response containing durable execution history events."""
+
+ events: list[Event]
+ next_marker: str | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> GetDurableExecutionHistoryResponse:
+ events = [Event.from_dict(event_data) for event_data in data.get("Events", [])]
+ return cls(
+ events=events,
+ next_marker=data.get("NextMarker"),
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {"Events": [event.to_dict() for event in self.events]}
+ if self.next_marker is not None:
+ result["NextMarker"] = self.next_marker
+ return result
+
+
+@dataclass(frozen=True)
+class ListDurableExecutionsByFunctionRequest:
+ """Request to list durable executions by function."""
+
+ function_name: str
+ qualifier: str | None = None
+ durable_execution_name: str | None = None
+ status_filter: list[str] | None = None
+ started_after: str | None = None
+ started_before: str | None = None
+ marker: str | None = None
+ max_items: int = 0
+ reverse_order: bool | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> ListDurableExecutionsByFunctionRequest:
+ # Handle query parameters that may be lists
+ function_name = data.get("FunctionName")
+ if isinstance(function_name, list):
+ function_name = function_name[0] if function_name else ""
+ elif not function_name:
+ function_name = ""
+
+ qualifier = data.get("Qualifier") or data.get("functionVersion")
+ if isinstance(qualifier, list):
+ qualifier = qualifier[0] if qualifier else None
+
+ durable_execution_name = data.get("DurableExecutionName") or data.get(
+ "executionName"
+ )
+ if isinstance(durable_execution_name, list):
+ durable_execution_name = (
+ durable_execution_name[0] if durable_execution_name else None
+ )
+
+ status_filter = data.get("StatusFilter") or data.get("statusFilter")
+ if isinstance(status_filter, list):
+ status_filter = status_filter if status_filter else None
+ elif status_filter:
+ status_filter = [status_filter]
+
+ started_after = data.get("StartedAfter") or data.get("startedAfter")
+ if isinstance(started_after, list):
+ started_after = started_after[0] if started_after else None
+
+ started_before = data.get("StartedBefore") or data.get("startedBefore")
+ if isinstance(started_before, list):
+ started_before = started_before[0] if started_before else None
+
+ marker = data.get("Marker") or data.get("marker")
+ if isinstance(marker, list):
+ marker = marker[0] if marker else None
+
+ max_items = data.get("MaxItems") or data.get("maxItems", 0)
+ if isinstance(max_items, list):
+ max_items = int(max_items[0]) if max_items else 0
+
+ reverse_order = data.get("ReverseOrder") or data.get("reverseOrder")
+ if isinstance(reverse_order, list):
+ reverse_order = (
+ reverse_order[0].lower() in ("true", "1", "yes")
+ if reverse_order
+ else None
+ )
+ elif isinstance(reverse_order, str):
+ reverse_order = reverse_order.lower() in ("true", "1", "yes")
+
+ return cls(
+ function_name=function_name,
+ qualifier=qualifier,
+ durable_execution_name=durable_execution_name,
+ status_filter=status_filter,
+ started_after=started_after,
+ started_before=started_before,
+ marker=marker,
+ max_items=max_items,
+ reverse_order=reverse_order,
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {"FunctionName": self.function_name}
+ if self.qualifier is not None:
+ result["Qualifier"] = self.qualifier
+ if self.durable_execution_name is not None:
+ result["DurableExecutionName"] = self.durable_execution_name
+ if self.status_filter is not None:
+ result["StatusFilter"] = self.status_filter
+ if self.started_after is not None:
+ result["StartedAfter"] = self.started_after
+ if self.started_before is not None:
+ result["StartedBefore"] = self.started_before
+ if self.marker is not None:
+ result["Marker"] = self.marker
+ if self.max_items is not None:
+ result["MaxItems"] = self.max_items
+ if self.reverse_order is not None:
+ result["ReverseOrder"] = self.reverse_order
+ return result
+
+
+@dataclass(frozen=True)
+class ListDurableExecutionsByFunctionResponse:
+ """Response containing list of durable executions by function."""
+
+ durable_executions: list[Execution]
+ next_marker: str | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> ListDurableExecutionsByFunctionResponse:
+ executions = [
+ Execution.from_dict(exec_data)
+ for exec_data in data.get("DurableExecutions", [])
+ ]
+ return cls(
+ durable_executions=executions,
+ next_marker=data.get("NextMarker"),
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {
+ "DurableExecutions": [exe.to_dict() for exe in self.durable_executions]
+ }
+ if self.next_marker is not None:
+ result["NextMarker"] = self.next_marker
+ return result
+
+
+# endregion history_models
+
+
+# region callback_models
+# Callback-related models
+@dataclass(frozen=True)
+class SendDurableExecutionCallbackSuccessRequest:
+ """Request to send callback success."""
+
+ callback_id: str
+ result: bytes | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> SendDurableExecutionCallbackSuccessRequest:
+ return cls(
+ callback_id=data["CallbackId"],
+ result=data.get("Result"),
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {"CallbackId": self.callback_id}
+ if self.result is not None:
+ result["Result"] = self.result
+ return result
+
+
+@dataclass(frozen=True)
+class SendDurableExecutionCallbackSuccessResponse:
+ """Response from sending callback success."""
+
+
+@dataclass(frozen=True)
+class SendDurableExecutionCallbackFailureRequest:
+ """Request to send callback failure."""
+
+ callback_id: str
+ error: ErrorObject | None = None
+
+ @classmethod
+ def from_dict(
+ cls, data: dict, callback_id: str
+ ) -> SendDurableExecutionCallbackFailureRequest:
+ error = ErrorObject.from_dict(data) if data else None
+
+ return cls(
+ callback_id=callback_id,
+ error=error,
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {"CallbackId": self.callback_id}
+ if self.error is not None:
+ result["Error"] = self.error.to_dict()
+ return result
+
+
+@dataclass(frozen=True)
+class SendDurableExecutionCallbackFailureResponse:
+ """Response from sending callback failure."""
+
+
+@dataclass(frozen=True)
+class SendDurableExecutionCallbackHeartbeatRequest:
+ """Request to send callback heartbeat."""
+
+ callback_id: str
+
+ @classmethod
+ def from_dict(cls, data: dict) -> SendDurableExecutionCallbackHeartbeatRequest:
+ return cls(callback_id=data["CallbackId"])
+
+ def to_dict(self) -> dict[str, Any]:
+ return {"CallbackId": self.callback_id}
+
+
+@dataclass(frozen=True)
+class SendDurableExecutionCallbackHeartbeatResponse:
+ """Response from sending callback heartbeat."""
+
+
+# endregion callback_models
+
+
+# region checkpoint_models
+# Checkpoint-related models
+@dataclass(frozen=True)
+class CheckpointUpdatedExecutionState:
+ """Updated execution state from checkpoint."""
+
+ operations: list[Operation]
+ next_marker: str | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> CheckpointUpdatedExecutionState:
+ operations = [
+ Operation.from_dict(op_data) for op_data in data.get("Operations", [])
+ ]
+ return cls(
+ operations=operations,
+ next_marker=data.get("NextMarker"),
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {
+ "Operations": [op.to_dict() for op in self.operations]
+ }
+ if self.next_marker is not None:
+ result["NextMarker"] = self.next_marker
+ return result
+
+
+@dataclass(frozen=True)
+class CheckpointDurableExecutionRequest:
+ """Request to checkpoint a durable execution."""
+
+ durable_execution_arn: str
+ checkpoint_token: str
+ updates: list[OperationUpdate] | None = None
+ client_token: str | None = None
+
+ @classmethod
+ def from_dict(
+ cls, data: dict, durable_execution_arn: str
+ ) -> CheckpointDurableExecutionRequest:
+ updates = None
+ if updates_data := data.get("Updates"):
+ updates = []
+ for update_data in updates_data:
+ # Map dictionary fields to OperationUpdate constructor parameters
+ operation_update = OperationUpdate(
+ operation_id=update_data["Id"],
+ operation_type=OperationType(update_data["Type"]),
+ action=OperationAction(update_data["Action"]),
+ parent_id=update_data.get("ParentId"),
+ name=update_data.get("Name"),
+ sub_type=OperationSubType(update_data["SubType"])
+ if update_data.get("SubType")
+ else None,
+ payload=update_data.get("Payload"),
+ error=ErrorObject.from_dict(update_data["Error"])
+ if update_data.get("Error")
+ else None,
+ context_options=ContextOptions.from_dict(
+ update_data["ContextOptions"]
+ )
+ if update_data.get("ContextOptions")
+ else None,
+ step_options=StepOptions.from_dict(update_data["StepOptions"])
+ if update_data.get("StepOptions")
+ else None,
+ wait_options=WaitOptions.from_dict(update_data["WaitOptions"])
+ if update_data.get("WaitOptions")
+ else None,
+ callback_options=CallbackOptions.from_dict(
+ update_data["CallbackOptions"]
+ )
+ if update_data.get("CallbackOptions")
+ else None,
+ chained_invoke_options=ChainedInvokeOptions.from_dict(
+ update_data["ChainedInvokeOptions"]
+ )
+ if update_data.get("ChainedInvokeOptions")
+ else None,
+ )
+ updates.append(operation_update)
+
+ return cls(
+ durable_execution_arn=durable_execution_arn,
+ checkpoint_token=data["CheckpointToken"],
+ updates=updates,
+ client_token=data.get("ClientToken"),
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {
+ "DurableExecutionArn": self.durable_execution_arn,
+ "CheckpointToken": self.checkpoint_token,
+ }
+ if self.updates is not None:
+ result["Updates"] = [update.to_dict() for update in self.updates]
+ if self.client_token is not None:
+ result["ClientToken"] = self.client_token
+ return result
+
+
+@dataclass(frozen=True)
+class CheckpointDurableExecutionResponse:
+ """Response from checkpointing a durable execution."""
+
+ checkpoint_token: str
+ new_execution_state: CheckpointUpdatedExecutionState | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> CheckpointDurableExecutionResponse:
+ new_execution_state = None
+ if state_data := data.get("NewExecutionState"):
+ new_execution_state = CheckpointUpdatedExecutionState.from_dict(state_data)
+
+ return cls(
+ checkpoint_token=data["CheckpointToken"],
+ new_execution_state=new_execution_state,
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ result: dict[str, Any] = {"CheckpointToken": self.checkpoint_token}
+ if self.new_execution_state is not None:
+ result["NewExecutionState"] = self.new_execution_state.to_dict()
+ return result
+
+
+# endregion checkpoint_models
+
+
+# region error_models
+# Error response structure for consistent error handling
+@dataclass(frozen=True)
+class ErrorResponse:
+ """Structured error response for web service operations."""
+
+ error_type: str
+ error_message: str
+ error_code: str | None = None
+ request_id: str | None = None
+
+ @classmethod
+ def from_dict(cls, data: dict) -> ErrorResponse:
+ """Create ErrorResponse from dictionary.
+
+ Args:
+ data: Dictionary containing error data
+
+ Returns:
+ ErrorResponse: The error response object
+ """
+ error_data = data.get("error", data) # Support both nested and flat structures
+ return cls(
+ error_type=error_data["type"],
+ error_message=error_data["message"],
+ error_code=error_data.get("code"),
+ request_id=error_data.get("requestId"),
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ """Convert ErrorResponse to dictionary.
+
+ Returns:
+ dict: Dictionary representation of the error response
+ """
+ error_data: dict[str, Any] = {
+ "type": self.error_type,
+ "message": self.error_message,
+ }
+
+ if self.error_code is not None:
+ error_data["code"] = self.error_code
+ if self.request_id is not None:
+ error_data["requestId"] = self.request_id
+
+ return {"error": error_data}
+
+
+# endregion error_models
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/observer.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/observer.py
new file mode 100644
index 0000000..1b518ce
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/observer.py
@@ -0,0 +1,144 @@
+"""Checkpoint processors can notify the Execution of notable event state changes. Observer pattern."""
+
+from __future__ import annotations
+
+import threading
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING
+
+from aws_durable_execution_sdk_python_testing.token import CallbackToken
+
+if TYPE_CHECKING:
+ from collections.abc import Callable
+
+ from aws_durable_execution_sdk_python.lambda_service import (
+ ErrorObject,
+ CallbackOptions,
+ )
+
+
+class ExecutionObserver(ABC):
+ """Observer for execution lifecycle events."""
+
+ @abstractmethod
+ def on_completed(self, execution_arn: str, result: str | None = None) -> None:
+ """Called when execution completes successfully."""
+
+ @abstractmethod
+ def on_failed(self, execution_arn: str, error: ErrorObject) -> None:
+ """Called when execution fails."""
+
+ @abstractmethod
+ def on_timed_out(self, execution_arn: str, error: ErrorObject) -> None:
+ """Called when execution times out."""
+
+ @abstractmethod
+ def on_stopped(self, execution_arn: str, error: ErrorObject) -> None:
+ """Called when execution is stopped."""
+
+ @abstractmethod
+ def on_wait_timer_scheduled(
+ self, execution_arn: str, operation_id: str, delay: float
+ ) -> None:
+ """Called when wait timer scheduled."""
+
+ @abstractmethod
+ def on_step_retry_scheduled(
+ self, execution_arn: str, operation_id: str, delay: float
+ ) -> None:
+ """Called when step retry scheduled."""
+
+ @abstractmethod
+ def on_callback_created(
+ self,
+ execution_arn: str,
+ operation_id: str,
+ callback_options: CallbackOptions | None,
+ callback_token: CallbackToken,
+ ) -> None:
+ """Called when callback is created."""
+
+
+class ExecutionNotifier:
+ """Notifies observers about execution events. Thread-safe."""
+
+ def __init__(self) -> None:
+ self._observers: list[ExecutionObserver] = []
+ self._lock = threading.RLock()
+
+ def add_observer(self, observer: ExecutionObserver) -> None:
+ """Add an observer to be notified of execution events."""
+ with self._lock:
+ self._observers.append(observer)
+
+ def _notify_observers(self, method: Callable, *args, **kwargs) -> None:
+ """Notify all observers by calling the specified method."""
+ with self._lock:
+ observers = self._observers.copy()
+ for observer in observers:
+ getattr(observer, method.__name__)(*args, **kwargs)
+
+ # region event emitters
+ def notify_completed(self, execution_arn: str, result: str | None = None) -> None:
+ """Notify observers about execution completion."""
+ self._notify_observers(
+ ExecutionObserver.on_completed, execution_arn=execution_arn, result=result
+ )
+
+ def notify_failed(self, execution_arn: str, error: ErrorObject) -> None:
+ """Notify observers about execution failure."""
+ self._notify_observers(
+ ExecutionObserver.on_failed, execution_arn=execution_arn, error=error
+ )
+
+ def notify_timed_out(self, execution_arn: str, error: ErrorObject) -> None:
+ """Notify observers about execution timeout."""
+ self._notify_observers(
+ ExecutionObserver.on_timed_out, execution_arn=execution_arn, error=error
+ )
+
+ def notify_stopped(self, execution_arn: str, error: ErrorObject) -> None:
+ """Notify observers about execution being stopped."""
+ self._notify_observers(
+ ExecutionObserver.on_stopped, execution_arn=execution_arn, error=error
+ )
+
+ def notify_wait_timer_scheduled(
+ self, execution_arn: str, operation_id: str, delay: float
+ ) -> None:
+ """Notify observers about wait timer scheduling."""
+ self._notify_observers(
+ ExecutionObserver.on_wait_timer_scheduled,
+ execution_arn=execution_arn,
+ operation_id=operation_id,
+ delay=delay,
+ )
+
+ def notify_step_retry_scheduled(
+ self, execution_arn: str, operation_id: str, delay: float
+ ) -> None:
+ """Notify observers about step retry scheduling."""
+ self._notify_observers(
+ ExecutionObserver.on_step_retry_scheduled,
+ execution_arn=execution_arn,
+ operation_id=operation_id,
+ delay=delay,
+ )
+
+ def notify_callback_created(
+ self,
+ execution_arn: str,
+ operation_id: str,
+ callback_options: CallbackOptions | None,
+ callback_token: CallbackToken,
+ ) -> None:
+ """Notify observers about callback creation."""
+ self._notify_observers(
+ ExecutionObserver.on_callback_created,
+ execution_arn=execution_arn,
+ operation_id=operation_id,
+ callback_options=callback_options,
+ callback_token=callback_token,
+ )
+
+ # endregion event emitters
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/py.typed b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/py.typed
new file mode 100644
index 0000000..7ef2116
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/py.typed
@@ -0,0 +1 @@
+# Marker file that indicates this package supports typing
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/runner.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/runner.py
new file mode 100644
index 0000000..d60e774
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/runner.py
@@ -0,0 +1,1167 @@
+from __future__ import annotations
+
+import json
+import logging
+import os
+import time
+from dataclasses import dataclass, field
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Concatenate,
+ ParamSpec,
+ Protocol,
+ Self,
+ TypeVar,
+ cast,
+)
+
+import aws_durable_execution_sdk_python
+import boto3 # type: ignore
+from botocore.exceptions import ClientError # type: ignore
+from aws_durable_execution_sdk_python.execution import (
+ InvocationStatus,
+ durable_execution,
+)
+from aws_durable_execution_sdk_python.lambda_service import (
+ ErrorObject,
+ OperationPayload,
+ OperationStatus,
+ OperationSubType,
+ OperationType,
+)
+from aws_durable_execution_sdk_python.lambda_service import Operation as SvcOperation
+
+from aws_durable_execution_sdk_python_testing.checkpoint.processor import (
+ CheckpointProcessor,
+)
+from aws_durable_execution_sdk_python_testing.client import InMemoryServiceClient
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ DurableFunctionsLocalRunnerError,
+ DurableFunctionsTestError,
+ InvalidParameterValueException,
+ ResourceNotFoundException,
+)
+from aws_durable_execution_sdk_python_testing.executor import Executor
+from aws_durable_execution_sdk_python_testing.invoker import (
+ InProcessInvoker,
+ LambdaInvoker,
+ create_lambda_client,
+)
+from aws_durable_execution_sdk_python_testing.model import (
+ GetDurableExecutionHistoryResponse,
+ GetDurableExecutionResponse,
+ StartDurableExecutionInput,
+ StartDurableExecutionOutput,
+ events_to_operations,
+)
+from aws_durable_execution_sdk_python_testing.scheduler import Scheduler
+from aws_durable_execution_sdk_python_testing.stores.base import (
+ ExecutionStore,
+ StoreType,
+)
+from aws_durable_execution_sdk_python_testing.stores.filesystem import (
+ FileSystemExecutionStore,
+)
+from aws_durable_execution_sdk_python_testing.stores.memory import (
+ InMemoryExecutionStore,
+)
+from aws_durable_execution_sdk_python_testing.stores.sqlite import SQLiteExecutionStore
+from aws_durable_execution_sdk_python_testing.web.server import WebServer
+
+
+if TYPE_CHECKING:
+ import datetime
+ from collections.abc import Callable, MutableMapping
+
+ from aws_durable_execution_sdk_python.context import DurableContext
+ from aws_durable_execution_sdk_python.execution import InvocationStatus
+
+ from aws_durable_execution_sdk_python_testing.execution import Execution
+ from aws_durable_execution_sdk_python_testing.web.server import WebServiceConfig
+ from aws_durable_execution_sdk_python_testing.model import Event
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass(frozen=True)
+class WebRunnerConfig:
+ """Configuration for the WebRunner using composition pattern.
+
+ This configuration class encapsulates all settings needed to run the web server
+ for durable functions testing, including HTTP server configuration and Lambda
+ service configuration.
+ """
+
+ # HTTP server configuration (existing WebServiceConfig)
+ web_service: WebServiceConfig
+
+ # Lambda service configuration (web runner specific)
+ lambda_endpoint: str = "http://127.0.0.1:3001"
+ local_runner_endpoint: str = "http://0.0.0.0:5000"
+ local_runner_region: str = "us-west-2"
+ local_runner_mode: str = "local"
+
+ # Store configuration
+ store_type: StoreType = StoreType.MEMORY
+ store_path: str | None = None # Path for filesystem store
+
+
+@dataclass(frozen=True)
+class Operation:
+ operation_id: str
+ operation_type: OperationType
+ status: OperationStatus
+ parent_id: str | None = field(default=None, kw_only=True)
+ name: str | None = field(default=None, kw_only=True)
+ sub_type: OperationSubType | None = field(default=None, kw_only=True)
+ start_timestamp: datetime.datetime | None = field(default=None, kw_only=True)
+ end_timestamp: datetime.datetime | None = field(default=None, kw_only=True)
+
+
+T = TypeVar("T", bound=Operation)
+P = ParamSpec("P")
+
+
+class OperationFactory(Protocol):
+ @staticmethod
+ def from_svc_operation(
+ operation: SvcOperation, all_operations: list[SvcOperation] | None = None
+ ) -> Operation: ...
+
+
+@dataclass(frozen=True)
+class ExecutionOperation(Operation):
+ input_payload: str | None = None
+
+ @staticmethod
+ def from_svc_operation(
+ operation: SvcOperation,
+ all_operations: list[SvcOperation] | None = None, # noqa: ARG004
+ ) -> ExecutionOperation:
+ if operation.operation_type != OperationType.EXECUTION:
+ msg: str = f"Expected EXECUTION operation, got {operation.operation_type}"
+ raise InvalidParameterValueException(msg)
+ return ExecutionOperation(
+ operation_id=operation.operation_id,
+ operation_type=operation.operation_type,
+ status=operation.status,
+ parent_id=operation.parent_id,
+ name=operation.name,
+ sub_type=operation.sub_type,
+ start_timestamp=operation.start_timestamp,
+ end_timestamp=operation.end_timestamp,
+ input_payload=(
+ operation.execution_details.input_payload
+ if operation.execution_details
+ else None
+ ),
+ )
+
+
+@dataclass(frozen=True)
+class ContextOperation(Operation):
+ child_operations: list[Operation]
+ result: OperationPayload | None = None
+ error: ErrorObject | None = None
+
+ @staticmethod
+ def from_svc_operation(
+ operation: SvcOperation, all_operations: list[SvcOperation] | None = None
+ ) -> ContextOperation:
+ if operation.operation_type != OperationType.CONTEXT:
+ msg: str = f"Expected CONTEXT operation, got {operation.operation_type}"
+ raise InvalidParameterValueException(msg)
+
+ child_operations = []
+ if all_operations:
+ child_operations = [
+ create_operation(op, all_operations)
+ for op in all_operations
+ if op.parent_id == operation.operation_id
+ ]
+
+ return ContextOperation(
+ operation_id=operation.operation_id,
+ operation_type=operation.operation_type,
+ status=operation.status,
+ parent_id=operation.parent_id,
+ name=operation.name,
+ sub_type=operation.sub_type,
+ start_timestamp=operation.start_timestamp,
+ end_timestamp=operation.end_timestamp,
+ child_operations=child_operations,
+ result=operation.context_details.result
+ if operation.context_details
+ else None,
+ error=operation.context_details.error
+ if operation.context_details
+ else None,
+ )
+
+ def get_operation_by_name(self, name: str) -> Operation:
+ for operation in self.child_operations:
+ if operation.name == name:
+ return operation
+ msg: str = f"Child Operation with name '{name}' not found"
+ raise DurableFunctionsTestError(msg)
+
+ def get_step(self, name: str) -> StepOperation:
+ return cast(StepOperation, self.get_operation_by_name(name))
+
+ def get_wait(self, name: str) -> WaitOperation:
+ return cast(WaitOperation, self.get_operation_by_name(name))
+
+ def get_context(self, name: str) -> ContextOperation:
+ return cast(ContextOperation, self.get_operation_by_name(name))
+
+ def get_callback(self, name: str) -> CallbackOperation:
+ return cast(CallbackOperation, self.get_operation_by_name(name))
+
+ def get_invoke(self, name: str) -> InvokeOperation:
+ return cast(InvokeOperation, self.get_operation_by_name(name))
+
+ def get_execution(self, name: str) -> ExecutionOperation:
+ return cast(ExecutionOperation, self.get_operation_by_name(name))
+
+
+@dataclass(frozen=True)
+class StepOperation(ContextOperation):
+ attempt: int = 0
+ next_attempt_timestamp: datetime.datetime | None = None
+ result: OperationPayload | None = None
+ error: ErrorObject | None = None
+
+ @staticmethod
+ def from_svc_operation(
+ operation: SvcOperation, all_operations: list[SvcOperation] | None = None
+ ) -> StepOperation:
+ if operation.operation_type != OperationType.STEP:
+ msg: str = f"Expected STEP operation, got {operation.operation_type}"
+ raise InvalidParameterValueException(msg)
+
+ child_operations = []
+ if all_operations:
+ child_operations = [
+ create_operation(op, all_operations)
+ for op in all_operations
+ if op.parent_id == operation.operation_id
+ ]
+
+ return StepOperation(
+ operation_id=operation.operation_id,
+ operation_type=operation.operation_type,
+ status=operation.status,
+ parent_id=operation.parent_id,
+ name=operation.name,
+ sub_type=operation.sub_type,
+ start_timestamp=operation.start_timestamp,
+ end_timestamp=operation.end_timestamp,
+ child_operations=child_operations,
+ attempt=operation.step_details.attempt if operation.step_details else 0,
+ next_attempt_timestamp=(
+ operation.step_details.next_attempt_timestamp
+ if operation.step_details
+ else None
+ ),
+ result=operation.step_details.result if operation.step_details else None,
+ error=operation.step_details.error if operation.step_details else None,
+ )
+
+
+@dataclass(frozen=True)
+class WaitOperation(Operation):
+ scheduled_end_timestamp: datetime.datetime | None = None
+
+ @staticmethod
+ def from_svc_operation(
+ operation: SvcOperation,
+ all_operations: list[SvcOperation] | None = None, # noqa: ARG004
+ ) -> WaitOperation:
+ if operation.operation_type != OperationType.WAIT:
+ msg: str = f"Expected WAIT operation, got {operation.operation_type}"
+ raise InvalidParameterValueException(msg)
+ return WaitOperation(
+ operation_id=operation.operation_id,
+ operation_type=operation.operation_type,
+ status=operation.status,
+ parent_id=operation.parent_id,
+ name=operation.name,
+ sub_type=operation.sub_type,
+ start_timestamp=operation.start_timestamp,
+ end_timestamp=operation.end_timestamp,
+ scheduled_end_timestamp=(
+ operation.wait_details.scheduled_end_timestamp
+ if operation.wait_details
+ else None
+ ),
+ )
+
+
+@dataclass(frozen=True)
+class CallbackOperation(ContextOperation):
+ callback_id: str | None = None
+ result: OperationPayload | None = None
+ error: ErrorObject | None = None
+
+ @staticmethod
+ def from_svc_operation(
+ operation: SvcOperation, all_operations: list[SvcOperation] | None = None
+ ) -> CallbackOperation:
+ if operation.operation_type != OperationType.CALLBACK:
+ msg: str = f"Expected CALLBACK operation, got {operation.operation_type}"
+ raise InvalidParameterValueException(msg)
+
+ child_operations = []
+ if all_operations:
+ child_operations = [
+ create_operation(op, all_operations)
+ for op in all_operations
+ if op.parent_id == operation.operation_id
+ ]
+
+ return CallbackOperation(
+ operation_id=operation.operation_id,
+ operation_type=operation.operation_type,
+ status=operation.status,
+ parent_id=operation.parent_id,
+ name=operation.name,
+ sub_type=operation.sub_type,
+ start_timestamp=operation.start_timestamp,
+ end_timestamp=operation.end_timestamp,
+ child_operations=child_operations,
+ callback_id=(
+ operation.callback_details.callback_id
+ if operation.callback_details
+ else None
+ ),
+ result=operation.callback_details.result
+ if operation.callback_details
+ else None,
+ error=operation.callback_details.error
+ if operation.callback_details
+ else None,
+ )
+
+
+@dataclass(frozen=True)
+class InvokeOperation(Operation):
+ result: OperationPayload | None = None
+ error: ErrorObject | None = None
+
+ @staticmethod
+ def from_svc_operation(
+ operation: SvcOperation,
+ all_operations: list[SvcOperation] | None = None, # noqa: ARG004
+ ) -> InvokeOperation:
+ if operation.operation_type != OperationType.CHAINED_INVOKE:
+ msg: str = f"Expected INVOKE operation, got {operation.operation_type}"
+ raise InvalidParameterValueException(msg)
+ return InvokeOperation(
+ operation_id=operation.operation_id,
+ operation_type=operation.operation_type,
+ status=operation.status,
+ parent_id=operation.parent_id,
+ name=operation.name,
+ sub_type=operation.sub_type,
+ start_timestamp=operation.start_timestamp,
+ end_timestamp=operation.end_timestamp,
+ result=operation.chained_invoke_details.result
+ if operation.chained_invoke_details
+ else None,
+ error=operation.chained_invoke_details.error
+ if operation.chained_invoke_details
+ else None,
+ )
+
+
+OPERATION_FACTORIES: MutableMapping[OperationType, type[OperationFactory]] = {
+ OperationType.EXECUTION: ExecutionOperation,
+ OperationType.CONTEXT: ContextOperation,
+ OperationType.STEP: StepOperation,
+ OperationType.WAIT: WaitOperation,
+ OperationType.CHAINED_INVOKE: InvokeOperation,
+ OperationType.CALLBACK: CallbackOperation,
+}
+
+
+def create_operation(
+ svc_operation: SvcOperation, all_operations: list[SvcOperation] | None = None
+) -> Operation:
+ operation_class: type[OperationFactory] | None = OPERATION_FACTORIES.get(
+ svc_operation.operation_type
+ )
+ if not operation_class:
+ msg: str = f"Unknown operation type: {svc_operation.operation_type}"
+ raise DurableFunctionsTestError(msg)
+ return operation_class.from_svc_operation(svc_operation, all_operations)
+
+
+def _get_callback_id_from_events(
+ events: list[Event], name: str | None = None
+) -> str | None:
+ """
+ Get callback ID from execution history for callbacks that haven't completed.
+
+ Args:
+ execution_arn: The ARN of the execution to query.
+ name: Optional callback name to search for. If not provided, returns the latest callback.
+
+ Returns:
+ The callback ID string for a non-completed callback, or None if not found.
+
+ Raises:
+ DurableFunctionsTestError: If the named callback has already succeeded/failed/timed out.
+ """
+ callback_started_events = [
+ event for event in events if event.event_type == "CallbackStarted"
+ ]
+
+ if not callback_started_events:
+ return None
+
+ completed_callback_ids = {
+ event.event_id
+ for event in events
+ if event.event_type
+ in ["CallbackSucceeded", "CallbackFailed", "CallbackTimedOut"]
+ }
+
+ if name is not None:
+ for event in callback_started_events:
+ if event.name == name:
+ callback_id = event.event_id
+ if callback_id in completed_callback_ids:
+ raise DurableFunctionsTestError(
+ f"Callback {name} has already completed (succeeded/failed/timed out)"
+ )
+ return (
+ event.callback_started_details.callback_id
+ if event.callback_started_details
+ else None
+ )
+ return None
+
+ # If name is not provided, find the latest non-completed callback event
+ active_callbacks = [
+ event
+ for event in callback_started_events
+ if event.event_id not in completed_callback_ids
+ ]
+
+ if not active_callbacks:
+ return None
+
+ latest_event = active_callbacks[-1]
+ return (
+ latest_event.callback_started_details.callback_id
+ if latest_event.callback_started_details
+ else None
+ )
+
+
+@dataclass(frozen=True)
+class DurableFunctionTestResult:
+ status: InvocationStatus
+ operations: list[Operation]
+ result: OperationPayload | None = None
+ error: ErrorObject | None = None
+
+ @classmethod
+ def create(cls, execution: Execution) -> DurableFunctionTestResult:
+ operations = []
+ for operation in execution.operations:
+ if operation.operation_type is OperationType.EXECUTION:
+ # don't want the EXECUTION operations in the list test code asserts against
+ continue
+
+ if operation.parent_id is None:
+ operations.append(create_operation(operation, execution.operations))
+
+ if execution.result is None:
+ msg: str = "Execution result must exist to create test result."
+ raise DurableFunctionsTestError(msg)
+
+ return cls(
+ status=execution.result.status,
+ operations=operations,
+ result=execution.result.result,
+ error=execution.result.error,
+ )
+
+ @classmethod
+ def from_execution_history(
+ cls,
+ execution_response: GetDurableExecutionResponse,
+ history_response: GetDurableExecutionHistoryResponse,
+ ) -> DurableFunctionTestResult:
+ """Create test result from execution history responses.
+
+ Factory method for cloud runner that builds DurableFunctionTestResult
+ from GetDurableExecution and GetDurableExecutionHistory API responses.
+ """
+ # Map status string to InvocationStatus enum
+ try:
+ status = InvocationStatus[execution_response.status]
+ except KeyError:
+ logger.warning(
+ "Unknown status: %s, defaulting to FAILED", execution_response.status
+ )
+ status = InvocationStatus.FAILED
+
+ # Convert Events to Operations - group by operation_id and merge
+ try:
+ svc_operations = events_to_operations(history_response.events)
+ except Exception as e:
+ logger.warning("Failed to convert events to operations: %s", e)
+ svc_operations = []
+
+ # Build operation tree (exclude EXECUTION type from top level)
+ operations = []
+ for svc_op in svc_operations:
+ if svc_op.operation_type == OperationType.EXECUTION:
+ continue
+ if svc_op.parent_id is None:
+ operations.append(create_operation(svc_op, svc_operations))
+
+ return cls(
+ status=status,
+ operations=operations,
+ result=execution_response.result,
+ error=execution_response.error,
+ )
+
+ def get_operation_by_name(self, name: str) -> Operation:
+ for operation in self.operations:
+ if operation.name == name:
+ return operation
+ msg: str = f"Operation with name '{name}' not found"
+ raise DurableFunctionsTestError(msg)
+
+ def get_step(self, name: str) -> StepOperation:
+ return cast(StepOperation, self.get_operation_by_name(name))
+
+ def get_wait(self, name: str) -> WaitOperation:
+ return cast(WaitOperation, self.get_operation_by_name(name))
+
+ def get_context(self, name: str) -> ContextOperation:
+ return cast(ContextOperation, self.get_operation_by_name(name))
+
+ def get_callback(self, name: str) -> CallbackOperation:
+ return cast(CallbackOperation, self.get_operation_by_name(name))
+
+ def get_invoke(self, name: str) -> InvokeOperation:
+ return cast(InvokeOperation, self.get_operation_by_name(name))
+
+ def get_execution(self, name: str) -> ExecutionOperation:
+ return cast(ExecutionOperation, self.get_operation_by_name(name))
+
+ def get_all_operations(self) -> list[Operation]:
+ """Recursively get all operations including nested ones."""
+ all_ops = []
+ stack = list(self.operations)
+ while stack:
+ op = stack.pop()
+ all_ops.append(op)
+ # Add child operations to stack (if they exist)
+ if hasattr(op, "child_operations") and op.child_operations:
+ stack.extend(op.child_operations)
+ return all_ops
+
+
+class DurableFunctionTestRunner:
+ def __init__(self, handler: Callable, poll_interval: float = 1.0):
+ self._scheduler: Scheduler = Scheduler()
+ self._scheduler.start()
+ self._store = InMemoryExecutionStore()
+ self.poll_interval = poll_interval
+ self._checkpoint_processor = CheckpointProcessor(
+ store=self._store, scheduler=self._scheduler
+ )
+ self._service_client = InMemoryServiceClient(self._checkpoint_processor)
+ self._invoker = InProcessInvoker(handler, self._service_client)
+ self._executor = Executor(
+ store=self._store,
+ scheduler=self._scheduler,
+ invoker=self._invoker,
+ checkpoint_processor=self._checkpoint_processor,
+ )
+
+ # Wire up observer pattern - CheckpointProcessor uses this to notify executor of state changes
+ self._checkpoint_processor.add_execution_observer(self._executor)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.close()
+
+ def close(self):
+ self._scheduler.stop()
+
+ def run(
+ self,
+ input: str | None = None, # noqa: A002
+ timeout: int = 900,
+ function_name: str = "test-function",
+ execution_name: str = "execution-name",
+ account_id: str = "123456789012",
+ ) -> DurableFunctionTestResult:
+ execution_arn = self.run_async(
+ input=input,
+ timeout=timeout,
+ function_name=function_name,
+ execution_name=execution_name,
+ account_id=account_id,
+ )
+
+ return self.wait_for_result(execution_arn=execution_arn, timeout=timeout)
+
+ def send_callback_success(
+ self, callback_id: str, result: bytes | None = None
+ ) -> None:
+ self._executor.send_callback_success(callback_id=callback_id, result=result)
+
+ def send_callback_failure(
+ self, callback_id: str, error: ErrorObject | None = None
+ ) -> None:
+ self._executor.send_callback_failure(callback_id=callback_id, error=error)
+
+ def send_callback_heartbeat(self, callback_id: str) -> None:
+ self._executor.send_callback_heartbeat(callback_id=callback_id)
+
+ def run_async(
+ self,
+ input: str | None = None, # noqa: A002
+ timeout: int = 900,
+ function_name: str = "test-function",
+ execution_name: str = "execution-name",
+ account_id: str = "123456789012",
+ ) -> str:
+ start_input = StartDurableExecutionInput(
+ account_id=account_id,
+ function_name=function_name,
+ function_qualifier="$LATEST",
+ execution_name=execution_name,
+ execution_timeout_seconds=timeout,
+ execution_retention_period_days=7,
+ invocation_id="inv-12345678-1234-1234-1234-123456789012",
+ trace_fields={"trace_id": "abc123", "span_id": "def456"},
+ tenant_id="tenant-001",
+ input=input,
+ )
+
+ output: StartDurableExecutionOutput = self._executor.start_execution(
+ start_input
+ )
+
+ if output.execution_arn is None:
+ msg_arn: str = "Execution ARN must exist to run test."
+ raise DurableFunctionsTestError(msg_arn)
+ return output.execution_arn
+
+ def wait_for_result(
+ self, execution_arn: str, timeout: int = 60
+ ) -> DurableFunctionTestResult:
+ # Block until completion
+ completed = self._executor.wait_until_complete(execution_arn, timeout)
+
+ if not completed:
+ msg_timeout: str = "Execution did not complete within timeout"
+
+ raise TimeoutError(msg_timeout)
+
+ execution: Execution = self._store.load(execution_arn)
+ return DurableFunctionTestResult.create(execution=execution)
+
+ def wait_for_callback(
+ self, execution_arn: str, name: str | None = None, timeout: int = 60
+ ) -> str:
+ start_time = time.time()
+
+ while time.time() - start_time < timeout:
+ try:
+ history_response = self._executor.get_execution_history(execution_arn)
+ callback_id = _get_callback_id_from_events(
+ events=history_response.events, name=name
+ )
+ if callback_id:
+ return callback_id
+ except ResourceNotFoundException as e:
+ pass
+ except Exception as e:
+ msg = f"Failed to fetch execution history: {e}"
+ raise DurableFunctionsTestError(msg) from e
+
+ # Wait before next poll
+ time.sleep(self.poll_interval)
+
+ # Timeout reached
+ elapsed = time.time() - start_time
+ msg = f"Callback did not available within {timeout}s (elapsed: {elapsed:.1f}s."
+ raise TimeoutError(msg)
+
+
+class DurableChildContextTestRunner(DurableFunctionTestRunner):
+ """Test a durable block, annotated with @durable_with_child_context, in isolation."""
+
+ def __init__(
+ self,
+ context_function: Callable[Concatenate[DurableContext, P], Any],
+ *args,
+ **kwargs,
+ ):
+ # wrap the durable context around a durable execution handler as a convenience to run directly
+ @durable_execution
+ def handler(event: Any, context: DurableContext): # noqa: ARG001
+ return context_function(*args, **kwargs)(context)
+
+ super().__init__(handler)
+
+
+class WebRunner:
+ """Web server runner for durable functions testing with HTTP API endpoints."""
+
+ def __init__(self, config: WebRunnerConfig) -> None:
+ """Initialize WebRunner with configuration.
+
+ Args:
+ config: WebRunnerConfig containing server and Lambda service settings
+ """
+ self._config = config
+ self._server: WebServer | None = None
+ self._scheduler: Scheduler | None = None
+ self._store: ExecutionStore | None = None
+ self._invoker: LambdaInvoker | None = None
+ self._executor: Executor | None = None
+
+ def __enter__(self) -> Self:
+ """Context manager entry point.
+
+ Returns:
+ WebRunner: Self for use in with statement
+ """
+ self.start()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
+ """Context manager exit point with cleanup.
+
+ Args:
+ exc_type: Exception type if an exception occurred
+ exc_val: Exception value if an exception occurred
+ exc_tb: Exception traceback if an exception occurred
+ """
+ self.stop()
+
+ def start(self) -> None:
+ """Start the server and initialize all dependencies.
+
+ Creates and configures all required components including scheduler,
+ store, invoker, executor, and web server. It does not however start
+ serving web requests, for that you need serve_forever.
+
+ Raises:
+ DurableFunctionsLocalRunnerError: If server is already started
+ """
+ if self._server is not None:
+ msg = "Server is already running"
+ raise DurableFunctionsLocalRunnerError(msg)
+
+ # Create dependencies and server
+ if self._config.store_type == StoreType.SQLITE:
+ store_path = self._config.store_path
+ self._store = SQLiteExecutionStore.create_and_initialize(store_path)
+ elif self._config.store_type == StoreType.FILESYSTEM:
+ store_path = self._config.store_path or ".durable_executions"
+ self._store = FileSystemExecutionStore.create(store_path)
+ else:
+ self._store = InMemoryExecutionStore()
+ self._scheduler = Scheduler()
+ self._invoker = LambdaInvoker(self._create_boto3_client())
+
+ # Create shared CheckpointProcessor
+ checkpoint_processor = CheckpointProcessor(self._store, self._scheduler)
+
+ # Create executor with all dependencies including checkpoint processor
+ self._executor = Executor(
+ store=self._store,
+ scheduler=self._scheduler,
+ invoker=self._invoker,
+ checkpoint_processor=checkpoint_processor,
+ )
+
+ # Add executor as observer to the checkpoint processor
+ checkpoint_processor.add_execution_observer(self._executor)
+
+ # Start the scheduler
+ self._scheduler.start()
+
+ # Create web server with configuration and executor
+ self._server = WebServer(
+ config=self._config.web_service, executor=self._executor
+ )
+
+ def serve_forever(self) -> None:
+ """Start serving HTTP requests indefinitely.
+
+ Delegates to the underlying WebServer.serve_forever() method.
+ This method blocks until the server is stopped.
+
+ Raises:
+ DurableFunctionsLocalRunnerError: If server has not been started
+ """
+ if self._server is None:
+ msg = "Server not started"
+ raise DurableFunctionsLocalRunnerError(msg)
+
+ # This blocks until KeyboardInterrupt - let caller handle the exception
+ self._server.serve_forever()
+
+ def stop(self) -> None:
+ """Stop the web server and cleanup resources.
+
+ Gracefully shuts down the server, scheduler, and cleans up
+ all allocated resources. Safe to call multiple times.
+ Handles cleanup exceptions gracefully to ensure all resources
+ are cleaned up even if some fail.
+ """
+ if self._server is not None:
+ try:
+ self._server.server_close()
+ except Exception:
+ # Log the exception but continue cleanup
+ logger.exception("error closing web server")
+
+ self._server = None
+
+ if self._scheduler is not None:
+ try:
+ self._scheduler.stop()
+ except Exception:
+ logger.exception("error stopping scheduler")
+ self._scheduler = None
+
+ self._store = None
+ self._invoker = None
+ self._executor = None
+
+ def _create_boto3_client(self) -> Any:
+ """Create boto3 client for Lambda service.
+
+ Creates a boto3 client with the local runner endpoint and region from configuration.
+
+ Returns:
+ Configured boto3 client for Lambda service
+
+ Raises:
+ Exception: If client creation fails - exceptions propagate naturally
+ for CLI to handle as general Exception
+ """
+ return create_lambda_client(
+ endpoint_url=self._config.lambda_endpoint,
+ region_name=self._config.local_runner_region,
+ )
+
+
+class DurableFunctionCloudTestRunner:
+ """Test runner that executes durable functions against actual AWS Lambda backend.
+
+ This runner invokes deployed Lambda functions and polls for execution completion,
+ providing the same interface as DurableFunctionTestRunner for seamless test
+ compatibility between local and cloud modes.
+
+ Example:
+ >>> runner = DurableFunctionCloudTestRunner(
+ ... function_name="HelloWorld-Python-PR-123", region="us-west-2"
+ ... )
+ >>> with runner:
+ ... result = runner.run(input={"name": "World"}, timeout=60)
+ >>> assert result.current_status == InvocationStatus.SUCCEEDED
+ """
+
+ def __init__(
+ self,
+ function_name: str,
+ region: str = "us-west-2",
+ lambda_endpoint: str | None = None,
+ poll_interval: float = 1.0,
+ ):
+ """Initialize cloud test runner."""
+ self.function_name = function_name
+ self.region = region
+ self.lambda_endpoint = lambda_endpoint
+ self.poll_interval = poll_interval
+
+ client_config = boto3.session.Config(parameter_validation=False)
+ self.lambda_client = boto3.client(
+ "lambda",
+ endpoint_url=lambda_endpoint,
+ region_name=region,
+ config=client_config,
+ )
+
+ def run(
+ self,
+ input: str | None = None, # noqa: A002
+ timeout: int = 60,
+ ) -> DurableFunctionTestResult:
+ """Execute function on AWS Lambda and wait for completion."""
+ logger.info(
+ "Invoking Lambda function: %s (timeout: %ds)", self.function_name, timeout
+ )
+
+ # JSON encode input
+ payload = json.dumps(input)
+
+ # Invoke Lambda function
+ try:
+ response = self.lambda_client.invoke(
+ FunctionName=self.function_name,
+ InvocationType="RequestResponse",
+ Payload=payload,
+ )
+ except Exception as e:
+ msg = f"Failed to invoke Lambda function {self.function_name}: {e}"
+ raise DurableFunctionsTestError(msg) from e
+
+ # Check HTTP status code, 200 for RequestResponse
+ status_code = response.get("StatusCode")
+ if status_code != 200:
+ error_payload = response["Payload"].read().decode("utf-8")
+ msg = f"Lambda invocation failed with status {status_code}: {error_payload}"
+ raise DurableFunctionsTestError(msg)
+
+ # Check for function errors, we want to return function error for testing purpose
+ if "FunctionError" in response:
+ error_payload = response["Payload"].read().decode("utf-8")
+ logger.warning("Lambda function failed: %s", error_payload)
+
+ result_payload = response["Payload"].read().decode("utf-8")
+ logger.info(
+ "Lambda invocation completed, response: %s",
+ result_payload,
+ )
+
+ # Extract durable execution ARN from response headers
+ # The InvocationResponse includes X-Amz-Durable-Execution-Arn header
+ execution_arn = response.get("DurableExecutionArn")
+ if not execution_arn:
+ msg = (
+ f"No DurableExecutionArn in response for function {self.function_name}"
+ )
+ raise DurableFunctionsTestError(msg)
+
+ return self.wait_for_result(execution_arn=execution_arn, timeout=timeout)
+
+ def run_async(
+ self,
+ input: str | None = None, # noqa: A002
+ timeout: int = 60,
+ ) -> str:
+ """Execute function on AWS Lambda asynchronously"""
+ logger.info(
+ "Invoking Lambda function: %s (timeout: %ds)", self.function_name, timeout
+ )
+ payload = json.dumps(input)
+ try:
+ response = self.lambda_client.invoke(
+ FunctionName=self.function_name,
+ InvocationType="Event",
+ Payload=payload,
+ )
+ except Exception as e:
+ msg = f"Failed to invoke Lambda function {self.function_name}: {e}"
+ raise DurableFunctionsTestError(msg) from e
+
+ # Check HTTP status code, 202 for Event
+ status_code = response.get("StatusCode")
+ if status_code != 202:
+ error_payload = response["Payload"].read().decode("utf-8")
+ msg = f"Lambda invocation failed with status {status_code}: {error_payload}"
+ raise DurableFunctionsTestError(msg)
+
+ return response.get("DurableExecutionArn")
+
+ def send_callback_success(
+ self, callback_id: str, result: bytes | None = None
+ ) -> None:
+ try:
+ self.lambda_client.send_durable_execution_callback_success(
+ CallbackId=callback_id, Result=result
+ )
+ except Exception as e:
+ msg = f"Failed to send callback success for {self.function_name}, callback_id {callback_id}: {e}"
+ raise DurableFunctionsTestError(msg) from e
+
+ def send_callback_failure(
+ self, callback_id: str, error: ErrorObject | None = None
+ ) -> None:
+ try:
+ self.lambda_client.send_durable_execution_callback_failure(
+ CallbackId=callback_id, Error=error.to_dict() if error else None
+ )
+ except Exception as e:
+ msg = f"Failed to send callback failure for {self.function_name}, callback_id {callback_id}: {e}"
+ raise DurableFunctionsTestError(msg) from e
+
+ def send_callback_heartbeat(self, callback_id: str) -> None:
+ try:
+ self.lambda_client.send_durable_execution_callback_heartbeat(
+ CallbackId=callback_id
+ )
+ except Exception as e:
+ msg = f"Failed to send callback heartbeat for {self.function_name}, callback_id {callback_id}: {e}"
+ raise DurableFunctionsTestError(msg) from e
+
+ def _wait_for_completion(
+ self, execution_arn: str, timeout: int
+ ) -> GetDurableExecutionResponse:
+ """Poll execution status until completion or timeout.
+
+ Args:
+ execution_arn: ARN of the durable execution
+ timeout: Maximum seconds to wait
+
+ Returns:
+ GetDurableExecutionResponse with typed execution details
+
+ Raises:
+ TimeoutError: If execution doesn't complete within timeout
+ DurableFunctionsTestError: If status check fails
+ """
+ start_time = time.time()
+ last_status = None
+
+ while time.time() - start_time < timeout:
+ try:
+ execution_dict = self.lambda_client.get_durable_execution(
+ DurableExecutionArn=execution_arn
+ )
+ execution = GetDurableExecutionResponse.from_dict(execution_dict)
+ except Exception as e:
+ msg = f"Failed to get execution status: {e}"
+ raise DurableFunctionsTestError(msg) from e
+
+ # Log status changes
+ if execution.status != last_status:
+ logger.info("Execution status: %s", execution.status)
+ last_status = execution.status
+
+ # Check if execution completed
+ if execution.status == "SUCCEEDED":
+ logger.info("Execution succeeded")
+ return execution
+ if execution.status == "FAILED":
+ logger.warning("Execution failed")
+ return execution
+ if execution.status in ["TIMED_OUT", "ABORTED"]:
+ logger.warning("Execution terminated: %s", execution.status)
+ return execution
+
+ # Wait before next poll
+ time.sleep(self.poll_interval)
+
+ # Timeout reached
+ elapsed = time.time() - start_time
+ msg = (
+ f"Execution did not complete within {timeout}s "
+ f"(elapsed: {elapsed:.1f}s, last status: {last_status})"
+ )
+ raise TimeoutError(msg)
+
+ def wait_for_result(
+ self, execution_arn: str, timeout: int = 60
+ ) -> DurableFunctionTestResult:
+ # Poll for completion
+ execution_response = self._wait_for_completion(execution_arn, timeout)
+
+ try:
+ history_response = self._fetch_execution_history(execution_arn)
+ except Exception as e:
+ msg = f"Failed to fetch execution history: {e}"
+ raise DurableFunctionsTestError(msg) from e
+
+ # Build test result from execution history
+ return DurableFunctionTestResult.from_execution_history(
+ execution_response, history_response
+ )
+
+ def wait_for_callback(
+ self, execution_arn: str, name: str | None = None, timeout: int = 60
+ ) -> str:
+ """
+ Wait for and retrieve a callback ID from a Step Functions execution.
+
+ Polls the execution history at regular intervals until a callback ID is found
+ or the timeout is reached.
+
+ Args:
+ execution_arn: Execution Arn
+ name: Specific callback name, default to None
+ timeout: Maximum time in seconds to wait for callback. Defaults to 60.
+
+ Returns:
+ str: The callback ID/token retrieved from the execution history
+
+ Raises:
+ TimeoutError: If callback is not found within the specified timeout period
+ DurableFunctionsTestError: If there's an error fetching execution history
+ (excluding retryable errors)
+ """
+ start_time = time.time()
+
+ while time.time() - start_time < timeout:
+ try:
+ history_response = self._fetch_execution_history(execution_arn)
+ callback_id = _get_callback_id_from_events(
+ events=history_response.events, name=name
+ )
+ if callback_id:
+ return callback_id
+ except ClientError as e:
+ error_code = e.response["Error"]["Code"]
+ # retryable error, the execution may not start yet in async invoke situation
+ if error_code in ["ResourceNotFoundException"]:
+ pass
+ else:
+ msg = f"Failed to fetch execution history: {e}"
+ raise DurableFunctionsTestError(msg) from e
+ except DurableFunctionsTestError as e:
+ raise e
+ except Exception as e:
+ msg = f"Failed to fetch execution history: {e}"
+ raise DurableFunctionsTestError(msg) from e
+
+ # Wait before next poll
+ time.sleep(self.poll_interval)
+
+ # Timeout reached
+ elapsed = time.time() - start_time
+ msg = f"Callback did not available within {timeout}s (elapsed: {elapsed:.1f}s."
+ raise TimeoutError(msg)
+
+ def _fetch_execution_history(
+ self, execution_arn: str
+ ) -> GetDurableExecutionHistoryResponse:
+ """Retrieve execution history from Lambda service.
+
+ Args:
+ execution_arn: ARN of the durable execution
+
+ Returns:
+ GetDurableExecutionHistoryResponse with typed Event objects
+
+ Raises:
+ ClientError: If lambda client encounter error
+ """
+ history_dict = self.lambda_client.get_durable_execution_history(
+ DurableExecutionArn=execution_arn,
+ IncludeExecutionData=True,
+ )
+ history_response = GetDurableExecutionHistoryResponse.from_dict(history_dict)
+
+ logger.info("Retrieved %d events from history", len(history_response.events))
+
+ return history_response
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/scheduler.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/scheduler.py
new file mode 100644
index 0000000..a45b942
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/scheduler.py
@@ -0,0 +1,246 @@
+"""A Scheduler that can run awaitables or standard sync callables on a schedule once or repeatedly."""
+
+from __future__ import annotations
+
+import asyncio
+import itertools
+import logging
+import threading
+from typing import TYPE_CHECKING, Any
+
+
+if TYPE_CHECKING:
+ from collections.abc import Callable
+ from concurrent.futures import Future
+
+logger = logging.getLogger(__name__)
+
+
+class Event:
+ """An event created by Scheduler that will block on wait until it's set."""
+
+ def __init__(self, scheduler: Scheduler, asyncio_event: asyncio.Event) -> None:
+ self._scheduler: Scheduler = scheduler
+ self._asyncio_event: asyncio.Event = asyncio_event
+ self._exception: Exception | None = None
+
+ def set(self):
+ """Set the event with this to unblock wait."""
+ self._scheduler.set_event(self._asyncio_event)
+
+ def set_exception(self, exception: Exception):
+ """Set exception and unblock waiters."""
+ self._exception = exception
+ self._scheduler.set_event(self._asyncio_event)
+
+ def wait(self, timeout: float | None = None, *, clear_on_set: bool = True) -> bool:
+ """Wait until the event is set.
+
+ Args:
+ timeout (int | float | None): Wait for event to set until this timeout.
+ clear_on_set (bool): Remove the event from the Scheduler on completion.
+ Use this if you won't re-use the event.
+
+ Returns:
+ True when set. False if the event timed out without being set.
+
+ Raises:
+ Exception: If an exception was stored via set_exception().
+ """
+ result = self._scheduler.wait_for_event(self._asyncio_event, timeout)
+ if clear_on_set:
+ self._scheduler.remove_event(self._asyncio_event)
+ if result and self._exception:
+ raise self._exception
+ return result
+
+ def remove(self):
+ """Remove the event from the Scheduler. Do this to avoid build-up of many events in the scheduler."""
+ self._scheduler.remove_event(self._asyncio_event)
+
+
+class Scheduler:
+ """A Scheduler to run callables later, repeatedly or raise events."""
+
+ def __init__(self) -> None:
+ self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
+ self._ready_event: threading.Event = threading.Event()
+ self._thread: threading.Thread = threading.Thread(
+ target=self._start_loop, daemon=True
+ )
+ self._running: bool = False
+ self._events: set[asyncio.Event] = set()
+
+ # region context manager
+ def __enter__(self):
+ self.start()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.stop()
+
+ # endregion context manager
+
+ # region event loop
+ def start(self):
+ """Start the scheduler. Not thread-safe."""
+ if self._running:
+ return
+
+ self._running = True
+
+ self._thread.start()
+ # Wait for inside of loop to notify it's ready (meaning _start_loop has completed)
+ self._ready_event.wait()
+
+ def stop(self):
+ """Stop the scheduler, releasing resources. Not thread-safe."""
+ if not self._running:
+ return
+
+ self._running = False
+ self._loop.call_soon_threadsafe(self._cleanup_and_stop)
+ self._thread.join()
+
+ def is_started(self) -> bool:
+ """Return True if the scheduler is started."""
+ return self._running
+
+ def event_count(self) -> int:
+ """Return the number of events in the scheduler."""
+ return len(self._events)
+
+ def task_count(self) -> int:
+ """Return the number of tasks in the scheduler."""
+ if not self._running:
+ return 0
+ return len(asyncio.all_tasks(self._loop))
+
+ def _cleanup_and_stop(self):
+ """Cancel all tasks and clear all events. Stop the event-loop."""
+ # Cancel all tasks
+ for task in asyncio.all_tasks(self._loop):
+ task.cancel()
+
+ # Clear events (don't set them)
+ self._events.clear()
+
+ self._loop.stop()
+
+ def _start_loop(self):
+ """Initialize the event-loop. The ready event notifies that the loop is started."""
+ asyncio.set_event_loop(self._loop)
+ # signal that loop is ready from within the loop
+ self._loop.call_soon(self._ready_event.set)
+ # block indefinitely - call_soon with the read_event will run soon as the loop starts
+ self._loop.run_forever()
+
+ # endregion event loop
+ # region Tasks
+ def call_later(
+ self,
+ func: Callable[[], Any],
+ delay: float = 0,
+ count: int | None = 1,
+ completion_event: Event | None = None,
+ ) -> Future[Any]:
+ """Call func after the delay.
+
+ If func is async it runs inside a thread-safe coroutine. If func is sync it runs in its own
+ threadpool, so it won't block the event loop.
+
+ Args:
+ func (Callable[[], Any]): The function to call later. This can be an async or a standard
+ sync function.
+ delay (float | int): Delay in seconds before calling func.
+ count (int | None): Number of times to call func. Default is 1 (call once).
+ Use None for infinite repeats.
+ completion_event (Event | None): Event to notify on exception.
+
+ Returns: Future that completes when the scheduled work is done.
+ """
+ # infinite counter if count = None, else it maxes out at count
+ loop_iter: itertools.count[int] | range = (
+ itertools.count() if count is None else range(count)
+ )
+
+ async def delayed_func() -> Any:
+ try:
+ for _ in loop_iter:
+ await asyncio.sleep(delay)
+
+ try:
+ if asyncio.iscoroutinefunction(func):
+ result = await func()
+ else:
+ result = await asyncio.to_thread(func)
+ return result # noqa: TRY300
+ except Exception as err:
+ if completion_event:
+ completion_event.set_exception(err)
+ else:
+ msg: str = "error in scheduled task"
+ logger.exception(msg)
+ raise
+ except asyncio.CancelledError: # noqa: TRY302
+ # might want to handle more things here
+ raise
+
+ future: Future[Any] = asyncio.run_coroutine_threadsafe(
+ delayed_func(), self._loop
+ )
+ return future
+
+ # endregion Tasks
+
+ # region Events
+
+ def create_event(self) -> Event:
+ """Create an event controlled by the Scheduler to signal between threads and coroutines."""
+ # create event inside the Scheduler event-loop
+ future: Future[asyncio.Event] = asyncio.run_coroutine_threadsafe(
+ self._create_event(), self._loop
+ )
+
+ # Add timeout to prevent surprising "hangs" if for whatever reason event fails to create.
+ # result with block. Do NOT call anything in _create_event that calls back into scheduler
+ # methods because it could create a circular depdendency which will deadlock.
+ event = future.result(timeout=5.0)
+ return Event(self, event)
+
+ def wait_for_event(
+ self, event: asyncio.Event, timeout: float | None = None
+ ) -> bool:
+ """Run event's wait inside the Scheduler event-loop."""
+ if event not in self._events:
+ return False
+
+ future: Future[bool] = asyncio.run_coroutine_threadsafe(
+ asyncio.wait_for(event.wait(), timeout), self._loop
+ )
+
+ try:
+ return future.result()
+ except TimeoutError:
+ return False
+
+ def set_event(self, event: asyncio.Event):
+ """Set event inside the Scheduler event-loop."""
+ if event in self._events:
+ self._loop.call_soon_threadsafe(event.set)
+
+ def remove_event(self, event: asyncio.Event):
+ """Remove event from Scheduler in the Scheduler event-loop."""
+
+ def _remove():
+ self._events.discard(event)
+
+ self._loop.call_soon_threadsafe(_remove)
+
+ async def _create_event(self) -> asyncio.Event:
+ """Create event and add it to the scheduler events list."""
+ event = asyncio.Event()
+ self._events.add(event)
+ return event
+
+ # endregion Events
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/__init__.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/__init__.py
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/__init__.py
@@ -0,0 +1 @@
+
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/base.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/base.py
new file mode 100644
index 0000000..ca87e28
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/base.py
@@ -0,0 +1,147 @@
+"""Base classes and protocols for execution stores."""
+
+from __future__ import annotations
+
+from datetime import UTC
+from enum import Enum
+from typing import TYPE_CHECKING, Protocol
+
+
+if TYPE_CHECKING:
+ from aws_durable_execution_sdk_python.lambda_service import Operation
+
+ from aws_durable_execution_sdk_python_testing.execution import Execution
+
+
+class StoreType(Enum):
+ """Supported execution store types."""
+
+ MEMORY = "memory"
+ FILESYSTEM = "filesystem"
+ SQLITE = "sqlite"
+
+
+class ExecutionStore(Protocol):
+ """Protocol for execution storage implementations."""
+
+ # ignore cover because coverage doesn't understand elipses
+ def save(self, execution: Execution) -> None: ... # pragma: no cover
+ def load(self, execution_arn: str) -> Execution: ... # pragma: no cover
+ def update(self, execution: Execution) -> None: ... # pragma: no cover
+ def query(
+ self,
+ function_name: str | None = None,
+ execution_name: str | None = None,
+ status_filter: str | None = None,
+ started_after: str | None = None,
+ started_before: str | None = None,
+ limit: int | None = None,
+ offset: int = 0,
+ reverse_order: bool = False, # noqa: FBT001, FBT002
+ ) -> tuple[list[Execution], str | None]: ... # pragma: no cover
+ def list_all(
+ self,
+ ) -> list[Execution]: ... # pragma: no cover # Keep for backward compatibility
+
+
+class BaseExecutionStore(ExecutionStore):
+ """Base implementation for execution stores with shared query logic."""
+
+ @staticmethod
+ def process_query(
+ executions: list[Execution],
+ function_name: str | None = None,
+ execution_name: str | None = None,
+ status_filter: str | None = None,
+ started_after: str | None = None,
+ started_before: str | None = None,
+ limit: int | None = None,
+ offset: int = 0,
+ reverse_order: bool = False, # noqa: FBT001, FBT002
+ ) -> tuple[list[Execution], str | None]:
+ """Apply filtering, sorting, and pagination to executions."""
+ # Apply filters
+ filtered: list[Execution] = []
+ for execution in executions:
+ if function_name and execution.start_input.function_name != function_name:
+ continue
+ if (
+ execution_name
+ and execution.start_input.execution_name != execution_name
+ ):
+ continue
+
+ # Status filtering
+ if status_filter and execution.current_status().value != status_filter:
+ continue
+
+ # Time filtering
+ if started_after or started_before:
+ try:
+ operation: Operation = execution.get_operation_execution_started()
+ if operation.start_timestamp:
+ timestamp: float = (
+ operation.start_timestamp.timestamp()
+ if hasattr(operation.start_timestamp, "timestamp")
+ else operation.start_timestamp.replace(
+ tzinfo=UTC
+ ).timestamp()
+ )
+ if started_after and timestamp < float(started_after):
+ continue
+ if started_before and timestamp > float(started_before):
+ continue
+ except (ValueError, AttributeError):
+ continue
+
+ filtered.append(execution)
+
+ # Sort by start timestamp
+ def get_sort_key(exe: Execution):
+ try:
+ op: Operation = exe.get_operation_execution_started()
+ if op.start_timestamp:
+ return (
+ op.start_timestamp.timestamp()
+ if hasattr(op.start_timestamp, "timestamp")
+ else op.start_timestamp.replace(tzinfo=UTC).timestamp()
+ )
+ except Exception: # noqa: BLE001, S110
+ pass
+ return 0
+
+ filtered.sort(key=get_sort_key, reverse=reverse_order)
+
+ # Apply pagination
+ if limit is not None and limit > 0:
+ end_idx: int = offset + limit
+ paginated: list[Execution] = filtered[offset:end_idx]
+ has_more: bool = end_idx < len(filtered)
+ next_marker: str | None = str(end_idx) if has_more else None
+ return paginated, next_marker
+ return filtered[offset:], None
+
+ def query(
+ self,
+ function_name: str | None = None,
+ execution_name: str | None = None,
+ status_filter: str | None = None,
+ started_after: str | None = None,
+ started_before: str | None = None,
+ limit: int | None = None,
+ offset: int = 0,
+ reverse_order: bool = False, # noqa: FBT001, FBT002
+ ) -> tuple[list[Execution], str | None]:
+ """Apply filtering, sorting, and pagination to executions."""
+ executions: list[Execution] = self.list_all()
+ return self.process_query(
+ executions,
+ function_name=function_name,
+ execution_name=execution_name,
+ status_filter=status_filter,
+ started_after=started_after,
+ started_before=started_before,
+ limit=limit,
+ offset=offset,
+ reverse_order=reverse_order,
+ )
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/filesystem.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/filesystem.py
new file mode 100644
index 0000000..f0f3154
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/filesystem.py
@@ -0,0 +1,79 @@
+"""File system-based execution store implementation."""
+
+from __future__ import annotations
+
+import json
+import logging
+from pathlib import Path
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ ResourceNotFoundException,
+)
+from aws_durable_execution_sdk_python_testing.execution import Execution
+from aws_durable_execution_sdk_python_testing.stores.base import (
+ BaseExecutionStore,
+)
+
+
+class FileSystemExecutionStore(BaseExecutionStore):
+ """File system-based execution store for persistence."""
+
+ def __init__(self, storage_dir: Path) -> None:
+ self._storage_dir = storage_dir
+
+ @classmethod
+ def create(cls, storage_dir: str | Path | None = None) -> FileSystemExecutionStore:
+ """Create a FileSystemExecutionStore with directory creation.
+
+ Args:
+ storage_dir: Directory path for storage. Defaults to '.durable_executions'
+
+ Returns:
+ FileSystemExecutionStore instance with created directory
+ """
+ path = Path(storage_dir) if storage_dir else Path(".durable_executions")
+ path.mkdir(exist_ok=True)
+ return cls(storage_dir=path)
+
+ def _get_file_path(self, execution_arn: str) -> Path:
+ """Get file path for execution ARN."""
+ # Use ARN as filename with .json extension, replacing unsafe characters
+ safe_filename = execution_arn.replace(":", "_").replace("/", "_")
+ return self._storage_dir / f"{safe_filename}.json"
+
+ def save(self, execution: Execution) -> None:
+ """Save execution to file system."""
+ file_path = self._get_file_path(execution.durable_execution_arn)
+ data = execution.to_json_dict()
+
+ with open(file_path, "w", encoding="utf-8") as f:
+ json.dump(data, f, indent=2)
+
+ def load(self, execution_arn: str) -> Execution:
+ """Load execution from file system."""
+ file_path = self._get_file_path(execution_arn)
+ if not file_path.exists():
+ msg = f"Execution {execution_arn} not found"
+ raise ResourceNotFoundException(msg)
+
+ with open(file_path, encoding="utf-8") as f:
+ data = json.load(f)
+
+ return Execution.from_json_dict(data)
+
+ def update(self, execution: Execution) -> None:
+ """Update execution in file system (same as save)."""
+ self.save(execution)
+
+ def list_all(self) -> list[Execution]:
+ """List all executions from file system."""
+ executions = []
+ for file_path in self._storage_dir.glob("*.json"):
+ try:
+ with open(file_path, encoding="utf-8") as f:
+ data = json.load(f)
+ executions.append(Execution.from_json_dict(data))
+ except (json.JSONDecodeError, KeyError, OSError) as e:
+ logging.warning("Skipping corrupted file %s: %s", file_path, e)
+ continue
+ return executions
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/memory.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/memory.py
new file mode 100644
index 0000000..5e6e083
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/memory.py
@@ -0,0 +1,38 @@
+"""In-memory execution store implementation."""
+
+from __future__ import annotations
+
+from threading import Lock
+from typing import TYPE_CHECKING
+
+from aws_durable_execution_sdk_python_testing.stores.base import (
+ BaseExecutionStore,
+)
+
+
+if TYPE_CHECKING:
+ from aws_durable_execution_sdk_python_testing.execution import Execution
+
+
+class InMemoryExecutionStore(BaseExecutionStore):
+ """Dict-based storage for testing."""
+
+ def __init__(self) -> None:
+ self._store: dict[str, Execution] = {}
+ self._lock: Lock = Lock()
+
+ def save(self, execution: Execution) -> None:
+ with self._lock:
+ self._store[execution.durable_execution_arn] = execution
+
+ def load(self, execution_arn: str) -> Execution:
+ with self._lock:
+ return self._store[execution_arn]
+
+ def update(self, execution: Execution) -> None:
+ with self._lock:
+ self._store[execution.durable_execution_arn] = execution
+
+ def list_all(self) -> list[Execution]:
+ with self._lock:
+ return list(self._store.values())
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/sqlite.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/sqlite.py
new file mode 100644
index 0000000..fac1ca4
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/sqlite.py
@@ -0,0 +1,273 @@
+"""SQLite-based execution store implementation."""
+
+from __future__ import annotations
+
+import json
+import sqlite3
+from datetime import datetime
+from pathlib import Path
+from typing import Any, cast
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ ResourceNotFoundException,
+ InvalidParameterValueException,
+ RuntimeException,
+)
+from aws_durable_execution_sdk_python_testing.execution import Execution
+from aws_durable_execution_sdk_python_testing.stores.base import (
+ ExecutionStore,
+)
+
+
+class SQLiteExecutionStore(ExecutionStore):
+ """SQLite-based execution store for efficient querying."""
+
+ def __init__(self, db_path: Path) -> None:
+ self.db_path: Path = db_path
+
+ @classmethod
+ def create_and_initialize(
+ cls, db_path: Path | str | None = None
+ ) -> SQLiteExecutionStore:
+ """Create SQLite store with default path."""
+ path: Path = Path(db_path) if db_path else Path("durable-executions.db")
+ path.parent.mkdir(exist_ok=True)
+ store: SQLiteExecutionStore = cls(path)
+ store._init_db()
+ return store
+
+ def _get_connection(self) -> sqlite3.Connection:
+ """Get SQLite connection with optimizations."""
+ conn: sqlite3.Connection = sqlite3.connect(self.db_path, timeout=30.0)
+ conn.execute("PRAGMA journal_mode=WAL;")
+ conn.execute("PRAGMA synchronous=NORMAL;")
+ return conn
+
+ def _init_db(self) -> None:
+ """Initialize database schema."""
+ try:
+ with self._get_connection() as conn:
+ conn.execute("""
+ CREATE TABLE IF NOT EXISTS executions (
+ durable_execution_arn TEXT PRIMARY KEY,
+ function_name TEXT NOT NULL,
+ execution_name TEXT,
+ status TEXT NOT NULL,
+ start_timestamp REAL,
+ end_timestamp REAL,
+ data TEXT NOT NULL
+ )
+ """)
+ # Create indexes for better query performance
+ conn.execute(
+ "CREATE INDEX IF NOT EXISTS idx_function_name ON executions(function_name)"
+ )
+ conn.execute(
+ "CREATE INDEX IF NOT EXISTS idx_status ON executions(status)"
+ )
+ conn.execute(
+ "CREATE INDEX IF NOT EXISTS idx_start_timestamp ON executions(start_timestamp)"
+ )
+ conn.execute(
+ "CREATE INDEX IF NOT EXISTS idx_composite ON executions(function_name, status, start_timestamp)"
+ )
+ except sqlite3.Error as e:
+ raise RuntimeError(f"Failed to initialize database: {e}") from e
+
+ def save(self, execution: Execution) -> None:
+ """Save execution to SQLite."""
+ try:
+ execution_op = execution.get_operation_execution_started()
+ status: str = execution.current_status().value
+
+ with self._get_connection() as conn:
+ conn.execute(
+ """
+ INSERT OR REPLACE INTO executions
+ (durable_execution_arn, function_name, execution_name, status, start_timestamp, end_timestamp, data)
+ VALUES (?, ?, ?, ?, ?, ?, ?)
+ """,
+ (
+ execution.durable_execution_arn,
+ execution.start_input.function_name,
+ execution.start_input.execution_name,
+ status,
+ execution_op.start_timestamp.timestamp()
+ if execution_op.start_timestamp
+ else None,
+ execution_op.end_timestamp.timestamp()
+ if execution_op.end_timestamp
+ else None,
+ json.dumps(execution.to_json_dict()),
+ ),
+ )
+ except sqlite3.Error as e:
+ raise RuntimeError(
+ f"Failed to save execution {execution.durable_execution_arn}: {e}"
+ ) from e
+ except (AttributeError, TypeError) as e:
+ raise ValueError(f"Invalid execution data: {e}") from e
+
+ def load(self, execution_arn: str) -> Execution:
+ """Load execution from SQLite."""
+ try:
+ with self._get_connection() as conn:
+ cursor: sqlite3.Cursor = conn.execute(
+ "SELECT data FROM executions WHERE durable_execution_arn = ?",
+ (execution_arn,),
+ )
+ row: tuple[str] | None = cursor.fetchone()
+
+ if not row:
+ raise ResourceNotFoundException(f"Execution {execution_arn} not found")
+
+ return Execution.from_json_dict(json.loads(row[0]))
+ except sqlite3.Error as e:
+ raise RuntimeError(f"Failed to load execution {execution_arn}: {e}") from e
+ except json.JSONDecodeError as e:
+ raise ValueError(
+ f"Corrupted execution data for {execution_arn}: {e}"
+ ) from e
+
+ def update(self, execution: Execution) -> None:
+ """Update execution (same as save)."""
+ self.save(execution)
+
+ def query(
+ self,
+ function_name: str | None = None,
+ execution_name: str | None = None,
+ status_filter: str | None = None,
+ started_after: str | None = None,
+ started_before: str | None = None,
+ limit: int | None = None,
+ offset: int = 0,
+ reverse_order: bool = False,
+ ) -> tuple[list[Execution], str | None]:
+ """Query executions with efficient SQL filtering."""
+ try:
+ # Build query safely with parameterized conditions
+ conditions: list[str] = []
+ params: list[str | float | int] = []
+
+ if function_name:
+ conditions.append("function_name = ?")
+ params.append(function_name)
+
+ if execution_name:
+ conditions.append("execution_name = ?")
+ params.append(execution_name)
+
+ if status_filter:
+ conditions.append("status = ?")
+ params.append(status_filter)
+
+ if started_after:
+ started_after_float: float = datetime.fromisoformat(
+ started_after
+ ).timestamp()
+ conditions.append("start_timestamp >= ?")
+ params.append(started_after_float)
+
+ if started_before:
+ started_before_float: float = datetime.fromisoformat(
+ started_before
+ ).timestamp()
+ conditions.append("start_timestamp <= ?")
+ params.append(started_before_float)
+
+ # Build WHERE clause safely
+ where_clause: str = ""
+ if conditions:
+ where_clause = "WHERE " + " AND ".join(conditions)
+
+ # Build ORDER BY clause
+ order_direction: str = "DESC" if reverse_order else "ASC"
+ order_clause: str = f"ORDER BY start_timestamp {order_direction}"
+
+ # For better performance, only get metadata for counting and pagination
+ base_query: str = f"FROM executions {where_clause}"
+ count_query: str = f"SELECT COUNT(*) {base_query}"
+
+ limit_exists: bool = limit is not None and limit > 0
+
+ # Only fetch data we need
+ if limit_exists:
+ data_query: str = f"SELECT durable_execution_arn, data {base_query} {order_clause} LIMIT ? OFFSET ?"
+ params_with_limit: list[str | float | int] = params + [
+ cast(int, limit),
+ offset,
+ ]
+ else:
+ data_query = (
+ f"SELECT durable_execution_arn, data {base_query} {order_clause}"
+ )
+ params_with_limit = params
+
+ with self._get_connection() as conn:
+ # Get total count for pagination
+ total_count: int = int(conn.execute(count_query, params).fetchone()[0])
+
+ # Get actual data
+ cursor: sqlite3.Cursor = conn.execute(data_query, params_with_limit)
+ rows: list[tuple[str, str]] = cursor.fetchall()
+
+ # Only deserialize the executions we actually need
+ executions: list[Execution] = []
+ for durable_execution_arn, data in rows:
+ try:
+ executions.append(Execution.from_json_dict(json.loads(data)))
+ except (json.JSONDecodeError, ValueError) as e:
+ # Log corrupted data but continue with other records
+ print(
+ f"Warning: Skipping corrupted execution {durable_execution_arn}: {e}"
+ )
+ continue
+
+ # Calculate pagination
+ has_more: bool = limit_exists and (offset + len(executions) < total_count)
+ next_marker: str | None = (
+ str(offset + len(executions)) if has_more else None
+ )
+
+ return executions, next_marker
+
+ except sqlite3.Error as e:
+ raise RuntimeException(f"Query failed: {e}") from e
+ except ValueError as e:
+ raise InvalidParameterValueException(
+ f"Invalid query parameters: {e}"
+ ) from e
+
+ def list_all(self) -> list[Execution]:
+ """List all executions (for backward compatibility)."""
+ executions, _ = self.query()
+ return executions
+
+ def get_execution_metadata(self, execution_arn: str) -> dict[str, Any] | None:
+ """Get just the metadata without full deserialization for performance."""
+ try:
+ with self._get_connection() as conn:
+ cursor: sqlite3.Cursor = conn.execute(
+ "SELECT function_name, execution_name, status, start_timestamp, end_timestamp FROM executions WHERE durable_execution_arn = ?",
+ (execution_arn,),
+ )
+ row: tuple[str, str | None, str, float | None, float | None] | None = (
+ cursor.fetchone()
+ )
+
+ if not row:
+ return None
+
+ return {
+ "durable_execution_arn": execution_arn,
+ "function_name": row[0],
+ "execution_name": row[1],
+ "status": row[2],
+ "start_timestamp": row[3],
+ "end_timestamp": row[4],
+ }
+ except sqlite3.Error as e:
+ raise RuntimeError(
+ f"Failed to get metadata for {execution_arn}: {e}"
+ ) from e
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/token.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/token.py
new file mode 100644
index 0000000..23d81be
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/token.py
@@ -0,0 +1,49 @@
+"""Token models."""
+
+from __future__ import annotations
+
+import base64
+import json
+from dataclasses import dataclass
+
+
+@dataclass(frozen=True)
+class CheckpointToken:
+ """Model a checkpoint token. This isn't exactly the same format as the actual svc, but it will do for testing purposes."""
+
+ execution_arn: str
+ token_sequence: int
+
+ def to_str(self) -> str:
+ data = {"arn": self.execution_arn, "seq": self.token_sequence}
+ json_str = json.dumps(data, separators=(",", ":"))
+ # str -> bytes -> base64 bytes -> str
+ return base64.b64encode(json_str.encode()).decode()
+
+ @classmethod
+ def from_str(cls, token: str) -> CheckpointToken:
+ # str -> base64 bytes -> str
+ decoded = base64.b64decode(token).decode()
+ data = json.loads(decoded)
+ return cls(execution_arn=data["arn"], token_sequence=data["seq"])
+
+
+@dataclass(frozen=True)
+class CallbackToken:
+ """Model a callback token."""
+
+ execution_arn: str
+ operation_id: str
+
+ def to_str(self) -> str:
+ data = {"arn": self.execution_arn, "op": self.operation_id}
+ json_str = json.dumps(data, separators=(",", ":"))
+ # str -> bytes -> base64 bytes -> str
+ return base64.b64encode(json_str.encode()).decode()
+
+ @classmethod
+ def from_str(cls, token: str) -> CallbackToken:
+ # str -> base64 bytes -> str
+ decoded = base64.b64decode(token).decode()
+ data = json.loads(decoded)
+ return cls(execution_arn=data["arn"], operation_id=data["op"])
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/__init__.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/__init__.py
new file mode 100644
index 0000000..5f5f19e
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/__init__.py
@@ -0,0 +1 @@
+"""Web server module for the AWS Durable Functions SDK Python Testing Framework."""
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/errors.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/errors.py
new file mode 100644
index 0000000..75bc81c
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/errors.py
@@ -0,0 +1,8 @@
+"""Error handling utilities for AWS Lambda Durable Functions web service.
+
+This module is deprecated and will be removed. All error handling now uses
+AWS-compliant exception classes directly.
+"""
+
+# This file is kept temporarily for backward compatibility during migration.
+# All functionality has been moved to direct AWS exception usage.
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/handlers.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/handlers.py
new file mode 100644
index 0000000..e8cb841
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/handlers.py
@@ -0,0 +1,813 @@
+"""HTTP endpoint handlers for AWS Lambda Durable Functions operations."""
+
+from __future__ import annotations
+
+import base64
+import json
+import logging
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, cast
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ AwsApiException,
+ ExecutionAlreadyStartedException,
+ ExecutionConflictException,
+ IllegalStateException,
+ InvalidParameterValueException,
+ ServiceException,
+)
+from aws_durable_execution_sdk_python_testing.model import (
+ CheckpointDurableExecutionRequest,
+ CheckpointDurableExecutionResponse,
+ GetDurableExecutionHistoryResponse,
+ GetDurableExecutionStateResponse,
+ ListDurableExecutionsByFunctionRequest,
+ ListDurableExecutionsRequest,
+ ListDurableExecutionsResponse,
+ SendDurableExecutionCallbackFailureRequest,
+ SendDurableExecutionCallbackFailureResponse,
+ SendDurableExecutionCallbackHeartbeatRequest,
+ SendDurableExecutionCallbackHeartbeatResponse,
+ SendDurableExecutionCallbackSuccessResponse,
+ StartDurableExecutionInput,
+ StartDurableExecutionOutput,
+ StopDurableExecutionRequest,
+ StopDurableExecutionResponse,
+)
+from aws_durable_execution_sdk_python_testing.web.models import (
+ HTTPRequest,
+ HTTPResponse,
+)
+from aws_durable_execution_sdk_python_testing.web.routes import (
+ CallbackFailureRoute,
+ CallbackHeartbeatRoute,
+ CallbackSuccessRoute,
+ CheckpointDurableExecutionRoute,
+ GetDurableExecutionHistoryRoute,
+ GetDurableExecutionRoute,
+ GetDurableExecutionStateRoute,
+ ListDurableExecutionsByFunctionRoute,
+ StopDurableExecutionRoute,
+)
+
+
+if TYPE_CHECKING:
+ from aws_durable_execution_sdk_python_testing.executor import Executor
+ from aws_durable_execution_sdk_python_testing.web.routes import Route
+
+logger = logging.getLogger(__name__)
+
+
+class EndpointHandler(ABC):
+ """Abstract base class for HTTP endpoint handlers."""
+
+ def __init__(self, executor: Executor) -> None:
+ """Initialize the handler with an executor.
+
+ Args:
+ executor: The executor instance for handling operations
+ """
+ self.executor = executor
+
+ @abstractmethod
+ def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse:
+ """Handle an HTTP request and return an HTTP response.
+
+ Args:
+ parsed_route: The strongly-typed route object
+ request: The HTTP request data
+
+ Returns:
+ HTTPResponse: The HTTP response to send to the client
+ """
+
+ def _parse_json_body(self, request: HTTPRequest) -> dict[str, Any]:
+ """Parse JSON body from HTTP request with validation.
+
+ Args:
+ request: The HTTP request containing the JSON body
+
+ Returns:
+ dict: The parsed JSON data
+
+ Raises:
+ InvalidParameterValueException: If the request body is empty
+ """
+ if not request.body:
+ msg = "Request body is required"
+ raise InvalidParameterValueException(msg)
+ return self._parse_json_body_optional(request)
+
+ def _parse_json_body_optional(self, request: HTTPRequest) -> dict[str, Any]:
+ """Parse JSON body from HTTP request with validation.
+
+ Args:
+ request: The HTTP request containing the JSON body
+
+ Returns:
+ dict: The parsed JSON data
+
+ Raises:
+ InvalidParameterValueException: If the request body is invalid JSON
+ """
+ if not request.body:
+ return {}
+
+ # Handle both dict and bytes body types
+ if isinstance(request.body, dict):
+ return request.body
+
+ try:
+ return json.loads(request.body.decode("utf-8"))
+ except (json.JSONDecodeError, UnicodeDecodeError) as e:
+ msg = f"Invalid JSON in request body: {e}"
+ raise InvalidParameterValueException(msg) from e
+
+ def _json_response(
+ self,
+ status_code: int,
+ data: dict[str, Any],
+ additional_headers: dict[str, str] | None = None,
+ ) -> HTTPResponse:
+ """Create a JSON HTTP response.
+
+ Args:
+ status_code: HTTP status code
+ data: Data to serialize as JSON
+ additional_headers: Optional additional headers to include
+
+ Returns:
+ HTTPResponse: The HTTP response with JSON body
+ """
+ return HTTPResponse.create_json(status_code, data, additional_headers)
+
+ def _success_response(
+ self, data: dict[str, Any], additional_headers: dict[str, str] | None = None
+ ) -> HTTPResponse:
+ """Create a successful JSON response (200 OK).
+
+ Args:
+ data: Data to serialize as JSON
+ additional_headers: Optional additional headers to include
+
+ Returns:
+ HTTPResponse: The HTTP response with JSON body
+ """
+ return self._json_response(200, data, additional_headers)
+
+ def _created_response(
+ self, data: dict[str, Any], additional_headers: dict[str, str] | None = None
+ ) -> HTTPResponse:
+ """Create a created JSON response (201 Created).
+
+ Args:
+ data: Data to serialize as JSON
+ additional_headers: Optional additional headers to include
+
+ Returns:
+ HTTPResponse: The HTTP response with JSON body
+ """
+ return self._json_response(201, data, additional_headers)
+
+ def _no_content_response(
+ self, additional_headers: dict[str, str] | None = None
+ ) -> HTTPResponse:
+ """Create a no content response (204 No Content).
+
+ Args:
+ additional_headers: Optional additional headers to include
+
+ Returns:
+ HTTPResponse: The HTTP response with empty body
+ """
+ return HTTPResponse.create_empty(204, additional_headers)
+
+ # Removed deprecated _error_response method - use AWS exceptions directly
+
+ def _parse_callback_result_payload(self, request: HTTPRequest) -> bytes:
+ """Parse callback result payload from request body.
+
+ Args:
+ request: The HTTP request containing the binary payload
+
+ Returns:
+ bytes: The result payload
+
+ Raises:
+ InvalidParameterValueException: If payload parsing fails
+ """
+ if not request.body or not isinstance(request.body, bytes):
+ return b""
+
+ return request.body
+
+ def _parse_query_param(self, request: HTTPRequest, param_name: str) -> str | None:
+ """Parse a single query parameter from the request.
+
+ Args:
+ request: The HTTP request
+ param_name: Name of the query parameter
+
+ Returns:
+ str | None: The parameter value or None if not present
+ """
+ param_values = request.query_params.get(param_name)
+ return param_values[0] if param_values else None
+
+ def _parse_query_param_list(
+ self, request: HTTPRequest, param_name: str
+ ) -> list[str]:
+ """Parse a query parameter that can have multiple values.
+
+ Args:
+ request: The HTTP request
+ param_name: Name of the query parameter
+
+ Returns:
+ list[str]: List of parameter values (empty if not present)
+ """
+ return request.query_params.get(param_name, [])
+
+ def _validate_required_fields(
+ self, data: dict[str, Any], required_fields: list[str]
+ ) -> None:
+ """Validate that required fields are present in the data.
+
+ Args:
+ data: The data dictionary to validate
+ required_fields: List of required field names
+
+ Raises:
+ ValueError: If any required field is missing
+ """
+ missing_fields = [field for field in required_fields if field not in data]
+ if missing_fields:
+ msg = f"Missing required fields: {', '.join(missing_fields)}"
+ raise InvalidParameterValueException(msg)
+
+ def _handle_aws_exception(self, exception: AwsApiException) -> HTTPResponse:
+ """Handle AWS API exceptions directly.
+
+ Args:
+ exception: The AWS API exception
+
+ Returns:
+ HTTPResponse: AWS-compliant error response
+ """
+ # Log server errors
+ if exception.http_status_code >= 500: # noqa: PLR2004
+ logger.exception("Server error: %s", exception)
+ return HTTPResponse.create_error_from_exception(exception)
+
+ def _handle_framework_exception(self, exception: Exception) -> HTTPResponse:
+ """Handle framework exceptions by mapping to AWS exceptions.
+
+ Args:
+ exception: The framework exception
+
+ Returns:
+ HTTPResponse: AWS-compliant error response
+ """
+ if isinstance(exception, (ValueError | KeyError)):
+ return HTTPResponse.create_error_from_exception(
+ InvalidParameterValueException(str(exception))
+ )
+ logger.exception("Unexpected error: %s", exception)
+ return HTTPResponse.create_error_from_exception(
+ ServiceException(str(exception))
+ )
+
+
+class StartExecutionHandler(EndpointHandler):
+ """Handler for POST /start-durable-execution."""
+
+ def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: # noqa: ARG002
+ """Handle start execution request.
+
+ Args:
+ parsed_route: The strongly-typed route object
+ request: The HTTP request data
+
+ Returns:
+ HTTPResponse: The HTTP response to send to the client
+ """
+ logger.debug("🌟 HANDLER: Received POST /start-durable-execution request")
+ try:
+ body_data: dict[str, Any] = self._parse_json_body(request)
+ logger.debug("🌟 HANDLER: Parsed request body successfully")
+
+ start_input: StartDurableExecutionInput = (
+ StartDurableExecutionInput.from_dict(body_data)
+ )
+ logger.debug(
+ "🌟 HANDLER: Created StartDurableExecutionInput, calling executor.start_execution()"
+ )
+
+ start_output: StartDurableExecutionOutput = self.executor.start_execution(
+ start_input
+ )
+ logger.debug("🌟 HANDLER: executor.start_execution() returned successfully")
+
+ response_data: dict[str, Any] = start_output.to_dict()
+
+ # Return HTTP 201 Created response
+ return self._created_response(response_data)
+
+ except IllegalStateException as e:
+ # For StartExecution operations, map to ExecutionAlreadyStartedException
+ aws_exception = ExecutionAlreadyStartedException(
+ str(e),
+ "arn:aws:lambda:us-east-1:123456789012:function:test",
+ )
+ return HTTPResponse.create_error_from_exception(aws_exception)
+ except AwsApiException as e:
+ return self._handle_aws_exception(e)
+ except Exception as e: # noqa: BLE001
+ return self._handle_framework_exception(e)
+
+
+class GetDurableExecutionHandler(EndpointHandler):
+ """Handler for GET /2025-12-01/durable-executions/{arn}."""
+
+ def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: # noqa: ARG002
+ """Handle get durable execution request.
+
+ Args:
+ parsed_route: The strongly-typed route object
+ request: The HTTP request data
+
+ Returns:
+ HTTPResponse: The HTTP response to send to the client
+ """
+ try:
+ route = cast(GetDurableExecutionRoute, parsed_route)
+
+ execution_response = self.executor.get_execution_details(route.arn)
+
+ response_data: dict[str, Any] = execution_response.to_dict()
+
+ # HTTP 200 OK response
+ return self._success_response(response_data)
+
+ except AwsApiException as e:
+ return self._handle_aws_exception(e)
+ except Exception as e: # noqa: BLE001
+ return self._handle_framework_exception(e)
+
+
+class CheckpointDurableExecutionHandler(EndpointHandler):
+ """Handler for POST /2025-12-01/durable-executions/{arn}/checkpoint."""
+
+ def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse:
+ """Handle checkpoint durable execution request.
+
+ Args:
+ parsed_route: The strongly-typed route object
+ request: The HTTP request data
+
+ Returns:
+ HTTPResponse: The HTTP response to send to the client
+ """
+ try:
+ body_data: dict[str, Any] = self._parse_json_body(request)
+
+ checkpoint_route = cast(CheckpointDurableExecutionRoute, parsed_route)
+ execution_arn: str = checkpoint_route.arn
+
+ checkpoint_request: CheckpointDurableExecutionRequest = (
+ CheckpointDurableExecutionRequest.from_dict(body_data, execution_arn)
+ )
+
+ checkpoint_response: CheckpointDurableExecutionResponse = (
+ self.executor.checkpoint_execution(
+ execution_arn,
+ checkpoint_request.checkpoint_token,
+ checkpoint_request.updates,
+ checkpoint_request.client_token,
+ )
+ )
+
+ response_data: dict[str, Any] = checkpoint_response.to_dict()
+
+ return self._success_response(response_data)
+
+ except AwsApiException as e:
+ return self._handle_aws_exception(e)
+ except Exception as e: # noqa: BLE001
+ return self._handle_framework_exception(e)
+
+
+class StopDurableExecutionHandler(EndpointHandler):
+ """Handler for POST /2025-12-01/durable-executions/{arn}/stop."""
+
+ def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse:
+ """Handle stop durable execution request.
+
+ Args:
+ parsed_route: The strongly-typed route object
+ request: The HTTP request data
+
+ Returns:
+ HTTPResponse: The HTTP response to send to the client
+ """
+ try:
+ body_data: dict[str, Any] = self._parse_json_body_optional(request)
+
+ stop_route = cast(StopDurableExecutionRoute, parsed_route)
+ execution_arn: str = stop_route.arn
+
+ body_data["DurableExecutionArn"] = execution_arn
+ stop_request: StopDurableExecutionRequest = (
+ StopDurableExecutionRequest.from_dict(body_data)
+ )
+
+ stop_response: StopDurableExecutionResponse = self.executor.stop_execution(
+ execution_arn, stop_request.error
+ )
+
+ response_data: dict[str, Any] = stop_response.to_dict()
+
+ return self._success_response(response_data)
+
+ except AwsApiException as e:
+ return self._handle_aws_exception(e)
+ except Exception as e: # noqa: BLE001
+ return self._handle_framework_exception(e)
+
+
+class GetDurableExecutionStateHandler(EndpointHandler):
+ """Handler for GET /2025-12-01/durable-executions/{arn}/state."""
+
+ def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: # noqa: ARG002
+ """Handle get durable execution state request.
+
+ Args:
+ parsed_route: The strongly-typed route object
+ request: The HTTP request data
+
+ Returns:
+ HTTPResponse: The HTTP response to send to the client
+ """
+ try:
+ state_route = cast(GetDurableExecutionStateRoute, parsed_route)
+ execution_arn: str = state_route.arn
+
+ state_response: GetDurableExecutionStateResponse = (
+ self.executor.get_execution_state(execution_arn)
+ )
+
+ response_data: dict[str, Any] = state_response.to_dict()
+
+ return self._success_response(response_data)
+
+ except AwsApiException as e:
+ return self._handle_aws_exception(e)
+ except Exception as e: # noqa: BLE001
+ return self._handle_framework_exception(e)
+
+
+class GetDurableExecutionHistoryHandler(EndpointHandler):
+ """Handler for GET /2025-12-01/durable-executions/{arn}/history."""
+
+ def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse:
+ """Handle get durable execution history request.
+
+ Args:
+ parsed_route: The strongly-typed route object
+ request: The HTTP request data
+
+ Returns:
+ HTTPResponse: The HTTP response to send to the client
+ """
+ try:
+ history_route = cast(GetDurableExecutionHistoryRoute, parsed_route)
+ execution_arn: str = history_route.arn
+
+ max_items: str | None = self._parse_query_param(request, "MaxItems")
+ marker: str | None = self._parse_query_param(request, "Marker")
+ include_execution_data_str: str | None = self._parse_query_param(
+ request, "IncludeExecutionData"
+ )
+ include_execution_data: bool = (
+ include_execution_data_str == "true"
+ if include_execution_data_str
+ else False
+ )
+
+ history_response: GetDurableExecutionHistoryResponse = (
+ self.executor.get_execution_history(
+ execution_arn,
+ include_execution_data=include_execution_data,
+ reverse_order=False,
+ marker=marker,
+ max_items=int(max_items) if max_items else None,
+ )
+ )
+
+ response_data: dict[str, Any] = history_response.to_dict()
+
+ return self._success_response(response_data)
+
+ except AwsApiException as e:
+ return self._handle_aws_exception(e)
+ except Exception as e: # noqa: BLE001
+ return self._handle_framework_exception(e)
+
+
+class ListDurableExecutionsHandler(EndpointHandler):
+ """Handler for GET /2025-12-01/durable-executions."""
+
+ def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: # noqa: ARG002
+ """Handle list durable executions request.
+
+ Args:
+ parsed_route: The strongly-typed route object
+ request: The HTTP request data
+
+ Returns:
+ HTTPResponse: The HTTP response to send to the client
+ """
+ try:
+ list_request: ListDurableExecutionsRequest = (
+ ListDurableExecutionsRequest.from_dict(request.query_params)
+ )
+
+ # Call executor method with correct attribute mapping
+ list_response: ListDurableExecutionsResponse = self.executor.list_executions(
+ function_name=list_request.function_name,
+ function_version=list_request.function_version,
+ execution_name=list_request.durable_execution_name, # Map to executor parameter
+ status_filter=list_request.status_filter[0]
+ if list_request.status_filter
+ else None, # Executor expects single string
+ started_after=list_request.started_after,
+ started_before=list_request.started_before,
+ marker=list_request.marker,
+ max_items=list_request.max_items
+ if list_request.max_items > 0
+ else None,
+ reverse_order=list_request.reverse_order or False,
+ )
+
+ # Serialize response
+ response_data: dict[str, Any] = list_response.to_dict()
+
+ return self._success_response(response_data)
+
+ except AwsApiException as e:
+ return self._handle_aws_exception(e)
+ except Exception as e: # noqa: BLE001
+ return self._handle_framework_exception(e)
+
+
+class ListDurableExecutionsByFunctionHandler(EndpointHandler):
+ """Handler for GET /2025-12-01/functions/{function_name}/durable-executions."""
+
+ @staticmethod
+ def _validate_function_name(function_name: str) -> None:
+ """Validate function name parameter."""
+ if not function_name or not function_name.strip():
+ msg = "Function name is required"
+ raise InvalidParameterValueException(msg)
+
+ def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse:
+ """Handle list durable executions by function request.
+
+ Args:
+ parsed_route: The strongly-typed route object
+ request: The HTTP request data
+
+ Returns:
+ HTTPResponse: The HTTP response to send to the client
+ """
+ function_route = cast(ListDurableExecutionsByFunctionRoute, parsed_route)
+ function_name: str = function_route.function_name
+
+ # Validate function name before processing
+ self._validate_function_name(function_name)
+
+ try:
+ # Add function name from route to query params
+ query_params = dict(request.query_params)
+ query_params["FunctionName"] = [function_name]
+ list_request = ListDurableExecutionsByFunctionRequest.from_dict(
+ query_params
+ )
+
+ list_response = self.executor.list_executions_by_function(
+ function_name=list_request.function_name,
+ qualifier=list_request.qualifier,
+ execution_name=list_request.durable_execution_name,
+ status_filter=list_request.status_filter[0]
+ if list_request.status_filter
+ else None,
+ started_after=list_request.started_after,
+ started_before=list_request.started_before,
+ marker=list_request.marker,
+ max_items=list_request.max_items
+ if list_request.max_items > 0
+ else None,
+ reverse_order=list_request.reverse_order or False,
+ )
+
+ return self._success_response(list_response.to_dict())
+
+ except AwsApiException as e:
+ return self._handle_aws_exception(e)
+ except Exception as e: # noqa: BLE001
+ return self._handle_framework_exception(e)
+
+
+class SendDurableExecutionCallbackSuccessHandler(EndpointHandler):
+ """Handler for POST /2025-12-01/durable-execution-callbacks/{callback_id}/succeed."""
+
+ def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse:
+ """Handle send durable execution callback success request.
+
+ Args:
+ parsed_route: The strongly-typed route object
+ request: The HTTP request data
+
+ Returns:
+ HTTPResponse: The HTTP response to send to the client
+ """
+ try:
+ callback_route = cast(CallbackSuccessRoute, parsed_route)
+ callback_id: str = callback_route.callback_id
+
+ result_bytes: bytes = self._parse_callback_result_payload(request)
+
+ callback_response: SendDurableExecutionCallbackSuccessResponse = ( # noqa: F841
+ self.executor.send_callback_success(
+ callback_id=callback_id, result=result_bytes
+ )
+ )
+
+ logger.debug(
+ "Callback %s succeeded with result: %s",
+ callback_id,
+ result_bytes.decode("utf-8", errors="replace"),
+ )
+
+ # Callback success response is empty
+ return self._success_response({})
+
+ except IllegalStateException as e:
+ # For callback operations, map to ExecutionConflictException
+ aws_exception = ExecutionConflictException(str(e))
+ return HTTPResponse.create_error_from_exception(aws_exception)
+ except AwsApiException as e:
+ return self._handle_aws_exception(e)
+ except Exception as e: # noqa: BLE001
+ return self._handle_framework_exception(e)
+
+
+class SendDurableExecutionCallbackFailureHandler(EndpointHandler):
+ """Handler for POST /2025-12-01/durable-execution-callbacks/{callback_id}/fail."""
+
+ def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse:
+ """Handle send durable execution callback failure request.
+
+ Args:
+ parsed_route: The strongly-typed route object
+ request: The HTTP request data
+
+ Returns:
+ HTTPResponse: The HTTP response to send to the client
+ """
+ try:
+ callback_route = cast(CallbackFailureRoute, parsed_route)
+ callback_id: str = callback_route.callback_id
+
+ body_data: dict[str, Any] = self._parse_json_body_optional(request)
+ callback_request: SendDurableExecutionCallbackFailureRequest = (
+ SendDurableExecutionCallbackFailureRequest.from_dict(
+ body_data, callback_id
+ )
+ )
+
+ callback_response: SendDurableExecutionCallbackFailureResponse = ( # noqa: F841
+ self.executor.send_callback_failure(
+ callback_id=callback_id, error=callback_request.error
+ )
+ )
+
+ logger.debug(
+ "Callback %s failed with error: %s", callback_id, callback_request.error
+ )
+
+ # Callback failure response is empty
+ return self._success_response({})
+
+ except IllegalStateException as e:
+ # For callback operations, map to ExecutionConflictException
+ aws_exception = ExecutionConflictException(str(e))
+ return HTTPResponse.create_error_from_exception(aws_exception)
+ except AwsApiException as e:
+ return self._handle_aws_exception(e)
+ except Exception as e: # noqa: BLE001
+ return self._handle_framework_exception(e)
+
+
+class SendDurableExecutionCallbackHeartbeatHandler(EndpointHandler):
+ """Handler for POST /2025-12-01/durable-execution-callbacks/{callback_id}/heartbeat."""
+
+ def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse:
+ """Handle send durable execution callback heartbeat request.
+
+ Args:
+ parsed_route: The strongly-typed route object
+ request: The HTTP request data
+
+ Returns:
+ HTTPResponse: The HTTP response to send to the client
+ """
+ try:
+ # Heartbeat requests don't have a body, only callback_id from URL
+ callback_route = cast(CallbackHeartbeatRoute, parsed_route)
+ callback_id: str = callback_route.callback_id
+
+ callback_response: SendDurableExecutionCallbackHeartbeatResponse = ( # noqa: F841
+ self.executor.send_callback_heartbeat(callback_id=callback_id)
+ )
+
+ # Callback heartbeat response is empty
+ return self._success_response({})
+
+ except IllegalStateException as e:
+ # For callback operations, map to ExecutionConflictException
+ aws_exception = ExecutionConflictException(str(e))
+ return HTTPResponse.create_error_from_exception(aws_exception)
+ except AwsApiException as e:
+ return self._handle_aws_exception(e)
+ except Exception as e: # noqa: BLE001
+ return self._handle_framework_exception(e)
+
+
+# TODO: should this be /ping instead?
+class HealthHandler(EndpointHandler):
+ """Handler for GET /health."""
+
+ def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: # noqa: ARG002
+ """Handle health check request.
+
+ Args:
+ parsed_route: The strongly-typed route object
+ request: The HTTP request data
+
+ Returns:
+ HTTPResponse: The HTTP response to send to the client
+ """
+ return self._success_response({"status": "healthy"})
+
+
+class MetricsHandler(EndpointHandler):
+ """Handler for GET /metrics."""
+
+ def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: # noqa: ARG002
+ """Handle metrics request.
+
+ Args:
+ parsed_route: The strongly-typed route object
+ request: The HTTP request data
+
+ Returns:
+ HTTPResponse: The HTTP response to send to the client
+ """
+ # TODO: Implement metrics collection logic
+ return self._success_response({"metrics": {}})
+
+
+class UpdateLambdaEndpointHandler(EndpointHandler):
+ """Handler for PUT /lambda-endpoint."""
+
+ def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: # noqa: ARG002
+ """Handle update Lambda endpoint request.
+
+ Args:
+ parsed_route: The strongly-typed route object
+ request: The HTTP request data
+
+ Returns:
+ HTTPResponse: The HTTP response to send to the client
+ """
+ try:
+ body = self._parse_json_body(request)
+ endpoint_url = body.get("EndpointUrl")
+ region_name = body.get("RegionName", "us-east-1")
+
+ if not endpoint_url:
+ return self._handle_aws_exception(
+ InvalidParameterValueException("EndpointUrl is required")
+ )
+
+ # Update the invoker's Lambda endpoint
+ invoker = self.executor._invoker # noqa: SLF001
+ logger.info("Updating lambda endpoint to %s", endpoint_url)
+ invoker.update_endpoint(endpoint_url, region_name)
+ return self._success_response(
+ {"message": "Lambda endpoint updated successfully"}
+ )
+
+ except Exception as e: # noqa: BLE001
+ return self._handle_framework_exception(e)
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/models.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/models.py
new file mode 100644
index 0000000..eebd0fe
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/models.py
@@ -0,0 +1,266 @@
+"""HTTP request/response data models and utilities for the web runner."""
+
+from __future__ import annotations
+
+import json
+import logging
+from dataclasses import dataclass
+from typing import Any, Protocol
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ AwsApiException,
+ InvalidParameterValueException,
+)
+
+# Removed deprecated imports from web.errors
+from aws_durable_execution_sdk_python_testing.web.routes import Route
+from aws_durable_execution_sdk_python_testing.web.serialization import (
+ AwsRestJsonDeserializer,
+ JSONSerializer,
+ Serializer,
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass(frozen=True)
+class HTTPRequest:
+ """HTTP request data model with dict or bytes body for handler logic."""
+
+ method: str
+ path: Route
+ headers: dict[str, str]
+ query_params: dict[str, list[str]]
+ body: dict[str, Any] | bytes
+
+ @classmethod
+ def from_raw_bytes(
+ cls,
+ body_bytes: bytes,
+ method: str = "POST",
+ path: Route | None = None,
+ headers: dict[str, str] | None = None,
+ query_params: dict[str, list[str]] | None = None,
+ ) -> HTTPRequest:
+ """Create HTTPRequest with raw bytes body (no parsing)."""
+ if headers is None:
+ headers = {}
+ if query_params is None:
+ query_params = {}
+ if path is None:
+ path = Route.from_string("")
+
+ return cls(
+ method=method,
+ path=path,
+ headers=headers,
+ query_params=query_params,
+ body=body_bytes,
+ )
+
+ @classmethod
+ def from_bytes(
+ cls,
+ body_bytes: bytes,
+ operation_name: str | None = None,
+ method: str = "POST",
+ path: Route | None = None,
+ headers: dict[str, str] | None = None,
+ query_params: dict[str, list[str]] | None = None,
+ ) -> HTTPRequest:
+ """Create HTTPRequest from raw bytes, deserializing to dict body.
+
+ Args:
+ body_bytes: Raw bytes to deserialize
+ operation_name: Optional AWS operation name for boto deserialization
+ method: HTTP method (default: POST)
+ path: Route object (required for actual usage)
+ headers: HTTP headers (default: empty dict)
+ query_params: Query parameters (default: empty dict)
+
+ Returns:
+ HTTPRequest: Request with deserialized dict body
+
+ Raises:
+ InvalidParameterValueException: If deserialization fails with both AWS and JSON methods
+ """
+ if headers is None:
+ headers = {}
+ if query_params is None:
+ query_params = {}
+
+ # Skip body parsing for GET requests
+ if method == "GET":
+ body_dict = {}
+ logger.debug("GET request, skipping body parsing")
+ # Try AWS deserialization first if operation_name provided
+ elif operation_name:
+ try:
+ deserializer = AwsRestJsonDeserializer.create(operation_name)
+ body_dict = deserializer.from_bytes(body_bytes)
+ logger.debug(
+ "Successfully deserialized request using AWS deserializer for %s",
+ operation_name,
+ )
+ except InvalidParameterValueException as e:
+ logger.warning(
+ "AWS deserialization failed for %s, falling back to JSON: %s",
+ operation_name,
+ e,
+ )
+ # Fall back to standard JSON
+ try:
+ body_dict = json.loads(body_bytes.decode("utf-8"))
+ logger.debug(
+ "Successfully deserialized request using JSON fallback"
+ )
+ except (json.JSONDecodeError, UnicodeDecodeError) as json_error:
+ msg = f"Both AWS and JSON deserialization failed: AWS error: {e}, JSON error: {json_error}"
+ raise InvalidParameterValueException(msg) from json_error
+ else:
+ # Use standard JSON deserialization
+ try:
+ if body_bytes == b"":
+ body_dict = {}
+ else:
+ body_dict = json.loads(body_bytes.decode("utf-8"))
+ logger.debug("Successfully deserialized request using standard JSON")
+ except (json.JSONDecodeError, UnicodeDecodeError) as e:
+ msg = f"JSON deserialization failed: {e}"
+ raise InvalidParameterValueException(msg) from e
+
+ # Handle case where path is None for testing
+ if path is None:
+ path = Route.from_string("")
+
+ return cls(
+ method=method,
+ path=path,
+ headers=headers,
+ query_params=query_params,
+ body=body_dict,
+ )
+
+
+@dataclass(frozen=True)
+class HTTPResponse:
+ """HTTP response data model with dict body and serialization capabilities."""
+
+ status_code: int
+ headers: dict[str, str]
+ body: dict[str, Any]
+ serializer: Serializer = JSONSerializer()
+
+ def body_to_bytes(self) -> bytes:
+ """Convert response dict body to bytes for HTTP transmission.
+
+ Returns:
+ bytes: Serialized response body
+
+ Raises:
+ InvalidParameterValueException: If serialization fails with both AWS and JSON methods
+ """
+ result = self.serializer.to_bytes(data=self.body)
+ logger.debug("Serialized result - before: %s, after: %s", self.body, result)
+ return result
+
+ @classmethod
+ def from_dict(
+ cls,
+ data: dict[str, Any],
+ status_code: int = 200,
+ headers: dict[str, str] | None = None,
+ ) -> HTTPResponse:
+ """Create HTTPResponse from dict data.
+
+ Args:
+ data: Response data as dictionary
+ status_code: HTTP status code (default: 200)
+ headers: HTTP headers (default: empty dict)
+
+ Returns:
+ HTTPResponse: Response with dict body
+ """
+ if headers is None:
+ headers = {}
+
+ return cls(status_code=status_code, headers=headers, body=data)
+
+ @staticmethod
+ def create_json(
+ status_code: int,
+ data: dict[str, Any],
+ additional_headers: dict[str, str] | None = None,
+ ) -> HTTPResponse:
+ """Create a JSON HTTP response.
+
+ Args:
+ status_code: HTTP status code
+ data: Data to serialize as JSON
+ additional_headers: Optional additional headers to include
+
+ Returns:
+ HTTPResponse: The HTTP response with dict body
+ """
+ headers = {"Content-Type": "application/json"}
+ if additional_headers:
+ headers.update(additional_headers)
+
+ return HTTPResponse(status_code=status_code, headers=headers, body=data)
+
+ # Removed deprecated create_error method - use create_error_from_exception instead
+
+ @staticmethod
+ def create_error_from_exception(aws_exception: AwsApiException) -> HTTPResponse:
+ """Create AWS-compliant error response from AwsApiException.
+
+ Args:
+ aws_exception: The AWS API exception to convert to HTTP response
+
+ Returns:
+ HTTPResponse: The HTTP error response with AWS-compliant format
+ """
+ if not isinstance(aws_exception, AwsApiException):
+ msg = f"Expected AwsApiException, got {type(aws_exception)}"
+ raise TypeError(msg)
+
+ # Use exception's http_status_code and to_dict() method
+ # This removes the wrapper "error" object to match AWS format
+ error_data = aws_exception.to_dict()
+ return HTTPResponse.create_json(aws_exception.http_status_code, error_data)
+
+ @staticmethod
+ def create_empty(
+ status_code: int, additional_headers: dict[str, str] | None = None
+ ) -> HTTPResponse:
+ """Create an empty HTTP response.
+
+ Args:
+ status_code: HTTP status code
+ additional_headers: Optional additional headers to include
+
+ Returns:
+ HTTPResponse: The HTTP response with empty dict body
+ """
+ headers = {}
+ if additional_headers:
+ headers.update(additional_headers)
+
+ return HTTPResponse(status_code=status_code, headers=headers, body={})
+
+
+class OperationHandler(Protocol):
+ """Protocol for handling HTTP operations with strongly-typed paths."""
+
+ def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse:
+ """Handle an HTTP request and return an HTTP response.
+
+ Args:
+ parsed_route: The strongly-typed route object
+ request: The HTTP request data
+
+ Returns:
+ HTTPResponse: The HTTP response to send to the client
+ """
+ ... # pragma: no cover
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/routes.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/routes.py
new file mode 100644
index 0000000..bc09906
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/routes.py
@@ -0,0 +1,692 @@
+"""Strongly-typed route parsing system for HTTP request routing."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from urllib.parse import unquote
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ UnknownRouteError,
+)
+
+
+@dataclass(frozen=True)
+class Route:
+ """Base route with segments and pattern matching capabilities."""
+
+ raw_path: str
+ segments: list[str]
+
+ @classmethod
+ def from_route(cls, _route: Route) -> Route:
+ """Create a typed route from a base Route.
+
+ Note: Call is_match(route, method) first to ensure the route is valid for this type.
+
+ Args:
+ route: Base route to convert
+
+ Returns:
+ Typed route instance
+
+ Raises:
+ NotImplementedError: This is an abstract method that must be implemented by subclasses
+ """
+ msg = "Subclasses must implement from_route()"
+ raise NotImplementedError(msg)
+
+ @classmethod
+ def from_string(cls, path: str) -> Route:
+ """Create a Route from a string.
+
+ Each segment is URL-decoded; ``raw_path`` is preserved as the
+ original wire path. Splitting on ``/`` happens before decoding so
+ that an encoded ``%2F`` inside a captured value (e.g. an ARN that
+ contains ``/``) stays inside its segment instead of being treated
+ as a path separator.
+
+ Args:
+ path: The raw path string
+
+ Returns:
+ Route instance with parsed, URL-decoded segments
+ """
+ # Remove leading/trailing slashes, split on '/', then URL-decode each
+ # segment. Order matters: split on the literal '/' first so '%2F'-
+ # encoded slashes inside values don't act as separators.
+ segments = [unquote(s) for s in path.strip("/").split("/") if s]
+ return cls(raw_path=path, segments=segments)
+
+ def matches_pattern(self, pattern: list[str]) -> bool:
+ """Check if route matches the given pattern.
+
+ Args:
+ pattern: List of pattern segments. Use '*' for wildcards.
+
+ Returns:
+ True if the route matches the pattern
+ """
+ if len(self.segments) != len(pattern):
+ return False
+
+ for segment, pattern_part in zip(self.segments, pattern, strict=False):
+ if pattern_part not in ("*", segment):
+ return False
+ return True
+
+ @classmethod
+ def is_match(cls, _route: Route, _method: str) -> bool:
+ """Check if the route and HTTP method match this route type.
+
+ Args:
+ _route: Route to check
+ _method: HTTP method to check
+
+ Returns:
+ True if the route and method match
+
+ Raises:
+ NotImplementedError: This is an abstract method that must be implemented by subclasses
+ """
+ msg = "Subclasses must implement is_match()"
+ raise NotImplementedError(msg)
+
+
+@dataclass(frozen=True)
+class StartExecutionRoute(Route):
+ """Route: POST /start-durable-execution"""
+
+ @classmethod
+ def is_match(cls, route: Route, method: str) -> bool:
+ """Check if the route and HTTP method match this route type.
+
+ Args:
+ route: Route to check
+ method: HTTP method to check
+
+ Returns:
+ True if the route and method match
+ """
+ return route.raw_path == "/start-durable-execution" and method == "POST"
+
+ @classmethod
+ def from_route(cls, route: Route) -> StartExecutionRoute:
+ """Create a StartExecutionRoute from a base Route.
+
+ Note: Call is_match(route, method) first to ensure the route is valid for this type.
+
+ Args:
+ route: Base route to convert
+
+ Returns:
+ StartExecutionRoute instance
+ """
+ return cls(raw_path=route.raw_path, segments=route.segments)
+
+
+@dataclass(frozen=True)
+class GetDurableExecutionRoute(Route):
+ """Route: GET /2025-12-01/durable-executions/{arn}"""
+
+ arn: str
+
+ @classmethod
+ def is_match(cls, route: Route, method: str) -> bool:
+ """Check if the route and HTTP method match this route type.
+
+ Args:
+ route: Route to check
+ method: HTTP method to check
+
+ Returns:
+ True if the route and method match
+ """
+ return (
+ route.matches_pattern(["2025-12-01", "durable-executions", "*"])
+ and method == "GET"
+ )
+
+ @classmethod
+ def from_route(cls, route: Route) -> GetDurableExecutionRoute:
+ """Create a GetDurableExecutionRoute from a base Route.
+
+ Note: Call is_match(route, method) first to ensure the route is valid for this type.
+
+ Args:
+ route: Base route to convert
+
+ Returns:
+ GetDurableExecutionRoute instance with extracted ARN
+ """
+ return cls(
+ raw_path=route.raw_path,
+ segments=route.segments,
+ arn=route.segments[2],
+ )
+
+
+@dataclass(frozen=True)
+class CheckpointDurableExecutionRoute(Route):
+ """Route: POST /2025-12-01/durable-executions/{arn}/checkpoint"""
+
+ arn: str
+
+ @classmethod
+ def is_match(cls, route: Route, method: str) -> bool:
+ """Check if the route and HTTP method match this route type.
+
+ Args:
+ route: Route to check
+ method: HTTP method to check
+
+ Returns:
+ True if the route and method match
+ """
+ return (
+ route.matches_pattern(
+ ["2025-12-01", "durable-executions", "*", "checkpoint"]
+ )
+ and method == "POST"
+ )
+
+ @classmethod
+ def from_route(cls, route: Route) -> CheckpointDurableExecutionRoute:
+ """Create a CheckpointDurableExecutionRoute from a base Route.
+
+ Note: Call is_match(route, method) first to ensure the route is valid for this type.
+
+ Args:
+ route: Base route to convert
+
+ Returns:
+ CheckpointDurableExecutionRoute instance with extracted ARN
+ """
+ return cls(
+ raw_path=route.raw_path,
+ segments=route.segments,
+ arn=route.segments[2],
+ )
+
+
+@dataclass(frozen=True)
+class StopDurableExecutionRoute(Route):
+ """Route: POST /2025-12-01/durable-executions/{arn}/stop"""
+
+ arn: str
+
+ @classmethod
+ def is_match(cls, route: Route, method: str) -> bool:
+ """Check if the route and HTTP method match this route type.
+
+ Args:
+ route: Route to check
+ method: HTTP method to check
+
+ Returns:
+ True if the route and method match
+ """
+ return (
+ route.matches_pattern(["2025-12-01", "durable-executions", "*", "stop"])
+ and method == "POST"
+ )
+
+ @classmethod
+ def from_route(cls, route: Route) -> StopDurableExecutionRoute:
+ """Create a StopDurableExecutionRoute from a base Route.
+
+ Note: Call is_match(route, method) first to ensure the route is valid for this type.
+
+ Args:
+ route: Base route to convert
+
+ Returns:
+ StopDurableExecutionRoute instance with extracted ARN
+ """
+ return cls(
+ raw_path=route.raw_path,
+ segments=route.segments,
+ arn=route.segments[2],
+ )
+
+
+@dataclass(frozen=True)
+class GetDurableExecutionStateRoute(Route):
+ """Route: GET /2025-12-01/durable-executions/{arn}/state"""
+
+ arn: str
+
+ @classmethod
+ def is_match(cls, route: Route, method: str) -> bool:
+ """Check if the route and HTTP method match this route type.
+
+ Args:
+ route: Route to check
+ method: HTTP method to check
+
+ Returns:
+ True if the route and method match
+ """
+ return (
+ route.matches_pattern(["2025-12-01", "durable-executions", "*", "state"])
+ and method == "GET"
+ )
+
+ @classmethod
+ def from_route(cls, route: Route) -> GetDurableExecutionStateRoute:
+ """Create a GetDurableExecutionStateRoute from a base Route.
+
+ Note: Call is_match(route, method) first to ensure the route is valid for this type.
+
+ Args:
+ route: Base route to convert
+
+ Returns:
+ GetDurableExecutionStateRoute instance with extracted ARN
+ """
+ return cls(
+ raw_path=route.raw_path,
+ segments=route.segments,
+ arn=route.segments[2],
+ )
+
+
+@dataclass(frozen=True)
+class GetDurableExecutionHistoryRoute(Route):
+ """Route: GET /2025-12-01/durable-executions/{arn}/history"""
+
+ arn: str
+
+ @classmethod
+ def is_match(cls, route: Route, method: str) -> bool:
+ """Check if the route and HTTP method match this route type.
+
+ Args:
+ route: Route to check
+ method: HTTP method to check
+
+ Returns:
+ True if the route and method match
+ """
+ return (
+ route.matches_pattern(["2025-12-01", "durable-executions", "*", "history"])
+ and method == "GET"
+ )
+
+ @classmethod
+ def from_route(cls, route: Route) -> GetDurableExecutionHistoryRoute:
+ """Create a GetDurableExecutionHistoryRoute from a base Route.
+
+ Note: Call is_match(route, method) first to ensure the route is valid for this type.
+
+ Args:
+ route: Base route to convert
+
+ Returns:
+ GetDurableExecutionHistoryRoute instance with extracted ARN
+ """
+ return cls(
+ raw_path=route.raw_path,
+ segments=route.segments,
+ arn=route.segments[2],
+ )
+
+
+@dataclass(frozen=True)
+class ListDurableExecutionsRoute(Route):
+ """Route: GET /2025-12-01/durable-executions"""
+
+ @classmethod
+ def is_match(cls, route: Route, method: str) -> bool:
+ """Check if the route and HTTP method match this route type.
+
+ Args:
+ route: Route to check
+ method: HTTP method to check
+
+ Returns:
+ True if the route and method match
+ """
+ return (
+ route.matches_pattern(["2025-12-01", "durable-executions"])
+ and method == "GET"
+ )
+
+ @classmethod
+ def from_route(cls, route: Route) -> ListDurableExecutionsRoute:
+ """Create a ListDurableExecutionsRoute from a base Route.
+
+ Note: Call is_match(route, method) first to ensure the route is valid for this type.
+
+ Args:
+ route: Base route to convert
+
+ Returns:
+ ListDurableExecutionsRoute instance
+ """
+ return cls(raw_path=route.raw_path, segments=route.segments)
+
+
+@dataclass(frozen=True)
+class ListDurableExecutionsByFunctionRoute(Route):
+ """Route: GET /2025-12-01/functions/{function_name}/durable-executions"""
+
+ function_name: str
+
+ @classmethod
+ def is_match(cls, route: Route, method: str) -> bool:
+ """Check if the route and HTTP method match this route type.
+
+ Args:
+ route: Route to check
+ method: HTTP method to check
+
+ Returns:
+ True if the route and method match
+ """
+ return (
+ route.matches_pattern(
+ ["2025-12-01", "functions", "*", "durable-executions"]
+ )
+ and method == "GET"
+ )
+
+ @classmethod
+ def from_route(cls, route: Route) -> ListDurableExecutionsByFunctionRoute:
+ """Create a ListDurableExecutionsByFunctionRoute from a base Route.
+
+ Note: Call is_match(route, method) first to ensure the route is valid for this type.
+
+ Args:
+ route: Base route to convert
+
+ Returns:
+ ListDurableExecutionsByFunctionRoute instance with extracted function name
+ """
+ return cls(
+ raw_path=route.raw_path,
+ segments=route.segments,
+ function_name=route.segments[2],
+ )
+
+
+@dataclass(frozen=True)
+class BytesPayloadRoute(Route):
+ """Base class for routes that handle raw bytes payloads instead of JSON."""
+
+
+@dataclass(frozen=True)
+class CallbackSuccessRoute(BytesPayloadRoute):
+ """Route: POST /2025-12-01/durable-execution-callbacks/{callback_id}/succeed"""
+
+ callback_id: str
+
+ @classmethod
+ def is_match(cls, route: Route, method: str) -> bool:
+ """Check if the route and HTTP method match this route type.
+
+ Args:
+ route: Route to check
+ method: HTTP method to check
+
+ Returns:
+ True if the route and method match
+ """
+ return (
+ route.matches_pattern(
+ ["2025-12-01", "durable-execution-callbacks", "*", "succeed"]
+ )
+ and method == "POST"
+ )
+
+ @classmethod
+ def from_route(cls, route: Route) -> CallbackSuccessRoute:
+ """Create a CallbackSuccessRoute from a base Route.
+
+ Note: Call is_match(route, method) first to ensure the route is valid for this type.
+
+ Args:
+ route: Base route to convert
+
+ Returns:
+ CallbackSuccessRoute instance with extracted callback ID
+ """
+ return cls(
+ raw_path=route.raw_path,
+ segments=route.segments,
+ callback_id=route.segments[2],
+ )
+
+
+@dataclass(frozen=True)
+class CallbackFailureRoute(BytesPayloadRoute):
+ """Route: POST /2025-12-01/durable-execution-callbacks/{callback_id}/fail"""
+
+ callback_id: str
+
+ @classmethod
+ def is_match(cls, route: Route, method: str) -> bool:
+ """Check if the route and HTTP method match this route type.
+
+ Args:
+ route: Route to check
+ method: HTTP method to check
+
+ Returns:
+ True if the route and method match
+ """
+ return (
+ route.matches_pattern(
+ ["2025-12-01", "durable-execution-callbacks", "*", "fail"]
+ )
+ and method == "POST"
+ )
+
+ @classmethod
+ def from_route(cls, route: Route) -> CallbackFailureRoute:
+ """Create a CallbackFailureRoute from a base Route.
+
+ Note: Call is_match(route, method) first to ensure the route is valid for this type.
+
+ Args:
+ route: Base route to convert
+
+ Returns:
+ CallbackFailureRoute instance with extracted callback ID
+ """
+ return cls(
+ raw_path=route.raw_path,
+ segments=route.segments,
+ callback_id=route.segments[2],
+ )
+
+
+@dataclass(frozen=True)
+class CallbackHeartbeatRoute(Route):
+ """Route: POST /2025-12-01/durable-execution-callbacks/{callback_id}/heartbeat"""
+
+ callback_id: str
+
+ @classmethod
+ def is_match(cls, route: Route, method: str) -> bool:
+ """Check if the route and HTTP method match this route type.
+
+ Args:
+ route: Route to check
+ method: HTTP method to check
+
+ Returns:
+ True if the route and method match
+ """
+ return (
+ route.matches_pattern(
+ ["2025-12-01", "durable-execution-callbacks", "*", "heartbeat"]
+ )
+ and method == "POST"
+ )
+
+ @classmethod
+ def from_route(cls, route: Route) -> CallbackHeartbeatRoute:
+ """Create a CallbackHeartbeatRoute from a base Route.
+
+ Note: Call is_match(route, method) first to ensure the route is valid for this type.
+
+ Args:
+ route: Base route to convert
+
+ Returns:
+ CallbackHeartbeatRoute instance with extracted callback ID
+ """
+ return cls(
+ raw_path=route.raw_path,
+ segments=route.segments,
+ callback_id=route.segments[2],
+ )
+
+
+@dataclass(frozen=True)
+class HealthRoute(Route):
+ """Route: GET /health"""
+
+ @classmethod
+ def is_match(cls, route: Route, method: str) -> bool:
+ """Check if the route and HTTP method match this route type.
+
+ Args:
+ route: Route to check
+ method: HTTP method to check
+
+ Returns:
+ True if the route and method match
+ """
+ return route.raw_path == "/health" and method == "GET"
+
+ @classmethod
+ def from_route(cls, route: Route) -> HealthRoute:
+ """Create a HealthRoute from a base Route.
+
+ Note: Call is_match(route, method) first to ensure the route is valid for this type.
+
+ Args:
+ route: Base route to convert
+
+ Returns:
+ HealthRoute instance
+ """
+ return cls(raw_path=route.raw_path, segments=route.segments)
+
+
+@dataclass(frozen=True)
+class UpdateLambdaEndpointRoute(Route):
+ """Route: PUT /lambda-endpoint"""
+
+ @classmethod
+ def is_match(cls, route: Route, method: str) -> bool:
+ """Check if the route and HTTP method match this route type.
+
+ Args:
+ route: Route to check
+ method: HTTP method to check
+
+ Returns:
+ True if the route and method match
+ """
+ return route.raw_path == "/lambda-endpoint" and method == "PUT"
+
+ @classmethod
+ def from_route(cls, route: Route) -> UpdateLambdaEndpointRoute:
+ """Create UpdateLambdaEndpointRoute from base route.
+
+ Args:
+ route: Base route to convert
+
+ Returns:
+ UpdateLambdaEndpointRoute instance
+ """
+ return cls(raw_path=route.raw_path, segments=route.segments)
+
+
+@dataclass(frozen=True)
+class MetricsRoute(Route):
+ """Route: GET /metrics"""
+
+ @classmethod
+ def is_match(cls, route: Route, method: str) -> bool:
+ """Check if the route and HTTP method match this route type.
+
+ Args:
+ route: Route to check
+ method: HTTP method to check
+
+ Returns:
+ True if the route and method match
+ """
+ return route.raw_path == "/metrics" and method == "GET"
+
+ @classmethod
+ def from_route(cls, route: Route) -> MetricsRoute:
+ """Create a MetricsRoute from a base Route.
+
+ Note: Call is_match(route, method) first to ensure the route is valid for this type.
+
+ Args:
+ route: Base route to convert
+
+ Returns:
+ MetricsRoute instance
+ """
+ return cls(raw_path=route.raw_path, segments=route.segments)
+
+
+# Default registry of all route types for matching
+DEFAULT_ROUTE_TYPES: list[type[Route]] = [
+ StartExecutionRoute,
+ GetDurableExecutionRoute,
+ CheckpointDurableExecutionRoute,
+ StopDurableExecutionRoute,
+ GetDurableExecutionStateRoute,
+ GetDurableExecutionHistoryRoute,
+ ListDurableExecutionsRoute,
+ ListDurableExecutionsByFunctionRoute,
+ CallbackSuccessRoute,
+ CallbackFailureRoute,
+ CallbackHeartbeatRoute,
+ HealthRoute,
+ UpdateLambdaEndpointRoute,
+ MetricsRoute,
+]
+
+
+class Router:
+ """HTTP request router that matches routes to strongly-typed route objects."""
+
+ def __init__(self, route_types: list[type[Route]] | None = None) -> None:
+ """Initialize the router with route types.
+
+ Args:
+ route_types: List of route type classes to use for matching.
+ If None, uses the default route types.
+ """
+ self._route_types = (
+ route_types if route_types is not None else DEFAULT_ROUTE_TYPES
+ )
+
+ def find_route(self, path: str, method: str) -> Route:
+ """Find a matching route for the given path and HTTP method.
+
+ Args:
+ path: The raw path string to parse
+ method: The HTTP method (GET, POST, etc.)
+
+ Returns:
+ Strongly-typed Route instance
+
+ Raises:
+ UnknownRouteError: If the path and method don't match any known pattern
+ """
+ base_route = Route.from_string(path)
+
+ for route_type in self._route_types:
+ if route_type.is_match(base_route, method):
+ return route_type.from_route(base_route)
+
+ raise UnknownRouteError(method, path)
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/serialization.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/serialization.py
new file mode 100644
index 0000000..6532a66
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/serialization.py
@@ -0,0 +1,235 @@
+"""Serialization interfaces and AWS boto integration for HTTP request/response models.
+
+This module provides Protocol interfaces for serialization and deserialization,
+along with AWS-compatible implementations using boto's rest-json serializers.
+"""
+
+from __future__ import annotations
+
+import json
+import os
+from typing import Any, Protocol
+from datetime import datetime
+
+import aws_durable_execution_sdk_python
+import botocore.loaders # type: ignore
+from botocore.model import ServiceModel # type: ignore
+from botocore.parsers import create_parser # type: ignore
+from botocore.serialize import create_serializer # type: ignore
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+
+
+class Serializer(Protocol):
+ """Interface for serializing data to bytes."""
+
+ def to_bytes(self, data: Any) -> bytes:
+ """Serialize data to bytes.
+
+ Args:
+ data: The data to serialize
+
+ Returns:
+ bytes: The serialized data
+
+ Raises:
+ InvalidParameterValueException: If serialization fails
+ """
+ ... # pragma: no cover
+
+
+class Deserializer(Protocol):
+ """Interface for deserializing bytes to data."""
+
+ def from_bytes(self, data: bytes) -> dict[str, Any]:
+ """Deserialize bytes to dictionary.
+
+ Args:
+ data: The bytes to deserialize
+
+ Returns:
+ dict: The deserialized data
+
+ Raises:
+ InvalidParameterValueException: If deserialization fails
+ """
+ ... # pragma: no cover
+
+
+class JSONSerializer:
+ """JSON serializer with datetime support."""
+
+ def to_bytes(self, data: Any) -> bytes:
+ """Serialize data to JSON bytes."""
+ try:
+ json_string = json.dumps(
+ data, separators=(",", ":"), default=self._default_handler
+ )
+ return json_string.encode("utf-8")
+ except (TypeError, ValueError) as e:
+ raise InvalidParameterValueException(
+ f"Failed to serialize data to JSON: {str(e)}"
+ )
+
+ def _default_handler(self, obj: Any) -> float:
+ """Handle non-permitive objects."""
+ if isinstance(obj, datetime):
+ return obj.timestamp()
+ # Raise TypeError for unsupported types
+ raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
+
+
+class AwsRestJsonSerializer:
+ """AWS rest-json serializer using boto."""
+
+ def __init__(self, operation_name: str, serializer: Any, operation_model: Any):
+ """Initialize the AWS rest-json serializer.
+
+ Args:
+ operation_name: Name of the AWS operation
+ serializer: Boto serializer instance
+ operation_model: Boto operation model
+ """
+ self._operation_name = operation_name
+ self._serializer = serializer
+ self._operation_model = operation_model
+
+ @classmethod
+ def create(cls, operation_name: str) -> AwsRestJsonSerializer:
+ """Create serializer with boto components.
+
+ Args:
+ operation_name: Name of the AWS operation
+
+ Returns:
+ AwsRestJsonSerializer: Configured serializer instance
+
+ Raises:
+ InvalidParameterValueException: If serializer creation fails
+ """
+ try:
+ # Load service model
+ loader = botocore.loaders.Loader()
+
+ raw_model = loader.load_service_model("lambda", "service-2")
+ service_model = ServiceModel(raw_model)
+
+ # Create serializer (rest-json protocol)
+ serializer = create_serializer("rest-json", include_validation=True)
+ operation_model = service_model.operation_model(operation_name)
+
+ return cls(operation_name, serializer, operation_model)
+ except Exception as e:
+ msg = f"Failed to create serializer for {operation_name}: {e}"
+ raise InvalidParameterValueException(msg) from e
+
+ def to_bytes(self, data: dict[str, Any]) -> bytes:
+ """Serialize data using boto rest-json serializer.
+
+ Args:
+ data: Dictionary data to serialize
+
+ Returns:
+ bytes: Serialized data
+
+ Raises:
+ InvalidParameterValueException: If serialization fails
+ """
+ if not self._serializer or not self._operation_model:
+ msg = f"Serializer not initialized for {self._operation_name}"
+ raise InvalidParameterValueException(msg)
+
+ try:
+ serialized = self._serializer.serialize_to_request(
+ data, self._operation_model
+ )
+ body = serialized.get("body", b"")
+
+ if isinstance(body, str):
+ return body.encode("utf-8")
+
+ return body # noqa: TRY300
+ except Exception as e:
+ msg = f"Failed to serialize data for {self._operation_name}: {e}"
+ raise InvalidParameterValueException(msg) from e
+
+
+class AwsRestJsonDeserializer:
+ """AWS rest-json deserializer using boto."""
+
+ def __init__(self, operation_name: str, parser: Any, operation_model: Any):
+ """Initialize the AWS rest-json deserializer.
+
+ Args:
+ operation_name: Name of the AWS operation
+ parser: Boto parser instance
+ operation_model: Boto operation model
+ """
+ self._operation_name = operation_name
+ self._parser = parser
+ self._operation_model = operation_model
+
+ @classmethod
+ def create(cls, operation_name: str) -> AwsRestJsonDeserializer:
+ """Create deserializer with boto components.
+
+ Args:
+ operation_name: Name of the AWS operation
+
+ Returns:
+ AwsRestJsonDeserializer: Configured deserializer instance
+
+ Raises:
+ InvalidParameterValueException: If deserializer creation fails
+ """
+ try:
+ # Load service model
+ loader = botocore.loaders.Loader()
+
+ raw_model = loader.load_service_model("lambda", "service-2")
+ service_model = ServiceModel(raw_model)
+
+ # Create parser (rest-json protocol)
+ parser = create_parser("rest-json")
+ operation_model = service_model.operation_model(operation_name)
+
+ return cls(operation_name, parser, operation_model)
+ except Exception as e:
+ msg = f"Failed to create deserializer for {operation_name}: {e}"
+ raise InvalidParameterValueException(msg) from e
+
+ def from_bytes(self, data: bytes) -> dict[str, Any]:
+ """Deserialize bytes using boto rest-json parser.
+
+ Args:
+ data: Bytes to deserialize
+
+ Returns:
+ dict: Deserialized data
+
+ Raises:
+ InvalidParameterValueException: If deserialization fails
+ """
+ if not self._parser or not self._operation_model:
+ msg = f"Parser not initialized for {self._operation_name}"
+ raise InvalidParameterValueException(msg)
+
+ try:
+ if self._operation_model.output_shape:
+ # Create response dict for boto parser
+ response_dict = {
+ "body": data,
+ "headers": {"content-type": "application/json"},
+ "status_code": 200,
+ }
+ return self._parser.parse(
+ response_dict, self._operation_model.output_shape
+ )
+
+ # If no output shape, just parse as JSON
+ return json.loads(data.decode("utf-8"))
+ except Exception as e:
+ msg = f"Failed to deserialize data for {self._operation_name}: {e}"
+ raise InvalidParameterValueException(msg) from e
diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/server.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/server.py
new file mode 100644
index 0000000..89cb219
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/server.py
@@ -0,0 +1,243 @@
+"""Local testing web server for AWS Lambda Durable Functions that mimics the actual Lambda backend services."""
+
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass
+from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
+from typing import TYPE_CHECKING, Self
+from urllib.parse import parse_qs, urlparse
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ AwsApiException,
+ ServiceException,
+ UnknownRouteError,
+)
+
+
+if TYPE_CHECKING:
+ from aws_durable_execution_sdk_python_testing.executor import Executor
+
+
+# Removed deprecated imports from web.errors
+from aws_durable_execution_sdk_python_testing.web.handlers import (
+ CheckpointDurableExecutionHandler,
+ EndpointHandler,
+ GetDurableExecutionHandler,
+ GetDurableExecutionHistoryHandler,
+ GetDurableExecutionStateHandler,
+ HealthHandler,
+ ListDurableExecutionsByFunctionHandler,
+ ListDurableExecutionsHandler,
+ MetricsHandler,
+ SendDurableExecutionCallbackFailureHandler,
+ SendDurableExecutionCallbackHeartbeatHandler,
+ SendDurableExecutionCallbackSuccessHandler,
+ StartExecutionHandler,
+ StopDurableExecutionHandler,
+ UpdateLambdaEndpointHandler,
+)
+from aws_durable_execution_sdk_python_testing.web.models import (
+ HTTPRequest,
+ HTTPResponse,
+)
+from aws_durable_execution_sdk_python_testing.web.routes import (
+ BytesPayloadRoute,
+ CallbackFailureRoute,
+ CallbackHeartbeatRoute,
+ CallbackSuccessRoute,
+ CheckpointDurableExecutionRoute,
+ GetDurableExecutionHistoryRoute,
+ GetDurableExecutionRoute,
+ GetDurableExecutionStateRoute,
+ HealthRoute,
+ ListDurableExecutionsByFunctionRoute,
+ ListDurableExecutionsRoute,
+ MetricsRoute,
+ Route,
+ Router,
+ StartExecutionRoute,
+ StopDurableExecutionRoute,
+ UpdateLambdaEndpointRoute,
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass(frozen=True)
+class WebServiceConfig:
+ """Configuration for the web service."""
+
+ host: str = "localhost"
+ port: int = 5000
+ log_level: int = logging.INFO
+ max_request_size: int = 10 * 1024 * 1024 # 10MB
+
+
+class RequestHandler(BaseHTTPRequestHandler):
+ """HTTP request handler for durable execution operations."""
+
+ def __init__(self, request, client_address, server) -> None:
+ self.executor: Executor = server.executor
+ self.router: Router = server.router # Access shared router
+ self.endpoint_handlers: dict[type[Route], EndpointHandler] = (
+ server.endpoint_handlers
+ ) # Access shared handlers
+ super().__init__(request, client_address, server)
+
+ def do_GET(self) -> None: # noqa: N802
+ """Handle GET requests."""
+ self._handle_request("GET")
+
+ def do_POST(self) -> None: # noqa: N802
+ """Handle POST requests."""
+ self._handle_request("POST")
+
+ def do_PUT(self) -> None: # noqa: N802
+ """Handle PUT requests."""
+ self._handle_request("PUT")
+
+ def _handle_request(self, method: str) -> None:
+ """Handle HTTP request with strongly-typed routing."""
+ try:
+ # Parse URL path and method into strongly-typed Route object using shared router
+ url_path: str = self.path.split("?")[0]
+ parsed_route: Route = self.router.find_route(url_path, method)
+
+ # Find handler for this route type
+ handler: EndpointHandler | None = self.endpoint_handlers.get(
+ type(parsed_route)
+ )
+
+ if not handler:
+ raise UnknownRouteError(method, url_path) # noqa: TRY301
+
+ # Parse query parameters and request body
+ parsed_url = urlparse(self.path)
+ query_params: dict[str, list[str]] = parse_qs(parsed_url.query)
+ content_length: int = int(self.headers.get("Content-Length", 0))
+ body_bytes: bytes = (
+ self.rfile.read(content_length) if content_length > 0 else b""
+ )
+
+ # For callback operations, use raw bytes directly
+ if isinstance(parsed_route, BytesPayloadRoute):
+ request = HTTPRequest.from_raw_bytes(
+ body_bytes=body_bytes,
+ method=method,
+ path=parsed_route,
+ headers=dict(self.headers),
+ query_params=query_params,
+ )
+ else:
+ # Create strongly-typed HTTP request object with pre-parsed body
+ request = HTTPRequest.from_bytes(
+ body_bytes=body_bytes,
+ operation_name=None,
+ method=method,
+ path=parsed_route,
+ headers=dict(self.headers),
+ query_params=query_params,
+ )
+
+ # Handle request with appropriate handler
+ response: HTTPResponse = handler.handle(parsed_route, request)
+
+ # Send HTTP response
+ self._send_response(response)
+
+ except Exception as e:
+ logger.exception("Request handling failed")
+
+ aws_exception: AwsApiException = (
+ e if isinstance(e, AwsApiException) else ServiceException(str(e))
+ )
+
+ http_response = HTTPResponse.create_error_from_exception(aws_exception)
+ self._send_response(http_response)
+
+ def _send_response(self, response: HTTPResponse) -> None:
+ """Send HTTP response to client."""
+ self.send_response(response.status_code)
+ for header_name, header_value in response.headers.items():
+ self.send_header(header_name, header_value)
+ self.end_headers()
+
+ # Convert response body to bytes for transmission
+ if response.body:
+ self.wfile.write(response.body_to_bytes())
+
+ def log_message(self, format_string: str, *args) -> None:
+ """Override to use Python logging instead of stderr."""
+ logger.info("%s - %s", self.address_string(), format_string % args)
+
+
+class WebServer(ThreadingHTTPServer):
+ """Multi-threaded HTTP server for durable execution operations."""
+
+ def __init__(self, config: WebServiceConfig, executor: Executor) -> None:
+ """Initialize the web server.
+
+ Args:
+ config: Server configuration
+ executor: Executor instance for handling operations
+ """
+ self.config = config
+ self.executor = executor
+
+ # Configure logging
+ logging.basicConfig(level=config.log_level)
+ logging.getLogger("botocore").setLevel(logging.WARNING)
+
+ # Create shared router and endpoint handlers
+ self.router = Router() # Shared across all request handlers
+ self.endpoint_handlers = (
+ self._create_endpoint_handlers()
+ ) # Shared handler registry
+
+ # Initialize the HTTP server
+ super().__init__((config.host, config.port), RequestHandler)
+
+ logger.info("Web server initialized on %s:%s", config.host, config.port)
+
+ def _create_endpoint_handlers(self) -> dict[type[Route], EndpointHandler]:
+ """Create endpoint handlers registry - called once during server initialization."""
+ return {
+ StartExecutionRoute: StartExecutionHandler(self.executor),
+ GetDurableExecutionRoute: GetDurableExecutionHandler(self.executor),
+ CheckpointDurableExecutionRoute: CheckpointDurableExecutionHandler(
+ self.executor
+ ),
+ StopDurableExecutionRoute: StopDurableExecutionHandler(self.executor),
+ GetDurableExecutionStateRoute: GetDurableExecutionStateHandler(
+ self.executor
+ ),
+ GetDurableExecutionHistoryRoute: GetDurableExecutionHistoryHandler(
+ self.executor
+ ),
+ ListDurableExecutionsRoute: ListDurableExecutionsHandler(self.executor),
+ ListDurableExecutionsByFunctionRoute: ListDurableExecutionsByFunctionHandler(
+ self.executor
+ ),
+ CallbackSuccessRoute: SendDurableExecutionCallbackSuccessHandler(
+ self.executor
+ ),
+ CallbackFailureRoute: SendDurableExecutionCallbackFailureHandler(
+ self.executor
+ ),
+ CallbackHeartbeatRoute: SendDurableExecutionCallbackHeartbeatHandler(
+ self.executor
+ ),
+ HealthRoute: HealthHandler(self.executor),
+ UpdateLambdaEndpointRoute: UpdateLambdaEndpointHandler(self.executor),
+ MetricsRoute: MetricsHandler(self.executor),
+ }
+
+ def __enter__(self) -> Self:
+ """Context manager entry."""
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
+ """Context manager exit - cleanup server resources."""
+ self.server_close()
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/__init__.py b/packages/aws-durable-execution-sdk-python-testing/tests/__init__.py
new file mode 100644
index 0000000..66173ae
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/__init__.py
@@ -0,0 +1 @@
+# Test package
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/__init__.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/__init__.py
new file mode 100644
index 0000000..78d8de9
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/__init__.py
@@ -0,0 +1 @@
+"""Test package"""
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processor_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processor_test.py
new file mode 100644
index 0000000..1df24cc
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processor_test.py
@@ -0,0 +1,295 @@
+"""Unit tests for CheckpointProcessor."""
+
+from unittest.mock import Mock, patch
+
+import pytest
+from aws_durable_execution_sdk_python.lambda_service import (
+ CheckpointOutput,
+ CheckpointUpdatedExecutionState,
+ OperationAction,
+ OperationType,
+ OperationUpdate,
+ StateOutput,
+)
+
+from aws_durable_execution_sdk_python_testing.checkpoint.processor import (
+ CheckpointProcessor,
+)
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+from aws_durable_execution_sdk_python_testing.execution import Execution
+from aws_durable_execution_sdk_python_testing.scheduler import Scheduler
+from aws_durable_execution_sdk_python_testing.stores.base import ExecutionStore
+from aws_durable_execution_sdk_python_testing.token import CheckpointToken
+
+
+def test_init():
+ """Test CheckpointProcessor initialization."""
+ store = Mock(spec=ExecutionStore)
+ scheduler = Mock(spec=Scheduler)
+
+ processor = CheckpointProcessor(store, scheduler)
+
+ # Test that processor was created successfully by calling a public method
+ # This indirectly verifies that internal components were initialized
+ assert processor is not None
+
+ # Test that we can add observers (verifies notifier is initialized)
+ observer = Mock()
+ processor.add_execution_observer(observer) # Should not raise an exception
+
+
+@patch(
+ "aws_durable_execution_sdk_python_testing.checkpoint.processor.ExecutionNotifier"
+)
+def test_add_execution_observer(mock_notifier_class):
+ """Test adding execution observer."""
+ store = Mock(spec=ExecutionStore)
+ scheduler = Mock(spec=Scheduler)
+ mock_notifier_instance = Mock()
+ mock_notifier_class.return_value = mock_notifier_instance
+
+ processor = CheckpointProcessor(store, scheduler)
+ observer = Mock()
+
+ processor.add_execution_observer(observer)
+
+ # Verify observer was added through the notifier's public method
+ mock_notifier_instance.add_observer.assert_called_once_with(observer)
+
+
+@patch(
+ "aws_durable_execution_sdk_python_testing.checkpoint.processor.CheckpointValidator"
+)
+@patch(
+ "aws_durable_execution_sdk_python_testing.checkpoint.processor.OperationTransformer"
+)
+def test_process_checkpoint_success(mock_transformer_class, mock_validator):
+ """Test successful checkpoint processing."""
+ # Setup mocks
+ store = Mock(spec=ExecutionStore)
+ scheduler = Mock(spec=Scheduler)
+ mock_transformer_instance = Mock()
+ mock_transformer_class.return_value = mock_transformer_instance
+
+ processor = CheckpointProcessor(store, scheduler)
+
+ # Mock execution
+ execution = Mock(spec=Execution)
+ execution.is_complete = False
+ execution.token_sequence = 1
+ execution.operations = []
+ execution.updates = []
+ execution.get_new_checkpoint_token.return_value = "new-token"
+ execution.get_navigable_operations.return_value = []
+
+ store.load.return_value = execution
+
+ # Mock transformer
+ mock_transformer_instance.process_updates.return_value = ([], [])
+
+ # Test data
+ checkpoint_token = "test-token" # noqa: S105
+ updates = [
+ OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ )
+ ]
+
+ # Mock token parsing
+ with patch.object(CheckpointToken, "from_str") as mock_from_str:
+ mock_token = Mock()
+ mock_token.execution_arn = "arn:test"
+ mock_token.token_sequence = 1
+ mock_from_str.return_value = mock_token
+
+ result = processor.process_checkpoint(checkpoint_token, updates, "client-token")
+
+ # Verify calls
+ store.load.assert_called_once_with("arn:test")
+ mock_validator.validate_input.assert_called_once_with(updates, execution)
+ mock_transformer_instance.process_updates.assert_called_once()
+ store.update.assert_called_once_with(execution)
+
+ # Verify result
+ assert isinstance(result, CheckpointOutput)
+ assert result.checkpoint_token == "new-token" # noqa: S105
+ assert isinstance(result.new_execution_state, CheckpointUpdatedExecutionState)
+
+
+@patch(
+ "aws_durable_execution_sdk_python_testing.checkpoint.processor.CheckpointValidator"
+)
+def test_process_checkpoint_invalid_token_complete_execution(mock_validator):
+ """Test checkpoint processing with complete execution."""
+ store = Mock(spec=ExecutionStore)
+ scheduler = Mock(spec=Scheduler)
+ processor = CheckpointProcessor(store, scheduler)
+
+ # Mock execution as complete
+ execution = Mock(spec=Execution)
+ execution.is_complete = True
+ execution.token_sequence = 1
+
+ store.load.return_value = execution
+
+ checkpoint_token = "test-token" # noqa: S105
+ updates = []
+
+ with patch.object(CheckpointToken, "from_str") as mock_from_str:
+ mock_token = Mock()
+ mock_token.execution_arn = "arn:test"
+ mock_token.token_sequence = 1
+ mock_from_str.return_value = mock_token
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Invalid checkpoint token"
+ ):
+ processor.process_checkpoint(checkpoint_token, updates, "client-token")
+
+
+@patch(
+ "aws_durable_execution_sdk_python_testing.checkpoint.processor.CheckpointValidator"
+)
+def test_process_checkpoint_invalid_token_sequence(mock_validator):
+ """Test checkpoint processing with invalid token sequence."""
+ store = Mock(spec=ExecutionStore)
+ scheduler = Mock(spec=Scheduler)
+ processor = CheckpointProcessor(store, scheduler)
+
+ # Mock execution with different token sequence
+ execution = Mock(spec=Execution)
+ execution.is_complete = False
+ execution.token_sequence = 2
+
+ store.load.return_value = execution
+
+ checkpoint_token = "test-token" # noqa: S105
+ updates = []
+
+ with patch.object(CheckpointToken, "from_str") as mock_from_str:
+ mock_token = Mock()
+ mock_token.execution_arn = "arn:test"
+ mock_token.token_sequence = 1 # Different from execution
+ mock_from_str.return_value = mock_token
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Invalid checkpoint token"
+ ):
+ processor.process_checkpoint(checkpoint_token, updates, "client-token")
+
+
+@patch(
+ "aws_durable_execution_sdk_python_testing.checkpoint.processor.CheckpointValidator"
+)
+@patch(
+ "aws_durable_execution_sdk_python_testing.checkpoint.processor.OperationTransformer"
+)
+def test_process_checkpoint_updates_execution_state(
+ mock_transformer_class, mock_validator
+):
+ """Test that checkpoint processing updates execution state correctly."""
+ store = Mock(spec=ExecutionStore)
+ scheduler = Mock(spec=Scheduler)
+ mock_transformer_instance = Mock()
+ mock_transformer_class.return_value = mock_transformer_instance
+
+ processor = CheckpointProcessor(store, scheduler)
+
+ # Mock execution
+ execution = Mock(spec=Execution)
+ execution.is_complete = False
+ execution.token_sequence = 1
+ execution.operations = []
+ execution.updates = []
+ execution.get_new_checkpoint_token.return_value = "new-token"
+ execution.get_navigable_operations.return_value = []
+
+ store.load.return_value = execution
+
+ # Mock transformer to return updated operations and updates
+ updated_operations = [Mock()]
+ all_updates = [Mock()]
+ mock_transformer_instance.process_updates.return_value = (
+ updated_operations,
+ all_updates,
+ )
+
+ checkpoint_token = "test-token" # noqa: S105
+ updates = [
+ OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ )
+ ]
+
+ with patch.object(CheckpointToken, "from_str") as mock_from_str:
+ mock_token = Mock()
+ mock_token.execution_arn = "arn:test"
+ mock_token.token_sequence = 1
+ mock_from_str.return_value = mock_token
+
+ processor.process_checkpoint(checkpoint_token, updates, "client-token")
+
+ # Verify execution state was updated
+ assert execution.operations == updated_operations
+ # Check that updates were extended (execution.updates is a real list)
+ assert len(execution.updates) == len(all_updates)
+
+
+def test_get_execution_state():
+ """Test getting execution state."""
+ store = Mock(spec=ExecutionStore)
+ scheduler = Mock(spec=Scheduler)
+ processor = CheckpointProcessor(store, scheduler)
+
+ # Mock execution
+ execution = Mock(spec=Execution)
+ navigable_ops = [Mock()]
+ execution.get_navigable_operations.return_value = navigable_ops
+
+ store.load.return_value = execution
+
+ checkpoint_token = "test-token" # noqa: S105
+
+ with patch.object(CheckpointToken, "from_str") as mock_from_str:
+ mock_token = Mock()
+ mock_token.execution_arn = "arn:test"
+ mock_from_str.return_value = mock_token
+
+ result = processor.get_execution_state(checkpoint_token, "next-marker", 500)
+
+ # Verify calls
+ store.load.assert_called_once_with("arn:test")
+ execution.get_navigable_operations.assert_called_once()
+
+ # Verify result
+ assert isinstance(result, StateOutput)
+ assert result.operations == navigable_ops
+ assert result.next_marker is None
+
+
+def test_get_execution_state_default_max_items():
+ """Test getting execution state with default max_items."""
+ store = Mock(spec=ExecutionStore)
+ scheduler = Mock(spec=Scheduler)
+ processor = CheckpointProcessor(store, scheduler)
+
+ execution = Mock(spec=Execution)
+ execution.get_navigable_operations.return_value = []
+ store.load.return_value = execution
+
+ checkpoint_token = "test-token" # noqa: S105
+
+ with patch.object(CheckpointToken, "from_str") as mock_from_str:
+ mock_token = Mock()
+ mock_token.execution_arn = "arn:test"
+ mock_from_str.return_value = mock_token
+
+ result = processor.get_execution_state(checkpoint_token, "next-marker")
+
+ assert isinstance(result, StateOutput)
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/__init__.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/__init__.py
new file mode 100644
index 0000000..78d8de9
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/__init__.py
@@ -0,0 +1 @@
+"""Test package"""
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/base_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/base_test.py
new file mode 100644
index 0000000..700fd96
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/base_test.py
@@ -0,0 +1,447 @@
+"""Tests for base operation processor."""
+
+import datetime
+from datetime import timedelta
+from unittest.mock import Mock
+
+import pytest
+from aws_durable_execution_sdk_python.lambda_service import (
+ CallbackDetails,
+ ChainedInvokeDetails,
+ ChainedInvokeOptions,
+ ContextDetails,
+ ErrorObject,
+ ExecutionDetails,
+ Operation,
+ OperationAction,
+ OperationStatus,
+ OperationType,
+ OperationUpdate,
+ StepDetails,
+ WaitDetails,
+ WaitOptions,
+ ContextOptions,
+)
+
+from aws_durable_execution_sdk_python_testing.checkpoint.processors.base import (
+ OperationProcessor,
+)
+
+
+def test_process_not_implemented():
+ processor = OperationProcessor()
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ )
+
+ try:
+ processor.process(update, None, Mock(), "test-arn")
+ pytest.fail("Expected NotImplementedError")
+ except NotImplementedError:
+ pass
+
+
+class MockProcessor(OperationProcessor):
+ """Mock processor for testing base functionality."""
+
+ def process(self, update, current_op, notifier, execution_arn):
+ return self._translate_update_to_operation(
+ update, current_op, OperationStatus.STARTED
+ )
+
+ def translate_update(self, update, current_op, status):
+ """Public method to access _translate_update_to_operation for testing."""
+ return self._translate_update_to_operation(update, current_op, status)
+
+ def get_end_time(self, current_op, status):
+ """Public method to access _get_end_time for testing."""
+ return self._get_end_time(current_op, status)
+
+ def create_execution_details(self, update):
+ """Public method to access _create_execution_details for testing."""
+ return self._create_execution_details(update)
+
+ def create_context_details(self, update):
+ """Public method to access _create_context_details for testing."""
+ return self._create_context_details(update)
+
+ def create_step_details(self, update, current_operation):
+ """Public method to access _create_step_details for testing."""
+ return self._create_step_details(update, current_operation)
+
+ def create_callback_details(self, update):
+ """Public method to access _create_callback_details for testing."""
+ return self._create_callback_details(update)
+
+ def create_invoke_details(self, update):
+ """Public method to access _create_invoke_details for testing."""
+ return self._create_invoke_details(update)
+
+ def create_wait_details(self, update, current_op):
+ """Public method to access _create_wait_details for testing."""
+ return self._create_wait_details(update, current_op)
+
+
+def test_get_end_time_with_existing_end_timestamp():
+ processor = MockProcessor()
+ end_time = datetime.datetime.now(tz=datetime.UTC)
+ current_op = Mock()
+ current_op.end_timestamp = end_time
+
+ result = processor.get_end_time(current_op, OperationStatus.STARTED)
+
+ assert result == end_time
+
+
+def test_get_end_time_with_terminal_status():
+ processor = MockProcessor()
+ current_op = Mock()
+ current_op.end_timestamp = None
+
+ result = processor.get_end_time(current_op, OperationStatus.SUCCEEDED)
+
+ assert result is not None
+ assert isinstance(result, datetime.datetime)
+
+
+def test_get_end_time_with_non_terminal_status():
+ processor = MockProcessor()
+ current_op = Mock()
+ current_op.end_timestamp = None
+
+ result = processor.get_end_time(current_op, OperationStatus.STARTED)
+
+ assert result is None
+
+
+def test_create_execution_details():
+ processor = MockProcessor()
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.EXECUTION,
+ action=OperationAction.START,
+ payload="test-payload",
+ )
+
+ result = processor.create_execution_details(update)
+
+ assert isinstance(result, ExecutionDetails)
+ assert result.input_payload == "test-payload"
+
+
+def test_create_execution_details_non_execution_type():
+ processor = MockProcessor()
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ payload="test-payload",
+ )
+
+ result = processor.create_execution_details(update)
+
+ assert result is None
+
+
+def test_create_context_details():
+ processor = MockProcessor()
+ error = ErrorObject.from_message("test error")
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.START,
+ payload="test-payload",
+ error=error,
+ )
+
+ result = processor.create_context_details(update)
+
+ assert isinstance(result, ContextDetails)
+ assert result.result == "test-payload"
+ assert result.error == error
+
+
+def test_create_context_details_non_context_type():
+ processor = MockProcessor()
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ payload="test-payload",
+ )
+
+ result = processor.create_context_details(update)
+
+ assert result is None
+
+
+def test_create_step_details():
+ processor = MockProcessor()
+ error = ErrorObject.from_message("test error")
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ payload="test-payload",
+ error=error,
+ )
+
+ current_op = Mock()
+ current_op.step_details = Mock()
+ current_op.step_details.attempt = Mock()
+
+ result = processor.create_step_details(update, current_op)
+
+ assert isinstance(result, StepDetails)
+ assert result.result == "test-payload"
+ assert result.error == error
+
+
+def test_create_context_details_with_replay_children():
+ processor = MockProcessor()
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.SUCCEED,
+ payload="test-payload",
+ context_options=ContextOptions(replay_children=True),
+ )
+
+ result = processor.create_context_details(update)
+
+ assert isinstance(result, ContextDetails)
+ assert result.result == "test-payload"
+ assert result.replay_children == True
+
+
+def test_create_step_details_non_step_type():
+ processor = MockProcessor()
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.START,
+ payload="test-payload",
+ )
+
+ current_op = Mock()
+ current_op.step_details = Mock()
+ current_op.step_details.attempt = Mock()
+
+ result = processor.create_step_details(update, current_op)
+
+ assert result is None
+
+
+def test_create_step_details_without_current_operation():
+ processor = MockProcessor()
+ error = ErrorObject.from_message("test error")
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ payload="test-payload",
+ error=error,
+ )
+
+ result = processor.create_step_details(update, None)
+
+ assert isinstance(result, StepDetails)
+ assert result.result == "test-payload"
+ assert result.error == error
+ assert result.attempt == 0
+
+
+def test_create_callback_details():
+ processor = MockProcessor()
+ error = ErrorObject.from_message("test error")
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.CALLBACK,
+ action=OperationAction.START,
+ payload="test-payload",
+ error=error,
+ )
+
+ result = processor.create_callback_details(update)
+
+ assert isinstance(result, CallbackDetails)
+ assert result.callback_id == "placeholder"
+ assert result.result == "test-payload"
+ assert result.error == error
+
+
+def test_create_callback_details_non_callback_type():
+ processor = MockProcessor()
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ payload="test-payload",
+ )
+
+ result = processor.create_callback_details(update)
+
+ assert result is None
+
+
+def test_create_invoke_details():
+ processor = MockProcessor()
+ error = ErrorObject.from_message("test error")
+ invoke_options = ChainedInvokeOptions(function_name="test-function")
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.CHAINED_INVOKE,
+ action=OperationAction.START,
+ payload="test-payload",
+ error=error,
+ chained_invoke_options=invoke_options,
+ )
+
+ result = processor.create_invoke_details(update)
+
+ assert isinstance(result, ChainedInvokeDetails)
+ assert result.result == "test-payload"
+ assert result.error == error
+
+
+def test_create_invoke_details_non_invoke_type():
+ processor = MockProcessor()
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ payload="test-payload",
+ )
+
+ result = processor.create_invoke_details(update)
+
+ assert result is None
+
+
+def test_create_invoke_details_no_options():
+ processor = MockProcessor()
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.CHAINED_INVOKE,
+ action=OperationAction.START,
+ payload="test-payload",
+ )
+
+ result = processor.create_invoke_details(update)
+
+ assert result is None
+
+
+def test_create_wait_details_with_current_operation():
+ processor = MockProcessor()
+ scheduled_end_timestamp = datetime.datetime.now(tz=datetime.UTC)
+ current_op = Mock()
+ current_op.wait_details = WaitDetails(
+ scheduled_end_timestamp=scheduled_end_timestamp
+ )
+
+ wait_options = WaitOptions(wait_seconds=30)
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.WAIT,
+ action=OperationAction.START,
+ wait_options=wait_options,
+ )
+
+ result = processor.create_wait_details(update, current_op)
+
+ assert isinstance(result, WaitDetails)
+ assert result.scheduled_end_timestamp == scheduled_end_timestamp
+
+
+def test_create_wait_details_without_current_operation():
+ processor = MockProcessor()
+ wait_options = WaitOptions(wait_seconds=30)
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.WAIT,
+ action=OperationAction.START,
+ wait_options=wait_options,
+ )
+
+ result = processor.create_wait_details(update, None)
+
+ assert isinstance(result, WaitDetails)
+ assert result.scheduled_end_timestamp > datetime.datetime.now(tz=datetime.UTC)
+
+
+def test_create_wait_details_non_wait_type():
+ processor = MockProcessor()
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ )
+
+ result = processor.create_wait_details(update, None)
+
+ assert result is None
+
+
+def test_translate_update_to_operation_with_current_operation():
+ processor = MockProcessor()
+ start_time = datetime.datetime.now(tz=datetime.UTC) - timedelta(minutes=5)
+ current_op = Mock()
+ current_op.start_timestamp = start_time
+
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ parent_id="parent-id",
+ name="test-operation",
+ sub_type="test-subtype",
+ )
+
+ result = processor.translate_update(update, current_op, OperationStatus.STARTED)
+
+ assert isinstance(result, Operation)
+ assert result.operation_id == "test-id"
+ assert result.parent_id == "parent-id"
+ assert result.name == "test-operation"
+ assert result.start_timestamp == start_time
+ assert result.operation_type == OperationType.STEP
+ assert result.status == OperationStatus.STARTED
+ assert result.sub_type == "test-subtype"
+
+
+def test_translate_update_to_operation_without_current_operation():
+ processor = MockProcessor()
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ parent_id="parent-id",
+ name="test-operation",
+ )
+
+ result = processor.translate_update(update, None, OperationStatus.STARTED)
+
+ assert isinstance(result, Operation)
+ assert result.operation_id == "test-id"
+ assert result.parent_id == "parent-id"
+ assert result.name == "test-operation"
+ assert result.start_timestamp is not None
+ assert result.operation_type == OperationType.STEP
+ assert result.status == OperationStatus.STARTED
+
+
+def test_translate_update_to_operation_with_terminal_status():
+ processor = MockProcessor()
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ )
+
+ result = processor.translate_update(update, None, OperationStatus.SUCCEEDED)
+
+ assert result.end_timestamp is not None
+ assert result.status == OperationStatus.SUCCEEDED
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/callback_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/callback_test.py
new file mode 100644
index 0000000..24bc129
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/callback_test.py
@@ -0,0 +1,259 @@
+"""Tests for callback operation processor."""
+
+from unittest.mock import Mock
+
+import pytest
+from aws_durable_execution_sdk_python.lambda_service import (
+ Operation,
+ OperationAction,
+ OperationStatus,
+ OperationType,
+ OperationUpdate,
+)
+
+from aws_durable_execution_sdk_python_testing.checkpoint.processors.callback import (
+ CallbackProcessor,
+)
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier
+
+
+class MockNotifier(ExecutionNotifier):
+ """Mock notifier for testing."""
+
+ def __init__(self):
+ super().__init__()
+ self.completed_calls = []
+ self.failed_calls = []
+ self.wait_timer_calls = []
+ self.step_retry_calls = []
+
+ def notify_completed(self, execution_arn, result=None):
+ self.completed_calls.append((execution_arn, result))
+
+ def notify_failed(self, execution_arn, error):
+ self.failed_calls.append((execution_arn, error))
+
+ def notify_wait_timer_scheduled(self, execution_arn, operation_id, delay):
+ self.wait_timer_calls.append((execution_arn, operation_id, delay))
+
+ def notify_step_retry_scheduled(self, execution_arn, operation_id, delay):
+ self.step_retry_calls.append((execution_arn, operation_id, delay))
+
+
+def test_process_start_action():
+ processor = CallbackProcessor()
+ notifier = MockNotifier()
+
+ update = OperationUpdate(
+ operation_id="callback-123",
+ operation_type=OperationType.CALLBACK,
+ action=OperationAction.START,
+ name="test-callback",
+ )
+
+ result = processor.process(
+ update, None, notifier, "arn:aws:states:us-east-1:123456789012:execution:test"
+ )
+
+ assert isinstance(result, Operation)
+ assert result.operation_id == "callback-123"
+ assert result.operation_type == OperationType.CALLBACK
+ assert result.status == OperationStatus.STARTED
+ assert result.name == "test-callback"
+ assert result.callback_details is not None
+
+
+def test_process_start_action_with_current_operation():
+ processor = CallbackProcessor()
+ notifier = MockNotifier()
+
+ current_op = Mock()
+ current_op.start_timestamp = Mock()
+
+ update = OperationUpdate(
+ operation_id="callback-123",
+ operation_type=OperationType.CALLBACK,
+ action=OperationAction.START,
+ name="test-callback",
+ )
+
+ result = processor.process(
+ update,
+ current_op,
+ notifier,
+ "arn:aws:states:us-east-1:123456789012:execution:test",
+ )
+
+ assert isinstance(result, Operation)
+ assert result.operation_id == "callback-123"
+ assert result.status == OperationStatus.STARTED
+ assert result.start_timestamp == current_op.start_timestamp
+
+
+def test_process_invalid_action():
+ processor = CallbackProcessor()
+ notifier = MockNotifier()
+
+ update = OperationUpdate(
+ operation_id="callback-123",
+ operation_type=OperationType.CALLBACK,
+ action=OperationAction.SUCCEED,
+ name="test-callback",
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Invalid action for CALLBACK operation"
+ ):
+ processor.process(
+ update,
+ None,
+ notifier,
+ "arn:aws:states:us-east-1:123456789012:execution:test",
+ )
+
+
+def test_process_fail_action():
+ processor = CallbackProcessor()
+ notifier = MockNotifier()
+
+ update = OperationUpdate(
+ operation_id="callback-123",
+ operation_type=OperationType.CALLBACK,
+ action=OperationAction.FAIL,
+ name="test-callback",
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Invalid action for CALLBACK operation"
+ ):
+ processor.process(
+ update,
+ None,
+ notifier,
+ "arn:aws:states:us-east-1:123456789012:execution:test",
+ )
+
+
+def test_process_cancel_action():
+ processor = CallbackProcessor()
+ notifier = MockNotifier()
+
+ update = OperationUpdate(
+ operation_id="callback-123",
+ operation_type=OperationType.CALLBACK,
+ action=OperationAction.CANCEL,
+ name="test-callback",
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Invalid action for CALLBACK operation"
+ ):
+ processor.process(
+ update,
+ None,
+ notifier,
+ "arn:aws:states:us-east-1:123456789012:execution:test",
+ )
+
+
+def test_process_retry_action():
+ processor = CallbackProcessor()
+ notifier = MockNotifier()
+
+ update = OperationUpdate(
+ operation_id="callback-123",
+ operation_type=OperationType.CALLBACK,
+ action=OperationAction.RETRY,
+ name="test-callback",
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Invalid action for CALLBACK operation"
+ ):
+ processor.process(
+ update,
+ None,
+ notifier,
+ "arn:aws:states:us-east-1:123456789012:execution:test",
+ )
+
+
+def test_process_with_payload():
+ processor = CallbackProcessor()
+ notifier = MockNotifier()
+
+ update = OperationUpdate(
+ operation_id="callback-123",
+ operation_type=OperationType.CALLBACK,
+ action=OperationAction.START,
+ name="test-callback",
+ payload="test-payload",
+ )
+
+ result = processor.process(
+ update, None, notifier, "arn:aws:states:us-east-1:123456789012:execution:test"
+ )
+
+ assert result.callback_details.result == "test-payload"
+
+
+def test_process_with_parent_id():
+ processor = CallbackProcessor()
+ notifier = MockNotifier()
+
+ update = OperationUpdate(
+ operation_id="callback-123",
+ operation_type=OperationType.CALLBACK,
+ action=OperationAction.START,
+ name="test-callback",
+ parent_id="parent-456",
+ )
+
+ result = processor.process(
+ update, None, notifier, "arn:aws:states:us-east-1:123456789012:execution:test"
+ )
+
+ assert result.parent_id == "parent-456"
+
+
+def test_process_with_sub_type():
+ processor = CallbackProcessor()
+ notifier = MockNotifier()
+
+ update = OperationUpdate(
+ operation_id="callback-123",
+ operation_type=OperationType.CALLBACK,
+ action=OperationAction.START,
+ name="test-callback",
+ sub_type="activity",
+ )
+
+ result = processor.process(
+ update, None, notifier, "arn:aws:states:us-east-1:123456789012:execution:test"
+ )
+
+ assert result.sub_type == "activity"
+
+
+def test_notifier_not_called_for_start():
+ processor = CallbackProcessor()
+ notifier = MockNotifier()
+
+ update = OperationUpdate(
+ operation_id="callback-123",
+ operation_type=OperationType.CALLBACK,
+ action=OperationAction.START,
+ name="test-callback",
+ )
+
+ processor.process(
+ update, None, notifier, "arn:aws:states:us-east-1:123456789012:execution:test"
+ )
+
+ assert len(notifier.completed_calls) == 0
+ assert len(notifier.failed_calls) == 0
+ assert len(notifier.wait_timer_calls) == 0
+ assert len(notifier.step_retry_calls) == 0
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/context_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/context_test.py
new file mode 100644
index 0000000..a070bc1
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/context_test.py
@@ -0,0 +1,379 @@
+"""Tests for context operation processor."""
+
+from datetime import UTC, datetime
+from unittest.mock import Mock
+
+import pytest
+from aws_durable_execution_sdk_python.lambda_service import (
+ ErrorObject,
+ Operation,
+ OperationAction,
+ OperationStatus,
+ OperationType,
+ OperationUpdate,
+)
+
+from aws_durable_execution_sdk_python_testing.checkpoint.processors.context import (
+ ContextProcessor,
+)
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier
+
+
+class MockNotifier(ExecutionNotifier):
+ """Mock notifier for testing."""
+
+ def __init__(self):
+ super().__init__()
+ self.completed_calls = []
+ self.failed_calls = []
+ self.wait_timer_calls = []
+ self.step_retry_calls = []
+
+ def notify_completed(self, execution_arn, result=None):
+ self.completed_calls.append((execution_arn, result))
+
+ def notify_failed(self, execution_arn, error):
+ self.failed_calls.append((execution_arn, error))
+
+ def notify_wait_timer_scheduled(self, execution_arn, operation_id, delay):
+ self.wait_timer_calls.append((execution_arn, operation_id, delay))
+
+ def notify_step_retry_scheduled(self, execution_arn, operation_id, delay):
+ self.step_retry_calls.append((execution_arn, operation_id, delay))
+
+
+def test_process_start_action():
+ processor = ContextProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="context-123",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.START,
+ name="test-context",
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert isinstance(result, Operation)
+ assert result.operation_id == "context-123"
+ assert result.operation_type == OperationType.CONTEXT
+ assert result.status == OperationStatus.STARTED
+ assert result.name == "test-context"
+ assert result.context_details is not None
+
+
+def test_process_start_action_with_current_operation():
+ processor = ContextProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ current_op = Mock()
+ current_op.start_timestamp = datetime.now(UTC)
+
+ update = OperationUpdate(
+ operation_id="context-123",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.START,
+ name="test-context",
+ )
+
+ result = processor.process(update, current_op, notifier, execution_arn)
+
+ assert result.start_timestamp == current_op.start_timestamp
+ assert result.status == OperationStatus.STARTED
+
+
+def test_process_succeed_action():
+ processor = ContextProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="context-123",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.SUCCEED,
+ name="test-context",
+ payload="success-result",
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert isinstance(result, Operation)
+ assert result.operation_id == "context-123"
+ assert result.status == OperationStatus.SUCCEEDED
+ assert result.context_details.result == "success-result"
+ assert result.context_details.error is None
+
+
+def test_process_succeed_action_with_current_operation():
+ processor = ContextProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ current_op = Mock()
+ current_op.start_timestamp = datetime.now(UTC)
+
+ update = OperationUpdate(
+ operation_id="context-123",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.SUCCEED,
+ name="test-context",
+ payload="success-result",
+ )
+
+ result = processor.process(update, current_op, notifier, execution_arn)
+
+ assert result.start_timestamp == current_op.start_timestamp
+ assert result.status == OperationStatus.SUCCEEDED
+
+
+def test_process_fail_action():
+ processor = ContextProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ error = ErrorObject.from_message("context failed")
+ update = OperationUpdate(
+ operation_id="context-123",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.FAIL,
+ name="test-context",
+ error=error,
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert isinstance(result, Operation)
+ assert result.operation_id == "context-123"
+ assert result.status == OperationStatus.FAILED
+ assert result.context_details.error == error
+ assert result.context_details.result is None
+
+
+def test_process_fail_action_with_current_operation():
+ processor = ContextProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ current_op = Mock()
+ current_op.start_timestamp = datetime.now(UTC)
+
+ error = ErrorObject.from_message("context failed")
+ update = OperationUpdate(
+ operation_id="context-123",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.FAIL,
+ name="test-context",
+ error=error,
+ )
+
+ result = processor.process(update, current_op, notifier, execution_arn)
+
+ assert result.start_timestamp == current_op.start_timestamp
+ assert result.status == OperationStatus.FAILED
+
+
+def test_process_fail_action_with_payload_and_error():
+ processor = ContextProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ error = ErrorObject.from_message("context failed")
+ update = OperationUpdate(
+ operation_id="context-123",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.FAIL,
+ name="test-context",
+ payload="partial-result",
+ error=error,
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert result.context_details.result == "partial-result"
+ assert result.context_details.error == error
+
+
+def test_process_invalid_action():
+ processor = ContextProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="context-123",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.RETRY,
+ name="test-context",
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Invalid action for CONTEXT operation"
+ ):
+ processor.process(update, None, notifier, execution_arn)
+
+
+def test_process_cancel_action():
+ processor = ContextProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="context-123",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.CANCEL,
+ name="test-context",
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Invalid action for CONTEXT operation"
+ ):
+ processor.process(update, None, notifier, execution_arn)
+
+
+def test_process_with_parent_id():
+ processor = ContextProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="context-123",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.START,
+ name="test-context",
+ parent_id="parent-456",
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert result.parent_id == "parent-456"
+
+
+def test_process_with_sub_type():
+ processor = ContextProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="context-123",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.START,
+ name="test-context",
+ sub_type="parallel",
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert result.sub_type == "parallel"
+
+
+def test_process_start_without_payload():
+ processor = ContextProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="context-123",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.START,
+ name="test-context",
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert result.context_details.result is None
+ assert result.context_details.error is None
+
+
+def test_process_succeed_without_payload():
+ processor = ContextProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="context-123",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.SUCCEED,
+ name="test-context",
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert result.context_details.result is None
+ assert result.context_details.error is None
+
+
+def test_process_fail_without_error():
+ processor = ContextProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="context-123",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.FAIL,
+ name="test-context",
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert result.context_details.result is None
+ assert result.context_details.error is None
+
+
+def test_no_notifier_calls():
+ processor = ContextProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="context-123",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.START,
+ name="test-context",
+ )
+
+ processor.process(update, None, notifier, execution_arn)
+
+ assert len(notifier.completed_calls) == 0
+ assert len(notifier.failed_calls) == 0
+ assert len(notifier.wait_timer_calls) == 0
+ assert len(notifier.step_retry_calls) == 0
+
+
+def test_end_timestamp_set_for_terminal_states():
+ processor = ContextProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="context-123",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.SUCCEED,
+ name="test-context",
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert result.end_timestamp is not None
+
+
+def test_end_timestamp_not_set_for_non_terminal_states():
+ processor = ContextProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="context-123",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.START,
+ name="test-context",
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert result.end_timestamp is None
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/execution_processor_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/execution_processor_test.py
new file mode 100644
index 0000000..37c91ea
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/execution_processor_test.py
@@ -0,0 +1,242 @@
+"""Tests for execution operation processor."""
+
+from unittest.mock import Mock
+
+from aws_durable_execution_sdk_python.lambda_service import (
+ ErrorObject,
+ OperationAction,
+ OperationType,
+ OperationUpdate,
+)
+
+from aws_durable_execution_sdk_python_testing.checkpoint.processors.execution import (
+ ExecutionProcessor,
+)
+from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier
+
+
+class MockNotifier(ExecutionNotifier):
+ """Mock notifier for testing."""
+
+ def __init__(self):
+ super().__init__()
+ self.completed_calls = []
+ self.failed_calls = []
+ self.wait_timer_calls = []
+ self.step_retry_calls = []
+
+ def notify_completed(self, execution_arn, result=None):
+ self.completed_calls.append((execution_arn, result))
+
+ def notify_failed(self, execution_arn, error):
+ self.failed_calls.append((execution_arn, error))
+
+ def notify_wait_timer_scheduled(self, execution_arn, operation_id, delay):
+ self.wait_timer_calls.append((execution_arn, operation_id, delay))
+
+ def notify_step_retry_scheduled(self, execution_arn, operation_id, delay):
+ self.step_retry_calls.append((execution_arn, operation_id, delay))
+
+
+def test_process_succeed_action():
+ processor = ExecutionProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="execution-123",
+ operation_type=OperationType.EXECUTION,
+ action=OperationAction.SUCCEED,
+ payload="success-result",
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert result is None
+ assert len(notifier.completed_calls) == 1
+ assert notifier.completed_calls[0] == (execution_arn, "success-result")
+ assert len(notifier.failed_calls) == 0
+
+
+def test_process_succeed_action_with_current_operation():
+ processor = ExecutionProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ current_op = Mock()
+
+ update = OperationUpdate(
+ operation_id="execution-123",
+ operation_type=OperationType.EXECUTION,
+ action=OperationAction.SUCCEED,
+ payload="success-result",
+ )
+
+ result = processor.process(update, current_op, notifier, execution_arn)
+
+ assert result is None
+ assert len(notifier.completed_calls) == 1
+ assert notifier.completed_calls[0] == (execution_arn, "success-result")
+
+
+def test_process_succeed_action_without_payload():
+ processor = ExecutionProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="execution-123",
+ operation_type=OperationType.EXECUTION,
+ action=OperationAction.SUCCEED,
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert result is None
+ assert len(notifier.completed_calls) == 1
+ assert notifier.completed_calls[0] == (execution_arn, None)
+
+
+def test_process_fail_action_with_error():
+ processor = ExecutionProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ error = ErrorObject.from_message("execution failed")
+ update = OperationUpdate(
+ operation_id="execution-123",
+ operation_type=OperationType.EXECUTION,
+ action=OperationAction.FAIL,
+ error=error,
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert result is None
+ assert len(notifier.failed_calls) == 1
+ assert notifier.failed_calls[0] == (execution_arn, error)
+ assert len(notifier.completed_calls) == 0
+
+
+def test_process_fail_action_without_error():
+ processor = ExecutionProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="execution-123",
+ operation_type=OperationType.EXECUTION,
+ action=OperationAction.FAIL,
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert result is None
+ assert len(notifier.failed_calls) == 1
+ execution_arn_arg, error_arg = notifier.failed_calls[0]
+ assert execution_arn_arg == execution_arn
+ assert isinstance(error_arg, ErrorObject)
+ assert (
+ "There is no error details but EXECUTION checkpoint action is not SUCCEED"
+ in str(error_arg)
+ )
+
+
+def test_process_start_action():
+ processor = ExecutionProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="execution-123",
+ operation_type=OperationType.EXECUTION,
+ action=OperationAction.START,
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert result is None
+ assert len(notifier.failed_calls) == 1
+ execution_arn_arg, error_arg = notifier.failed_calls[0]
+ assert execution_arn_arg == execution_arn
+ assert isinstance(error_arg, ErrorObject)
+
+
+def test_process_retry_action():
+ processor = ExecutionProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="execution-123",
+ operation_type=OperationType.EXECUTION,
+ action=OperationAction.RETRY,
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert result is None
+ assert len(notifier.failed_calls) == 1
+ execution_arn_arg, error_arg = notifier.failed_calls[0]
+ assert execution_arn_arg == execution_arn
+ assert isinstance(error_arg, ErrorObject)
+
+
+def test_process_cancel_action():
+ processor = ExecutionProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="execution-123",
+ operation_type=OperationType.EXECUTION,
+ action=OperationAction.CANCEL,
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert result is None
+ assert len(notifier.failed_calls) == 1
+ execution_arn_arg, error_arg = notifier.failed_calls[0]
+ assert execution_arn_arg == execution_arn
+ assert isinstance(error_arg, ErrorObject)
+
+
+def test_process_with_current_operation_and_error():
+ processor = ExecutionProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ current_op = Mock()
+ error = ErrorObject.from_message("custom error")
+
+ update = OperationUpdate(
+ operation_id="execution-123",
+ operation_type=OperationType.EXECUTION,
+ action=OperationAction.FAIL,
+ error=error,
+ )
+
+ result = processor.process(update, current_op, notifier, execution_arn)
+
+ assert result is None
+ assert len(notifier.failed_calls) == 1
+ assert notifier.failed_calls[0] == (execution_arn, error)
+
+
+def test_no_wait_timer_or_step_retry_calls():
+ processor = ExecutionProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="execution-123",
+ operation_type=OperationType.EXECUTION,
+ action=OperationAction.SUCCEED,
+ payload="result",
+ )
+
+ processor.process(update, None, notifier, execution_arn)
+
+ assert len(notifier.wait_timer_calls) == 0
+ assert len(notifier.step_retry_calls) == 0
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/step_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/step_test.py
new file mode 100644
index 0000000..6ba5cc6
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/step_test.py
@@ -0,0 +1,421 @@
+"""Tests for step operation processor."""
+
+from datetime import UTC, datetime
+from unittest.mock import Mock
+
+import pytest
+from aws_durable_execution_sdk_python.lambda_service import (
+ ErrorObject,
+ Operation,
+ OperationAction,
+ OperationStatus,
+ OperationType,
+ OperationUpdate,
+ StepDetails,
+ StepOptions,
+)
+
+from aws_durable_execution_sdk_python_testing.checkpoint.processors.step import (
+ StepProcessor,
+)
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier
+
+
+class MockNotifier(ExecutionNotifier):
+ """Mock notifier for testing."""
+
+ def __init__(self):
+ super().__init__()
+ self.completed_calls = []
+ self.failed_calls = []
+ self.wait_timer_calls = []
+ self.step_retry_calls = []
+
+ def notify_completed(self, execution_arn, result=None):
+ self.completed_calls.append((execution_arn, result))
+
+ def notify_failed(self, execution_arn, error):
+ self.failed_calls.append((execution_arn, error))
+
+ def notify_wait_timer_scheduled(self, execution_arn, operation_id, delay):
+ self.wait_timer_calls.append((execution_arn, operation_id, delay))
+
+ def notify_step_retry_scheduled(self, execution_arn, operation_id, delay):
+ self.step_retry_calls.append((execution_arn, operation_id, delay))
+
+
+def test_process_start_action():
+ processor = StepProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="step-123",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ name="test-step",
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert isinstance(result, Operation)
+ assert result.operation_id == "step-123"
+ assert result.operation_type == OperationType.STEP
+ assert result.status == OperationStatus.STARTED
+ assert result.name == "test-step"
+ assert result.step_details is not None
+
+
+def test_process_start_action_with_current_operation():
+ processor = StepProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ current_op = Mock()
+ current_op.start_timestamp = datetime.now(UTC)
+
+ update = OperationUpdate(
+ operation_id="step-123",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ name="test-step",
+ )
+
+ result = processor.process(update, current_op, notifier, execution_arn)
+
+ assert result.start_timestamp == current_op.start_timestamp
+
+
+def test_process_retry_action():
+ processor = StepProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ current_op = Mock()
+ current_op.start_timestamp = datetime.now(UTC)
+ current_op.step_details = StepDetails(attempt=1, result="previous-result")
+ current_op.execution_details = None
+ current_op.context_details = None
+ current_op.wait_details = None
+ current_op.callback_details = None
+ current_op.chained_invoke_details = None
+
+ step_options = StepOptions(next_attempt_delay_seconds=30)
+ update = OperationUpdate(
+ operation_id="step-123",
+ operation_type=OperationType.STEP,
+ action=OperationAction.RETRY,
+ name="test-step",
+ step_options=step_options,
+ )
+
+ result = processor.process(update, current_op, notifier, execution_arn)
+
+ assert isinstance(result, Operation)
+ assert result.operation_id == "step-123"
+ assert result.status == OperationStatus.PENDING
+ assert result.step_details.attempt == 2
+ assert result.step_details.result == "previous-result"
+ assert result.step_details.next_attempt_timestamp is not None
+
+ assert len(notifier.step_retry_calls) == 1
+ assert notifier.step_retry_calls[0] == (execution_arn, "step-123", 30)
+
+
+def test_process_retry_action_without_step_options():
+ processor = StepProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ current_op = Mock()
+ current_op.start_timestamp = datetime.now(UTC)
+ current_op.step_details = StepDetails(attempt=0)
+ current_op.execution_details = None
+ current_op.context_details = None
+ current_op.wait_details = None
+ current_op.callback_details = None
+ current_op.chained_invoke_details = None
+
+ update = OperationUpdate(
+ operation_id="step-123",
+ operation_type=OperationType.STEP,
+ action=OperationAction.RETRY,
+ name="test-step",
+ )
+
+ result = processor.process(update, current_op, notifier, execution_arn)
+
+ assert result.step_details.attempt == 1
+ assert len(notifier.step_retry_calls) == 1
+ assert notifier.step_retry_calls[0] == (execution_arn, "step-123", 0)
+
+
+def test_process_retry_action_without_current_operation():
+ processor = StepProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ step_options = StepOptions(next_attempt_delay_seconds=15)
+ update = OperationUpdate(
+ operation_id="step-123",
+ operation_type=OperationType.STEP,
+ action=OperationAction.RETRY,
+ name="test-step",
+ step_options=step_options,
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert result.step_details.attempt == 1
+ assert result.step_details.result is None
+ assert result.step_details.error is None
+
+
+def test_process_retry_action_without_current_step_details():
+ processor = StepProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ current_op = Mock()
+ current_op.start_timestamp = datetime.now(UTC)
+ current_op.step_details = None
+ current_op.execution_details = None
+ current_op.context_details = None
+ current_op.wait_details = None
+ current_op.callback_details = None
+ current_op.chained_invoke_details = None
+
+ step_options = StepOptions(next_attempt_delay_seconds=45)
+ update = OperationUpdate(
+ operation_id="step-123",
+ operation_type=OperationType.STEP,
+ action=OperationAction.RETRY,
+ name="test-step",
+ step_options=step_options,
+ )
+
+ result = processor.process(update, current_op, notifier, execution_arn)
+
+ assert result.step_details.attempt == 1
+
+
+def test_process_succeed_action():
+ processor = StepProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="step-123",
+ operation_type=OperationType.STEP,
+ action=OperationAction.SUCCEED,
+ name="test-step",
+ payload="success-result",
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert isinstance(result, Operation)
+ assert result.operation_id == "step-123"
+ assert result.status == OperationStatus.SUCCEEDED
+ assert result.step_details.result == "success-result"
+
+
+def test_process_succeed_action_with_current_operation():
+ processor = StepProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ current_op = Mock()
+ current_op.start_timestamp = datetime.now(UTC)
+ current_op.step_details = StepDetails()
+
+ update = OperationUpdate(
+ operation_id="step-123",
+ operation_type=OperationType.STEP,
+ action=OperationAction.SUCCEED,
+ name="test-step",
+ payload="success-result",
+ )
+
+ result = processor.process(update, current_op, notifier, execution_arn)
+
+ assert result.start_timestamp == current_op.start_timestamp
+ assert result.status == OperationStatus.SUCCEEDED
+ assert result.step_details.attempt == 1
+
+
+def test_process_fail_action():
+ processor = StepProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ error = ErrorObject.from_message("step failed")
+ update = OperationUpdate(
+ operation_id="step-123",
+ operation_type=OperationType.STEP,
+ action=OperationAction.FAIL,
+ name="test-step",
+ error=error,
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert isinstance(result, Operation)
+ assert result.operation_id == "step-123"
+ assert result.status == OperationStatus.FAILED
+ assert result.step_details.error == error
+
+
+def test_process_fail_action_with_current_operation():
+ processor = StepProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ current_op = Mock()
+ current_op.start_timestamp = datetime.now(UTC)
+ current_op.step_details = StepDetails()
+
+ error = ErrorObject.from_message("step failed")
+ update = OperationUpdate(
+ operation_id="step-123",
+ operation_type=OperationType.STEP,
+ action=OperationAction.FAIL,
+ name="test-step",
+ error=error,
+ )
+
+ result = processor.process(update, current_op, notifier, execution_arn)
+
+ assert result.start_timestamp == current_op.start_timestamp
+ assert result.status == OperationStatus.FAILED
+ assert result.step_details.attempt == 1
+
+
+def test_process_invalid_action():
+ processor = StepProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="step-123",
+ operation_type=OperationType.STEP,
+ action=OperationAction.CANCEL,
+ name="test-step",
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Invalid action for STEP operation"
+ ):
+ processor.process(update, None, notifier, execution_arn)
+
+
+def test_process_with_parent_id():
+ processor = StepProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="step-123",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ name="test-step",
+ parent_id="parent-456",
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert result.parent_id == "parent-456"
+
+
+def test_process_with_sub_type():
+ processor = StepProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="step-123",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ name="test-step",
+ sub_type="lambda",
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert result.sub_type == "lambda"
+
+
+def test_retry_preserves_current_operation_details():
+ processor = StepProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ current_op = Mock()
+ current_op.start_timestamp = datetime.now(UTC)
+ current_op.step_details = StepDetails(
+ attempt=2, result="old-result", error=ErrorObject.from_message("old-error")
+ )
+ current_op.execution_details = Mock()
+ current_op.context_details = Mock()
+ current_op.wait_details = Mock()
+ current_op.callback_details = Mock()
+ current_op.chained_invoke_details = Mock()
+
+ step_options = StepOptions(next_attempt_delay_seconds=60)
+ update = OperationUpdate(
+ operation_id="step-123",
+ operation_type=OperationType.STEP,
+ action=OperationAction.RETRY,
+ name="test-step",
+ step_options=step_options,
+ )
+
+ result = processor.process(update, current_op, notifier, execution_arn)
+
+ assert result.step_details.attempt == 3
+ assert result.step_details.result == "old-result"
+ assert result.step_details.error == current_op.step_details.error
+ assert result.execution_details == current_op.execution_details
+ assert result.context_details == current_op.context_details
+ assert result.wait_details == current_op.wait_details
+ assert result.callback_details == current_op.callback_details
+ assert result.chained_invoke_details == current_op.chained_invoke_details
+
+
+def test_no_completed_or_failed_calls_for_non_execution_actions():
+ processor = StepProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="step-123",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ name="test-step",
+ )
+
+ processor.process(update, None, notifier, execution_arn)
+
+ assert len(notifier.completed_calls) == 0
+ assert len(notifier.failed_calls) == 0
+ assert len(notifier.wait_timer_calls) == 0
+
+
+def test_no_step_retry_calls_for_non_retry_actions():
+ processor = StepProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="step-123",
+ operation_type=OperationType.STEP,
+ action=OperationAction.SUCCEED,
+ name="test-step",
+ )
+
+ processor.process(update, None, notifier, execution_arn)
+
+ assert len(notifier.step_retry_calls) == 0
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/wait_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/wait_test.py
new file mode 100644
index 0000000..42c29aa
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/wait_test.py
@@ -0,0 +1,313 @@
+"""Tests for wait operation processor."""
+
+from datetime import UTC, datetime
+from unittest.mock import Mock
+
+import pytest
+from aws_durable_execution_sdk_python.lambda_service import (
+ Operation,
+ OperationAction,
+ OperationStatus,
+ OperationType,
+ OperationUpdate,
+ WaitOptions,
+)
+
+from aws_durable_execution_sdk_python_testing.checkpoint.processors.wait import (
+ WaitProcessor,
+)
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier
+
+
+class MockNotifier(ExecutionNotifier):
+ """Mock notifier for testing."""
+
+ def __init__(self):
+ super().__init__()
+ self.completed_calls = []
+ self.failed_calls = []
+ self.wait_timer_calls = []
+ self.step_retry_calls = []
+
+ def notify_completed(self, execution_arn, result=None):
+ self.completed_calls.append((execution_arn, result))
+
+ def notify_failed(self, execution_arn, error):
+ self.failed_calls.append((execution_arn, error))
+
+ def notify_wait_timer_scheduled(self, execution_arn, operation_id, delay):
+ self.wait_timer_calls.append((execution_arn, operation_id, delay))
+
+ def notify_step_retry_scheduled(self, execution_arn, operation_id, delay):
+ self.step_retry_calls.append((execution_arn, operation_id, delay))
+
+
+def test_process_start_action():
+ processor = WaitProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ wait_options = WaitOptions(wait_seconds=30)
+ update = OperationUpdate(
+ operation_id="wait-123",
+ operation_type=OperationType.WAIT,
+ action=OperationAction.START,
+ name="test-wait",
+ wait_options=wait_options,
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert isinstance(result, Operation)
+ assert result.operation_id == "wait-123"
+ assert result.operation_type == OperationType.WAIT
+ assert result.status == OperationStatus.STARTED
+ assert result.name == "test-wait"
+ assert result.wait_details is not None
+ assert result.wait_details.scheduled_end_timestamp > datetime.now(UTC)
+
+ assert len(notifier.wait_timer_calls) == 1
+ assert notifier.wait_timer_calls[0] == (execution_arn, "wait-123", 30)
+
+
+def test_process_start_action_without_wait_options():
+ processor = WaitProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="wait-123",
+ operation_type=OperationType.WAIT,
+ action=OperationAction.START,
+ name="test-wait",
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert isinstance(result, Operation)
+ assert result.wait_details is not None
+
+ assert len(notifier.wait_timer_calls) == 1
+ assert notifier.wait_timer_calls[0] == (execution_arn, "wait-123", 0)
+
+
+def test_process_start_action_with_zero_seconds():
+ processor = WaitProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ wait_options = WaitOptions(wait_seconds=0)
+ update = OperationUpdate(
+ operation_id="wait-123",
+ operation_type=OperationType.WAIT,
+ action=OperationAction.START,
+ name="test-wait",
+ wait_options=wait_options,
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert isinstance(result, Operation)
+ assert result.wait_details is not None
+
+ assert len(notifier.wait_timer_calls) == 1
+ assert notifier.wait_timer_calls[0] == (execution_arn, "wait-123", 0)
+
+
+def test_process_start_action_with_parent_id():
+ processor = WaitProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ wait_options = WaitOptions(wait_seconds=15)
+ update = OperationUpdate(
+ operation_id="wait-123",
+ operation_type=OperationType.WAIT,
+ action=OperationAction.START,
+ name="test-wait",
+ parent_id="parent-456",
+ wait_options=wait_options,
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert result.parent_id == "parent-456"
+
+
+def test_process_start_action_with_sub_type():
+ processor = WaitProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ wait_options = WaitOptions(wait_seconds=15)
+ update = OperationUpdate(
+ operation_id="wait-123",
+ operation_type=OperationType.WAIT,
+ action=OperationAction.START,
+ name="test-wait",
+ sub_type="timer",
+ wait_options=wait_options,
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert result.sub_type == "timer"
+
+
+def test_process_cancel_action():
+ processor = WaitProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ current_op = Mock()
+ current_op.start_timestamp = datetime.now(UTC)
+
+ update = OperationUpdate(
+ operation_id="wait-123",
+ operation_type=OperationType.WAIT,
+ action=OperationAction.CANCEL,
+ name="test-wait",
+ )
+
+ result = processor.process(update, current_op, notifier, execution_arn)
+
+ assert isinstance(result, Operation)
+ assert result.operation_id == "wait-123"
+ assert result.status == OperationStatus.CANCELLED
+ assert result.start_timestamp == current_op.start_timestamp
+
+
+def test_process_cancel_action_without_current_operation():
+ processor = WaitProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="wait-123",
+ operation_type=OperationType.WAIT,
+ action=OperationAction.CANCEL,
+ name="test-wait",
+ )
+
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert isinstance(result, Operation)
+ assert result.status == OperationStatus.CANCELLED
+
+
+def test_process_invalid_action():
+ processor = WaitProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="wait-123",
+ operation_type=OperationType.WAIT,
+ action=OperationAction.SUCCEED,
+ name="test-wait",
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Invalid action for WAIT operation"
+ ):
+ processor.process(update, None, notifier, execution_arn)
+
+
+def test_process_fail_action():
+ processor = WaitProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="wait-123",
+ operation_type=OperationType.WAIT,
+ action=OperationAction.FAIL,
+ name="test-wait",
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Invalid action for WAIT operation"
+ ):
+ processor.process(update, None, notifier, execution_arn)
+
+
+def test_process_retry_action():
+ processor = WaitProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ update = OperationUpdate(
+ operation_id="wait-123",
+ operation_type=OperationType.WAIT,
+ action=OperationAction.RETRY,
+ name="test-wait",
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Invalid action for WAIT operation"
+ ):
+ processor.process(update, None, notifier, execution_arn)
+
+
+def test_wait_details_created_correctly():
+ processor = WaitProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ wait_options = WaitOptions(wait_seconds=60)
+ update = OperationUpdate(
+ operation_id="wait-123",
+ operation_type=OperationType.WAIT,
+ action=OperationAction.START,
+ name="test-wait",
+ wait_options=wait_options,
+ )
+
+ before_time = datetime.now(UTC)
+ result = processor.process(update, None, notifier, execution_arn)
+
+ assert result.wait_details.scheduled_end_timestamp > before_time
+
+
+def test_no_completed_or_failed_calls():
+ processor = WaitProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ wait_options = WaitOptions(wait_seconds=30)
+ update = OperationUpdate(
+ operation_id="wait-123",
+ operation_type=OperationType.WAIT,
+ action=OperationAction.START,
+ name="test-wait",
+ wait_options=wait_options,
+ )
+
+ processor.process(update, None, notifier, execution_arn)
+
+ assert len(notifier.completed_calls) == 0
+ assert len(notifier.failed_calls) == 0
+ assert len(notifier.step_retry_calls) == 0
+
+
+def test_cancel_no_timer_scheduled():
+ processor = WaitProcessor()
+ notifier = MockNotifier()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ current_op = Mock()
+ current_op.start_timestamp = datetime.now(UTC)
+
+ update = OperationUpdate(
+ operation_id="wait-123",
+ operation_type=OperationType.WAIT,
+ action=OperationAction.CANCEL,
+ name="test-wait",
+ )
+
+ processor.process(update, current_op, notifier, execution_arn)
+
+ assert len(notifier.wait_timer_calls) == 0
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/transformer_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/transformer_test.py
new file mode 100644
index 0000000..387a96c
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/transformer_test.py
@@ -0,0 +1,394 @@
+"""Unit tests for OperationTransformer."""
+
+from unittest.mock import Mock
+
+import pytest
+from aws_durable_execution_sdk_python.lambda_service import (
+ OperationAction,
+ OperationType,
+ OperationUpdate,
+)
+
+from aws_durable_execution_sdk_python_testing.checkpoint.processors.base import (
+ OperationProcessor,
+)
+from aws_durable_execution_sdk_python_testing.checkpoint.transformer import (
+ OperationTransformer,
+)
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+
+
+class MockProcessor(OperationProcessor):
+ """Mock processor for testing."""
+
+ def __init__(self, return_value=None):
+ self.return_value = return_value
+ self.process_calls = []
+
+ def process(self, update, current_op, notifier, execution_arn):
+ self.process_calls.append((update, current_op, notifier, execution_arn))
+ return self.return_value
+
+
+def test_init_with_default_processors():
+ """Test initialization with default processors."""
+ transformer = OperationTransformer()
+
+ assert OperationType.STEP in transformer.processors
+ assert OperationType.WAIT in transformer.processors
+ assert OperationType.CONTEXT in transformer.processors
+ assert OperationType.CALLBACK in transformer.processors
+ assert OperationType.EXECUTION in transformer.processors
+
+
+def test_init_with_custom_processors():
+ """Test initialization with custom processors."""
+ custom_processors = {OperationType.STEP: MockProcessor()}
+ transformer = OperationTransformer(processors=custom_processors)
+
+ assert transformer.processors == custom_processors
+
+
+def test_process_updates_empty_lists():
+ """Test processing with empty updates and operations."""
+ transformer = OperationTransformer()
+ notifier = Mock()
+
+ operations, updates = transformer.process_updates([], [], notifier, "arn:test")
+
+ assert operations == []
+ assert updates == []
+
+
+def test_process_updates_processor_not_found_raises_error():
+ """Test that missing processor raises InvalidParameterValueException."""
+ transformer = OperationTransformer(processors={OperationType.STEP: MockProcessor()})
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.WAIT,
+ action=OperationAction.START,
+ )
+ notifier = Mock()
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Checkpoint for OperationType.WAIT is not implemented yet.",
+ ):
+ transformer.process_updates([update], [], notifier, "arn:test")
+
+
+def test_process_updates_processor_returns_none():
+ """Test processing when processor returns None."""
+ mock_processor = MockProcessor(return_value=None)
+ transformer = OperationTransformer(processors={OperationType.STEP: mock_processor})
+
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ )
+ notifier = Mock()
+
+ operations, updates = transformer.process_updates(
+ [update], [], notifier, "arn:test"
+ )
+
+ assert operations == []
+ assert updates == [update]
+ assert len(mock_processor.process_calls) == 1
+
+
+def test_process_updates_new_operation():
+ """Test processing creates new operation."""
+ new_operation = Mock()
+ new_operation.operation_id = "new-id"
+ mock_processor = MockProcessor(return_value=new_operation)
+ transformer = OperationTransformer(processors={OperationType.STEP: mock_processor})
+
+ update = OperationUpdate(
+ operation_id="new-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ )
+ notifier = Mock()
+
+ operations, updates = transformer.process_updates(
+ [update], [], notifier, "arn:test"
+ )
+
+ assert len(operations) == 1
+ assert operations[0] == new_operation
+ assert updates == [update]
+
+
+def test_process_updates_existing_operation():
+ """Test processing updates existing operation."""
+ existing_operation = Mock()
+ existing_operation.operation_id = "existing-id"
+ updated_operation = Mock()
+ updated_operation.operation_id = "existing-id"
+
+ mock_processor = MockProcessor(return_value=updated_operation)
+ transformer = OperationTransformer(processors={OperationType.STEP: mock_processor})
+
+ update = OperationUpdate(
+ operation_id="existing-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.SUCCEED,
+ )
+ notifier = Mock()
+
+ operations, updates = transformer.process_updates(
+ [update], [existing_operation], notifier, "arn:test"
+ )
+
+ assert len(operations) == 1
+ assert operations[0] == updated_operation
+ assert updates == [update]
+
+
+def test_process_updates_multiple_operations_preserve_order():
+ """Test processing multiple operations preserves order."""
+ op1 = Mock()
+ op1.operation_id = "op1"
+ op2 = Mock()
+ op2.operation_id = "op2"
+ op3 = Mock()
+ op3.operation_id = "op3"
+
+ updated_op2 = Mock()
+ updated_op2.operation_id = "op2"
+ new_op4 = Mock()
+ new_op4.operation_id = "op4"
+
+ mock_processor = MockProcessor()
+ transformer = OperationTransformer(processors={OperationType.STEP: mock_processor})
+
+ mock_processor.return_value = updated_op2
+
+ updates = [
+ OperationUpdate(
+ operation_id="op2",
+ operation_type=OperationType.STEP,
+ action=OperationAction.SUCCEED,
+ ),
+ ]
+ notifier = Mock()
+
+ operations, result_updates = transformer.process_updates(
+ updates, [op1, op2, op3], notifier, "arn:test"
+ )
+
+ assert len(operations) == 3
+ assert operations[0] == op1
+ assert operations[1] == updated_op2
+ assert operations[2] == op3
+ assert result_updates == updates
+
+ mock_processor.return_value = new_op4
+ updates2 = [
+ OperationUpdate(
+ operation_id="op4",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ )
+ ]
+
+ operations2, result_updates2 = transformer.process_updates(
+ updates2, [op1, updated_op2, op3], notifier, "arn:test"
+ )
+
+ assert len(operations2) == 4
+ assert operations2[0] == op1
+ assert operations2[1] == updated_op2
+ assert operations2[2] == op3
+ assert operations2[3] == new_op4
+
+
+def test_process_updates_multiple_processors():
+ """Test processing with multiple processor types."""
+ step_op = Mock()
+ step_op.operation_id = "step-id"
+ wait_op = Mock()
+ wait_op.operation_id = "wait-id"
+
+ step_processor = MockProcessor(return_value=step_op)
+ wait_processor = MockProcessor(return_value=wait_op)
+
+ transformer = OperationTransformer(
+ processors={
+ OperationType.STEP: step_processor,
+ OperationType.WAIT: wait_processor,
+ }
+ )
+
+ updates = [
+ OperationUpdate(
+ operation_id="step-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ ),
+ OperationUpdate(
+ operation_id="wait-id",
+ operation_type=OperationType.WAIT,
+ action=OperationAction.START,
+ ),
+ ]
+ notifier = Mock()
+
+ operations, result_updates = transformer.process_updates(
+ updates, [], notifier, "arn:test"
+ )
+
+ assert len(operations) == 2
+ assert operations[0] == step_op
+ assert operations[1] == wait_op
+ assert len(step_processor.process_calls) == 1
+ assert len(wait_processor.process_calls) == 1
+
+
+def test_process_updates_passes_correct_parameters():
+ """Test that correct parameters are passed to processor."""
+ existing_op = Mock()
+ existing_op.operation_id = "test-id"
+ mock_processor = MockProcessor(return_value=existing_op)
+ transformer = OperationTransformer(processors={OperationType.STEP: mock_processor})
+
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ )
+ notifier = Mock()
+ execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test"
+
+ transformer.process_updates([update], [existing_op], notifier, execution_arn)
+
+ call_args = mock_processor.process_calls[0]
+ assert call_args[0] == update
+ assert call_args[1] == existing_op
+ assert call_args[2] == notifier
+ assert call_args[3] == execution_arn
+
+
+def test_process_updates_new_operation_not_in_map():
+ """Test processing creates new operation when operation_id not in current operations."""
+ new_operation = Mock()
+ new_operation.operation_id = "new-id"
+ mock_processor = MockProcessor(return_value=new_operation)
+ transformer = OperationTransformer(processors={OperationType.STEP: mock_processor})
+
+ # Existing operations with different IDs
+ existing_op = Mock()
+ existing_op.operation_id = "existing-id"
+
+ update = OperationUpdate(
+ operation_id="new-id", # Different from existing operation
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ )
+ notifier = Mock()
+
+ operations, updates = transformer.process_updates(
+ [update], [existing_op], notifier, "arn:test"
+ )
+
+ # Should have both existing and new operation
+ assert len(operations) == 2
+ assert operations[0] == existing_op # Original operation preserved
+ assert operations[1] == new_operation # New operation appended
+ assert updates == [update]
+
+
+def test_process_updates_in_place_update_with_multiple_operations():
+ """Test in-place update when operation exists in middle of operations list."""
+ # Create three operations
+ op1 = Mock()
+ op1.operation_id = "op1"
+ op2 = Mock()
+ op2.operation_id = "op2"
+ op3 = Mock()
+ op3.operation_id = "op3"
+
+ # Updated version of op2
+ updated_op2 = Mock()
+ updated_op2.operation_id = "op2"
+
+ mock_processor = MockProcessor(return_value=updated_op2)
+ transformer = OperationTransformer(processors={OperationType.STEP: mock_processor})
+
+ # Update for op2 (middle operation)
+ update = OperationUpdate(
+ operation_id="op2",
+ operation_type=OperationType.STEP,
+ action=OperationAction.SUCCEED,
+ )
+ notifier = Mock()
+
+ # Process update with op2 in the middle of the list
+ operations, updates = transformer.process_updates(
+ [update], [op1, op2, op3], notifier, "arn:test"
+ )
+
+ # Verify in-place update occurred
+ assert len(operations) == 3
+ assert operations[0] == op1 # First operation unchanged
+ assert operations[1] == updated_op2 # Middle operation updated in-place
+ assert operations[2] == op3 # Last operation unchanged
+ assert updates == [update]
+
+
+def test_process_updates_in_place_update_break_coverage():
+ """Test to ensure break statement in in-place update loop is covered."""
+ # Create operations where target is first in list to ensure break is hit
+ target_op = Mock()
+ target_op.operation_id = "target"
+ other_op = Mock()
+ other_op.operation_id = "other"
+
+ updated_target = Mock()
+ updated_target.operation_id = "target"
+
+ mock_processor = MockProcessor(return_value=updated_target)
+ transformer = OperationTransformer(processors={OperationType.STEP: mock_processor})
+
+ update = OperationUpdate(
+ operation_id="target",
+ operation_type=OperationType.STEP,
+ action=OperationAction.SUCCEED,
+ )
+ notifier = Mock()
+
+ # Target operation is first - should hit break immediately
+ operations, updates = transformer.process_updates(
+ [update], [target_op, other_op], notifier, "arn:test"
+ )
+
+ assert len(operations) == 2
+ assert operations[0] == updated_target
+
+
+def test_process_updates_empty_operations_list():
+ """Test for loop exit when result_operations is empty."""
+ updated_op = Mock()
+ updated_op.operation_id = "test-id"
+
+ mock_processor = MockProcessor(return_value=updated_op)
+ transformer = OperationTransformer(processors={OperationType.STEP: mock_processor})
+
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.SUCCEED,
+ )
+ notifier = Mock()
+
+ # Empty current_operations list - for loop should exit immediately
+ operations, updates = transformer.process_updates(
+ [update], [], notifier, "arn:test"
+ )
+
+ assert len(operations) == 1
+ assert operations[0] == updated_op
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/__init__.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/__init__.py
new file mode 100644
index 0000000..78d8de9
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/__init__.py
@@ -0,0 +1 @@
+"""Test package"""
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/checkpoint_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/checkpoint_test.py
new file mode 100644
index 0000000..548e988
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/checkpoint_test.py
@@ -0,0 +1,688 @@
+"""Unit tests for checkpoint validator."""
+
+import json
+
+import pytest
+from aws_durable_execution_sdk_python.lambda_service import (
+ ErrorObject,
+ Operation,
+ OperationAction,
+ OperationStatus,
+ OperationType,
+ OperationUpdate,
+)
+
+from aws_durable_execution_sdk_python_testing.checkpoint.validators.checkpoint import (
+ MAX_ERROR_PAYLOAD_SIZE_BYTES,
+ CheckpointValidator,
+)
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+from aws_durable_execution_sdk_python_testing.execution import Execution
+from aws_durable_execution_sdk_python_testing.model import StartDurableExecutionInput
+
+
+def _create_test_execution() -> Execution:
+ """Create a test execution with basic setup."""
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=900,
+ execution_retention_period_days=7,
+ input=json.dumps({"test": "data"}),
+ invocation_id="test-invocation-id",
+ )
+ execution = Execution.new(start_input)
+ execution.start()
+ return execution
+
+
+def test_validate_input_empty_updates():
+ """Test validation with empty updates list."""
+ execution = _create_test_execution()
+ CheckpointValidator.validate_input([], execution)
+
+
+def test_validate_input_single_valid_update():
+ """Test validation with single valid update."""
+ execution = _create_test_execution()
+ updates = [
+ OperationUpdate(
+ operation_id="test-step-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ )
+ ]
+ CheckpointValidator.validate_input(updates, execution)
+
+
+def test_validate_conflicting_execution_update_multiple():
+ """Test validation fails with multiple execution updates."""
+ execution = _create_test_execution()
+ updates = [
+ OperationUpdate(
+ operation_id="exec-1",
+ operation_type=OperationType.EXECUTION,
+ action=OperationAction.SUCCEED,
+ ),
+ OperationUpdate(
+ operation_id="exec-2",
+ operation_type=OperationType.EXECUTION,
+ action=OperationAction.FAIL,
+ ),
+ ]
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Cannot checkpoint multiple EXECUTION updates",
+ ):
+ CheckpointValidator.validate_input(updates, execution)
+
+
+def test_validate_conflicting_execution_update_not_last():
+ """Test validation fails when execution update is not last."""
+ execution = _create_test_execution()
+ updates = [
+ OperationUpdate(
+ operation_id="exec-1",
+ operation_type=OperationType.EXECUTION,
+ action=OperationAction.SUCCEED,
+ ),
+ OperationUpdate(
+ operation_id="step-1",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ ),
+ ]
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="EXECUTION checkpoint must be the last update",
+ ):
+ CheckpointValidator.validate_input(updates, execution)
+
+
+def test_validate_execution_update_as_last():
+ """Test validation passes when execution update is last."""
+ execution = _create_test_execution()
+ updates = [
+ OperationUpdate(
+ operation_id="step-1",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ ),
+ OperationUpdate(
+ operation_id="exec-1",
+ operation_type=OperationType.EXECUTION,
+ action=OperationAction.SUCCEED,
+ ),
+ ]
+ CheckpointValidator.validate_input(updates, execution)
+
+
+def test_validate_payload_sizes_error_too_large():
+ """Test validation fails when error payload is too large."""
+ execution = _create_test_execution()
+
+ large_message = "x" * (MAX_ERROR_PAYLOAD_SIZE_BYTES + 1)
+ large_error = ErrorObject(
+ message=large_message, type="TestError", data=None, stack_trace=None
+ )
+
+ updates = [
+ OperationUpdate(
+ operation_id="step-1",
+ operation_type=OperationType.STEP,
+ action=OperationAction.FAIL,
+ error=large_error,
+ )
+ ]
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match=f"Error object size must be less than {MAX_ERROR_PAYLOAD_SIZE_BYTES} bytes",
+ ):
+ CheckpointValidator.validate_input(updates, execution)
+
+
+def test_validate_payload_sizes_error_within_limit():
+ """Test validation passes when error payload is within limit."""
+ execution = _create_test_execution()
+
+ small_error = ErrorObject(
+ message="Small error", type="TestError", data=None, stack_trace=None
+ )
+ updates = [
+ OperationUpdate(
+ operation_id="step-1",
+ operation_type=OperationType.STEP,
+ action=OperationAction.FAIL,
+ error=small_error,
+ )
+ ]
+ CheckpointValidator.validate_input(updates, execution)
+
+
+def test_validate_duplicate_operation_ids():
+ """Test validation allows duplicate operation IDs in same batch.
+
+ With background batching, the SDK can send multiple updates for the same
+ operation in a single batch (e.g., START followed by SUCCEED). This is
+ valid behavior and should be allowed.
+ """
+ execution = _create_test_execution()
+ updates = [
+ OperationUpdate(
+ operation_id="duplicate-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ ),
+ OperationUpdate(
+ operation_id="duplicate-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.SUCCEED,
+ ),
+ ]
+
+ # Should not raise - duplicate operation IDs are allowed in batches
+ CheckpointValidator.validate_input(updates, execution)
+
+
+def test_validate_valid_parent_id_in_execution():
+ """Test validation passes with valid parent ID from execution."""
+ execution = _create_test_execution()
+
+ context_op = Operation(
+ operation_id="context-1",
+ operation_type=OperationType.CONTEXT,
+ status=OperationStatus.STARTED,
+ )
+ execution.operations.append(context_op)
+
+ updates = [
+ OperationUpdate(
+ operation_id="step-1",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ parent_id="context-1",
+ )
+ ]
+ CheckpointValidator.validate_input(updates, execution)
+
+
+def test_validate_valid_parent_id_in_updates():
+ """Test validation passes with valid parent ID from updates."""
+ execution = _create_test_execution()
+ updates = [
+ OperationUpdate(
+ operation_id="context-1",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.START,
+ ),
+ OperationUpdate(
+ operation_id="step-1",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ parent_id="context-1",
+ ),
+ ]
+ CheckpointValidator.validate_input(updates, execution)
+
+
+def test_validate_invalid_parent_id_wrong_type():
+ """Test validation fails with parent ID of wrong operation type."""
+ execution = _create_test_execution()
+
+ step_op = Operation(
+ operation_id="step-parent",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.STARTED,
+ )
+ execution.operations.append(step_op)
+
+ updates = [
+ OperationUpdate(
+ operation_id="step-1",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ parent_id="step-parent",
+ )
+ ]
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Invalid parent operation id"
+ ):
+ CheckpointValidator.validate_input(updates, execution)
+
+
+def test_validate_invalid_parent_id_not_found():
+ """Test validation fails with parent ID that doesn't exist."""
+ execution = _create_test_execution()
+ updates = [
+ OperationUpdate(
+ operation_id="step-1",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ parent_id="non-existent-parent",
+ )
+ ]
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Invalid parent operation id"
+ ):
+ CheckpointValidator.validate_input(updates, execution)
+
+
+def test_validate_no_parent_id():
+ """Test validation passes with no parent ID."""
+ execution = _create_test_execution()
+ updates = [
+ OperationUpdate(
+ operation_id="step-1",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ parent_id=None,
+ )
+ ]
+ CheckpointValidator.validate_input(updates, execution)
+
+
+def test_validate_operation_status_transition_step():
+ """Test validation calls step validator for STEP operations."""
+ execution = _create_test_execution()
+
+ step_op = Operation(
+ operation_id="step-1",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.READY,
+ )
+ execution.operations.append(step_op)
+
+ updates = [
+ OperationUpdate(
+ operation_id="step-1",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ )
+ ]
+ CheckpointValidator.validate_input(updates, execution)
+
+
+def test_validate_operation_status_transition_context():
+ """Test validation calls context validator for CONTEXT operations."""
+ execution = _create_test_execution()
+
+ context_op = Operation(
+ operation_id="context-1",
+ operation_type=OperationType.CONTEXT,
+ status=OperationStatus.STARTED,
+ )
+ execution.operations.append(context_op)
+
+ updates = [
+ OperationUpdate(
+ operation_id="context-1",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.SUCCEED,
+ )
+ ]
+ CheckpointValidator.validate_input(updates, execution)
+
+
+def test_validate_operation_status_transition_wait():
+ """Test validation calls wait validator for WAIT operations."""
+ execution = _create_test_execution()
+
+ wait_op = Operation(
+ operation_id="wait-1",
+ operation_type=OperationType.WAIT,
+ status=OperationStatus.STARTED,
+ )
+ execution.operations.append(wait_op)
+
+ updates = [
+ OperationUpdate(
+ operation_id="wait-1",
+ operation_type=OperationType.WAIT,
+ action=OperationAction.CANCEL,
+ )
+ ]
+ CheckpointValidator.validate_input(updates, execution)
+
+
+def test_validate_operation_status_transition_invoke():
+ """Test validation calls invoke validator for INVOKE operations."""
+ execution = _create_test_execution()
+
+ invoke_op = Operation(
+ operation_id="invoke-1",
+ operation_type=OperationType.CHAINED_INVOKE,
+ status=OperationStatus.STARTED,
+ )
+ execution.operations.append(invoke_op)
+
+ updates = [
+ OperationUpdate(
+ operation_id="invoke-1",
+ operation_type=OperationType.CHAINED_INVOKE,
+ action=OperationAction.CANCEL,
+ )
+ ]
+ CheckpointValidator.validate_input(updates, execution)
+
+
+def test_validate_operation_status_transition_execution():
+ """Test validation calls execution validator for EXECUTION operations."""
+ execution = _create_test_execution()
+ updates = [
+ OperationUpdate(
+ operation_id="exec-1",
+ operation_type=OperationType.EXECUTION,
+ action=OperationAction.SUCCEED,
+ )
+ ]
+ CheckpointValidator.validate_input(updates, execution)
+
+
+def test_validate_inconsistent_operation_type():
+ """Test validation fails when operation type is inconsistent."""
+ execution = _create_test_execution()
+
+ # Add existing operation
+ step_op = Operation(
+ operation_id="op-1",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.STARTED,
+ )
+ execution.operations.append(step_op)
+
+ # Try to update with different type
+ updates = [
+ OperationUpdate(
+ operation_id="op-1",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.SUCCEED,
+ )
+ ]
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Inconsistent operation type"
+ ):
+ CheckpointValidator.validate_input(updates, execution)
+
+
+def test_validate_inconsistent_operation_subtype():
+ """Test validation fails when operation subtype is inconsistent."""
+ execution = _create_test_execution()
+
+ # Add existing operation with subtype
+ from aws_durable_execution_sdk_python.lambda_service import OperationSubType
+
+ context_op = Operation(
+ operation_id="op-1",
+ operation_type=OperationType.CONTEXT,
+ status=OperationStatus.STARTED,
+ sub_type=OperationSubType.PARALLEL,
+ )
+ execution.operations.append(context_op)
+
+ # Try to update with different subtype
+ updates = [
+ OperationUpdate(
+ operation_id="op-1",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.SUCCEED,
+ sub_type=OperationSubType.MAP,
+ )
+ ]
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Inconsistent operation subtype"
+ ):
+ CheckpointValidator.validate_input(updates, execution)
+
+
+def test_validate_inconsistent_operation_name():
+ """Test validation fails when operation name is inconsistent."""
+ execution = _create_test_execution()
+
+ # Add existing operation with name
+ step_op = Operation(
+ operation_id="op-1",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.STARTED,
+ name="original_name",
+ )
+ execution.operations.append(step_op)
+
+ # Try to update with different name
+ updates = [
+ OperationUpdate(
+ operation_id="op-1",
+ operation_type=OperationType.STEP,
+ action=OperationAction.SUCCEED,
+ name="different_name",
+ )
+ ]
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Inconsistent operation name"
+ ):
+ CheckpointValidator.validate_input(updates, execution)
+
+
+def test_validate_inconsistent_parent_operation_id():
+ """Test validation fails when parent operation ID is inconsistent."""
+ execution = _create_test_execution()
+
+ # Add TWO context operations
+ context_op1 = Operation(
+ operation_id="context-1",
+ operation_type=OperationType.CONTEXT,
+ status=OperationStatus.STARTED,
+ )
+ execution.operations.append(context_op1)
+
+ context_op2 = Operation(
+ operation_id="context-2",
+ operation_type=OperationType.CONTEXT,
+ status=OperationStatus.STARTED,
+ )
+ execution.operations.append(context_op2)
+
+ # Add existing step with parent context-1
+ step_op = Operation(
+ operation_id="step-1",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.STARTED,
+ parent_id="context-1",
+ )
+ execution.operations.append(step_op)
+
+ # Try to update with different parent context-2 (which exists, so passes parent validation)
+ updates = [
+ OperationUpdate(
+ operation_id="step-1",
+ operation_type=OperationType.STEP,
+ action=OperationAction.SUCCEED,
+ parent_id="context-2",
+ )
+ ]
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Inconsistent parent operation id"
+ ):
+ CheckpointValidator.validate_input(updates, execution)
+
+
+def test_validate_invalid_duplicate_wait_operations():
+ """Test validation fails with duplicate WAIT operations."""
+ execution = _create_test_execution()
+
+ # WAIT operations cannot have duplicate updates in same batch
+ updates = [
+ OperationUpdate(
+ operation_id="wait-1",
+ operation_type=OperationType.WAIT,
+ action=OperationAction.START,
+ ),
+ OperationUpdate(
+ operation_id="wait-1",
+ operation_type=OperationType.WAIT,
+ action=OperationAction.CANCEL,
+ ),
+ ]
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Cannot checkpoint multiple operations with the same ID",
+ ):
+ CheckpointValidator.validate_input(updates, execution)
+
+
+def test_validate_invalid_duplicate_callback_operations():
+ """Test validation fails with duplicate CALLBACK operations."""
+ execution = _create_test_execution()
+
+ # CALLBACK operations cannot have duplicate updates in same batch
+ updates = [
+ OperationUpdate(
+ operation_id="callback-1",
+ operation_type=OperationType.CALLBACK,
+ action=OperationAction.START,
+ ),
+ OperationUpdate(
+ operation_id="callback-1",
+ operation_type=OperationType.CALLBACK,
+ action=OperationAction.SUCCEED,
+ ),
+ ]
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Cannot checkpoint multiple operations with the same ID",
+ ):
+ CheckpointValidator.validate_input(updates, execution)
+
+
+def test_validate_invalid_duplicate_invoke_operations():
+ """Test validation fails with duplicate CHAINED_INVOKE operations."""
+ execution = _create_test_execution()
+
+ # CHAINED_INVOKE operations cannot have duplicate updates in same batch
+ updates = [
+ OperationUpdate(
+ operation_id="invoke-1",
+ operation_type=OperationType.CHAINED_INVOKE,
+ action=OperationAction.START,
+ ),
+ OperationUpdate(
+ operation_id="invoke-1",
+ operation_type=OperationType.CHAINED_INVOKE,
+ action=OperationAction.SUCCEED,
+ ),
+ ]
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Cannot checkpoint multiple operations with the same ID",
+ ):
+ CheckpointValidator.validate_input(updates, execution)
+
+
+def test_validate_invalid_duplicate_execution_operations():
+ """Test validation fails with duplicate EXECUTION operations."""
+ execution = _create_test_execution()
+
+ # EXECUTION operations cannot have duplicate updates in same batch
+ # (though this is also caught by _validate_conflicting_execution_update)
+ updates = [
+ OperationUpdate(
+ operation_id="exec-1",
+ operation_type=OperationType.EXECUTION,
+ action=OperationAction.SUCCEED,
+ ),
+ OperationUpdate(
+ operation_id="exec-1",
+ operation_type=OperationType.EXECUTION,
+ action=OperationAction.SUCCEED,
+ ),
+ ]
+
+ with pytest.raises(InvalidParameterValueException):
+ CheckpointValidator.validate_input(updates, execution)
+
+
+def test_validate_duplicate_context_start_then_succeed():
+ """Test validation allows CONTEXT START followed by SUCCEED."""
+ execution = _create_test_execution()
+
+ # CONTEXT operations can have START + non-START in same batch
+ updates = [
+ OperationUpdate(
+ operation_id="context-1",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.START,
+ ),
+ OperationUpdate(
+ operation_id="context-1",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.SUCCEED,
+ ),
+ ]
+
+ # Should not raise
+ CheckpointValidator.validate_input(updates, execution)
+
+
+def test_validate_invalid_duplicate_context_non_start():
+ """Test validation fails with duplicate CONTEXT non-START operations."""
+ execution = _create_test_execution()
+
+ # CONTEXT operations cannot have duplicate non-START updates
+ updates = [
+ OperationUpdate(
+ operation_id="context-1",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.SUCCEED,
+ ),
+ OperationUpdate(
+ operation_id="context-1",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.SUCCEED,
+ ),
+ ]
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Cannot checkpoint multiple operations with the same ID",
+ ):
+ CheckpointValidator.validate_input(updates, execution)
+
+
+def test_validate_invalid_duplicate_step_non_start():
+ """Test validation fails with duplicate STEP non-START operations."""
+ execution = _create_test_execution()
+
+ # STEP operations cannot have duplicate non-START updates
+ updates = [
+ OperationUpdate(
+ operation_id="step-1",
+ operation_type=OperationType.STEP,
+ action=OperationAction.SUCCEED,
+ ),
+ OperationUpdate(
+ operation_id="step-1",
+ operation_type=OperationType.STEP,
+ action=OperationAction.SUCCEED,
+ ),
+ ]
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Cannot checkpoint multiple operations with the same ID",
+ ):
+ CheckpointValidator.validate_input(updates, execution)
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/__init__.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/__init__.py
new file mode 100644
index 0000000..866c947
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/__init__.py
@@ -0,0 +1 @@
+"""Test package for operation validators."""
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/callback_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/callback_test.py
new file mode 100644
index 0000000..93f6dd3
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/callback_test.py
@@ -0,0 +1,96 @@
+"""Unit tests for callback operation validator."""
+
+import pytest
+from aws_durable_execution_sdk_python.lambda_service import (
+ Operation,
+ OperationAction,
+ OperationStatus,
+ OperationType,
+ OperationUpdate,
+)
+
+from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.callback import (
+ CallbackOperationValidator,
+)
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+
+
+def test_validate_start_action_with_no_current_state():
+ """Test START action with no current state."""
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.CALLBACK,
+ action=OperationAction.START,
+ )
+ CallbackOperationValidator.validate(None, update)
+
+
+def test_validate_start_action_with_existing_state():
+ """Test START action with existing state raises error."""
+ current_state = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.CALLBACK,
+ status=OperationStatus.STARTED,
+ )
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.CALLBACK,
+ action=OperationAction.START,
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Cannot start a CALLBACK that already exist",
+ ):
+ CallbackOperationValidator.validate(current_state, update)
+
+
+def test_validate_cancel_action_with_no_current_state():
+ """Test CANCEL action with no current state raises error."""
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.CALLBACK,
+ action=OperationAction.CANCEL,
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Invalid action for CALLBACK operation.",
+ ):
+ CallbackOperationValidator.validate(None, update)
+
+
+def test_validate_cancel_action_with_completed_state():
+ """Test CANCEL action with completed state raises error."""
+ current_state = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.CALLBACK,
+ status=OperationStatus.SUCCEEDED,
+ )
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.CALLBACK,
+ action=OperationAction.CANCEL,
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Invalid action for CALLBACK operation.",
+ ):
+ CallbackOperationValidator.validate(current_state, update)
+
+
+def test_validate_invalid_action():
+ """Test invalid action raises error."""
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.CALLBACK,
+ action=OperationAction.SUCCEED,
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Invalid action for CALLBACK operation."
+ ):
+ CallbackOperationValidator.validate(None, update)
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/context_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/context_test.py
new file mode 100644
index 0000000..4119c17
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/context_test.py
@@ -0,0 +1,257 @@
+"""Tests for context operation validator."""
+
+import pytest
+from aws_durable_execution_sdk_python.lambda_service import (
+ ErrorObject,
+ Operation,
+ OperationAction,
+ OperationStatus,
+ OperationType,
+ OperationUpdate,
+)
+
+from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.context import (
+ VALID_ACTIONS_FOR_CONTEXT,
+ ContextOperationValidator,
+)
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+
+
+def test_valid_actions_for_context():
+ """Test that VALID_ACTIONS_FOR_CONTEXT contains expected actions."""
+ expected_actions = {
+ OperationAction.START,
+ OperationAction.FAIL,
+ OperationAction.SUCCEED,
+ }
+ assert expected_actions == VALID_ACTIONS_FOR_CONTEXT
+
+
+def test_validate_start_action_with_no_current_state():
+ """Test START action validation when no current state exists."""
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.START,
+ )
+
+ # Should not raise exception
+ ContextOperationValidator.validate(None, update)
+
+
+def test_validate_start_action_with_existing_state():
+ """Test START action validation when current state already exists."""
+ current_state = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.CONTEXT,
+ status=OperationStatus.STARTED,
+ )
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.START,
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Cannot start a CONTEXT that already exist.",
+ ):
+ ContextOperationValidator.validate(current_state, update)
+
+
+def test_validate_succeed_action_with_started_state():
+ """Test SUCCEED action validation with STARTED state."""
+ current_state = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.CONTEXT,
+ status=OperationStatus.STARTED,
+ )
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.SUCCEED,
+ payload="success_payload",
+ )
+
+ # Should not raise exception
+ ContextOperationValidator.validate(current_state, update)
+
+
+def test_validate_fail_action_with_started_state():
+ """Test FAIL action validation with STARTED state."""
+ current_state = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.CONTEXT,
+ status=OperationStatus.STARTED,
+ )
+ error = ErrorObject(
+ message="test error", type="TestError", data=None, stack_trace=None
+ )
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.FAIL,
+ error=error,
+ )
+
+ # Should not raise exception
+ ContextOperationValidator.validate(current_state, update)
+
+
+def test_validate_succeed_action_with_invalid_status():
+ """Test SUCCEED action validation with invalid status."""
+ invalid_statuses = [
+ OperationStatus.PENDING,
+ OperationStatus.READY,
+ OperationStatus.SUCCEEDED,
+ OperationStatus.FAILED,
+ OperationStatus.CANCELLED,
+ OperationStatus.TIMED_OUT,
+ OperationStatus.STOPPED,
+ ]
+
+ for status in invalid_statuses:
+ current_state = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.CONTEXT,
+ status=status,
+ )
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.SUCCEED,
+ payload="success_payload",
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Invalid current CONTEXT state to close.",
+ ):
+ ContextOperationValidator.validate(current_state, update)
+
+
+def test_validate_fail_action_with_invalid_status():
+ """Test FAIL action validation with invalid status."""
+ invalid_statuses = [
+ OperationStatus.PENDING,
+ OperationStatus.READY,
+ OperationStatus.SUCCEEDED,
+ OperationStatus.FAILED,
+ OperationStatus.CANCELLED,
+ OperationStatus.TIMED_OUT,
+ OperationStatus.STOPPED,
+ ]
+
+ error = ErrorObject(
+ message="test error", type="TestError", data=None, stack_trace=None
+ )
+
+ for status in invalid_statuses:
+ current_state = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.CONTEXT,
+ status=status,
+ )
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.FAIL,
+ error=error,
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Invalid current CONTEXT state to close.",
+ ):
+ ContextOperationValidator.validate(current_state, update)
+
+
+def test_validate_fail_action_with_payload():
+ """Test FAIL action validation when payload is provided."""
+ current_state = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.CONTEXT,
+ status=OperationStatus.STARTED,
+ )
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.FAIL,
+ payload="invalid_payload",
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Cannot provide a Payload for FAIL action.",
+ ):
+ ContextOperationValidator.validate(current_state, update)
+
+
+def test_validate_succeed_action_with_error():
+ """Test SUCCEED action validation when error is provided."""
+ current_state = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.CONTEXT,
+ status=OperationStatus.STARTED,
+ )
+ error = ErrorObject(
+ message="test error", type="TestError", data=None, stack_trace=None
+ )
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.SUCCEED,
+ error=error,
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Cannot provide an Error for SUCCEED action.",
+ ):
+ ContextOperationValidator.validate(current_state, update)
+
+
+def test_validate_close_actions_with_no_current_state():
+ """Test SUCCEED and FAIL actions validation when no current state exists."""
+ # SUCCEED with no current state should pass
+ succeed_update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.SUCCEED,
+ payload="success_payload",
+ )
+ ContextOperationValidator.validate(None, succeed_update)
+
+ # FAIL with no current state should pass
+ error = ErrorObject(
+ message="test error", type="TestError", data=None, stack_trace=None
+ )
+ fail_update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.CONTEXT,
+ action=OperationAction.FAIL,
+ error=error,
+ )
+ ContextOperationValidator.validate(None, fail_update)
+
+
+def test_validate_invalid_action():
+ """Test validation with invalid action."""
+ invalid_actions = [
+ OperationAction.RETRY,
+ OperationAction.CANCEL,
+ ]
+
+ for action in invalid_actions:
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.CONTEXT,
+ action=action,
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Invalid CONTEXT action."
+ ):
+ ContextOperationValidator.validate(None, update)
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/execution_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/execution_test.py
new file mode 100644
index 0000000..2c0e573
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/execution_test.py
@@ -0,0 +1,107 @@
+"""Unit tests for execution operation validator."""
+
+import pytest
+from aws_durable_execution_sdk_python.lambda_service import (
+ ErrorObject,
+ OperationAction,
+ OperationType,
+ OperationUpdate,
+)
+
+from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.execution import (
+ ExecutionOperationValidator,
+)
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+
+
+def test_validate_succeed_action():
+ """Test SUCCEED action validation."""
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.EXECUTION,
+ action=OperationAction.SUCCEED,
+ payload="success",
+ )
+ ExecutionOperationValidator.validate(update)
+
+
+def test_validate_fail_action():
+ """Test FAIL action validation."""
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.EXECUTION,
+ action=OperationAction.FAIL,
+ error=ErrorObject(
+ message="Test error", type="TestError", data=None, stack_trace=None
+ ),
+ )
+ ExecutionOperationValidator.validate(update)
+
+
+def test_validate_succeed_action_with_error():
+ """Test SUCCEED action with error raises error."""
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.EXECUTION,
+ action=OperationAction.SUCCEED,
+ error=ErrorObject(
+ message="Test error", type="TestError", data=None, stack_trace=None
+ ),
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Cannot provide an Error for SUCCEED action",
+ ):
+ ExecutionOperationValidator.validate(update)
+
+
+def test_validate_fail_action_with_payload():
+ """Test FAIL action with payload raises error."""
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.EXECUTION,
+ action=OperationAction.FAIL,
+ payload="invalid",
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Cannot provide a Payload for FAIL action"
+ ):
+ ExecutionOperationValidator.validate(update)
+
+
+def test_validate_invalid_action():
+ """Test invalid action raises error."""
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.EXECUTION,
+ action=OperationAction.START,
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Invalid EXECUTION action"
+ ):
+ ExecutionOperationValidator.validate(update)
+
+
+def test_validate_fail_action_without_error():
+ """Test FAIL action without error passes validation."""
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.EXECUTION,
+ action=OperationAction.FAIL,
+ )
+ ExecutionOperationValidator.validate(update)
+
+
+def test_validate_succeed_action_without_payload():
+ """Test SUCCEED action without payload passes validation."""
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.EXECUTION,
+ action=OperationAction.SUCCEED,
+ )
+ ExecutionOperationValidator.validate(update)
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/invoke_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/invoke_test.py
new file mode 100644
index 0000000..3300b1a
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/invoke_test.py
@@ -0,0 +1,109 @@
+"""Unit tests for invoke operation validator."""
+
+import pytest
+from aws_durable_execution_sdk_python.lambda_service import (
+ Operation,
+ OperationAction,
+ OperationStatus,
+ OperationType,
+ OperationUpdate,
+)
+
+from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.invoke import (
+ ChainedInvokeOperationValidator,
+)
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+
+
+def test_validate_start_action_with_no_current_state():
+ """Test START action with no current state."""
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.CHAINED_INVOKE,
+ action=OperationAction.START,
+ )
+ ChainedInvokeOperationValidator.validate(None, update)
+
+
+def test_validate_start_action_with_existing_state():
+ """Test START action with existing state raises error."""
+ current_state = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.CHAINED_INVOKE,
+ status=OperationStatus.STARTED,
+ )
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.CHAINED_INVOKE,
+ action=OperationAction.START,
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Cannot start an INVOKE that already exist",
+ ):
+ ChainedInvokeOperationValidator.validate(current_state, update)
+
+
+def test_validate_cancel_action_with_started_state():
+ """Test CANCEL action with STARTED state."""
+ current_state = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.CHAINED_INVOKE,
+ status=OperationStatus.STARTED,
+ )
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.CHAINED_INVOKE,
+ action=OperationAction.CANCEL,
+ )
+ ChainedInvokeOperationValidator.validate(current_state, update)
+
+
+def test_validate_cancel_action_with_no_current_state():
+ """Test CANCEL action with no current state raises error."""
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.CHAINED_INVOKE,
+ action=OperationAction.CANCEL,
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Cannot cancel an INVOKE that does not exist or has already completed",
+ ):
+ ChainedInvokeOperationValidator.validate(None, update)
+
+
+def test_validate_cancel_action_with_completed_state():
+ """Test CANCEL action with completed state raises error."""
+ current_state = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.CHAINED_INVOKE,
+ status=OperationStatus.SUCCEEDED,
+ )
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.CHAINED_INVOKE,
+ action=OperationAction.CANCEL,
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Cannot cancel an INVOKE that does not exist or has already completed",
+ ):
+ ChainedInvokeOperationValidator.validate(current_state, update)
+
+
+def test_validate_invalid_action():
+ """Test invalid action raises error."""
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.CHAINED_INVOKE,
+ action=OperationAction.SUCCEED,
+ )
+
+ with pytest.raises(InvalidParameterValueException, match="Invalid INVOKE action"):
+ ChainedInvokeOperationValidator.validate(None, update)
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/step_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/step_test.py
new file mode 100644
index 0000000..7d19132
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/step_test.py
@@ -0,0 +1,272 @@
+"""Unit tests for step operation validator."""
+
+import pytest
+from aws_durable_execution_sdk_python.lambda_service import (
+ ErrorObject,
+ Operation,
+ OperationAction,
+ OperationStatus,
+ OperationType,
+ OperationUpdate,
+ StepOptions,
+)
+
+from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.step import (
+ StepOperationValidator,
+)
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+
+
+def test_validate_with_no_current_state():
+ """Test validation with no current state."""
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ )
+ StepOperationValidator.validate(None, update)
+
+
+def test_validate_start_action_with_ready_state():
+ """Test START action with READY state."""
+ current_state = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.READY,
+ )
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ )
+ StepOperationValidator.validate(current_state, update)
+
+
+def test_validate_start_action_with_invalid_state():
+ """Test START action with invalid state raises error."""
+ current_state = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.SUCCEEDED,
+ )
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Invalid current STEP state to start"
+ ):
+ StepOperationValidator.validate(current_state, update)
+
+
+def test_validate_succeed_action_with_started_state():
+ """Test SUCCEED action with STARTED state."""
+ current_state = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.STARTED,
+ )
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.SUCCEED,
+ payload={"result": "success"},
+ )
+ StepOperationValidator.validate(current_state, update)
+
+
+def test_validate_fail_action_with_ready_state():
+ """Test FAIL action with READY state."""
+ current_state = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.READY,
+ )
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.FAIL,
+ error=ErrorObject(
+ message="Test error", type="TestError", data=None, stack_trace=None
+ ),
+ )
+ StepOperationValidator.validate(current_state, update)
+
+
+def test_validate_fail_action_with_invalid_state():
+ """Test FAIL action with invalid state raises error."""
+ current_state = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.SUCCEEDED,
+ )
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.FAIL,
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Invalid current STEP state to close"
+ ):
+ StepOperationValidator.validate(current_state, update)
+
+
+def test_validate_fail_action_with_payload():
+ """Test FAIL action with payload raises error."""
+ current_state = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.STARTED,
+ )
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.FAIL,
+ payload={"invalid": "payload"},
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Cannot provide a Payload for FAIL action"
+ ):
+ StepOperationValidator.validate(current_state, update)
+
+
+def test_validate_succeed_action_with_error():
+ """Test SUCCEED action with error raises error."""
+ current_state = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.STARTED,
+ )
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.SUCCEED,
+ error=ErrorObject(
+ message="Test error", type="TestError", data=None, stack_trace=None
+ ),
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Cannot provide an Error for SUCCEED action",
+ ):
+ StepOperationValidator.validate(current_state, update)
+
+
+def test_validate_retry_action_with_started_state():
+ """Test RETRY action with STARTED state."""
+ current_state = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.STARTED,
+ )
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.RETRY,
+ step_options=StepOptions(next_attempt_delay_seconds=3),
+ )
+ StepOperationValidator.validate(current_state, update)
+
+
+def test_validate_retry_action_with_ready_state():
+ """Test RETRY action with READY state."""
+ current_state = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.READY,
+ )
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.RETRY,
+ step_options=StepOptions(next_attempt_delay_seconds=3),
+ )
+ StepOperationValidator.validate(current_state, update)
+
+
+def test_validate_retry_action_with_invalid_state():
+ """Test RETRY action with invalid state raises error."""
+ current_state = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.SUCCEEDED,
+ )
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.RETRY,
+ step_options=StepOptions(next_attempt_delay_seconds=3),
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Invalid current STEP state to re-attempt"
+ ):
+ StepOperationValidator.validate(current_state, update)
+
+
+def test_validate_retry_action_without_step_options():
+ """Test RETRY action without step options raises error."""
+ current_state = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.STARTED,
+ )
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.RETRY,
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Invalid StepOptions for the given action"
+ ):
+ StepOperationValidator.validate(current_state, update)
+
+
+def test_validate_retry_action_with_both_error_and_payload():
+ """Test RETRY action with both error and payload raises error."""
+ current_state = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.STARTED,
+ )
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.RETRY,
+ step_options=StepOptions(next_attempt_delay_seconds=3),
+ error=ErrorObject(
+ message="Test error", type="TestError", data=None, stack_trace=None
+ ),
+ payload={"result": "success"},
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Cannot provide both error and payload to RETRY a STEP",
+ ):
+ StepOperationValidator.validate(current_state, update)
+
+
+def test_validate_invalid_action():
+ """Test invalid action raises error."""
+ current_state = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.STARTED,
+ )
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.CANCEL,
+ )
+
+ with pytest.raises(InvalidParameterValueException, match="Invalid STEP action"):
+ StepOperationValidator.validate(current_state, update)
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/wait_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/wait_test.py
new file mode 100644
index 0000000..77f3536
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/wait_test.py
@@ -0,0 +1,108 @@
+"""Unit tests for wait operation validator."""
+
+import pytest
+from aws_durable_execution_sdk_python.lambda_service import (
+ Operation,
+ OperationAction,
+ OperationStatus,
+ OperationType,
+ OperationUpdate,
+)
+
+from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.wait import (
+ WaitOperationValidator,
+)
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+
+
+def test_validate_start_action_with_no_current_state():
+ """Test START action with no current state."""
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.WAIT,
+ action=OperationAction.START,
+ )
+ WaitOperationValidator.validate(None, update)
+
+
+def test_validate_start_action_with_existing_state():
+ """Test START action with existing state raises error."""
+ current_state = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.WAIT,
+ status=OperationStatus.STARTED,
+ )
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.WAIT,
+ action=OperationAction.START,
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Cannot start a WAIT that already exist"
+ ):
+ WaitOperationValidator.validate(current_state, update)
+
+
+def test_validate_cancel_action_with_started_state():
+ """Test CANCEL action with STARTED state."""
+ current_state = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.WAIT,
+ status=OperationStatus.STARTED,
+ )
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.WAIT,
+ action=OperationAction.CANCEL,
+ )
+ WaitOperationValidator.validate(current_state, update)
+
+
+def test_validate_cancel_action_with_no_current_state():
+ """Test CANCEL action with no current state raises error."""
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.WAIT,
+ action=OperationAction.CANCEL,
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Cannot cancel a WAIT that does not exist or has already completed",
+ ):
+ WaitOperationValidator.validate(None, update)
+
+
+def test_validate_cancel_action_with_completed_state():
+ """Test CANCEL action with completed state raises error."""
+ current_state = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.WAIT,
+ status=OperationStatus.SUCCEEDED,
+ )
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.WAIT,
+ action=OperationAction.CANCEL,
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Cannot cancel a WAIT that does not exist or has already completed",
+ ):
+ WaitOperationValidator.validate(current_state, update)
+
+
+def test_validate_invalid_action():
+ """Test invalid action raises error."""
+ update = OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.WAIT,
+ action=OperationAction.SUCCEED,
+ )
+
+ with pytest.raises(InvalidParameterValueException, match="Invalid WAIT action"):
+ WaitOperationValidator.validate(None, update)
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/transitions_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/transitions_test.py
new file mode 100644
index 0000000..901db3c
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/transitions_test.py
@@ -0,0 +1,150 @@
+"""Unit tests for transitions validator."""
+
+import pytest
+from aws_durable_execution_sdk_python.lambda_service import (
+ OperationAction,
+ OperationType,
+)
+
+from aws_durable_execution_sdk_python_testing.checkpoint.validators.transitions import (
+ ValidActionsByOperationTypeValidator,
+)
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+
+
+def test_validate_step_valid_actions():
+ """Test valid actions for STEP operations."""
+ valid_actions = [
+ OperationAction.START,
+ OperationAction.FAIL,
+ OperationAction.RETRY,
+ OperationAction.SUCCEED,
+ ]
+ for action in valid_actions:
+ ValidActionsByOperationTypeValidator.validate(OperationType.STEP, action)
+
+
+def test_validate_context_valid_actions():
+ """Test valid actions for CONTEXT operations."""
+ valid_actions = [
+ OperationAction.START,
+ OperationAction.FAIL,
+ OperationAction.SUCCEED,
+ ]
+ for action in valid_actions:
+ ValidActionsByOperationTypeValidator.validate(OperationType.CONTEXT, action)
+
+
+def test_validate_wait_valid_actions():
+ """Test valid actions for WAIT operations."""
+ valid_actions = [
+ OperationAction.START,
+ OperationAction.CANCEL,
+ ]
+ for action in valid_actions:
+ ValidActionsByOperationTypeValidator.validate(OperationType.WAIT, action)
+
+
+def test_validate_callback_valid_actions():
+ """Test valid actions for CALLBACK operations."""
+ valid_actions = [
+ OperationAction.START,
+ ]
+ for action in valid_actions:
+ ValidActionsByOperationTypeValidator.validate(OperationType.CALLBACK, action)
+
+
+def test_validate_invoke_valid_actions():
+ """Test valid actions for INVOKE operations."""
+ valid_actions = [
+ OperationAction.START,
+ OperationAction.CANCEL,
+ ]
+ for action in valid_actions:
+ ValidActionsByOperationTypeValidator.validate(
+ OperationType.CHAINED_INVOKE, action
+ )
+
+
+def test_validate_execution_valid_actions():
+ """Test valid actions for EXECUTION operations."""
+ valid_actions = [
+ OperationAction.SUCCEED,
+ OperationAction.FAIL,
+ ]
+ for action in valid_actions:
+ ValidActionsByOperationTypeValidator.validate(OperationType.EXECUTION, action)
+
+
+def test_validate_invalid_action_for_step():
+ """Test invalid action for STEP operation."""
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Invalid action for the given operation type",
+ ):
+ ValidActionsByOperationTypeValidator.validate(
+ OperationType.STEP, OperationAction.CANCEL
+ )
+
+
+def test_validate_invalid_action_for_context():
+ """Test invalid action for CONTEXT operation."""
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Invalid action for the given operation type",
+ ):
+ ValidActionsByOperationTypeValidator.validate(
+ OperationType.CONTEXT, OperationAction.RETRY
+ )
+
+
+def test_validate_invalid_action_for_wait():
+ """Test invalid action for WAIT operation."""
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Invalid action for the given operation type",
+ ):
+ ValidActionsByOperationTypeValidator.validate(
+ OperationType.WAIT, OperationAction.SUCCEED
+ )
+
+
+def test_validate_invalid_action_for_callback():
+ """Test invalid action for CALLBACK operation."""
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Invalid action for the given operation type",
+ ):
+ ValidActionsByOperationTypeValidator.validate(
+ OperationType.CALLBACK, OperationAction.FAIL
+ )
+
+
+def test_validate_invalid_action_for_invoke():
+ """Test invalid action for INVOKE operation."""
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Invalid action for the given operation type",
+ ):
+ ValidActionsByOperationTypeValidator.validate(
+ OperationType.CHAINED_INVOKE, OperationAction.RETRY
+ )
+
+
+def test_validate_invalid_action_for_execution():
+ """Test invalid action for EXECUTION operation."""
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Invalid action for the given operation type",
+ ):
+ ValidActionsByOperationTypeValidator.validate(
+ OperationType.EXECUTION, OperationAction.START
+ )
+
+
+def test_validate_unknown_operation_type():
+ """Test validation with unknown operation type."""
+ with pytest.raises(InvalidParameterValueException, match="Unknown operation type"):
+ ValidActionsByOperationTypeValidator.validate(None, OperationAction.START)
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/cli_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/cli_test.py
new file mode 100644
index 0000000..98b53c6
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/cli_test.py
@@ -0,0 +1,1117 @@
+"""Tests for the CLI module."""
+
+from __future__ import annotations
+
+import argparse
+import json
+import logging
+import os
+import sys
+from http.client import HTTPMessage
+from io import StringIO, BytesIO
+from unittest.mock import Mock, patch
+
+import pytest
+from urllib.error import HTTPError, URLError
+
+from botocore.exceptions import ConnectionError # type: ignore
+
+from aws_durable_execution_sdk_python_testing.cli import CliApp, CliConfig, main
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ DurableFunctionsLocalRunnerError,
+ InvalidParameterValueException,
+ ResourceNotFoundException,
+ ServiceException,
+ TooManyRequestsException,
+)
+
+
+def test_cli_config_has_correct_default_values() -> None:
+ """Test that CliConfig has correct default values."""
+ config = CliConfig()
+
+ assert config.host == "0.0.0.0" # noqa: S104
+ assert config.port == 5000
+ assert config.log_level == logging.INFO
+ assert config.lambda_endpoint == "http://127.0.0.1:3001"
+ assert config.local_runner_endpoint == "http://0.0.0.0:5000"
+ assert config.local_runner_region == "us-west-2"
+ assert config.local_runner_mode == "local"
+
+
+def test_cli_config_from_environment_uses_defaults_when_no_env_vars() -> None:
+ """Test from_environment with no environment variables set."""
+ with patch.dict(os.environ, {}, clear=True):
+ config = CliConfig.from_environment()
+
+ assert config.host == "0.0.0.0" # noqa: S104
+ assert config.port == 5000
+ assert config.log_level == logging.INFO
+ assert config.lambda_endpoint == "http://127.0.0.1:3001"
+ assert config.local_runner_endpoint == "http://0.0.0.0:5000"
+ assert config.local_runner_region == "us-west-2"
+ assert config.local_runner_mode == "local"
+
+
+def test_cli_config_from_environment_uses_all_env_vars_when_set() -> None:
+ """Test from_environment with all environment variables set."""
+ env_vars = {
+ "AWS_DEX_HOST": "127.0.0.1",
+ "AWS_DEX_PORT": "8080",
+ "AWS_DEX_LOG_LEVEL": "DEBUG",
+ "AWS_DEX_LAMBDA_ENDPOINT": "http://localhost:4000",
+ "AWS_DEX_LOCAL_RUNNER_ENDPOINT": "http://localhost:8080",
+ "AWS_DEX_LOCAL_RUNNER_REGION": "us-east-1",
+ "AWS_DEX_LOCAL_RUNNER_MODE": "remote",
+ }
+
+ with patch.dict(os.environ, env_vars, clear=True):
+ config = CliConfig.from_environment()
+
+ assert config.host == "127.0.0.1"
+ assert config.port == 8080
+ assert config.log_level == logging.DEBUG
+ assert config.lambda_endpoint == "http://localhost:4000"
+ assert config.local_runner_endpoint == "http://localhost:8080"
+ assert config.local_runner_region == "us-east-1"
+ assert config.local_runner_mode == "remote"
+
+
+def test_cli_config_from_environment_uses_partial_env_vars_with_defaults() -> None:
+ """Test from_environment with some environment variables set."""
+ env_vars = {
+ "AWS_DEX_HOST": "192.168.1.1",
+ "AWS_DEX_PORT": "9000",
+ }
+
+ with patch.dict(os.environ, env_vars, clear=True):
+ config = CliConfig.from_environment()
+
+ assert config.host == "192.168.1.1"
+ assert config.port == 9000
+ # Other values should be defaults
+ assert config.log_level == logging.INFO
+ assert config.lambda_endpoint == "http://127.0.0.1:3001"
+
+
+def test_cli_app_loads_config_from_environment_on_init() -> None:
+ """Test that CliApp loads configuration from environment on init."""
+ env_vars = {"AWS_DEX_HOST": "test-host", "AWS_DEX_PORT": "7777"}
+
+ with patch.dict(os.environ, env_vars, clear=True):
+ app = CliApp()
+
+ assert app.config.host == "test-host"
+ assert app.config.port == 7777
+
+
+def test_cli_app_shows_help_and_returns_error_when_no_command() -> None:
+ """Test that running with no command shows help and returns error code."""
+ app = CliApp()
+
+ with patch("sys.stderr", new_callable=StringIO) as mock_stderr:
+ exit_code = app.run([])
+
+ assert exit_code == 2 # argparse error code
+ assert "required" in mock_stderr.getvalue().lower()
+
+
+def test_cli_app_shows_usage_information_with_help_flag() -> None:
+ """Test that --help shows usage information."""
+ app = CliApp()
+
+ with patch("sys.stdout", new_callable=StringIO) as mock_stdout:
+ exit_code = app.run(["--help"])
+
+ assert exit_code == 0
+ output = mock_stdout.getvalue()
+ assert "dex-local-runner" in output
+ assert "start-server" in output
+ assert "invoke" in output
+ assert "get-durable-execution" in output
+ assert "get-durable-execution-history" in output
+
+
+def test_cli_app_handles_keyboard_interrupt_gracefully() -> None:
+ """Test that KeyboardInterrupt is handled gracefully."""
+ app = CliApp()
+
+ with patch.object(app, "_create_parsers") as mock_setup:
+ mock_setup.side_effect = KeyboardInterrupt()
+
+ with patch("sys.stderr", new_callable=StringIO) as mock_stderr:
+ exit_code = app.run(["start-server"])
+
+ assert exit_code == 130
+ assert "cancelled by user" in mock_stderr.getvalue()
+
+
+def test_start_server_command_parses_arguments_correctly() -> None:
+ """Test that start-server command parses arguments correctly."""
+ app = CliApp()
+
+ # Test with default values
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.WebRunner"
+ ) as mock_web_runner:
+ # Mock runner context manager
+ mock_runner_instance = mock_web_runner.return_value
+ mock_runner_instance.__enter__.return_value = mock_runner_instance
+ mock_runner_instance.__exit__.return_value = None
+ mock_runner_instance.serve_forever.side_effect = KeyboardInterrupt()
+
+ exit_code = app.run(["start-server"])
+ assert exit_code == 130 # KeyboardInterrupt exit code
+
+ # Test with custom values
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.WebRunner"
+ ) as mock_web_runner:
+ # Mock runner context manager
+ mock_runner_instance = mock_web_runner.return_value
+ mock_runner_instance.__enter__.return_value = mock_runner_instance
+ mock_runner_instance.__exit__.return_value = None
+ mock_runner_instance.serve_forever.side_effect = KeyboardInterrupt()
+
+ exit_code = app.run(
+ [
+ "start-server",
+ "--host",
+ "127.0.0.1",
+ "--port",
+ "8080",
+ "--log-level",
+ "DEBUG",
+ "--lambda-endpoint",
+ "http://localhost:4000",
+ "--local-runner-endpoint",
+ "http://localhost:8080",
+ "--local-runner-region",
+ "us-east-1",
+ "--local-runner-mode",
+ "remote",
+ ]
+ )
+ assert exit_code == 130 # KeyboardInterrupt exit code
+
+
+def test_invoke_command_parses_arguments_correctly() -> None:
+ """Test that invoke command parses arguments correctly."""
+ app = CliApp()
+
+ # Test with required function-name
+ with patch("sys.stdout", new_callable=StringIO):
+ exit_code = app.run(["invoke", "--function-name", "test-function"])
+ assert exit_code == 1 # Not implemented yet
+
+ # Test with all parameters
+ with patch("sys.stdout", new_callable=StringIO):
+ exit_code = app.run(
+ [
+ "invoke",
+ "--function-name",
+ "test-function",
+ "--input",
+ '{"key": "value"}',
+ "--durable-execution-name",
+ "test-execution",
+ ]
+ )
+ assert exit_code == 1 # Not implemented yet
+
+
+def test_invoke_command_requires_function_name_parameter() -> None:
+ """Test that invoke command requires function-name parameter."""
+ app = CliApp()
+
+ with patch("sys.stderr", new_callable=StringIO) as mock_stderr:
+ exit_code = app.run(["invoke"])
+
+ assert exit_code == 2 # argparse error code
+ assert "required" in mock_stderr.getvalue().lower()
+
+
+def test_invoke_command_validates_json_input_format() -> None:
+ """Test that invoke command validates JSON input."""
+ app = CliApp()
+
+ exit_code = app.run(
+ [
+ "invoke",
+ "--function-name",
+ "test-function",
+ "--input",
+ "invalid-json",
+ ]
+ )
+
+ assert exit_code == 1
+
+
+def test_get_durable_execution_command_parses_arguments_correctly() -> None:
+ """Test that get-durable-execution command parses arguments correctly."""
+ app = CliApp()
+
+ with patch.object(app, "_create_boto3_client") as mock_create_client:
+ mock_client = Mock()
+ mock_client.get_durable_execution.side_effect = Exception("Connection refused")
+ mock_create_client.return_value = mock_client
+
+ with patch("sys.stderr", new_callable=StringIO):
+ exit_code = app.run(
+ ["get-durable-execution", "--durable-execution-arn", "test-arn"]
+ )
+ assert exit_code == 1 # Connection error
+
+
+def test_get_durable_execution_command_requires_arn_parameter() -> None:
+ """Test that get-durable-execution command requires ARN parameter."""
+ app = CliApp()
+
+ with patch("sys.stderr", new_callable=StringIO) as mock_stderr:
+ exit_code = app.run(["get-durable-execution"])
+
+ assert exit_code == 2 # argparse error code
+ assert "required" in mock_stderr.getvalue().lower()
+
+
+def test_get_durable_execution_history_command_parses_arguments_correctly() -> None:
+ """Test that get-durable-execution-history command parses arguments correctly."""
+ app = CliApp()
+
+ with patch.object(app, "_create_boto3_client") as mock_create_client:
+ mock_client = Mock()
+ mock_client.get_durable_execution_history.side_effect = Exception(
+ "Connection refused"
+ )
+ mock_create_client.return_value = mock_client
+
+ with patch("sys.stderr", new_callable=StringIO):
+ exit_code = app.run(
+ ["get-durable-execution-history", "--durable-execution-arn", "test-arn"]
+ )
+ assert exit_code == 1 # Connection error
+
+
+def test_get_durable_execution_history_command_requires_arn_parameter() -> None:
+ """Test that get-durable-execution-history command requires ARN parameter."""
+ app = CliApp()
+
+ with patch("sys.stderr", new_callable=StringIO) as mock_stderr:
+ exit_code = app.run(["get-durable-execution-history"])
+
+ assert exit_code == 2 # argparse error code
+ assert "required" in mock_stderr.getvalue().lower()
+
+
+def test_logging_configuration_uses_specified_log_level() -> None:
+ """Test that logging is configured based on log level."""
+ app = CliApp()
+
+ with patch("logging.basicConfig") as mock_basic_config:
+ with patch("sys.stdout", new_callable=StringIO):
+ with patch.object(app, "start_server_command", return_value=0):
+ app.run(["start-server", "--log-level", "DEBUG"])
+
+ mock_basic_config.assert_called_once()
+ call_args = mock_basic_config.call_args
+ assert call_args[1]["level"] == 10
+
+
+def test_parser_creation_includes_all_subcommands() -> None:
+ """Test that parser creation includes all expected subcommands."""
+ app = CliApp()
+
+ # Test that all subcommands are available
+ with patch("sys.stdout", new_callable=StringIO) as mock_stdout:
+ exit_code = app.run(["--help"])
+ assert exit_code == 0
+ output = mock_stdout.getvalue()
+ assert "start-server" in output
+ assert "invoke" in output
+ assert "get-durable-execution" in output
+ assert "get-durable-execution-history" in output
+
+
+def test_start_server_command_works_with_mocked_dependencies() -> None:
+ """Test start-server command with mocked WebRunner."""
+ app = CliApp()
+
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.WebRunner"
+ ) as mock_web_runner:
+ # Mock runner context manager
+ mock_runner_instance = mock_web_runner.return_value
+ mock_runner_instance.__enter__.return_value = mock_runner_instance
+ mock_runner_instance.__exit__.return_value = None
+
+ # Mock serve_forever to avoid blocking
+ mock_runner_instance.serve_forever.side_effect = KeyboardInterrupt()
+
+ exit_code = app.run(
+ [
+ "start-server",
+ "--host",
+ "127.0.0.1",
+ "--port",
+ "8080",
+ "--log-level",
+ "DEBUG",
+ ]
+ )
+
+ assert exit_code == 130 # KeyboardInterrupt exit code
+ mock_web_runner.assert_called_once()
+
+ # Verify WebRunnerConfig was created with correct values
+ call_args = mock_web_runner.call_args[0][0] # First positional argument
+ assert call_args.web_service.host == "127.0.0.1"
+ assert call_args.web_service.port == 8080
+ assert call_args.web_service.log_level == "DEBUG"
+
+
+def test_start_server_command_handles_server_startup_errors() -> None:
+ """Test start-server command handles server startup errors."""
+ app = CliApp()
+
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.WebRunner"
+ ) as mock_web_runner:
+ # Make WebRunner constructor raise an exception
+ mock_web_runner.side_effect = Exception("Server startup failed")
+
+ exit_code = app.run(["start-server"])
+
+ assert exit_code == 1
+
+
+def test_start_server_command_creates_correct_web_runner_config() -> None:
+ """Test that start-server command creates WebRunnerConfig with all CLI arguments."""
+ app = CliApp()
+
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.WebRunner"
+ ) as mock_web_runner:
+ # Mock runner context manager
+ mock_runner_instance = mock_web_runner.return_value
+ mock_runner_instance.__enter__.return_value = mock_runner_instance
+ mock_runner_instance.__exit__.return_value = None
+ mock_runner_instance.serve_forever.side_effect = KeyboardInterrupt()
+
+ exit_code = app.run(
+ [
+ "start-server",
+ "--host",
+ "192.168.1.100",
+ "--port",
+ "9000",
+ "--log-level",
+ "WARNING",
+ "--lambda-endpoint",
+ "http://custom-lambda:4000",
+ "--local-runner-endpoint",
+ "http://custom-runner:9000",
+ "--local-runner-region",
+ "eu-west-1",
+ "--local-runner-mode",
+ "remote",
+ ]
+ )
+
+ assert exit_code == 130 # KeyboardInterrupt exit code
+ mock_web_runner.assert_called_once()
+
+ # Verify WebRunnerConfig was created with all custom values
+ config = mock_web_runner.call_args[0][0] # First positional argument
+
+ # Verify web service configuration
+ assert config.web_service.host == "192.168.1.100"
+ assert config.web_service.port == 9000
+ assert config.web_service.log_level == "WARNING"
+
+ # Verify Lambda service configuration
+ assert config.lambda_endpoint == "http://custom-lambda:4000"
+ assert config.local_runner_endpoint == "http://custom-runner:9000"
+ assert config.local_runner_region == "eu-west-1"
+ assert config.local_runner_mode == "remote"
+
+
+def test_start_server_command_uses_context_manager_properly() -> None:
+ """Test that start-server command uses WebRunner as context manager."""
+ app = CliApp()
+
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.WebRunner"
+ ) as mock_web_runner:
+ # Mock runner context manager
+ mock_runner_instance = mock_web_runner.return_value
+ mock_runner_instance.__enter__.return_value = mock_runner_instance
+ mock_runner_instance.__exit__.return_value = None
+ mock_runner_instance.serve_forever.return_value = None
+
+ exit_code = app.run(["start-server"])
+
+ assert exit_code == 0
+ mock_web_runner.assert_called_once()
+
+ # Verify context manager methods were called
+ mock_runner_instance.__enter__.assert_called_once()
+ mock_runner_instance.__exit__.assert_called_once()
+ mock_runner_instance.serve_forever.assert_called_once()
+
+
+def test_start_server_command_handles_runtime_error_from_web_runner() -> None:
+ """Test that start-server command handles DurableFunctionsLocalRunnerError from WebRunner."""
+ app = CliApp()
+
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.WebRunner"
+ ) as mock_web_runner:
+ # Mock runner context manager that raises DurableFunctionsLocalRunnerError
+ mock_runner_instance = mock_web_runner.return_value
+ mock_runner_instance.__enter__.side_effect = DurableFunctionsLocalRunnerError(
+ "Server already running"
+ )
+
+ exit_code = app.run(["start-server"])
+
+ assert exit_code == 1
+
+
+def test_start_server_command_logs_configuration_details() -> None:
+ """Test that start-server command logs configuration details."""
+ app = CliApp()
+
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.WebRunner"
+ ) as mock_web_runner:
+ # Mock runner context manager
+ mock_runner_instance = mock_web_runner.return_value
+ mock_runner_instance.__enter__.return_value = mock_runner_instance
+ mock_runner_instance.__exit__.return_value = None
+ mock_runner_instance.serve_forever.side_effect = KeyboardInterrupt()
+
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.logger"
+ ) as mock_logger:
+ exit_code = app.run(
+ [
+ "start-server",
+ "--host",
+ "test-host",
+ "--port",
+ "8888",
+ ]
+ )
+
+ assert exit_code == 130
+
+ # Verify configuration logging
+ mock_logger.info.assert_any_call(
+ "Starting Durable Functions Local Runner on %s:%s",
+ "test-host",
+ 8888,
+ )
+ mock_logger.info.assert_any_call("Configuration:")
+ mock_logger.info.assert_any_call(" Host: %s", "test-host")
+ mock_logger.info.assert_any_call(" Port: %s", 8888)
+
+
+def test_start_server_command_maintains_backward_compatible_logging() -> None:
+ """Test that start-server command maintains backward compatible logging messages."""
+ app = CliApp()
+
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.WebRunner"
+ ) as mock_web_runner:
+ # Mock runner context manager
+ mock_runner_instance = mock_web_runner.return_value
+ mock_runner_instance.__enter__.return_value = mock_runner_instance
+ mock_runner_instance.__exit__.return_value = None
+ mock_runner_instance.serve_forever.side_effect = KeyboardInterrupt()
+
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.logger"
+ ) as mock_logger:
+ exit_code = app.run(["start-server"])
+
+ assert exit_code == 130
+
+ # Verify backward compatible logging messages
+ mock_logger.info.assert_any_call(
+ "Server started successfully. Press Ctrl+C to stop."
+ )
+ mock_logger.info.assert_any_call(
+ "Received shutdown signal, stopping server..."
+ )
+
+
+def test_start_server_command_handles_serve_forever_exception() -> None:
+ """Test that start-server command handles exceptions from serve_forever."""
+ app = CliApp()
+
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.WebRunner"
+ ) as mock_web_runner:
+ # Mock runner context manager
+ mock_runner_instance = mock_web_runner.return_value
+ mock_runner_instance.__enter__.return_value = mock_runner_instance
+ mock_runner_instance.__exit__.return_value = None
+ mock_runner_instance.serve_forever.side_effect = (
+ DurableFunctionsLocalRunnerError("Server error during operation")
+ )
+
+ exit_code = app.run(["start-server"])
+
+ assert exit_code == 1
+
+
+def test_main_function_creates_cli_app_and_runs() -> None:
+ """Test the main function entry point."""
+ with patch("aws_durable_execution_sdk_python_testing.cli.CliApp") as mock_cli_app:
+ mock_app_instance = mock_cli_app.return_value
+ mock_app_instance.run.return_value = 42
+
+ exit_code = main()
+
+ mock_cli_app.assert_called_once()
+ mock_app_instance.run.assert_called_once()
+ assert exit_code == 42
+
+
+def test_main_function_works_when_called_as_script() -> None:
+ """Test that main function works when called as script."""
+ original_argv = sys.argv[:]
+ try:
+ sys.argv = ["dex-local-runner", "--help"]
+
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.CliApp"
+ ) as mock_cli_app:
+ mock_app_instance = mock_cli_app.return_value
+ mock_app_instance.run.return_value = 0
+
+ exit_code = main()
+
+ assert exit_code == 0
+ mock_app_instance.run.assert_called_once()
+ finally:
+ sys.argv = original_argv
+
+
+# Tests for client operation CLI commands
+
+
+def test_invoke_command_makes_http_request_to_start_execution_endpoint() -> None:
+ """Test that invoke command makes HTTP request to start-durable-execution endpoint."""
+ app = CliApp()
+
+ response_body = json.dumps(
+ {
+ "ExecutionArn": "arn:aws:lambda:us-west-2:123456789012:function:test-function:execution:test-execution"
+ }
+ ).encode("utf-8")
+
+ mock_response = Mock()
+ mock_response.read.return_value = response_body
+ mock_response.__enter__ = Mock(return_value=mock_response)
+ mock_response.__exit__ = Mock(return_value=False)
+
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.urlopen",
+ return_value=mock_response,
+ ) as mock_urlopen:
+ with patch("sys.stdout", new_callable=StringIO) as mock_stdout:
+ exit_code = app.invoke_command(
+ argparse.Namespace(
+ function_name="test-function",
+ input='{"key": "value"}',
+ durable_execution_name="test-execution",
+ )
+ )
+
+ assert exit_code == 0
+ mock_urlopen.assert_called_once()
+
+ # Verify the request details
+ call_args = mock_urlopen.call_args
+ req = call_args[0][0]
+ assert req.full_url.endswith("/start-durable-execution")
+ assert req.get_header("Content-type") == "application/json"
+ assert call_args[1]["timeout"] == 10
+
+ # Verify payload structure
+ payload = json.loads(req.data.decode("utf-8"))
+ assert payload["FunctionName"] == "test-function"
+ assert payload["Input"] == '{"key": "value"}'
+ assert payload["ExecutionName"] == "test-execution"
+
+ # Verify output
+ output = mock_stdout.getvalue()
+ assert "ExecutionArn" in output
+
+
+def test_invoke_command_uses_default_execution_name_when_not_provided() -> None:
+ """Test that invoke command generates default execution name when not provided."""
+ app = CliApp()
+
+ response_body = json.dumps({"ExecutionArn": "test-arn"}).encode("utf-8")
+ mock_response = Mock()
+ mock_response.read.return_value = response_body
+ mock_response.__enter__ = Mock(return_value=mock_response)
+ mock_response.__exit__ = Mock(return_value=False)
+
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.urlopen",
+ return_value=mock_response,
+ ) as mock_urlopen:
+ app.invoke_command(
+ argparse.Namespace(
+ function_name="my-function",
+ input="{}",
+ durable_execution_name=None,
+ )
+ )
+
+ # Verify default execution name is generated
+ req = mock_urlopen.call_args[0][0]
+ payload = json.loads(req.data.decode("utf-8"))
+ assert payload["ExecutionName"] == "my-function-execution"
+
+
+def test_invoke_command_handles_connection_error() -> None:
+ """Test that invoke command handles connection errors gracefully."""
+ app = CliApp()
+
+ with patch("aws_durable_execution_sdk_python_testing.cli.urlopen") as mock_urlopen:
+ mock_urlopen.side_effect = URLError("Connection refused")
+
+ exit_code = app.invoke_command(
+ argparse.Namespace(
+ function_name="test-function",
+ input="{}",
+ durable_execution_name=None,
+ )
+ )
+
+ assert exit_code == 1
+
+
+def test_invoke_command_handles_http_error_response() -> None:
+ """Test that invoke command handles HTTP error responses."""
+ app = CliApp()
+
+ error_body = json.dumps(
+ {
+ "ErrorMessage": "Invalid parameter value",
+ "ErrorType": "InvalidParameterValueException",
+ }
+ ).encode("utf-8")
+
+ with patch("aws_durable_execution_sdk_python_testing.cli.urlopen") as mock_urlopen:
+ mock_urlopen.side_effect = HTTPError(
+ url="http://0.0.0.0:5000/start-durable-execution",
+ code=400,
+ msg="Bad Request",
+ hdrs=HTTPMessage(),
+ fp=BytesIO(error_body),
+ )
+
+ with patch("sys.stderr", new_callable=StringIO) as mock_stderr:
+ exit_code = app.invoke_command(
+ argparse.Namespace(
+ function_name="test-function",
+ input="{}",
+ durable_execution_name=None,
+ )
+ )
+
+ assert exit_code == 1
+ assert "Invalid parameter value" in mock_stderr.getvalue()
+
+
+def test_invoke_command_handles_non_json_error_response() -> None:
+ """Test that invoke command handles non-JSON error responses."""
+ app = CliApp()
+
+ with patch("aws_durable_execution_sdk_python_testing.cli.urlopen") as mock_urlopen:
+ mock_urlopen.side_effect = HTTPError(
+ url="http://0.0.0.0:5000/start-durable-execution",
+ code=500,
+ msg="Internal Server Error",
+ hdrs=HTTPMessage(),
+ fp=BytesIO(b"Internal Server Error"),
+ )
+
+ exit_code = app.invoke_command(
+ argparse.Namespace(
+ function_name="test-function",
+ input="{}",
+ durable_execution_name=None,
+ )
+ )
+
+ assert exit_code == 1
+
+
+def test_get_durable_execution_command_uses_boto3_client() -> None:
+ """Test that get-durable-execution command uses boto3 client."""
+ app = CliApp()
+
+ with patch.object(app, "_create_boto3_client") as mock_create_client:
+ mock_client = mock_create_client.return_value
+ mock_client.get_durable_execution.return_value = {
+ "DurableExecutionArn": "test-arn",
+ "Status": "SUCCEEDED",
+ "Result": {"output": "success"},
+ }
+
+ with patch("sys.stdout", new_callable=StringIO) as mock_stdout:
+ exit_code = app.get_durable_execution_command(
+ argparse.Namespace(durable_execution_arn="test-arn")
+ )
+
+ assert exit_code == 0
+ mock_create_client.assert_called_once()
+ mock_client.get_durable_execution.assert_called_once_with(
+ DurableExecutionArn="test-arn"
+ )
+
+ # Verify JSON output
+ output = mock_stdout.getvalue()
+ assert "test-arn" in output
+ assert "SUCCEEDED" in output
+
+
+def test_get_durable_execution_command_handles_resource_not_found() -> None:
+ """Test that get-durable-execution command handles ResourceNotFoundException."""
+ app = CliApp()
+
+ with patch.object(app, "_create_boto3_client") as mock_create_client:
+ mock_client = mock_create_client.return_value
+
+ mock_client.exceptions.ResourceNotFoundException = ResourceNotFoundException
+ mock_client.exceptions.InvalidParameterValueException = (
+ InvalidParameterValueException
+ )
+ mock_client.exceptions.TooManyRequestsException = TooManyRequestsException
+ mock_client.exceptions.ServiceException = ServiceException
+
+ mock_client.get_durable_execution.side_effect = ResourceNotFoundException(
+ "Resource not found"
+ )
+
+ with patch("sys.stderr", new_callable=StringIO) as mock_stderr:
+ exit_code = app.get_durable_execution_command(
+ argparse.Namespace(durable_execution_arn="nonexistent-arn")
+ )
+
+ assert exit_code == 1
+ assert "Error: Execution not found" in mock_stderr.getvalue()
+
+
+def test_get_durable_execution_command_handles_invalid_parameter() -> None:
+ """Test that get-durable-execution command handles InvalidParameterValueException."""
+ app = CliApp()
+
+ with patch.object(app, "_create_boto3_client") as mock_create_client:
+ mock_client = mock_create_client.return_value
+
+ mock_client.exceptions.ResourceNotFoundException = ResourceNotFoundException
+ mock_client.exceptions.InvalidParameterValueException = (
+ InvalidParameterValueException
+ )
+ mock_client.exceptions.TooManyRequestsException = TooManyRequestsException
+ mock_client.exceptions.ServiceException = ServiceException
+
+ mock_client.get_durable_execution.side_effect = InvalidParameterValueException(
+ "Invalid parameters"
+ )
+
+ with patch("sys.stderr", new_callable=StringIO) as mock_stderr:
+ exit_code = app.get_durable_execution_command(
+ argparse.Namespace(durable_execution_arn="invalid-arn")
+ )
+
+ assert exit_code == 1
+ assert "Error: Invalid parameter" in mock_stderr.getvalue()
+
+
+def test_get_durable_execution_command_handles_too_many_requests() -> None:
+ """Test that get-durable-execution command handles InvalidParameterValueException."""
+ app = CliApp()
+
+ with patch.object(app, "_create_boto3_client") as mock_create_client:
+ mock_client = mock_create_client.return_value
+
+ mock_client.exceptions.ResourceNotFoundException = ResourceNotFoundException
+ mock_client.exceptions.InvalidParameterValueException = (
+ InvalidParameterValueException
+ )
+ mock_client.exceptions.TooManyRequestsException = TooManyRequestsException
+ mock_client.exceptions.ServiceException = ServiceException
+
+ mock_client.get_durable_execution.side_effect = TooManyRequestsException(
+ "Too many requests"
+ )
+
+ with patch("sys.stderr", new_callable=StringIO) as mock_stderr:
+ exit_code = app.get_durable_execution_command(
+ argparse.Namespace(durable_execution_arn="my-arn")
+ )
+
+ assert exit_code == 1
+ assert "Error: Too many requests" in mock_stderr.getvalue()
+
+
+def test_get_durable_execution_command_handles_service_exception() -> None:
+ """Test that get-durable-execution command handles InvalidParameterValueException."""
+ app = CliApp()
+
+ with patch.object(app, "_create_boto3_client") as mock_create_client:
+ mock_client = mock_create_client.return_value
+
+ mock_client.exceptions.ResourceNotFoundException = ResourceNotFoundException
+ mock_client.exceptions.InvalidParameterValueException = (
+ InvalidParameterValueException
+ )
+ mock_client.exceptions.TooManyRequestsException = TooManyRequestsException
+ mock_client.exceptions.ServiceException = ServiceException
+
+ mock_client.get_durable_execution.side_effect = ServiceException(
+ "Service exception"
+ )
+
+ with patch("sys.stderr", new_callable=StringIO) as mock_stderr:
+ exit_code = app.get_durable_execution_command(
+ argparse.Namespace(durable_execution_arn="my-arn")
+ )
+
+ assert exit_code == 1
+ assert "Error: Service error" in mock_stderr.getvalue()
+
+
+def test_get_durable_execution_command_handles_connection_error() -> None:
+ """Test that get-durable-execution command handles connection errors."""
+ app = CliApp()
+
+ with patch.object(app, "_create_boto3_client") as mock_create_client:
+ mock_client = mock_create_client.return_value
+
+ mock_client.exceptions.ResourceNotFoundException = ResourceNotFoundException
+ mock_client.exceptions.InvalidParameterValueException = (
+ InvalidParameterValueException
+ )
+ mock_client.exceptions.TooManyRequestsException = TooManyRequestsException
+ mock_client.exceptions.ServiceException = ServiceException
+
+ mock_client.get_durable_execution.side_effect = ConnectionError(
+ error="Mocked connection error"
+ )
+
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.logger"
+ ) as mock_logger:
+ exit_code = app.get_durable_execution_command(
+ argparse.Namespace(durable_execution_arn="my-arn")
+ )
+
+ assert exit_code == 1
+ mock_logger.exception.assert_called_once_with(
+ "Error: Could not connect to the local runner server. Is it running?"
+ )
+
+
+def test_get_durable_execution_history_command_uses_boto3_client() -> None:
+ """Test that get-durable-execution-history command uses boto3 client."""
+ app = CliApp()
+
+ with patch.object(app, "_create_boto3_client") as mock_create_client:
+ mock_client = mock_create_client.return_value
+ mock_client.get_durable_execution_history.return_value = {
+ "Events": [
+ {
+ "EventType": "ExecutionStarted",
+ "EventTimestamp": "2024-01-01T00:00:00Z",
+ },
+ {
+ "EventType": "ExecutionSucceeded",
+ "EventTimestamp": "2024-01-01T00:01:00Z",
+ },
+ ]
+ }
+
+ with patch("sys.stdout", new_callable=StringIO) as mock_stdout:
+ exit_code = app.get_durable_execution_history_command(
+ argparse.Namespace(durable_execution_arn="test-arn")
+ )
+
+ assert exit_code == 0
+ mock_create_client.assert_called_once()
+ mock_client.get_durable_execution_history.assert_called_once_with(
+ DurableExecutionArn="test-arn"
+ )
+
+ # Verify JSON output
+ output = mock_stdout.getvalue()
+ assert "ExecutionStarted" in output
+ assert "ExecutionSucceeded" in output
+
+
+def test_get_durable_execution_history_command_handles_resource_not_found() -> None:
+ """Test that get-durable-execution-history command handles ResourceNotFoundException."""
+ app = CliApp()
+
+ with patch.object(app, "_create_boto3_client") as mock_create_client:
+ mock_client = mock_create_client.return_value
+ mock_client.get_durable_execution_history.side_effect = Exception(
+ "ResourceNotFoundException: Execution not found"
+ )
+
+ exit_code = app.get_durable_execution_history_command(
+ argparse.Namespace(durable_execution_arn="nonexistent-arn")
+ )
+
+ assert exit_code == 1
+
+
+def test_get_durable_execution_history_command_handles_connection_error() -> None:
+ """Test that get-durable-execution-history command handles connection errors."""
+ app = CliApp()
+
+ with patch.object(app, "_create_boto3_client") as mock_create_client:
+ mock_client = mock_create_client.return_value
+ mock_client.get_durable_execution_history.side_effect = Exception(
+ "Connection refused"
+ )
+
+ exit_code = app.get_durable_execution_history_command(
+ argparse.Namespace(durable_execution_arn="test-arn")
+ )
+
+ assert exit_code == 1
+
+
+def test_create_boto3_client_creates_client_correctly() -> None:
+ """Test that _create_boto3_client creates boto3 client correctly."""
+ app = CliApp()
+
+ with patch("boto3.client") as mock_boto3_client:
+ app._create_boto3_client() # noqa: SLF001
+
+ # Verify boto3 client is created with correct parameters
+ mock_boto3_client.assert_called_once_with(
+ "lambda",
+ endpoint_url=app.config.local_runner_endpoint,
+ region_name=app.config.local_runner_region,
+ )
+
+
+def test_create_boto3_client_handles_creation_failure() -> None:
+ """Test that _create_boto3_client handles client creation failures."""
+ app = CliApp()
+
+ with patch("boto3.client") as mock_boto3_client:
+ mock_boto3_client.side_effect = Exception("Client creation failed")
+
+ with pytest.raises(DurableFunctionsLocalRunnerError) as exc_info:
+ app._create_boto3_client() # noqa: SLF001
+
+ assert "Failed to create boto3 client" in str(exc_info.value)
+ assert "Client creation failed" in str(exc_info.value)
+
+
+def test_cli_app_handles_durable_functions_test_error() -> None:
+ """Test that DurableFunctionsTestError is handled gracefully."""
+ app = CliApp()
+
+ with patch.object(app, "_create_parsers") as mock_setup:
+ from aws_durable_execution_sdk_python_testing.exceptions import (
+ DurableFunctionsTestError,
+ )
+
+ mock_setup.side_effect = DurableFunctionsTestError("Test error")
+
+ exit_code = app.run(["start-server"])
+
+ assert exit_code == 1
+
+
+def test_cli_app_handles_unexpected_exception() -> None:
+ """Test that unexpected exceptions are handled gracefully."""
+ app = CliApp()
+
+ with patch.object(app, "_create_parsers") as mock_setup:
+ mock_setup.side_effect = RuntimeError("Unexpected error")
+
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.logger"
+ ) as mock_logger:
+ exit_code = app.run(["start-server"])
+
+ assert exit_code == 1
+ mock_logger.exception.assert_called_once_with("Unexpected error.")
+
+
+def test_invoke_command_handles_general_exception() -> None:
+ """Test that invoke command handles general exceptions."""
+ app = CliApp()
+
+ with patch("aws_durable_execution_sdk_python_testing.cli.urlopen") as mock_urlopen:
+ mock_urlopen.side_effect = ValueError("Some unexpected error")
+
+ exit_code = app.invoke_command(
+ argparse.Namespace(
+ function_name="test-function",
+ input="{}",
+ durable_execution_name=None,
+ )
+ )
+
+ assert exit_code == 1
+
+
+def test_get_durable_execution_command_handles_general_exception() -> None:
+ """Test that get-durable-execution command handles general exceptions."""
+ app = CliApp()
+
+ with patch.object(app, "_create_boto3_client") as mock_create_client:
+ mock_client = mock_create_client.return_value
+ mock_client.exceptions.ResourceNotFoundException = ResourceNotFoundException
+ mock_client.exceptions.InvalidParameterValueException = (
+ InvalidParameterValueException
+ )
+ mock_client.exceptions.TooManyRequestsException = TooManyRequestsException
+ mock_client.exceptions.ServiceException = ServiceException
+ mock_client.get_durable_execution.side_effect = ValueError(
+ "Some unexpected error"
+ )
+
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.logger"
+ ) as mock_logger:
+ exit_code = app.get_durable_execution_command(
+ argparse.Namespace(durable_execution_arn="my-arn")
+ )
+
+ assert exit_code == 1
+ mock_logger.exception.assert_called_once_with(
+ "Unexpected error in get-durable-execution command"
+ )
+
+
+def test_get_durable_execution_history_command_handles_general_exception() -> None:
+ """Test that get-durable-execution-history command handles general exceptions."""
+ app = CliApp()
+
+ with patch.object(app, "_create_boto3_client") as mock_create_client:
+ mock_client = mock_create_client.return_value
+ mock_client.get_durable_execution_history.side_effect = ValueError(
+ "Some unexpected error"
+ )
+
+ exit_code = app.get_durable_execution_history_command(
+ argparse.Namespace(durable_execution_arn="test-arn")
+ )
+
+ assert exit_code == 1
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/client_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/client_test.py
new file mode 100644
index 0000000..15aa366
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/client_test.py
@@ -0,0 +1,103 @@
+"""Unit tests for InMemoryServiceClient."""
+
+import datetime
+from unittest.mock import Mock
+
+from aws_durable_execution_sdk_python.lambda_service import (
+ CheckpointOutput,
+ OperationAction,
+ OperationType,
+ OperationUpdate,
+ StateOutput,
+)
+
+from aws_durable_execution_sdk_python_testing.client import InMemoryServiceClient
+
+
+def test_checkpoint():
+ """Test checkpoint method delegates to processor."""
+ processor = Mock()
+ expected_output = CheckpointOutput(
+ checkpoint_token="new-token", # noqa: S106
+ new_execution_state=Mock(),
+ )
+ processor.process_checkpoint.return_value = expected_output
+
+ client = InMemoryServiceClient(processor)
+
+ updates = [
+ OperationUpdate(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ action=OperationAction.START,
+ )
+ ]
+
+ result = client.checkpoint(
+ "arn:aws:lambda:us-east-1:123456789012:function:test",
+ "token",
+ updates,
+ "client-token",
+ )
+
+ assert result == expected_output
+ processor.process_checkpoint.assert_called_once_with(
+ "token", updates, "client-token"
+ )
+
+
+def test_get_execution_state():
+ """Test get_execution_state method delegates to processor."""
+ processor = Mock()
+ expected_output = StateOutput(operations=[], next_marker="marker")
+ processor.get_execution_state.return_value = expected_output
+
+ client = InMemoryServiceClient(processor)
+
+ result = client.get_execution_state(
+ "arn:aws:lambda:us-east-1:123456789012:function:test", "token", "marker", 500
+ )
+
+ assert result == expected_output
+ processor.get_execution_state.assert_called_once_with("token", "marker", 500)
+
+
+def test_get_execution_state_default_max_items():
+ """Test get_execution_state with default max_items."""
+ processor = Mock()
+ expected_output = StateOutput(operations=[], next_marker="marker")
+ processor.get_execution_state.return_value = expected_output
+
+ client = InMemoryServiceClient(processor)
+
+ result = client.get_execution_state(
+ "arn:aws:lambda:us-east-1:123456789012:function:test", "token", "marker"
+ )
+
+ assert result == expected_output
+ processor.get_execution_state.assert_called_once_with("token", "marker", 1000)
+
+
+def test_stop():
+ """Test stop method returns current datetime."""
+ processor = Mock()
+ client = InMemoryServiceClient(processor)
+
+ before = datetime.datetime.now(tz=datetime.UTC)
+ result = client.stop(
+ "arn:aws:states:us-east-1:123456789012:execution:test", b"payload"
+ )
+ after = datetime.datetime.now(tz=datetime.UTC)
+
+ assert isinstance(result, datetime.datetime)
+ assert before <= result <= after
+
+
+def test_stop_with_none_payload():
+ """Test stop method with None payload."""
+ processor = Mock()
+ client = InMemoryServiceClient(processor)
+
+ result = client.stop("arn:aws:states:us-east-1:123456789012:execution:test", None)
+
+ assert isinstance(result, datetime.datetime)
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/durable_executions_python_testing_library_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/durable_executions_python_testing_library_test.py
new file mode 100644
index 0000000..d827bbf
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/durable_executions_python_testing_library_test.py
@@ -0,0 +1,7 @@
+"""Tests for DurableExecutionsPythonTestingLibrary module."""
+
+import aws_durable_execution_sdk_python_testing # noqa: F401
+
+
+def test_aws_durable_execution_sdk_python_testing_importable():
+ """Test aws_durable_execution_sdk_python_testing is importable."""
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/e2e/__init__.py b/packages/aws-durable-execution-sdk-python-testing/tests/e2e/__init__.py
new file mode 100644
index 0000000..78d8de9
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/e2e/__init__.py
@@ -0,0 +1 @@
+"""Test package"""
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/e2e/basic_success_path_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/e2e/basic_success_path_test.py
new file mode 100644
index 0000000..3e93bcf
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/e2e/basic_success_path_test.py
@@ -0,0 +1,92 @@
+"""Functional tests, covering end-to-end DurableTestRunner."""
+
+import json
+from typing import Any
+
+from aws_durable_execution_sdk_python.context import (
+ DurableContext,
+ durable_step,
+ durable_with_child_context,
+)
+from aws_durable_execution_sdk_python.execution import (
+ InvocationStatus,
+ durable_execution,
+)
+from aws_durable_execution_sdk_python.types import StepContext
+
+from aws_durable_execution_sdk_python_testing.runner import (
+ ContextOperation,
+ DurableFunctionTestResult,
+ DurableFunctionTestRunner,
+ StepOperation,
+)
+from aws_durable_execution_sdk_python.config import Duration
+
+
+# brazil-test-exec pytest test/runner_int_test.py
+def test_basic_durable_function() -> None:
+ @durable_step
+ def one(step_context: StepContext, a: int, b: int) -> str:
+ # print("[DEBUG] one called")
+ return f"{a} {b}"
+
+ @durable_step
+ def two_1(step_context: StepContext, a: int, b: int) -> str:
+ # print("[DEBUG] two_1 called")
+ return f"{a} {b}"
+
+ @durable_step
+ def two_2(step_context: StepContext, a: int, b: int) -> str:
+ # print("[DEBUG] two_2 called")
+ return f"{b} {a}"
+
+ @durable_with_child_context
+ def two(ctx: DurableContext, a: int, b: int) -> str:
+ # print("[DEBUG] two called")
+ two_1_result: str = ctx.step(two_1(a, b))
+ two_2_result: str = ctx.step(two_2(a, b))
+ return f"{two_1_result} {two_2_result}"
+
+ @durable_step
+ def three(step_context: StepContext, a: int, b: int) -> str:
+ # print("[DEBUG] three called")
+ return f"{a} {b}"
+
+ @durable_execution
+ def function_under_test(event: Any, context: DurableContext) -> list[str]:
+ results: list[str] = []
+
+ result_one: str = context.step(one(1, 2))
+ results.append(result_one)
+
+ context.wait(Duration.from_seconds(1))
+
+ result_two: str = context.run_in_child_context(two(3, 4))
+ results.append(result_two)
+
+ result_three: str = context.step(three(5, 6))
+ results.append(result_three)
+
+ return results
+
+ with DurableFunctionTestRunner(handler=function_under_test) as runner:
+ result: DurableFunctionTestResult = runner.run(input="input str", timeout=10)
+
+ assert result.status is InvocationStatus.SUCCEEDED
+ assert result.result == json.dumps(["1 2", "3 4 4 3", "5 6"])
+
+ one_result: StepOperation = result.get_step("one")
+ assert one_result.result == json.dumps("1 2")
+
+ two_result: ContextOperation = result.get_context("two")
+ assert two_result.result == json.dumps("3 4 4 3")
+
+ three_result: StepOperation = result.get_step("three")
+ assert three_result.result == json.dumps("5 6")
+
+ # currently has the optimization where it's not saving child checkpoints after parent done
+ # prob should unpick that for test
+ # two_one_op = cast(StepOperation, two_result_op.get_operation_by_name("two_1"))
+ # assert two_one_op.result == '"3 4"'
+
+ # print("done")
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/event_factory_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/event_factory_test.py
new file mode 100644
index 0000000..1f4238e
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/event_factory_test.py
@@ -0,0 +1,2002 @@
+"""Tests for Event factory methods.
+
+This module tests all the event creation factory methods in the Event class.
+"""
+
+from datetime import UTC, datetime
+from unittest.mock import Mock
+
+import pytest
+from aws_durable_execution_sdk_python.lambda_service import (
+ ErrorObject,
+ OperationStatus,
+ OperationType,
+ StepDetails,
+ OperationUpdate,
+ OperationSubType,
+ OperationAction,
+ StepOptions,
+)
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+from aws_durable_execution_sdk_python_testing.model import (
+ CheckpointDurableExecutionRequest,
+ ErrorResponse,
+ Event,
+ EventCreationContext,
+ EventError,
+ EventInput,
+ EventResult,
+ Execution,
+ ExecutionStartedDetails,
+ LambdaContext,
+ StartDurableExecutionInput,
+)
+
+
+# Helper function to create mock operations
+def create_mock_operation(
+ operation_id: str = "op-1",
+ name: str = "test_op",
+ parent_id=None,
+ status: OperationStatus = OperationStatus.STARTED,
+):
+ from unittest.mock import Mock
+
+ op = Mock()
+ op.operation_id = operation_id
+ op.name = name
+ op.parent_id = parent_id
+ op.status = status
+ return op
+
+
+# region execution-tests
+def test_create_execution_started():
+ from unittest.mock import Mock
+ from aws_durable_execution_sdk_python.lambda_service import ExecutionDetails
+
+ operation = Mock()
+ operation.operation_id = "op-1"
+ operation.name = "test_execution"
+ operation.parent_id = None
+ operation.status = OperationStatus.STARTED
+ operation.start_timestamp = datetime.now(UTC)
+ operation.operation_type = OperationType.EXECUTION
+ operation.sub_type = None
+ operation.execution_details = ExecutionDetails(input_payload='{"test": "data"}')
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=1,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ include_execution_data=True,
+ )
+ event = Event.create_execution_event(context)
+
+ assert event.event_type == "ExecutionStarted"
+ assert event.operation_id == "op-1"
+ assert event.name == "test_execution"
+ assert event.execution_started_details.input.payload == '{"test": "data"}'
+ assert event.execution_started_details.execution_timeout == 300
+
+
+def test_create_execution_succeeded():
+ from aws_durable_execution_sdk_python.execution import (
+ DurableExecutionInvocationOutput,
+ InvocationStatus,
+ )
+
+ operation = create_mock_operation("op-1", status=OperationStatus.SUCCEEDED)
+ operation.end_timestamp = datetime.now(UTC)
+
+ result = DurableExecutionInvocationOutput(
+ status=InvocationStatus.SUCCEEDED, result='{"result": "success"}'
+ )
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=2,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ result=result,
+ include_execution_data=True,
+ )
+ event = Event.create_execution_event(context)
+
+ assert event.event_type == "ExecutionSucceeded"
+ assert event.execution_succeeded_details.result.payload == '{"result": "success"}'
+
+
+def test_create_execution_failed():
+ from aws_durable_execution_sdk_python.execution import (
+ DurableExecutionInvocationOutput,
+ InvocationStatus,
+ )
+
+ operation = create_mock_operation("op-1", status=OperationStatus.FAILED)
+ operation.end_timestamp = datetime.now(UTC)
+
+ error_result = DurableExecutionInvocationOutput(
+ status=InvocationStatus.FAILED,
+ error=ErrorObject.from_message("Execution failed"),
+ )
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=3,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ result=error_result,
+ include_execution_data=True,
+ )
+ event = Event.create_execution_event(context)
+
+ assert event.event_type == "ExecutionFailed"
+ assert event.execution_failed_details.error.payload.message == "Execution failed"
+
+
+def test_create_execution_timed_out():
+ from aws_durable_execution_sdk_python.execution import (
+ DurableExecutionInvocationOutput,
+ InvocationStatus,
+ )
+
+ operation = create_mock_operation("op-1", status=OperationStatus.TIMED_OUT)
+ operation.end_timestamp = datetime.now(UTC)
+
+ error_result = DurableExecutionInvocationOutput(
+ status=InvocationStatus.FAILED,
+ error=ErrorObject.from_message("Execution timed out"),
+ )
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=4,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ result=error_result,
+ include_execution_data=True,
+ )
+ event = Event.create_execution_event(context)
+
+ assert event.event_type == "ExecutionTimedOut"
+ assert (
+ event.execution_timed_out_details.error.payload.message == "Execution timed out"
+ )
+
+
+def test_create_execution_stopped():
+ from aws_durable_execution_sdk_python.execution import (
+ DurableExecutionInvocationOutput,
+ InvocationStatus,
+ )
+
+ operation = create_mock_operation("op-1", status=OperationStatus.STOPPED)
+ operation.end_timestamp = datetime.now(UTC)
+
+ error_result = DurableExecutionInvocationOutput(
+ status=InvocationStatus.FAILED,
+ error=ErrorObject.from_message("Execution stopped"),
+ )
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=5,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ result=error_result,
+ include_execution_data=True,
+ )
+ event = Event.create_execution_event(context)
+
+ assert event.event_type == "ExecutionStopped"
+ assert event.execution_stopped_details.error.payload.message == "Execution stopped"
+
+
+def test_create_execution_invalid_status():
+ operation = create_mock_operation("op-1", status=OperationStatus.CANCELLED)
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=1,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Operation status .* is not valid for execution operations",
+ ):
+ Event.create_execution_event(context)
+
+
+# endregion execution-tests
+
+
+# region context-tests
+def test_create_context_started():
+ operation = create_mock_operation(
+ "ctx-1", "test_context", status=OperationStatus.STARTED
+ )
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=1,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ event = Event.create_context_event(context)
+
+ assert event.event_type == "ContextStarted"
+ assert event.operation_id == "ctx-1"
+ assert event.name == "test_context"
+ assert event.context_started_details is not None
+
+
+def test_create_context_succeeded():
+ operation = create_mock_operation("ctx-1", status=OperationStatus.SUCCEEDED)
+ operation.context_details = type(
+ "MockDetails", (), {"result": '{"context": "result"}', "error": None}
+ )()
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=2,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ include_execution_data=True,
+ )
+ event = Event.create_context_event(context)
+
+ assert event.event_type == "ContextSucceeded"
+ assert event.context_succeeded_details.result.payload == '{"context": "result"}'
+
+
+def test_create_context_failed():
+ operation = create_mock_operation("ctx-1", status=OperationStatus.FAILED)
+ error_obj = ErrorObject.from_message("Context failed")
+ operation.context_details = type(
+ "MockDetails", (), {"result": None, "error": error_obj}
+ )()
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=3,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ event = Event.create_context_event(context)
+
+ assert event.event_type == "ContextFailed"
+ assert event.context_failed_details.error.payload.message == "Context failed"
+
+
+def test_create_context_invalid_status():
+ operation = create_mock_operation("ctx-1", status=OperationStatus.TIMED_OUT)
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=1,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Operation status .* is not valid for context operations",
+ ):
+ Event.create_context_event(context)
+
+
+# endregion context-tests
+
+
+# region wait-tests
+def test_create_wait_started():
+ operation = create_mock_operation("wait-1", status=OperationStatus.STARTED)
+ operation.start_timestamp = datetime.fromisoformat("2024-01-01T12:00:00Z")
+ operation.wait_details = type(
+ "MockDetails",
+ (),
+ {"scheduled_end_timestamp": datetime.fromisoformat("2024-01-01T12:05:00Z")},
+ )()
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=1,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ event = Event.create_wait_event(context)
+
+ assert event.event_type == "WaitStarted"
+ assert event.wait_started_details.duration == 300
+ assert event.wait_started_details.scheduled_end_timestamp == datetime.fromisoformat(
+ "2024-01-01T12:05:00Z"
+ )
+
+
+def test_create_wait_succeeded():
+ operation = create_mock_operation("wait-1", status=OperationStatus.SUCCEEDED)
+ operation.start_timestamp = datetime.fromisoformat("2024-01-01T12:00:00Z")
+ operation.wait_details = type(
+ "MockDetails",
+ (),
+ {"scheduled_end_timestamp": datetime.fromisoformat("2024-01-01T12:05:00Z")},
+ )()
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=2,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ event = Event.create_wait_event(context)
+
+ assert event.event_type == "WaitSucceeded"
+ assert event.wait_succeeded_details.duration == 300
+
+
+def test_create_wait_cancelled():
+ operation = create_mock_operation("wait-1", status=OperationStatus.CANCELLED)
+ operation.wait_details = None
+ mock_operation_update = Mock()
+ mock_operation_update.operation_type = OperationType.WAIT
+ mock_operation_update.operation_update.action = OperationAction.CANCEL
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=3,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ operation_update=mock_operation_update,
+ )
+ event = Event.create_wait_event(context)
+
+ assert event.event_type == "WaitCancelled"
+ assert event.wait_cancelled_details is not None
+
+
+def test_create_wait_invalid_status():
+ operation = create_mock_operation("wait-1", status=OperationStatus.FAILED)
+ operation.wait_details.scheduled_end_timestamp = operation.start_timestamp = (
+ datetime.fromisoformat("2024-01-01T12:00:00Z")
+ )
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=1,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Operation status .* is not valid for wait operations",
+ ):
+ Event.create_wait_event(context)
+
+
+# endregion wait-tests
+
+
+# region step-tests
+def test_create_step_started():
+ operation = create_mock_operation(
+ "step-1", "test_step", status=OperationStatus.STARTED
+ )
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=1,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ event = Event.create_step_event(context)
+
+ assert event.event_type == "StepStarted"
+ assert event.operation_id == "step-1"
+ assert event.name == "test_step"
+ assert event.step_started_details is not None
+
+
+def test_create_step_succeeded():
+ operation = create_mock_operation("step-1", status=OperationStatus.SUCCEEDED)
+ operation.step_details = type(
+ "MockDetails", (), {"result": '{"step": "result"}', "error": None}
+ )()
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=2,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ include_execution_data=True,
+ )
+ event = Event.create_step_event(context)
+
+ assert event.event_type == "StepSucceeded"
+ assert event.step_succeeded_details.result.payload == '{"step": "result"}'
+
+
+def test_create_step_failed():
+ operation = create_mock_operation("step-1", status=OperationStatus.FAILED)
+ error_obj = ErrorObject.from_message("Step failed")
+ operation.step_details = type(
+ "MockDetails", (), {"result": None, "error": error_obj}
+ )()
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=3,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ event = Event.create_step_event(context)
+
+ assert event.event_type == "StepFailed"
+ assert event.step_failed_details.error.payload.message == "Step failed"
+
+
+def test_create_step_invalid_status():
+ operation = create_mock_operation("step-1", status=OperationStatus.TIMED_OUT)
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=1,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Operation status .* is not valid for step operations",
+ ):
+ Event.create_step_event(context)
+
+
+# endregion step-tests
+
+
+# region chained_invoke
+def test_create_chained_invoke_started():
+ operation = create_mock_operation(
+ "invoke-1", "test_invoke", status=OperationStatus.STARTED
+ )
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=1,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ event = Event.create_chained_invoke_event(context)
+
+ assert event.event_type == "ChainedInvokeStarted"
+ assert event.operation_id == "invoke-1"
+ assert event.name == "test_invoke"
+ assert event.chained_invoke_started_details is not None
+
+
+# endregion callback
+
+
+# endregion helpers-test
+
+
+def test_create_chained_invoke_succeeded():
+ operation = create_mock_operation("invoke-1", status=OperationStatus.SUCCEEDED)
+ operation.chained_invoke_details = type(
+ "MockDetails", (), {"result": '{"invoke": "result"}', "error": None}
+ )()
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=2,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ include_execution_data=True,
+ )
+ event = Event.create_chained_invoke_event(context)
+
+ assert event.event_type == "ChainedInvokeSucceeded"
+ assert (
+ event.chained_invoke_succeeded_details.result.payload == '{"invoke": "result"}'
+ )
+
+
+def test_create_chained_invoke_failed():
+ operation = create_mock_operation("invoke-1", status=OperationStatus.FAILED)
+ error_obj = ErrorObject.from_message("Invoke failed")
+ operation.chained_invoke_details = type(
+ "MockDetails", (), {"result": None, "error": error_obj}
+ )()
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=3,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ event = Event.create_chained_invoke_event(context)
+
+ assert event.event_type == "ChainedInvokeFailed"
+ assert event.chained_invoke_failed_details.error.payload.message == "Invoke failed"
+
+
+def test_create_chained_invoke_timed_out():
+ operation = create_mock_operation("invoke-1", status=OperationStatus.TIMED_OUT)
+ error_obj = ErrorObject.from_message("Invoke timed out")
+ operation.chained_invoke_details = type(
+ "MockDetails", (), {"result": None, "error": error_obj}
+ )()
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=4,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ event = Event.create_chained_invoke_event(context)
+
+ assert event.event_type == "ChainedInvokeTimedOut"
+ assert (
+ event.chained_invoke_timed_out_details.error.payload.message
+ == "Invoke timed out"
+ )
+
+
+def test_create_chained_invoke_stopped():
+ operation = create_mock_operation("invoke-1", status=OperationStatus.STOPPED)
+ error_obj = ErrorObject.from_message("Invoke stopped")
+ operation.chained_invoke_details = type(
+ "MockDetails", (), {"result": None, "error": error_obj}
+ )()
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=5,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ event = Event.create_chained_invoke_event(context)
+
+ assert event.event_type == "ChainedInvokeStopped"
+ assert (
+ event.chained_invoke_stopped_details.error.payload.message == "Invoke stopped"
+ )
+
+
+def test_create_chained_invoke_invalid_status():
+ operation = create_mock_operation("invoke-1", status=OperationStatus.CANCELLED)
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=1,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Operation status .* is not valid for chained invoke operations",
+ ):
+ Event.create_chained_invoke_event(context)
+
+
+# endregion chained_invoke
+
+
+# region callback-tests
+def test_create_callback_started():
+ operation = create_mock_operation(
+ "callback-1", "test_callback", status=OperationStatus.STARTED
+ )
+ operation.callback_details = type(
+ "MockDetails", (), {"callback_id": "cb-123", "result": None, "error": None}
+ )()
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=1,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ event = Event.create_callback_event(context)
+
+ assert event.event_type == "CallbackStarted"
+ assert event.operation_id == "callback-1"
+ assert event.name == "test_callback"
+ assert event.callback_started_details.callback_id == "cb-123"
+
+
+def test_create_callback_succeeded():
+ operation = create_mock_operation("callback-1", status=OperationStatus.SUCCEEDED)
+ operation.callback_details = type(
+ "MockDetails",
+ (),
+ {"callback_id": None, "result": '{"callback": "result"}', "error": None},
+ )()
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=2,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ include_execution_data=True,
+ )
+ event = Event.create_callback_event(context)
+
+ assert event.event_type == "CallbackSucceeded"
+ assert event.callback_succeeded_details.result.payload == '{"callback": "result"}'
+
+
+def test_create_callback_failed():
+ operation = create_mock_operation("callback-1", status=OperationStatus.FAILED)
+ error_obj = ErrorObject.from_message("Callback failed")
+ operation.callback_details = type(
+ "MockDetails", (), {"callback_id": None, "result": None, "error": error_obj}
+ )()
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=3,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ event = Event.create_callback_event(context)
+
+ assert event.event_type == "CallbackFailed"
+ assert event.callback_failed_details.error.payload.message == "Callback failed"
+
+
+def test_create_callback_timed_out():
+ operation = create_mock_operation("callback-1", status=OperationStatus.TIMED_OUT)
+ error_obj = ErrorObject.from_message("Callback timed out")
+ operation.callback_details = type(
+ "MockDetails", (), {"callback_id": None, "result": None, "error": error_obj}
+ )()
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=4,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ event = Event.create_callback_event(context)
+
+ assert event.event_type == "CallbackTimedOut"
+ assert (
+ event.callback_timed_out_details.error.payload.message == "Callback timed out"
+ )
+
+
+def test_create_callback_invalid_status():
+ operation = create_mock_operation("callback-1", status=OperationStatus.STOPPED)
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=1,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Operation status .* is not valid for callback operations",
+ ):
+ Event.create_callback_event(context)
+
+
+# endregion callback-tests
+
+
+# region model-tests
+def test_lambda_context():
+ context = LambdaContext(aws_request_id="test-123")
+ assert context.aws_request_id == "test-123"
+ assert context.get_remaining_time_in_millis() == 900000
+ context.log("test message") # Should not raise
+
+
+def test_start_durable_execution_input_missing_field():
+ with pytest.raises(
+ InvalidParameterValueException, match="Missing required field: AccountId"
+ ):
+ StartDurableExecutionInput.from_dict({})
+
+
+def test_start_durable_execution_input_to_dict_with_optionals():
+ input_obj = StartDurableExecutionInput(
+ account_id="123456789",
+ function_name="test-func",
+ function_qualifier="$LATEST",
+ execution_name="test-exec",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="inv-123",
+ trace_fields={"key": "value"},
+ tenant_id="tenant-123",
+ input='{"test": "data"}',
+ )
+ result = input_obj.to_dict()
+ assert result["InvocationId"] == "inv-123"
+ assert result["TraceFields"] == {"key": "value"}
+ assert result["TenantId"] == "tenant-123"
+ assert result["Input"] == '{"test": "data"}'
+
+
+def test_execution_from_dict_empty_function_arn():
+ data = {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789:function:test",
+ "DurableExecutionName": "test-exec",
+ "Status": "SUCCEEDED",
+ "StartTimestamp": 1640995200.0,
+ }
+ execution = Execution.from_dict(data)
+ assert execution.function_arn == ""
+
+
+def test_execution_to_dict_with_function_arn():
+ execution = Execution(
+ durable_execution_arn="arn:aws:lambda:us-east-1:123456789:function:test",
+ durable_execution_name="test-exec",
+ function_arn="arn:aws:lambda:us-east-1:123456789:function:test",
+ status="SUCCEEDED",
+ start_timestamp=1640995200.0,
+ )
+ result = execution.to_dict()
+ assert "FunctionArn" in result
+
+
+def test_event_input_from_details():
+ from aws_durable_execution_sdk_python.lambda_service import ExecutionDetails
+
+ details = ExecutionDetails(input_payload='{"test": "data"}')
+ event_input = EventInput.from_details(details, include=True)
+ assert event_input.payload == '{"test": "data"}'
+ assert not event_input.truncated
+
+ event_input_truncated = EventInput.from_details(details, include=False)
+ assert event_input_truncated.payload is None
+ assert event_input_truncated.truncated
+
+
+def test_event_result_from_details():
+ from aws_durable_execution_sdk_python.lambda_service import StepDetails
+
+ details = StepDetails(result='{"result": "success"}')
+ event_result = EventResult.from_details(details, include=True)
+ assert event_result.payload == '{"result": "success"}'
+ assert not event_result.truncated
+
+
+def test_event_error_from_details():
+ from aws_durable_execution_sdk_python.lambda_service import StepDetails
+
+ error_obj = ErrorObject.from_message("Test error")
+ details = StepDetails(error=error_obj)
+ event_error = EventError.from_details(details)
+ assert event_error.payload.message == "Test error"
+
+
+def test_event_from_dict_with_all_details():
+ data = {
+ "EventType": "ExecutionStarted",
+ "EventTimestamp": datetime.fromisoformat("2024-01-01T12:00:00Z"),
+ "EventId": 1,
+ "Id": "op-1",
+ "Name": "test",
+ "ParentId": "parent-1",
+ "SubType": "test-subtype",
+ "ExecutionStartedDetails": {
+ "Input": {"Payload": '{"test": "data"}', "Truncated": False},
+ "ExecutionTimeout": 300,
+ },
+ }
+ event = Event.from_dict(data)
+ assert event.sub_type == "test-subtype"
+ assert event.parent_id == "parent-1"
+
+
+def test_event_to_dict_with_all_details():
+ event = Event(
+ event_type="ExecutionStarted",
+ event_timestamp=datetime.fromisoformat("2024-01-01T12:00:00Z"),
+ event_id=1,
+ operation_id="op-1",
+ name="test",
+ parent_id="parent-1",
+ sub_type="test-subtype",
+ execution_started_details=ExecutionStartedDetails(
+ input=EventInput(payload='{"test": "data"}', truncated=False),
+ execution_timeout=300,
+ ),
+ )
+ result = event.to_dict()
+ assert result["SubType"] == "test-subtype"
+ assert result["ParentId"] == "parent-1"
+ assert result["ExecutionStartedDetails"]["ExecutionTimeout"] == 300
+
+
+def test_error_response_from_dict_nested():
+ data = {
+ "error": {
+ "type": "ValidationError",
+ "message": "Invalid input",
+ "code": "400",
+ "requestId": "req-123",
+ }
+ }
+ error_response = ErrorResponse.from_dict(data)
+ assert error_response.error_type == "ValidationError"
+ assert error_response.error_message == "Invalid input"
+ assert error_response.error_code == "400"
+ assert error_response.request_id == "req-123"
+
+
+def test_error_response_from_dict_flat():
+ data = {"type": "ValidationError", "message": "Invalid input"}
+ error_response = ErrorResponse.from_dict(data)
+ assert error_response.error_type == "ValidationError"
+ assert error_response.error_message == "Invalid input"
+
+
+def test_checkpoint_durable_execution_request_from_dict():
+ token: str = "token-123"
+ data = {
+ "CheckpointToken": token,
+ "Updates": [
+ {"Id": "op-1", "Type": "STEP", "Action": "START", "SubType": "Step"}
+ ],
+ }
+ request = CheckpointDurableExecutionRequest.from_dict(data, "arn:test")
+ assert request.checkpoint_token == token
+ assert len(request.updates) == 1
+ assert request.updates[0].operation_id == "op-1"
+
+
+# endregion model-tests
+
+
+# region from_operation_started_tests
+class TestFromOperationStarted:
+ """Tests for Event.from_operation_started method."""
+
+ def test_from_operation_started_execution(self):
+ """Test converting execution operation to started event."""
+ operation = Mock()
+ operation.operation_id = "exec-123"
+ operation.name = "test_execution"
+ operation.parent_id = "parent-123"
+ operation.operation_type = OperationType.EXECUTION
+ operation.start_timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC)
+
+ execution_details = Mock()
+ execution_details.input_payload = '{"test": "data"}'
+ operation.execution_details = execution_details
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=1,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ include_execution_data=True,
+ )
+ event = Event.create_event_started(context)
+
+ assert event.event_type == "ExecutionStarted"
+ assert event.operation_id == "exec-123"
+ assert event.name == "test_execution"
+ assert event.parent_id == "parent-123"
+ assert event.execution_started_details.input.payload == '{"test": "data"}'
+ assert not event.execution_started_details.input.truncated
+
+ def test_from_operation_started_execution_no_data(self):
+ """Test execution operation with include_execution_data=False."""
+ operation = Mock()
+ operation.operation_id = "exec-123"
+ operation.name = "test_execution"
+ operation.parent_id = "parent-123"
+ operation.operation_type = OperationType.EXECUTION
+ operation.start_timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC)
+
+ execution_details = Mock()
+ execution_details.input_payload = '{"test": "data"}'
+ operation.execution_details = execution_details
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=1,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ include_execution_data=False,
+ )
+ event = Event.create_event_started(context)
+
+ assert event.event_type == "ExecutionStarted"
+ assert event.execution_started_details.input.payload is None
+ assert event.execution_started_details.input.truncated
+
+ def test_from_operation_started_step(self):
+ """Test converting step operation to started event."""
+ operation = Mock()
+ operation.operation_id = "step-123"
+ operation.name = "test_step"
+ operation.parent_id = "parent-123"
+ operation.operation_type = OperationType.STEP
+ operation.start_timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC)
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=2,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ event = Event.create_event_started(context)
+
+ assert event.event_type == "StepStarted"
+ assert event.operation_id == "step-123"
+ assert event.name == "test_step"
+ assert event.parent_id == "parent-123"
+ assert event.step_started_details is not None
+
+ def test_from_operation_started_wait(self):
+ """Test converting wait operation to started event."""
+ operation = Mock()
+ operation.operation_id = "wait-123"
+ operation.name = "test_wait"
+ operation.parent_id = "parent-123"
+ operation.operation_type = OperationType.WAIT
+ operation.start_timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC)
+
+ wait_details = Mock()
+ wait_details.scheduled_end_timestamp = datetime(
+ 2024, 1, 1, 12, 5, 0, tzinfo=UTC
+ )
+ operation.wait_details = wait_details
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=3,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ event = Event.create_event_started(context)
+
+ assert event.event_type == "WaitStarted"
+ assert event.operation_id == "wait-123"
+ assert event.name == "test_wait"
+ assert event.parent_id == "parent-123"
+ assert event.wait_started_details.duration == 300
+ assert (
+ event.wait_started_details.scheduled_end_timestamp
+ == datetime.fromisoformat("2024-01-01T12:05:00+00:00")
+ )
+
+ def test_from_operation_started_callback(self):
+ """Test converting callback operation to started event."""
+ operation = Mock()
+ operation.operation_id = "callback-123"
+ operation.name = "test_callback"
+ operation.parent_id = "parent-123"
+ operation.operation_type = OperationType.CALLBACK
+ operation.start_timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC)
+
+ callback_details = Mock()
+ callback_details.callback_id = "cb-456"
+ operation.callback_details = callback_details
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=4,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ event = Event.create_event_started(context)
+
+ assert event.event_type == "CallbackStarted"
+ assert event.operation_id == "callback-123"
+ assert event.name == "test_callback"
+ assert event.parent_id == "parent-123"
+ assert event.callback_started_details.callback_id == "cb-456"
+
+ def test_from_operation_started_chained_invoke(self):
+ """Test converting chained invoke operation to started event."""
+ operation = Mock()
+ operation.operation_id = "invoke-123"
+ operation.name = "test_invoke"
+ operation.parent_id = "parent-123"
+ operation.operation_type = OperationType.CHAINED_INVOKE
+ operation.start_timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC)
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=5,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ event = Event.create_event_started(context)
+
+ assert event.event_type == "ChainedInvokeStarted"
+ assert event.operation_id == "invoke-123"
+ assert event.name == "test_invoke"
+ assert event.parent_id == "parent-123"
+ assert event.chained_invoke_started_details is not None
+
+ def test_from_operation_started_context(self):
+ """Test converting context operation to started event."""
+ operation = Mock()
+ operation.operation_id = "context-123"
+ operation.name = "test_context"
+ operation.parent_id = "parent-123"
+ operation.operation_type = OperationType.CONTEXT
+ operation.start_timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC)
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=6,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ event = Event.create_event_started(context)
+
+ assert event.event_type == "ContextStarted"
+ assert event.operation_id == "context-123"
+ assert event.name == "test_context"
+ assert event.parent_id == "parent-123"
+ assert event.context_started_details is not None
+
+ def test_from_operation_started_no_timestamp(self):
+ """Test error when operation has no start timestamp."""
+ operation = Mock()
+ operation.start_timestamp = None
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=1,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Operation start timestamp cannot be None",
+ ):
+ Event.create_event_started(context)
+
+ def test_from_operation_started_unknown_type(self):
+ """Test error with unknown operation type."""
+ operation = Mock()
+ operation.operation_type = "UNKNOWN_TYPE"
+ operation.start_timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC)
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=1,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ with pytest.raises(
+ InvalidParameterValueException, match="Unknown operation type: UNKNOWN_TYPE"
+ ):
+ Event.create_event_started(context)
+
+
+# endregion from_operation_started_tests
+
+
+# region from_operation_finished_tests
+class TestFromOperationFinished:
+ """Tests for Event.from_operation_finished method."""
+
+ def test_from_operation_finished_execution_succeeded(self):
+ """Test converting succeeded execution operation to finished event."""
+ operation = Mock()
+ operation.operation_id = "exec-123"
+ operation.name = "test_execution"
+ operation.parent_id = "parent-123"
+ operation.operation_type = OperationType.EXECUTION
+ operation.status = OperationStatus.SUCCEEDED
+ operation.end_timestamp = datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC)
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=1,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ event = Event.create_event_terminated(context)
+
+ assert event.event_type == "ExecutionSucceeded"
+ assert event.operation_id == "exec-123"
+ assert event.name == "test_execution"
+ assert event.parent_id == "parent-123"
+
+ def test_from_operation_finished_execution_failed(self):
+ """Test converting failed execution operation to finished event."""
+ operation = Mock()
+ operation.operation_id = "exec-123"
+ operation.name = "test_execution"
+ operation.parent_id = "parent-123"
+ operation.operation_type = OperationType.EXECUTION
+ operation.status = OperationStatus.FAILED
+ operation.end_timestamp = datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC)
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=1,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ event = Event.create_event_terminated(context)
+
+ assert event.event_type == "ExecutionFailed"
+ assert event.operation_id == "exec-123"
+
+ def test_from_operation_finished_step_with_result(self):
+ """Test converting succeeded step operation with result."""
+ operation = Mock()
+ operation.operation_id = "step-123"
+ operation.name = "test_step"
+ operation.parent_id = "parent-123"
+ operation.operation_type = OperationType.STEP
+ operation.status = OperationStatus.SUCCEEDED
+ operation.end_timestamp = datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC)
+
+ step_details = Mock()
+ step_details.result = '{"result": "success"}'
+ step_details.error = None
+ operation.step_details = step_details
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=2,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ include_execution_data=True,
+ )
+ event = Event.create_event_terminated(context)
+
+ assert event.event_type == "StepSucceeded"
+ assert event.operation_id == "step-123"
+ assert event.step_succeeded_details.result.payload == '{"result": "success"}'
+
+ def test_from_operation_finished_step_with_error(self):
+ """Test converting failed step operation with error."""
+ operation = Mock()
+ operation.operation_id = "step-123"
+ operation.name = "test_step"
+ operation.parent_id = "parent-123"
+ operation.operation_type = OperationType.STEP
+ operation.status = OperationStatus.FAILED
+ operation.end_timestamp = datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC)
+
+ step_details = Mock()
+ step_details.result = None
+ step_details.error = ErrorObject.from_message("Step failed")
+ operation.step_details = step_details
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=2,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ event = Event.create_event_terminated(context)
+
+ assert event.event_type == "StepFailed"
+ assert event.step_failed_details.error.payload.message == "Step failed"
+
+ def test_from_operation_finished_wait_succeeded(self):
+ """Test converting succeeded wait operation."""
+ operation = Mock()
+ operation.operation_id = "wait-123"
+ operation.name = "test_wait"
+ operation.parent_id = "parent-123"
+ operation.operation_type = OperationType.WAIT
+ operation.status = OperationStatus.SUCCEEDED
+ operation.start_timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC)
+ operation.end_timestamp = datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC)
+
+ wait_details = Mock()
+ wait_details.scheduled_end_timestamp = datetime(
+ 2024, 1, 1, 12, 5, 0, tzinfo=UTC
+ )
+ operation.wait_details = wait_details
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=3,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ event = Event.create_event_terminated(context)
+
+ assert event.event_type == "WaitSucceeded"
+ assert event.wait_succeeded_details.duration == 300
+
+ def test_from_operation_finished_wait_cancelled(self):
+ """Test converting cancelled wait operation."""
+ operation = Mock()
+ operation.operation_id = "wait-123"
+ operation.name = "test_wait"
+ operation.parent_id = "parent-123"
+ operation.operation_type = OperationType.WAIT
+ operation.status = OperationStatus.CANCELLED
+ operation.end_timestamp = datetime(2024, 1, 1, 12, 3, 0, tzinfo=UTC)
+ operation.wait_details = None
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=3,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ event = Event.create_event_terminated(context)
+
+ assert event.event_type == "WaitCancelled"
+ assert event.wait_cancelled_details is not None
+
+ def test_from_operation_finished_callback_succeeded(self):
+ """Test converting succeeded callback operation."""
+ operation = Mock()
+ operation.operation_id = "callback-123"
+ operation.name = "test_callback"
+ operation.parent_id = "parent-123"
+ operation.operation_type = OperationType.CALLBACK
+ operation.status = OperationStatus.SUCCEEDED
+ operation.end_timestamp = datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC)
+
+ callback_details = Mock()
+ callback_details.result = '{"callback": "result"}'
+ callback_details.error = None
+ operation.callback_details = callback_details
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=4,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ include_execution_data=True,
+ )
+ event = Event.create_event_terminated(context)
+
+ assert event.event_type == "CallbackSucceeded"
+ assert (
+ event.callback_succeeded_details.result.payload == '{"callback": "result"}'
+ )
+
+ def test_from_operation_finished_callback_timed_out(self):
+ """Test converting timed out callback operation."""
+ operation = Mock()
+ operation.operation_id = "callback-123"
+ operation.name = "test_callback"
+ operation.parent_id = "parent-123"
+ operation.operation_type = OperationType.CALLBACK
+ operation.status = OperationStatus.TIMED_OUT
+ operation.end_timestamp = datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC)
+
+ callback_details = Mock()
+ callback_details.result = None
+ callback_details.error = ErrorObject.from_message("Callback timed out")
+ operation.callback_details = callback_details
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=4,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ event = Event.create_event_terminated(context)
+
+ assert event.event_type == "CallbackTimedOut"
+ assert (
+ event.callback_timed_out_details.error.payload.message
+ == "Callback timed out"
+ )
+
+ def test_from_operation_finished_chained_invoke_succeeded(self):
+ """Test converting succeeded chained invoke operation."""
+ operation = Mock()
+ operation.operation_id = "invoke-123"
+ operation.name = "test_invoke"
+ operation.parent_id = "parent-123"
+ operation.operation_type = OperationType.CHAINED_INVOKE
+ operation.status = OperationStatus.SUCCEEDED
+ operation.end_timestamp = datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC)
+
+ chained_invoke_details = Mock()
+ chained_invoke_details.result = '{"invoke": "result"}'
+ chained_invoke_details.error = None
+ operation.chained_invoke_details = chained_invoke_details
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=5,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ include_execution_data=True,
+ )
+ event = Event.create_event_terminated(context)
+
+ assert event.event_type == "ChainedInvokeSucceeded"
+ assert (
+ event.chained_invoke_succeeded_details.result.payload
+ == '{"invoke": "result"}'
+ )
+
+ def test_from_operation_finished_chained_invoke_stopped(self):
+ """Test converting stopped chained invoke operation."""
+ operation = Mock()
+ operation.operation_id = "invoke-123"
+ operation.name = "test_invoke"
+ operation.parent_id = "parent-123"
+ operation.operation_type = OperationType.CHAINED_INVOKE
+ operation.status = OperationStatus.STOPPED
+ operation.end_timestamp = datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC)
+
+ chained_invoke_details = Mock()
+ chained_invoke_details.result = None
+ chained_invoke_details.error = ErrorObject.from_message("Invoke stopped")
+ operation.chained_invoke_details = chained_invoke_details
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=5,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ event = Event.create_event_terminated(context)
+
+ assert event.event_type == "ChainedInvokeStopped"
+ assert (
+ event.chained_invoke_stopped_details.error.payload.message
+ == "Invoke stopped"
+ )
+
+ def test_from_operation_finished_context_succeeded(self):
+ """Test converting succeeded context operation."""
+ operation = Mock()
+ operation.operation_id = "context-123"
+ operation.name = "test_context"
+ operation.parent_id = "parent-123"
+ operation.operation_type = OperationType.CONTEXT
+ operation.status = OperationStatus.SUCCEEDED
+ operation.end_timestamp = datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC)
+
+ context_details = Mock()
+ context_details.result = '{"context": "result"}'
+ context_details.error = None
+ operation.context_details = context_details
+ operation.result = None
+ operation.error = None
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=6,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ include_execution_data=True,
+ )
+ event = Event.create_event_terminated(context)
+
+ assert event.event_type == "ContextSucceeded"
+ assert event.context_succeeded_details.result.payload == '{"context": "result"}'
+
+ def test_from_operation_finished_context_failed(self):
+ """Test converting failed context operation."""
+ operation = Mock()
+ operation.operation_id = "context-123"
+ operation.name = "test_context"
+ operation.parent_id = "parent-123"
+ operation.operation_type = OperationType.CONTEXT
+ operation.status = OperationStatus.FAILED
+ operation.end_timestamp = datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC)
+
+ context_details = Mock()
+ context_details.result = None
+ context_details.error = ErrorObject.from_message("Context failed")
+ operation.context_details = context_details
+ operation.result = None
+ operation.error = None
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=6,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ event = Event.create_event_terminated(context)
+
+ assert event.event_type == "ContextFailed"
+ assert event.context_failed_details.error.payload.message == "Context failed"
+
+ def test_from_operation_finished_no_end_timestamp(self):
+ """Test error when operation has no end timestamp."""
+ operation = Mock()
+ operation.end_timestamp = None
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=1,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Operation end timestamp cannot be None",
+ ):
+ Event.create_event_terminated(context)
+
+ def test_from_operation_finished_invalid_status(self):
+ """Test error with invalid operation status."""
+ operation = Mock()
+ operation.status = OperationStatus.STARTED
+ operation.end_timestamp = datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC)
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=1,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Operation status must be one of SUCCEEDED, FAILED, TIMED_OUT, STOPPED, or CANCELLED",
+ ):
+ Event.create_event_terminated(context)
+
+ def test_from_operation_finished_unknown_type(self):
+ """Test error with unknown operation type."""
+ operation = Mock()
+ operation.operation_type = "UNKNOWN_TYPE"
+ operation.status = OperationStatus.SUCCEEDED
+ operation.end_timestamp = datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC)
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=1,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ with pytest.raises(
+ InvalidParameterValueException, match="Unknown operation type: UNKNOWN_TYPE"
+ ):
+ Event.create_event_terminated(context)
+
+ def test_from_operation_finished_no_details(self):
+ """Test operations with no detail objects."""
+ operation = Mock()
+ operation.operation_id = "step-123"
+ operation.name = "test_step"
+ operation.parent_id = "parent-123"
+ operation.operation_type = OperationType.STEP
+ operation.status = OperationStatus.SUCCEEDED
+ operation.end_timestamp = datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC)
+ operation.step_details = None
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=2,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+ event = Event.create_event_terminated(context)
+
+ assert event.event_type == "StepSucceeded"
+ assert event.step_succeeded_details.result is None
+
+
+# endregion from_operation_finished_tests
+
+
+def test_chained_invoke_pending_details_from_dict():
+ """Test ChainedInvokePendingDetails parsing in Event.from_dict."""
+ data = {
+ "EventType": "ChainedInvokeStarted",
+ "EventTimestamp": datetime.now(UTC),
+ "ChainedInvokePendingDetails": {
+ "Input": {"Payload": "test-input", "Truncated": False},
+ "FunctionName": "test-function",
+ },
+ }
+
+ event = Event.from_dict(data)
+ assert event.chained_invoke_pending_details is not None
+ assert event.chained_invoke_pending_details.input.payload == "test-input"
+ assert event.chained_invoke_pending_details.function_name == "test-function"
+
+
+def test_event_creation_context_sub_type_property():
+ """Test EventCreationContext.sub_type property with and without sub_type."""
+ # Test with sub_type
+ operation = Mock()
+ operation.sub_type = OperationSubType.STEP
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=1,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+
+ assert context.sub_type == "Step"
+
+ # Test without sub_type
+ operation.sub_type = None
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=1,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+
+ assert context.sub_type is None
+
+
+def test_event_creation_context_get_retry_details():
+ """Test EventCreationContext.get_retry_details method."""
+ operation = Mock()
+ operation.step_details = StepDetails(attempt=2)
+
+ operation_update = OperationUpdate(
+ operation_id="step-1",
+ operation_type=OperationType.STEP,
+ action=OperationAction.SUCCEED,
+ step_options=StepOptions(next_attempt_delay_seconds=30),
+ )
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=1,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ operation_update=operation_update,
+ )
+
+ retry_details = context.get_retry_details()
+ assert retry_details is not None
+ assert retry_details.current_attempt == 2
+ assert retry_details.next_attempt_delay_seconds == 30
+
+ # Test with no step_details
+ operation.step_details = None
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=1,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ operation_update=operation_update,
+ )
+
+ retry_details = context.get_retry_details()
+ assert retry_details is None
+
+ # Test with no operation_update
+ operation.step_details = StepDetails(attempt=2)
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=1,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ )
+
+ retry_details = context.get_retry_details()
+ assert retry_details is None
+
+
+def test_create_chained_invoke_event_pending():
+ """Test Event.create_chained_invoke_event_pending method."""
+ operation = Mock()
+ operation.operation_id = "invoke-1"
+ operation.name = "test_invoke"
+ operation.parent_id = None
+ operation.status = OperationStatus.PENDING
+ operation.start_timestamp = datetime.now(UTC)
+ operation.sub_type = None
+
+ context = EventCreationContext.create(
+ operation=operation,
+ event_id=1,
+ durable_execution_arn="arn:test",
+ start_input=StartDurableExecutionInput(
+ account_id="123",
+ function_name="test",
+ function_qualifier="$LATEST",
+ execution_name="test",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ ),
+ include_execution_data=True,
+ )
+
+ event = Event.create_chained_invoke_event_pending(context)
+
+ assert event.event_type == "ChainedInvokeStarted"
+ assert event.operation_id == "invoke-1"
+ assert event.name == "test_invoke"
+ assert event.chained_invoke_pending_details is not None
+ assert event.chained_invoke_pending_details.function_name == "test"
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/exceptions_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/exceptions_test.py
new file mode 100644
index 0000000..1f1ea55
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/exceptions_test.py
@@ -0,0 +1,953 @@
+"""Tests for AWS-compliant exceptions and their boto3 compatibility.
+
+This module contains comprehensive tests for all exception types used in the
+AWS Durable Execution SDK Python Testing framework, including validation
+of boto3 compatibility for proper AWS service integration.
+"""
+
+import json
+
+import pytest
+
+from aws_durable_execution_sdk_python_testing import exceptions
+
+
+# =============================================================================
+# Base Exception Tests
+# =============================================================================
+
+
+def test_durable_functions_test_error_base_exception() -> None:
+ """Test DurableFunctionsTestError base exception."""
+ error = exceptions.DurableFunctionsTestError("Base error message")
+
+ assert str(error) == "Base error message"
+ assert isinstance(error, Exception)
+
+
+def test_durable_functions_local_runner_error_base_exception() -> None:
+ """Test DurableFunctionsLocalRunnerError base exception."""
+ error = exceptions.DurableFunctionsLocalRunnerError("Local runner error")
+
+ assert str(error) == "Local runner error"
+ assert isinstance(error, Exception)
+
+
+def test_serialization_error() -> None:
+ """Test SerializationError for serialization failures."""
+ error = exceptions.SerializationError("Failed to serialize data")
+
+ assert str(error) == "Failed to serialize data"
+ assert isinstance(error, exceptions.DurableFunctionsLocalRunnerError)
+
+
+def test_unknown_route_error() -> None:
+ """Test UnknownRouteError for unknown HTTP routes."""
+ error = exceptions.UnknownRouteError("POST", "/unknown/path")
+
+ assert str(error) == "Unknown path pattern: POST /unknown/path"
+ assert error.method == "POST"
+ assert error.path == "/unknown/path"
+ assert isinstance(error, exceptions.DurableFunctionsLocalRunnerError)
+
+
+def test_aws_api_exception_base() -> None:
+ """Test AwsApiException base class."""
+ # AwsApiException is abstract, so we test with a concrete implementation
+ error = exceptions.ServiceException("Test service error")
+
+ assert isinstance(error, exceptions.AwsApiException)
+ assert isinstance(error, exceptions.DurableFunctionsLocalRunnerError)
+ assert error.http_status_code == 500
+
+
+def test_exception_hierarchy() -> None:
+ """Test that all custom exceptions inherit from appropriate base exceptions."""
+ # Test AWS API exceptions
+ aws_exceptions = [
+ exceptions.IllegalStateException("test"),
+ exceptions.InvalidParameterValueException("test"),
+ exceptions.ResourceNotFoundException("test"),
+ exceptions.ServiceException("test"),
+ exceptions.CallbackTimeoutException("test"),
+ ]
+
+ for aws_exception in aws_exceptions:
+ assert isinstance(aws_exception, exceptions.AwsApiException)
+ assert isinstance(aws_exception, exceptions.DurableFunctionsLocalRunnerError)
+ assert isinstance(aws_exception, Exception)
+
+ # Test local runner exceptions
+ local_exceptions = [
+ exceptions.SerializationError("test"),
+ exceptions.UnknownRouteError("GET", "/test"),
+ ]
+
+ for local_exception in local_exceptions:
+ assert isinstance(local_exception, exceptions.DurableFunctionsLocalRunnerError)
+ assert isinstance(local_exception, Exception)
+
+ # Test testing exceptions
+ test_error = exceptions.DurableFunctionsTestError("test")
+ assert isinstance(test_error, Exception)
+
+
+def test_illegal_argument_exception() -> None:
+ """Test IllegalArgumentException maps to InvalidParameterValueException."""
+ error = exceptions.IllegalArgumentException("Invalid argument provided")
+
+ assert str(error) == "Invalid argument provided"
+ assert isinstance(error, exceptions.AwsApiException)
+ assert error.http_status_code == 400
+
+ # Test serialization maps to InvalidParameterValueException
+ json_dict = error.to_dict()
+ assert json_dict == {
+ "Type": "InvalidParameterValueException",
+ "message": "Invalid argument provided",
+ }
+
+
+def test_runtime_exception() -> None:
+ """Test RuntimeException maps to ServiceException."""
+ error = exceptions.RuntimeException("Runtime error occurred")
+
+ assert str(error) == "Runtime error occurred"
+ assert isinstance(error, exceptions.AwsApiException)
+ assert error.http_status_code == 500
+
+ # Test serialization maps to ServiceException
+ json_dict = error.to_dict()
+ assert json_dict == {
+ "Type": "ServiceException",
+ "Message": "Runtime error occurred",
+ }
+
+
+def test_illegal_state_exception() -> None:
+ """Test IllegalStateException for invalid state transitions."""
+ error = exceptions.IllegalStateException(
+ "Cannot transition from RUNNING to PENDING"
+ )
+
+ assert str(error) == "Cannot transition from RUNNING to PENDING"
+ assert isinstance(error, exceptions.AwsApiException)
+ assert error.http_status_code == 500
+
+ # Test serialization maps to ServiceException
+ json_dict = error.to_dict()
+ assert json_dict == {
+ "Type": "ServiceException",
+ "Message": "Cannot transition from RUNNING to PENDING",
+ }
+
+
+# =============================================================================
+# Boto3 Compatibility Tests
+# =============================================================================
+
+
+def test_invalid_parameter_value_exception_boto3_format() -> None:
+ """Test InvalidParameterValueException produces correct boto3 format."""
+ exception = exceptions.InvalidParameterValueException("Invalid parameter value")
+
+ # Test serialization
+ json_dict = exception.to_dict()
+
+ # Validate structure matches boto3 expectations
+ assert json_dict == {
+ "Type": "InvalidParameterValueException",
+ "message": "Invalid parameter value",
+ }
+
+ # Test that it can be serialized to JSON and back
+ json_str = json.dumps(json_dict)
+ parsed_back = json.loads(json_str)
+ assert parsed_back == json_dict
+
+ # Validate HTTP status code
+ assert exception.http_status_code == 400
+
+
+def test_resource_not_found_exception_boto3_format() -> None:
+ """Test ResourceNotFoundException produces correct boto3 format."""
+ exception = exceptions.ResourceNotFoundException("Resource not found")
+
+ json_dict = exception.to_dict()
+
+ assert json_dict == {
+ "Type": "ResourceNotFoundException",
+ "Message": "Resource not found", # Capital M per Smithy definition
+ }
+
+ # Test JSON serialization
+ json_str = json.dumps(json_dict)
+ parsed_back = json.loads(json_str)
+ assert parsed_back == json_dict
+
+ assert exception.http_status_code == 404
+
+
+def test_service_exception_boto3_format() -> None:
+ """Test ServiceException produces correct boto3 format."""
+ exception = exceptions.ServiceException("Internal service error")
+
+ json_dict = exception.to_dict()
+
+ assert json_dict == {
+ "Type": "ServiceException",
+ "Message": "Internal service error", # Capital M per Smithy definition
+ }
+
+ # Test JSON serialization
+ json_str = json.dumps(json_dict)
+ parsed_back = json.loads(json_str)
+ assert parsed_back == json_dict
+
+ assert exception.http_status_code == 500
+
+
+def test_callback_timeout_exception_boto3_format() -> None:
+ """Test CallbackTimeoutException produces correct boto3 format."""
+ exception = exceptions.CallbackTimeoutException("Callback timed out")
+
+ json_dict = exception.to_dict()
+
+ assert json_dict == {
+ "Type": "CallbackTimeoutException",
+ "message": "Callback timed out",
+ }
+
+ # Test JSON serialization
+ json_str = json.dumps(json_dict)
+ parsed_back = json.loads(json_str)
+ assert parsed_back == json_dict
+
+ assert exception.http_status_code == 408
+
+
+def test_execution_already_started_exception_special_format() -> None:
+ """Test ExecutionAlreadyStartedException has no Type field (special case)."""
+ exception = exceptions.ExecutionAlreadyStartedException(
+ "Execution already started",
+ "arn:aws:states:us-east-1:123456789012:execution:test",
+ )
+
+ json_dict = exception.to_dict()
+
+ # Special case: no Type field for this exception, includes DurableExecutionArn
+ assert json_dict == {
+ "message": "Execution already started",
+ "DurableExecutionArn": "arn:aws:states:us-east-1:123456789012:execution:test",
+ }
+
+ # Ensure Type field is not present
+ assert "Type" not in json_dict
+
+ # Test JSON serialization
+ json_str = json.dumps(json_dict)
+ parsed_back = json.loads(json_str)
+ assert parsed_back == json_dict
+
+ assert exception.http_status_code == 409
+
+
+def test_all_exceptions_have_correct_type_field_values() -> None:
+ """Test that Type field values match what boto3 expects for exception names."""
+ test_cases = [
+ (
+ exceptions.InvalidParameterValueException("test"),
+ "InvalidParameterValueException",
+ ),
+ (exceptions.ResourceNotFoundException("test"), "ResourceNotFoundException"),
+ (exceptions.ServiceException("test"), "ServiceException"),
+ (exceptions.CallbackTimeoutException("test"), "CallbackTimeoutException"),
+ ]
+
+ for exception, expected_type in test_cases:
+ json_dict = exception.to_dict()
+ assert json_dict["Type"] == expected_type
+
+
+def test_message_field_casing_compatibility() -> None:
+ """Test message field casing matches boto3 deserialization expectations."""
+ # InvalidParameterValueException uses lowercase 'message'
+ exception1 = exceptions.InvalidParameterValueException("Test message")
+ json_dict1 = exception1.to_dict()
+
+ assert "message" in json_dict1
+ assert "Message" not in json_dict1
+ assert json_dict1["message"] == "Test message"
+
+ # ResourceNotFoundException uses capital 'Message'
+ exception2 = exceptions.ResourceNotFoundException("Test message")
+ json_dict2 = exception2.to_dict()
+
+ assert "Message" in json_dict2
+ assert "message" not in json_dict2
+ assert json_dict2["Message"] == "Test message"
+
+
+def test_json_serialization_with_special_characters() -> None:
+ """Test that exceptions with special characters serialize correctly."""
+ special_message = 'Error with "quotes", newlines\n, and unicode: 🚀'
+ exception = exceptions.InvalidParameterValueException(special_message)
+
+ json_dict = exception.to_dict()
+
+ # Test that it can be serialized to JSON
+ json_str = json.dumps(json_dict)
+ parsed_back = json.loads(json_str)
+
+ assert parsed_back["message"] == special_message
+ assert parsed_back["Type"] == "InvalidParameterValueException"
+
+
+def test_empty_message_handling() -> None:
+ """Test that empty messages are handled correctly."""
+ exception = exceptions.InvalidParameterValueException("")
+ json_dict = exception.to_dict()
+
+ assert json_dict == {"Type": "InvalidParameterValueException", "message": ""}
+
+
+def test_none_message_handling() -> None:
+ """Test that None messages are converted to empty strings."""
+ # This tests the edge case where message might be None
+ exception = exceptions.InvalidParameterValueException(None) # type: ignore
+ json_dict = exception.to_dict()
+
+ # Should convert None to string "None" for JSON compatibility
+ assert json_dict["message"] is None or json_dict["message"] == "None"
+
+
+def test_http_status_codes_match_aws_standards() -> None:
+ """Test that HTTP status codes match AWS service standards."""
+ status_code_tests = [
+ (exceptions.InvalidParameterValueException("test"), 400), # Bad Request
+ (exceptions.ResourceNotFoundException("test"), 404), # Not Found
+ (
+ exceptions.ExecutionAlreadyStartedException(
+ "test", "arn:aws:states:us-east-1:123456789012:execution:test"
+ ),
+ 409,
+ ), # Conflict
+ (exceptions.CallbackTimeoutException("test"), 408), # Request Timeout
+ (exceptions.ServiceException("test"), 500), # Internal Server Error
+ ]
+
+ for exception, expected_status in status_code_tests:
+ assert exception.http_status_code == expected_status
+
+
+def test_json_structure_has_no_extra_fields() -> None:
+ """Test that JSON structure only contains expected fields."""
+ exception = exceptions.InvalidParameterValueException("test")
+ json_dict = exception.to_dict()
+
+ # Should only have Type and message fields
+ expected_fields = {"Type", "message"}
+ actual_fields = set(json_dict.keys())
+
+ assert actual_fields == expected_fields
+
+
+def test_execution_already_started_has_only_message_field() -> None:
+ """Test that ExecutionAlreadyStartedException only has message field."""
+ exception = exceptions.ExecutionAlreadyStartedException(
+ "test", "arn:aws:states:us-east-1:123456789012:execution:test"
+ )
+ json_dict = exception.to_dict()
+
+ # Should only have message and DurableExecutionArn fields (no Type)
+ expected_fields = {"message", "DurableExecutionArn"}
+ actual_fields = set(json_dict.keys())
+
+ assert actual_fields == expected_fields
+
+
+def test_large_message_serialization() -> None:
+ """Test that large messages can be serialized correctly."""
+ # Create a large message (but not too large to avoid memory issues in tests)
+ large_message = "Error: " + "x" * 1000
+ exception = exceptions.ServiceException(large_message)
+
+ json_dict = exception.to_dict()
+ json_str = json.dumps(json_dict)
+ parsed_back = json.loads(json_str)
+
+ assert parsed_back["Message"] == large_message # ServiceException uses capital M
+ assert len(parsed_back["Message"]) == len(large_message)
+
+
+def test_all_aws_exceptions_are_json_serializable() -> None:
+ """Test that all AWS exception types can be JSON serialized."""
+ test_exceptions = [
+ exceptions.InvalidParameterValueException("test"),
+ exceptions.ResourceNotFoundException("test"),
+ exceptions.ServiceException("test"),
+ exceptions.CallbackTimeoutException("test"),
+ exceptions.ExecutionAlreadyStartedException(
+ "test", "arn:aws:states:us-east-1:123456789012:execution:test"
+ ),
+ ]
+
+ for exception in test_exceptions:
+ json_dict = exception.to_dict()
+
+ # Should be able to serialize to JSON without errors
+ json_str = json.dumps(json_dict)
+
+ # Should be able to parse back from JSON
+ parsed_back = json.loads(json_str)
+
+ # Should match original structure
+ assert parsed_back == json_dict
+
+
+def test_too_many_requests_exception() -> None:
+ """Test TooManyRequestsException for rate limiting."""
+ exception = exceptions.TooManyRequestsException("Rate limit exceeded")
+
+ assert str(exception) == "Rate limit exceeded"
+ assert isinstance(exception, exceptions.AwsApiException)
+ assert exception.http_status_code == 429
+
+ json_dict = exception.to_dict()
+ assert json_dict == {
+ "Type": "TooManyRequestsException",
+ "message": "Rate limit exceeded",
+ }
+
+
+def test_execution_conflict_exception() -> None:
+ """Test ExecutionConflictException for execution conflicts."""
+ exception = exceptions.ExecutionConflictException("Execution conflict detected")
+
+ assert str(exception) == "Execution conflict detected"
+ assert isinstance(exception, exceptions.AwsApiException)
+ assert exception.http_status_code == 409
+
+ json_dict = exception.to_dict()
+ assert json_dict == {
+ "Type": "ExecutionConflictException",
+ "message": "Execution conflict detected",
+ }
+
+
+# =============================================================================
+# AWS Compliant Exception Tests (Comprehensive)
+# =============================================================================
+
+
+def test_base_exception_hierarchy():
+ """Test that all AWS exceptions inherit from the correct base classes."""
+ # Test base hierarchy
+ assert issubclass(
+ exceptions.AwsApiException, exceptions.DurableFunctionsLocalRunnerError
+ )
+ assert issubclass(exceptions.DurableFunctionsLocalRunnerError, Exception)
+
+ # Test all AWS exceptions inherit from AwsApiException
+ aws_exceptions = [
+ exceptions.InvalidParameterValueException,
+ exceptions.ResourceNotFoundException,
+ exceptions.ServiceException,
+ exceptions.ExecutionAlreadyStartedException,
+ exceptions.ExecutionConflictException,
+ exceptions.CallbackTimeoutException,
+ exceptions.TooManyRequestsException,
+ exceptions.IllegalStateException,
+ exceptions.RuntimeException,
+ exceptions.IllegalArgumentException,
+ ]
+
+ for exception_class in aws_exceptions:
+ assert issubclass(exception_class, exceptions.AwsApiException)
+
+
+def test_aws_api_exception_abstract_to_dict():
+ """Test that AwsApiException.to_dict() raises NotImplementedError."""
+ exception = exceptions.AwsApiException("test message")
+
+ with pytest.raises(NotImplementedError):
+ exception.to_dict()
+
+
+class TestSmithyMappedExceptions:
+ """Test Smithy-mapped exceptions (defined in Smithy models)."""
+
+ def test_invalid_parameter_value_exception(self):
+ """Test InvalidParameterValueException serialization and properties."""
+ message = "Invalid parameter"
+ exception = exceptions.InvalidParameterValueException(message)
+
+ # Test properties
+ assert exception.http_status_code == 400
+ assert exception.message == message
+ assert str(exception) == message
+
+ # Test serialization
+ expected_json = {"Type": "InvalidParameterValueException", "message": message}
+ assert exception.to_dict() == expected_json
+
+ def test_resource_not_found_exception(self):
+ """Test ResourceNotFoundException serialization and properties."""
+ message = "Resource not found"
+ exception = exceptions.ResourceNotFoundException(message)
+
+ # Test properties
+ assert exception.http_status_code == 404
+ assert exception.Message == message # Capital M per Smithy
+ assert str(exception) == message
+
+ # Test serialization
+ expected_json = {"Type": "ResourceNotFoundException", "Message": message}
+ assert exception.to_dict() == expected_json
+
+ def test_service_exception(self):
+ """Test ServiceException serialization and properties."""
+ message = "Service error"
+ exception = exceptions.ServiceException(message)
+
+ # Test properties
+ assert exception.http_status_code == 500
+ assert exception.Message == message # Capital M per Smithy
+ assert str(exception) == message
+
+ # Test serialization
+ expected_json = {"Type": "ServiceException", "Message": message}
+ assert exception.to_dict() == expected_json
+
+ def test_execution_already_started_exception(self):
+ """Test ExecutionAlreadyStartedException serialization and properties."""
+ message = "Execution already started"
+ arn = "arn:aws:lambda:us-east-1:123456789012:function:test"
+ exception = exceptions.ExecutionAlreadyStartedException(message, arn)
+
+ # Test properties
+ assert exception.http_status_code == 409
+ assert exception.message == message
+ assert exception.DurableExecutionArn == arn
+ assert str(exception) == message
+
+ # Test serialization (no Type field per Smithy definition)
+ expected_json = {"message": message, "DurableExecutionArn": arn}
+ assert exception.to_dict() == expected_json
+
+ def test_callback_timeout_exception(self):
+ """Test CallbackTimeoutException serialization and properties."""
+ message = "Callback timed out"
+ exception = exceptions.CallbackTimeoutException(message)
+
+ # Test properties
+ assert exception.http_status_code == 408
+ assert exception.message == message
+ assert str(exception) == message
+
+ # Test serialization
+ expected_json = {"Type": "CallbackTimeoutException", "message": message}
+ assert exception.to_dict() == expected_json
+
+ def test_too_many_requests_exception(self):
+ """Test TooManyRequestsException serialization and properties."""
+ message = "Too many requests"
+ exception = exceptions.TooManyRequestsException(message)
+
+ # Test properties
+ assert exception.http_status_code == 429
+ assert exception.message == message
+ assert str(exception) == message
+
+ # Test serialization
+ expected_json = {"Type": "TooManyRequestsException", "message": message}
+ assert exception.to_dict() == expected_json
+
+ def test_execution_conflict_exception(self):
+ """Test ExecutionConflictException serialization and properties."""
+ message = "Execution conflict"
+ exception = exceptions.ExecutionConflictException(message)
+
+ # Test properties
+ assert exception.http_status_code == 409
+ assert exception.message == message
+ assert str(exception) == message
+
+ # Test serialization
+ expected_json = {"Type": "ExecutionConflictException", "message": message}
+ assert exception.to_dict() == expected_json
+
+
+class TestUnmappedExceptions:
+ """Test unmapped exceptions (thrown by services but not in Smithy)."""
+
+ def test_illegal_state_exception(self):
+ """Test IllegalStateException maps to ServiceException when serialized."""
+ message = "Invalid state"
+ exception = exceptions.IllegalStateException(message)
+
+ # Test properties
+ assert exception.http_status_code == 500
+ assert exception.message == message
+ assert str(exception) == message
+
+ # Test serialization (maps to ServiceException)
+ expected_json = {"Type": "ServiceException", "Message": message}
+ assert exception.to_dict() == expected_json
+
+ def test_runtime_exception(self):
+ """Test RuntimeException maps to ServiceException when serialized."""
+ message = "Runtime error"
+ exception = exceptions.RuntimeException(message)
+
+ # Test properties
+ assert exception.http_status_code == 500
+ assert exception.message == message
+ assert str(exception) == message
+
+ # Test serialization (maps to ServiceException)
+ expected_json = {"Type": "ServiceException", "Message": message}
+ assert exception.to_dict() == expected_json
+
+ def test_illegal_argument_exception(self):
+ """Test IllegalArgumentException maps to InvalidParameterValueException when serialized."""
+ message = "Invalid argument"
+ exception = exceptions.IllegalArgumentException(message)
+
+ # Test properties
+ assert exception.http_status_code == 400
+ assert exception.message == message
+ assert str(exception) == message
+
+ # Test serialization (maps to InvalidParameterValueException)
+ expected_json = {"Type": "InvalidParameterValueException", "message": message}
+ assert exception.to_dict() == expected_json
+
+
+class TestHttpStatusCodes:
+ """Test HTTP status codes match Smithy @httpError annotations."""
+
+ def test_client_error_status_codes(self):
+ """Test client error (4xx) status codes."""
+ assert exceptions.InvalidParameterValueException("test").http_status_code == 400
+ assert exceptions.ResourceNotFoundException("test").http_status_code == 404
+ assert exceptions.CallbackTimeoutException("test").http_status_code == 408
+ assert (
+ exceptions.ExecutionAlreadyStartedException("test", "arn").http_status_code
+ == 409
+ )
+ assert exceptions.ExecutionConflictException("test").http_status_code == 409
+ assert exceptions.TooManyRequestsException("test").http_status_code == 429
+ assert exceptions.IllegalArgumentException("test").http_status_code == 400
+
+ def test_server_error_status_codes(self):
+ """Test server error (5xx) status codes."""
+ assert exceptions.ServiceException("test").http_status_code == 500
+ assert exceptions.IllegalStateException("test").http_status_code == 500
+ assert exceptions.RuntimeException("test").http_status_code == 500
+
+
+class TestFieldNameCasing:
+ """Test field name casing matches Smithy definitions."""
+
+ def test_lowercase_message_fields(self):
+ """Test exceptions that use lowercase 'message' field."""
+ # These use lowercase 'message' per Smithy definitions
+ exceptions_with_lowercase_message = [
+ exceptions.InvalidParameterValueException("test"),
+ exceptions.ExecutionAlreadyStartedException("test", "arn"),
+ exceptions.ExecutionConflictException("test"),
+ exceptions.CallbackTimeoutException("test"),
+ exceptions.TooManyRequestsException("test"),
+ exceptions.IllegalStateException("test"),
+ exceptions.RuntimeException("test"),
+ exceptions.IllegalArgumentException("test"),
+ ]
+
+ for exception in exceptions_with_lowercase_message:
+ if hasattr(exception, "message"):
+ assert exception.message == "test"
+
+ def test_uppercase_message_fields(self):
+ """Test exceptions that use uppercase 'Message' field."""
+ # These use uppercase 'Message' per Smithy definitions
+ exceptions_with_uppercase_message = [
+ exceptions.ResourceNotFoundException("test"),
+ exceptions.ServiceException("test"),
+ ]
+
+ for exception in exceptions_with_uppercase_message:
+ assert exception.Message == "test"
+
+
+class TestBoto3Compatibility:
+ """Test boto3 compatibility and JSON structure validation."""
+
+ def test_json_structure_matches_boto3_expectations(self):
+ """Test that JSON output matches what boto3 error factory expects."""
+ # Test that all exceptions produce valid JSON structures
+ test_cases = [
+ (
+ exceptions.InvalidParameterValueException("test"),
+ {"Type": "InvalidParameterValueException", "message": "test"},
+ ),
+ (
+ exceptions.ResourceNotFoundException("test"),
+ {"Type": "ResourceNotFoundException", "Message": "test"},
+ ),
+ (
+ exceptions.ServiceException("test"),
+ {"Type": "ServiceException", "Message": "test"},
+ ),
+ (
+ exceptions.ExecutionAlreadyStartedException("test", "arn"),
+ {"message": "test", "DurableExecutionArn": "arn"},
+ ),
+ (
+ exceptions.ExecutionConflictException("test"),
+ {"Type": "ExecutionConflictException", "message": "test"},
+ ),
+ (
+ exceptions.CallbackTimeoutException("test"),
+ {"Type": "CallbackTimeoutException", "message": "test"},
+ ),
+ (
+ exceptions.TooManyRequestsException("test"),
+ {"Type": "TooManyRequestsException", "message": "test"},
+ ),
+ ]
+
+ for exception, expected_json in test_cases:
+ actual_json = exception.to_dict()
+ assert actual_json == expected_json
+
+ # Verify JSON is serializable (no complex objects)
+ json_str = json.dumps(actual_json)
+ assert json.loads(json_str) == actual_json
+
+ def test_type_field_values_match_exception_names(self):
+ """Test that Type field values match what boto3 expects for exception names."""
+ type_field_mappings = [
+ (
+ exceptions.InvalidParameterValueException("test"),
+ "InvalidParameterValueException",
+ ),
+ (exceptions.ResourceNotFoundException("test"), "ResourceNotFoundException"),
+ (exceptions.ServiceException("test"), "ServiceException"),
+ (
+ exceptions.ExecutionConflictException("test"),
+ "ExecutionConflictException",
+ ),
+ (exceptions.CallbackTimeoutException("test"), "CallbackTimeoutException"),
+ (exceptions.TooManyRequestsException("test"), "TooManyRequestsException"),
+ # Unmapped exceptions map to different types
+ (exceptions.IllegalStateException("test"), "ServiceException"),
+ (exceptions.RuntimeException("test"), "ServiceException"),
+ (
+ exceptions.IllegalArgumentException("test"),
+ "InvalidParameterValueException",
+ ),
+ ]
+
+ for exception, expected_type in type_field_mappings:
+ json_output = exception.to_dict()
+ if (
+ "Type" in json_output
+ ): # ExecutionAlreadyStartedException doesn't have Type field
+ assert json_output["Type"] == expected_type
+
+ def test_execution_already_started_exception_special_case(self):
+ """Test ExecutionAlreadyStartedException special case (no Type field)."""
+ exception = exceptions.ExecutionAlreadyStartedException(
+ "test message", "test-arn"
+ )
+ json_output = exception.to_dict()
+
+ # Should not have Type field
+ assert "Type" not in json_output
+
+ # Should have required fields
+ assert "message" in json_output
+ assert "DurableExecutionArn" in json_output
+ assert json_output["message"] == "test message"
+ assert json_output["DurableExecutionArn"] == "test-arn"
+
+ def test_message_field_casing_compatibility(self):
+ """Test message field casing compatibility with boto3 deserialization."""
+ # Test lowercase 'message' field exceptions
+ lowercase_exceptions = [
+ exceptions.InvalidParameterValueException("test"),
+ exceptions.ExecutionAlreadyStartedException("test", "arn"),
+ exceptions.ExecutionConflictException("test"),
+ exceptions.CallbackTimeoutException("test"),
+ exceptions.TooManyRequestsException("test"),
+ ]
+
+ for exception in lowercase_exceptions:
+ json_output = exception.to_dict()
+ if "message" in json_output:
+ assert json_output["message"] == "test"
+ # Should not have uppercase Message
+ assert "Message" not in json_output
+
+ # Test uppercase 'Message' field exceptions
+ uppercase_exceptions = [
+ exceptions.ResourceNotFoundException("test"),
+ exceptions.ServiceException("test"),
+ ]
+
+ for exception in uppercase_exceptions:
+ json_output = exception.to_dict()
+ assert json_output["Message"] == "test"
+ # Should not have lowercase message
+ assert "message" not in json_output
+
+
+class TestEdgeCases:
+ """Test edge cases and error conditions."""
+
+ def test_empty_message_handling(self):
+ """Test handling of empty messages."""
+ exceptions_list = [
+ exceptions.InvalidParameterValueException(""),
+ exceptions.ResourceNotFoundException(""),
+ exceptions.ServiceException(""),
+ exceptions.ExecutionConflictException(""),
+ exceptions.CallbackTimeoutException(""),
+ exceptions.TooManyRequestsException(""),
+ exceptions.IllegalStateException(""),
+ exceptions.RuntimeException(""),
+ exceptions.IllegalArgumentException(""),
+ ]
+
+ for exception in exceptions_list:
+ # Should not raise exception during serialization
+ json_output = exception.to_dict()
+ assert isinstance(json_output, dict)
+
+ def test_special_characters_in_messages(self):
+ """Test handling of special characters in messages."""
+ special_message = 'Test with "quotes", newlines\n, and unicode: 🚀'
+
+ exceptions_list = [
+ exceptions.InvalidParameterValueException(special_message),
+ exceptions.ResourceNotFoundException(special_message),
+ exceptions.ServiceException(special_message),
+ ]
+
+ for exception in exceptions_list:
+ json_output = exception.to_dict()
+ # Message should be preserved exactly
+ message_field = "Message" if hasattr(exception, "Message") else "message"
+ assert json_output[message_field] == special_message
+
+ def test_execution_already_started_with_empty_arn(self):
+ """Test ExecutionAlreadyStartedException with empty ARN."""
+ exception = exceptions.ExecutionAlreadyStartedException("test", "")
+ json_output = exception.to_dict()
+
+ assert json_output["DurableExecutionArn"] == ""
+ assert json_output["message"] == "test"
+
+
+def test_exception_test_cases_data_structure():
+ """Test that we can create a comprehensive test data structure for all exceptions."""
+ # This validates the test data structure mentioned in the design document
+ exception_test_cases = [
+ # Smithy-mapped exceptions
+ {
+ "exception_class": exceptions.InvalidParameterValueException,
+ "args": ["Invalid parameter"],
+ "expected_json": {
+ "Type": "InvalidParameterValueException",
+ "message": "Invalid parameter",
+ },
+ "expected_status": 400,
+ },
+ {
+ "exception_class": exceptions.ResourceNotFoundException,
+ "args": ["Resource not found"],
+ "expected_json": {
+ "Type": "ResourceNotFoundException",
+ "Message": "Resource not found",
+ },
+ "expected_status": 404,
+ },
+ {
+ "exception_class": exceptions.ServiceException,
+ "args": ["Service error"],
+ "expected_json": {"Type": "ServiceException", "Message": "Service error"},
+ "expected_status": 500,
+ },
+ {
+ "exception_class": exceptions.ExecutionAlreadyStartedException,
+ "args": [
+ "Already started",
+ "arn:aws:lambda:us-east-1:123456789012:function:test",
+ ],
+ "expected_json": {
+ "message": "Already started",
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test",
+ },
+ "expected_status": 409,
+ },
+ {
+ "exception_class": exceptions.ExecutionConflictException,
+ "args": ["Execution conflict"],
+ "expected_json": {
+ "Type": "ExecutionConflictException",
+ "message": "Execution conflict",
+ },
+ "expected_status": 409,
+ },
+ {
+ "exception_class": exceptions.CallbackTimeoutException,
+ "args": ["Callback timeout"],
+ "expected_json": {
+ "Type": "CallbackTimeoutException",
+ "message": "Callback timeout",
+ },
+ "expected_status": 408,
+ },
+ {
+ "exception_class": exceptions.TooManyRequestsException,
+ "args": ["Too many requests"],
+ "expected_json": {
+ "Type": "TooManyRequestsException",
+ "message": "Too many requests",
+ },
+ "expected_status": 429,
+ },
+ # Unmapped exceptions
+ {
+ "exception_class": exceptions.IllegalStateException,
+ "args": ["Invalid state"],
+ "expected_json": {"Type": "ServiceException", "Message": "Invalid state"},
+ "expected_status": 500,
+ },
+ {
+ "exception_class": exceptions.RuntimeException,
+ "args": ["Runtime error"],
+ "expected_json": {"Type": "ServiceException", "Message": "Runtime error"},
+ "expected_status": 500,
+ },
+ {
+ "exception_class": exceptions.IllegalArgumentException,
+ "args": ["Invalid argument"],
+ "expected_json": {
+ "Type": "InvalidParameterValueException",
+ "message": "Invalid argument",
+ },
+ "expected_status": 400,
+ },
+ ]
+
+ # Test each case
+ for case in exception_test_cases:
+ exception = case["exception_class"](*case["args"])
+
+ # Test status code
+ assert exception.http_status_code == case["expected_status"]
+
+ # Test serialization
+ assert exception.to_dict() == case["expected_json"]
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/execution_concurrent_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/execution_concurrent_test.py
new file mode 100644
index 0000000..6ea2bef
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/execution_concurrent_test.py
@@ -0,0 +1,83 @@
+"""Concurrent access tests for Execution class."""
+
+import threading
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+from aws_durable_execution_sdk_python_testing.execution import Execution
+from aws_durable_execution_sdk_python_testing.model import StartDurableExecutionInput
+
+
+def test_concurrent_token_generation():
+ """Test concurrent checkpoint token generation."""
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-inv-id",
+ input='{"test": "data"}',
+ )
+ execution = Execution.new(input_data)
+ tokens = []
+ tokens_lock = threading.Lock()
+
+ def generate_token():
+ token = execution.get_new_checkpoint_token()
+ with tokens_lock:
+ tokens.append(token)
+
+ with ThreadPoolExecutor(max_workers=10) as executor:
+ futures = [executor.submit(generate_token) for _ in range(20)]
+
+ for future in as_completed(futures):
+ future.result()
+
+ # All tokens should be unique and sequential
+ assert len(tokens) == 20
+ assert len(set(tokens)) == 20 # All unique
+ assert execution.token_sequence == 20
+
+
+def test_concurrent_operations_modification():
+ """Test concurrent operations list modifications."""
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-inv-id",
+ input='{"test": "data"}',
+ )
+ execution = Execution.new(input_data)
+ results = []
+ results_lock = threading.Lock()
+
+ def start_execution():
+ execution.start()
+ with results_lock:
+ results.append("started")
+
+ def get_operations():
+ ops = execution.get_navigable_operations()
+ with results_lock:
+ results.append(f"ops-{len(ops)}")
+
+ with ThreadPoolExecutor(max_workers=5) as executor:
+ futures = []
+ # One start operation
+ futures.append(executor.submit(start_execution))
+ # Multiple read operations
+ futures.extend([executor.submit(get_operations) for _ in range(4)])
+
+ for future in as_completed(futures):
+ future.result()
+
+ assert len(results) == 5
+ assert "started" in results
+ # Should have at least one operation after start
+ final_ops = execution.get_navigable_operations()
+ assert len(final_ops) >= 1
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/execution_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/execution_test.py
new file mode 100644
index 0000000..5bde747
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/execution_test.py
@@ -0,0 +1,1014 @@
+"""Unit tests for execution module."""
+
+from datetime import datetime, timezone
+from unittest.mock import patch, Mock
+
+import pytest
+from aws_durable_execution_sdk_python.execution import (
+ InvocationStatus,
+)
+from aws_durable_execution_sdk_python.lambda_service import (
+ ErrorObject,
+ Operation,
+ OperationStatus,
+ OperationType,
+ StepDetails,
+ CallbackDetails,
+)
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ IllegalStateException,
+)
+from aws_durable_execution_sdk_python_testing.execution import Execution
+from aws_durable_execution_sdk_python_testing.model import StartDurableExecutionInput
+
+
+def test_execution_init():
+ """Test Execution initialization."""
+ arn = "test-arn"
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ operations = []
+
+ execution = Execution(arn, start_input, operations)
+
+ assert execution.durable_execution_arn == arn
+ assert execution.start_input == start_input
+ assert execution.operations == operations
+ assert execution.updates == []
+ assert execution.used_tokens == set()
+ assert execution.token_sequence == 0
+ assert execution.is_complete is False
+ assert execution.consecutive_failed_invocation_attempts == 0
+
+
+@patch("aws_durable_execution_sdk_python_testing.execution.uuid4")
+def test_execution_new(mock_uuid4):
+ """Test Execution.new static method."""
+ mock_uuid = "test-uuid-123"
+ mock_uuid4.return_value = mock_uuid
+
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id-1234",
+ )
+
+ execution = Execution.new(start_input)
+
+ assert (
+ execution.durable_execution_arn == str(mock_uuid) + "/test-invocation-id-1234"
+ )
+ assert execution.start_input == start_input
+ assert execution.operations == []
+
+
+@patch("aws_durable_execution_sdk_python_testing.execution.datetime")
+def test_execution_start(mock_datetime):
+ """Test Execution.start method."""
+ mock_now = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
+ mock_datetime.now.return_value = mock_now
+
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ input='{"key": "value"}',
+ )
+ execution = Execution("test-arn", start_input, [])
+
+ execution.start()
+
+ assert len(execution.operations) == 1
+ operation = execution.operations[0]
+ assert operation.operation_id == "test-invocation-id"
+ assert operation.parent_id is None
+ assert operation.name == "test-execution"
+ assert operation.start_timestamp == mock_now
+ assert operation.operation_type == OperationType.EXECUTION
+ assert operation.status == OperationStatus.STARTED
+ assert operation.execution_details.input_payload == '{"key": "value"}'
+
+
+def test_get_operation_execution_started():
+ """Test get_operation_execution_started method."""
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ execution = Execution("test-arn", start_input, [])
+ execution.start()
+
+ result = execution.get_operation_execution_started()
+
+ assert result == execution.operations[0]
+ assert result.operation_type == OperationType.EXECUTION
+
+
+def test_get_operation_execution_started_not_started():
+ """Test get_operation_execution_started raises error when not started."""
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ execution = Execution("test-arn", start_input, [])
+
+ with pytest.raises(IllegalStateException, match="execution not started"):
+ execution.get_operation_execution_started()
+
+
+def test_get_new_checkpoint_token():
+ """Test get_new_checkpoint_token method."""
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="invocation-id",
+ )
+ execution = Execution("test-arn", start_input, [])
+
+ token1 = execution.get_new_checkpoint_token()
+ token2 = execution.get_new_checkpoint_token()
+
+ assert execution.token_sequence == 2
+ assert token1 in execution.used_tokens
+ assert token2 in execution.used_tokens
+ assert token1 != token2
+
+
+def test_get_navigable_operations():
+ """Test get_navigable_operations method."""
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ operations = [
+ Operation(
+ operation_id="op1",
+ parent_id=None,
+ name="test",
+ start_timestamp=datetime.now(timezone.utc),
+ operation_type=OperationType.EXECUTION,
+ status=OperationStatus.STARTED,
+ )
+ ]
+ execution = Execution("test-arn", start_input, operations)
+
+ result = execution.get_navigable_operations()
+
+ assert result == operations
+
+
+def test_get_assertable_operations():
+ """Test get_assertable_operations method."""
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ execution_op = Operation(
+ operation_id="exec-op",
+ parent_id=None,
+ name="execution",
+ start_timestamp=datetime.now(timezone.utc),
+ operation_type=OperationType.EXECUTION,
+ status=OperationStatus.STARTED,
+ )
+ step_op = Operation(
+ operation_id="step-op",
+ parent_id=None,
+ name="step",
+ start_timestamp=datetime.now(timezone.utc),
+ operation_type=OperationType.STEP,
+ status=OperationStatus.STARTED,
+ )
+ operations = [execution_op, step_op]
+ execution = Execution("test-arn", start_input, operations)
+
+ result = execution.get_assertable_operations()
+
+ assert len(result) == 1
+ assert result[0] == step_op
+
+
+def test_has_pending_operations_with_pending_step():
+ """Test has_pending_operations returns True for pending STEP operations."""
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ operations = [
+ Operation(
+ operation_id="op1",
+ parent_id=None,
+ name="test",
+ start_timestamp=datetime.now(timezone.utc),
+ operation_type=OperationType.STEP,
+ status=OperationStatus.PENDING,
+ )
+ ]
+ execution = Execution("test-arn", start_input, operations)
+
+ result = execution.has_pending_operations(execution)
+
+ assert result is True
+
+
+def test_has_pending_operations_with_started_wait():
+ """Test has_pending_operations returns True for started WAIT operations."""
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ operations = [
+ Operation(
+ operation_id="op1",
+ parent_id=None,
+ name="test",
+ start_timestamp=datetime.now(timezone.utc),
+ operation_type=OperationType.WAIT,
+ status=OperationStatus.STARTED,
+ )
+ ]
+ execution = Execution("test-arn", start_input, operations)
+
+ result = execution.has_pending_operations(execution)
+
+ assert result is True
+
+
+def test_has_pending_operations_with_started_callback():
+ """Test has_pending_operations returns True for started CALLBACK operations."""
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ operations = [
+ Operation(
+ operation_id="op1",
+ parent_id=None,
+ name="test",
+ start_timestamp=datetime.now(timezone.utc),
+ operation_type=OperationType.CALLBACK,
+ status=OperationStatus.STARTED,
+ )
+ ]
+ execution = Execution("test-arn", start_input, operations)
+
+ result = execution.has_pending_operations(execution)
+
+ assert result is True
+
+
+def test_has_pending_operations_with_started_invoke():
+ """Test has_pending_operations returns True for started INVOKE operations."""
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ operations = [
+ Operation(
+ operation_id="op1",
+ parent_id=None,
+ name="test",
+ start_timestamp=datetime.now(timezone.utc),
+ operation_type=OperationType.CHAINED_INVOKE,
+ status=OperationStatus.STARTED,
+ )
+ ]
+ execution = Execution("test-arn", start_input, operations)
+
+ result = execution.has_pending_operations(execution)
+
+ assert result is True
+
+
+def test_has_pending_operations_no_pending():
+ """Test has_pending_operations returns False when no pending operations."""
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ operations = [
+ Operation(
+ operation_id="op1",
+ parent_id=None,
+ name="test",
+ start_timestamp=datetime.now(timezone.utc),
+ operation_type=OperationType.STEP,
+ status=OperationStatus.SUCCEEDED,
+ )
+ ]
+ execution = Execution("test-arn", start_input, operations)
+
+ result = execution.has_pending_operations(execution)
+
+ assert result is False
+
+
+def test_complete_success_with_string_result():
+ """Test complete_success method with string result."""
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ execution = Execution("test-arn", start_input, [Mock()])
+
+ execution.complete_success("success result")
+
+ assert execution.is_complete is True
+ assert execution.result.status == InvocationStatus.SUCCEEDED
+ assert execution.result.result == "success result"
+
+
+def test_complete_success_with_none_result():
+ """Test complete_success method with None result."""
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ execution = Execution("test-arn", start_input, [Mock()])
+
+ execution.complete_success(None)
+
+ assert execution.is_complete is True
+ assert execution.result.status == InvocationStatus.SUCCEEDED
+ assert execution.result.result is None
+
+
+def test_complete_fail():
+ """Test complete_fail method."""
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ execution = Execution("test-arn", start_input, [Mock()])
+ error = ErrorObject.from_message("Test error message")
+
+ execution.complete_fail(error)
+
+ assert execution.is_complete is True
+ assert execution.result.status == InvocationStatus.FAILED
+ assert execution.result.error == error
+
+
+def test_find_operation_exists():
+ """Test find_operation method when operation exists."""
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ operation = Operation(
+ operation_id="test-op-id",
+ parent_id=None,
+ name="test",
+ start_timestamp=datetime.now(timezone.utc),
+ operation_type=OperationType.STEP,
+ status=OperationStatus.STARTED,
+ )
+ execution = Execution("test-arn", start_input, [operation])
+
+ index, found_operation = execution.find_operation("test-op-id")
+
+ assert index == 0
+ assert found_operation == operation
+
+
+def test_find_operation_not_exists():
+ """Test find_operation method when operation doesn't exist."""
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ execution = Execution("test-arn", start_input, [])
+
+ with pytest.raises(
+ IllegalStateException, match="Attempting to update state of an Operation"
+ ):
+ execution.find_operation("non-existent-id")
+
+
+@patch("aws_durable_execution_sdk_python_testing.execution.datetime")
+def test_complete_wait_success(mock_datetime):
+ """Test complete_wait method successful completion."""
+ mock_now = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
+ mock_datetime.now.return_value = mock_now
+
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ operation = Operation(
+ operation_id="wait-op-id",
+ parent_id=None,
+ name="test-wait",
+ start_timestamp=datetime.now(timezone.utc),
+ operation_type=OperationType.WAIT,
+ status=OperationStatus.STARTED,
+ )
+ execution = Execution("test-arn", start_input, [operation])
+
+ result = execution.complete_wait("wait-op-id")
+
+ assert result.status == OperationStatus.SUCCEEDED
+ assert result.end_timestamp == mock_now
+ assert execution.token_sequence == 1
+ assert execution.operations[0] == result
+
+
+def test_complete_wait_wrong_status():
+ """Test complete_wait method with wrong operation status."""
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ operation = Operation(
+ operation_id="wait-op-id",
+ parent_id=None,
+ name="test-wait",
+ start_timestamp=datetime.now(timezone.utc),
+ operation_type=OperationType.WAIT,
+ status=OperationStatus.SUCCEEDED,
+ )
+ execution = Execution("test-arn", start_input, [operation])
+
+ with pytest.raises(
+ IllegalStateException, match="Attempting to transition a Wait Operation"
+ ):
+ execution.complete_wait("wait-op-id")
+
+
+def test_complete_wait_wrong_type():
+ """Test complete_wait method with wrong operation type."""
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ operation = Operation(
+ operation_id="step-op-id",
+ parent_id=None,
+ name="test-step",
+ start_timestamp=datetime.now(timezone.utc),
+ operation_type=OperationType.STEP,
+ status=OperationStatus.STARTED,
+ )
+ execution = Execution("test-arn", start_input, [operation])
+
+ with pytest.raises(IllegalStateException, match="Expected WAIT operation"):
+ execution.complete_wait("step-op-id")
+
+
+def test_complete_retry_success():
+ """Test complete_retry method successful completion."""
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ step_details = StepDetails(
+ next_attempt_timestamp=str(datetime.now(timezone.utc)),
+ attempt=1,
+ )
+ operation = Operation(
+ operation_id="step-op-id",
+ parent_id=None,
+ name="test-step",
+ start_timestamp=datetime.now(timezone.utc),
+ operation_type=OperationType.STEP,
+ status=OperationStatus.PENDING,
+ step_details=step_details,
+ )
+ execution = Execution("test-arn", start_input, [operation])
+
+ result = execution.complete_retry("step-op-id")
+
+ assert result.status == OperationStatus.READY
+ assert result.step_details.next_attempt_timestamp is None
+ assert execution.token_sequence == 1
+ assert execution.operations[0] == result
+
+
+def test_complete_retry_no_step_details():
+ """Test complete_retry method with no step_details."""
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ operation = Operation(
+ operation_id="step-op-id",
+ parent_id=None,
+ name="test-step",
+ start_timestamp=datetime.now(timezone.utc),
+ operation_type=OperationType.STEP,
+ status=OperationStatus.PENDING,
+ )
+ execution = Execution("test-arn", start_input, [operation])
+
+ result = execution.complete_retry("step-op-id")
+
+ assert result.status == OperationStatus.READY
+ assert result.step_details is None
+ assert execution.token_sequence == 1
+
+
+def test_complete_retry_wrong_status():
+ """Test complete_retry method with wrong operation status."""
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ operation = Operation(
+ operation_id="step-op-id",
+ parent_id=None,
+ name="test-step",
+ start_timestamp=datetime.now(timezone.utc),
+ operation_type=OperationType.STEP,
+ status=OperationStatus.STARTED,
+ )
+ execution = Execution("test-arn", start_input, [operation])
+
+ with pytest.raises(
+ IllegalStateException, match="Attempting to transition a Step Operation"
+ ):
+ execution.complete_retry("step-op-id")
+
+
+def test_complete_retry_wrong_type():
+ """Test complete_retry method with wrong operation type."""
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ operation = Operation(
+ operation_id="wait-op-id",
+ parent_id=None,
+ name="test-wait",
+ start_timestamp=datetime.now(timezone.utc),
+ operation_type=OperationType.WAIT,
+ status=OperationStatus.PENDING,
+ )
+ execution = Execution("test-arn", start_input, [operation])
+
+ with pytest.raises(IllegalStateException, match="Expected STEP operation"):
+ execution.complete_retry("wait-op-id")
+
+
+def test_status_running():
+ """Test status property returns RUNNING for incomplete execution."""
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ execution = Execution("test-arn", start_input, [])
+
+ assert execution.current_status().value == "RUNNING"
+
+
+def test_status_succeeded():
+ """Test status property returns SUCCEEDED for successful execution."""
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ execution = Execution("test-arn", start_input, [Mock()])
+ execution.complete_success("success result")
+
+ assert execution.current_status().value == "SUCCEEDED"
+
+
+def test_status_failed():
+ """Test status property returns FAILED for failed execution."""
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ execution = Execution("test-arn", start_input, [Mock()])
+ error = ErrorObject.from_message("Test error")
+ execution.complete_fail(error)
+
+ assert execution.current_status().value == "FAILED"
+
+
+def test_status_timed_out():
+ """Test status property returns TIMED_OUT for timeout errors."""
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="invocation-id",
+ )
+ execution = Execution("test-arn", start_input, [Mock()])
+ error = ErrorObject(
+ message="Execution timed out", type="TimeoutError", data=None, stack_trace=None
+ )
+ execution.complete_timeout(error)
+
+ assert execution.current_status().value == "TIMED_OUT"
+
+
+def test_status_stopped():
+ """Test status property returns STOPPED for stop errors."""
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="invocation-id",
+ )
+ execution = Execution("test-arn", start_input, [Mock()])
+ error = ErrorObject(
+ message="Execution stopped", type="StopError", data=None, stack_trace=None
+ )
+ execution.complete_stopped(error)
+
+ assert execution.current_status().value == "STOPPED"
+
+
+def test_status_no_result():
+ """Test status property returns FAILED for completed execution with no result."""
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="invocation-id",
+ )
+ execution = Execution("test-arn", start_input, [])
+ execution.is_complete = True
+ execution.result = None
+ with pytest.raises(
+ IllegalStateException,
+ match="close_status must be set when execution is complete",
+ ):
+ execution.current_status()
+
+
+def test_complete_retry_with_step_details():
+ """Test complete_retry with operation that has step_details."""
+ step_details = StepDetails(
+ attempt=1, next_attempt_timestamp=datetime.now(timezone.utc)
+ )
+ step_op = Operation(
+ operation_id="op-1",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.PENDING,
+ step_details=step_details,
+ )
+
+ execution = Execution("test-arn", Mock(), [step_op])
+
+ result = execution.complete_retry("op-1")
+ assert result.status == OperationStatus.READY
+ assert result.step_details.next_attempt_timestamp is None
+
+
+def test_complete_retry_without_step_details():
+ """Test complete_retry with operation that has no step_details."""
+ step_op = Operation(
+ operation_id="op-1",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.PENDING,
+ step_details=None, # No step details
+ )
+
+ execution = Execution("test-arn", Mock(), [step_op])
+
+ result = execution.complete_retry("op-1")
+ assert result.status == OperationStatus.READY
+ assert result.step_details is None
+
+
+# endregion retry
+
+
+def test_from_dict_with_none_result():
+ """Test from_dict with None result."""
+ data = {
+ "DurableExecutionArn": "test-arn",
+ "StartInput": {"function_name": "test"},
+ "Operations": [],
+ "Updates": [],
+ "UsedTokens": [],
+ "TokenSequence": 0,
+ "IsComplete": False,
+ "Result": None, # None result
+ "ConsecutiveFailedInvocationAttempts": 0,
+ "CloseStatus": None,
+ }
+
+ with patch(
+ "aws_durable_execution_sdk_python_testing.model.StartDurableExecutionInput.from_dict"
+ ) as mock_from_dict:
+ mock_from_dict.return_value = Mock()
+ execution = Execution.from_json_dict(data)
+ assert execution.result is None
+
+
+# region callback
+def test_find_callback_operation_not_found():
+ """Test find_callback_operation raises exception when callback not found."""
+ execution = Execution("test-arn", Mock(), [])
+
+ with pytest.raises(
+ IllegalStateException,
+ match="Callback operation with callback_id \\[nonexistent\\] not found",
+ ):
+ execution.find_callback_operation("nonexistent")
+
+
+def test_complete_callback_success_not_started():
+ """Test complete_callback_success raises exception when callback not in STARTED state."""
+ # Create callback operation in wrong state
+ callback_op = Operation(
+ operation_id="op-1",
+ operation_type=OperationType.CALLBACK,
+ status=OperationStatus.SUCCEEDED, # Wrong state
+ callback_details=CallbackDetails(callback_id="test-id"),
+ )
+
+ execution = Execution("test-arn", Mock(), [callback_op])
+
+ with pytest.raises(
+ IllegalStateException,
+ match="Callback operation \\[test-id\\] is not in STARTED state",
+ ):
+ execution.complete_callback_success("test-id")
+
+
+def test_complete_callback_failure_not_started():
+ """Test complete_callback_failure raises exception when callback not in STARTED state."""
+ # Create callback operation in wrong state
+ callback_op = Operation(
+ operation_id="op-1",
+ operation_type=OperationType.CALLBACK,
+ status=OperationStatus.FAILED, # Wrong state
+ callback_details=CallbackDetails(callback_id="test-id"),
+ )
+
+ execution = Execution("test-arn", Mock(), [callback_op])
+ error = ErrorObject.from_message("test error")
+
+ with pytest.raises(
+ IllegalStateException,
+ match="Callback operation \\[test-id\\] is not in STARTED state",
+ ):
+ execution.complete_callback_failure("test-id", error)
+
+
+def test_complete_callback_success_no_callback_details():
+ """Test complete_callback_success with operation that has no callback_details."""
+ callback_details = CallbackDetails(callback_id="test-id")
+ callback_op = Operation(
+ operation_id="op-1",
+ operation_type=OperationType.CALLBACK,
+ status=OperationStatus.STARTED,
+ callback_details=callback_details,
+ )
+
+ execution = Execution("test-arn", Mock(), [callback_op])
+
+ # Test with None result
+ result = execution.complete_callback_success("test-id", None)
+ assert result.status == OperationStatus.SUCCEEDED
+
+
+def test_complete_callback_failure_no_callback_details():
+ """Test complete_callback_failure with operation that has no callback_details."""
+ callback_details = CallbackDetails(callback_id="test-id")
+ callback_op = Operation(
+ operation_id="op-1",
+ operation_type=OperationType.CALLBACK,
+ status=OperationStatus.STARTED,
+ callback_details=callback_details,
+ )
+
+ execution = Execution("test-arn", Mock(), [callback_op])
+ error = ErrorObject.from_message("test error")
+
+ # Test with actual callback details
+ result = execution.complete_callback_failure("test-id", error)
+ assert result.status == OperationStatus.FAILED
+
+
+# region callback - details
+
+
+def test_complete_callback_success_with_none_callback_details():
+ """Test complete_callback_success when operation has None callback_details."""
+ callback_op = Operation(
+ operation_id="op-1",
+ operation_type=OperationType.CALLBACK,
+ status=OperationStatus.STARTED,
+ callback_details=None, # None callback details
+ )
+
+ execution = Execution("test-arn", Mock(), [callback_op])
+
+ # Mock find_callback_operation to return this operation
+ execution.find_callback_operation = Mock(return_value=(0, callback_op))
+
+ result = execution.complete_callback_success("test-id", b"result")
+ assert result.status == OperationStatus.SUCCEEDED
+ assert result.callback_details is None
+
+
+def test_complete_callback_failure_with_none_callback_details():
+ """Test complete_callback_failure when operation has None callback_details."""
+ callback_op = Operation(
+ operation_id="op-1",
+ operation_type=OperationType.CALLBACK,
+ status=OperationStatus.STARTED,
+ callback_details=None, # None callback details
+ )
+
+ execution = Execution("test-arn", Mock(), [callback_op])
+ error = ErrorObject.from_message("test error")
+
+ # Mock find_callback_operation to return this operation
+ execution.find_callback_operation = Mock(return_value=(0, callback_op))
+
+ result = execution.complete_callback_failure("test-id", error)
+ assert result.status == OperationStatus.FAILED
+ assert result.callback_details is None
+
+
+def test_complete_callback_success_with_bytes_result():
+ """Test complete_callback_success with bytes result that gets decoded."""
+ callback_details = CallbackDetails(callback_id="test-id")
+ callback_op = Operation(
+ operation_id="op-1",
+ operation_type=OperationType.CALLBACK,
+ status=OperationStatus.STARTED,
+ callback_details=callback_details,
+ )
+
+ execution = Execution("test-arn", Mock(), [callback_op])
+
+ result = execution.complete_callback_success("test-id", b"test result")
+ assert result.status == OperationStatus.SUCCEEDED
+ assert result.callback_details.result == "test result"
+
+
+def test_complete_callback_success_with_none_result():
+ """Test complete_callback_success with None result."""
+ callback_details = CallbackDetails(callback_id="test-id")
+ callback_op = Operation(
+ operation_id="op-1",
+ operation_type=OperationType.CALLBACK,
+ status=OperationStatus.STARTED,
+ callback_details=callback_details,
+ )
+
+ execution = Execution("test-arn", Mock(), [callback_op])
+
+ result = execution.complete_callback_success("test-id", None)
+ assert result.status == OperationStatus.SUCCEEDED
+ assert result.callback_details.result is None
+
+
+# endregion callback -details
+
+# endregion callback
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/execution_wait_retry_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/execution_wait_retry_test.py
new file mode 100644
index 0000000..b0c9db3
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/execution_wait_retry_test.py
@@ -0,0 +1,80 @@
+"""Additional concurrent tests for wait and retry operations."""
+
+import threading
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from datetime import UTC, datetime
+
+from aws_durable_execution_sdk_python.lambda_service import (
+ Operation,
+ OperationStatus,
+ OperationType,
+ StepDetails,
+)
+
+from aws_durable_execution_sdk_python_testing.execution import Execution
+from aws_durable_execution_sdk_python_testing.model import StartDurableExecutionInput
+
+
+def test_concurrent_wait_and_retry_completion():
+ """Test concurrent complete_wait and complete_retry operations."""
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-inv-id",
+ input='{"test": "data"}',
+ )
+ execution = Execution.new(input_data)
+
+ # Add WAIT and STEP operations
+ wait_op = Operation(
+ operation_id="wait-1",
+ parent_id=None,
+ name="test-wait",
+ start_timestamp=datetime.now(UTC),
+ operation_type=OperationType.WAIT,
+ status=OperationStatus.STARTED,
+ )
+
+ step_op = Operation(
+ operation_id="step-1",
+ parent_id=None,
+ name="test-step",
+ start_timestamp=datetime.now(UTC),
+ operation_type=OperationType.STEP,
+ status=OperationStatus.PENDING,
+ step_details=StepDetails(),
+ )
+
+ execution.operations.extend([wait_op, step_op])
+
+ results = []
+ results_lock = threading.Lock()
+
+ def complete_wait():
+ result = execution.complete_wait("wait-1")
+ with results_lock:
+ results.append(f"wait-completed-{result.status.value}")
+
+ def complete_retry():
+ result = execution.complete_retry("step-1")
+ with results_lock:
+ results.append(f"retry-completed-{result.status.value}")
+
+ with ThreadPoolExecutor(max_workers=2) as executor:
+ futures = []
+ futures.append(executor.submit(complete_wait))
+ futures.append(executor.submit(complete_retry))
+
+ for future in as_completed(futures):
+ future.result()
+
+ assert len(results) == 2
+ assert "wait-completed-SUCCEEDED" in results
+ assert "retry-completed-READY" in results
+
+ # Verify token sequence was incremented twice
+ assert execution.token_sequence == 2
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/executor_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/executor_test.py
new file mode 100644
index 0000000..e228a0d
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/executor_test.py
@@ -0,0 +1,2881 @@
+"""Unit tests for executor module."""
+
+import asyncio
+from datetime import UTC, datetime
+from unittest.mock import Mock, patch
+
+import pytest
+
+from aws_durable_execution_sdk_python.execution import (
+ DurableExecutionInvocationOutput,
+ InvocationStatus,
+)
+from aws_durable_execution_sdk_python.lambda_service import (
+ CallbackOptions,
+ OperationUpdate,
+ OperationAction,
+ OperationType,
+ Operation,
+ OperationStatus,
+ CallbackDetails,
+)
+from aws_durable_execution_sdk_python.lambda_service import (
+ ErrorObject,
+ ExecutionDetails,
+)
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ ExecutionAlreadyStartedException,
+ IllegalStateException,
+ InvalidParameterValueException,
+ ResourceNotFoundException,
+)
+from aws_durable_execution_sdk_python_testing.execution import (
+ ExecutionStatus,
+ Execution,
+)
+from aws_durable_execution_sdk_python_testing.executor import Executor
+from aws_durable_execution_sdk_python_testing.invoker import InvokeResponse
+from aws_durable_execution_sdk_python_testing.model import (
+ ListDurableExecutionsResponse,
+ SendDurableExecutionCallbackFailureResponse,
+ SendDurableExecutionCallbackHeartbeatResponse,
+ SendDurableExecutionCallbackSuccessResponse,
+ StartDurableExecutionInput,
+ StopDurableExecutionResponse,
+)
+from aws_durable_execution_sdk_python_testing.observer import (
+ ExecutionNotifier,
+ ExecutionObserver,
+)
+from aws_durable_execution_sdk_python_testing.token import (
+ CallbackToken,
+)
+
+
+class MockExecutionObserver(ExecutionObserver):
+ """Mock observer to capture execution events through public callbacks."""
+
+ def __init__(self):
+ self.completed_executions = {}
+ self.failed_executions = {}
+ self.wait_timers = {}
+ self.retry_schedules = {}
+ self.callback_creations = {}
+
+ def on_completed(self, execution_arn: str, result: str | None = None) -> None:
+ """Capture completion events."""
+ self.completed_executions[execution_arn] = result
+
+ def on_failed(self, execution_arn: str, error: ErrorObject) -> None:
+ """Capture failure events."""
+ self.failed_executions[execution_arn] = error
+
+ def on_wait_timer_scheduled(
+ self, execution_arn: str, operation_id: str, delay: float
+ ) -> None:
+ """Capture wait timer scheduling events."""
+ self.wait_timers[execution_arn] = {"operation_id": operation_id, "delay": delay}
+
+ def on_step_retry_scheduled(
+ self, execution_arn: str, operation_id: str, delay: float
+ ) -> None:
+ """Capture retry scheduling events."""
+ self.retry_schedules[execution_arn] = {
+ "operation_id": operation_id,
+ "delay": delay,
+ }
+
+ def on_callback_created(
+ self,
+ execution_arn: str,
+ operation_id: str,
+ callback_options: CallbackOptions | None,
+ callback_token: CallbackToken,
+ ) -> None:
+ """Capture callback creation events."""
+ self.callback_creations[execution_arn] = {
+ "operation_id": operation_id,
+ "callback_id": callback_token.to_str(),
+ }
+
+ def on_callback_completed(
+ self, execution_arn: str, operation_id: str, callback_id: str
+ ) -> None:
+ """Capture callback completion events."""
+ pass # Not needed for current tests
+
+ def on_timed_out(self, execution_arn: str, error: ErrorObject) -> None:
+ """Capture timeout events."""
+ pass # Not needed for current tests
+
+ def on_stopped(self, execution_arn: str, error: ErrorObject) -> None:
+ """Capture stop events."""
+ pass # Not needed for current tests
+
+
+@pytest.fixture
+def test_observer():
+ return MockExecutionObserver()
+
+
+@pytest.fixture
+def mock_store():
+ return Mock()
+
+
+@pytest.fixture
+def mock_scheduler():
+ return Mock()
+
+
+@pytest.fixture
+def mock_invoker():
+ return Mock()
+
+
+@pytest.fixture
+def mock_checkpoint_processor():
+ return Mock()
+
+
+@pytest.fixture
+def executor(mock_store, mock_scheduler, mock_invoker, mock_checkpoint_processor):
+ return Executor(mock_store, mock_scheduler, mock_invoker, mock_checkpoint_processor)
+
+
+@pytest.fixture
+def start_input():
+ return StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ )
+
+
+@pytest.fixture
+def mock_execution():
+ execution = Mock(spec=Execution)
+ execution.durable_execution_arn = "arn:aws:lambda:us-east-1:123456789012:function:test-function:execution:test-execution"
+ execution.is_complete = False
+ execution.consecutive_failed_invocation_attempts = 0
+ execution.start_input = Mock()
+ execution.start_input.function_name = "test-function"
+ return execution
+
+
+def test_init(mock_store, mock_scheduler, mock_invoker, mock_checkpoint_processor):
+ # Test that Executor can be constructed with dependencies
+ # Dependency injection is implementation detail - test behavior instead
+ executor = Executor(
+ mock_store, mock_scheduler, mock_invoker, mock_checkpoint_processor
+ )
+
+ # Verify executor is properly initialized by testing it can perform basic operations
+ assert executor is not None
+
+ # Test that the executor uses the injected dependencies by verifying behavior
+ # This will be covered by other tests that exercise the executor's functionality
+
+
+@patch("aws_durable_execution_sdk_python_testing.executor.Execution")
+def test_start_execution(
+ mock_execution_class, executor, start_input, mock_store, mock_scheduler
+):
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution_class.new.return_value = mock_execution
+ mock_event = Mock()
+ mock_scheduler.create_event.return_value = mock_event
+
+ with patch.object(executor, "_invoke_execution") as mock_invoke:
+ result = executor.start_execution(start_input)
+
+ # Test observable behavior through public API
+ # The executor should generate an invocation_id if not provided
+ call_args = mock_execution_class.new.call_args
+ actual_input = call_args.kwargs["input"]
+
+ # Verify all fields match except invocation_id should be generated
+ assert actual_input.account_id == start_input.account_id
+ assert actual_input.function_name == start_input.function_name
+ assert actual_input.function_qualifier == start_input.function_qualifier
+ assert actual_input.execution_name == start_input.execution_name
+ assert (
+ actual_input.execution_timeout_seconds == start_input.execution_timeout_seconds
+ )
+ assert (
+ actual_input.execution_retention_period_days
+ == start_input.execution_retention_period_days
+ )
+ assert actual_input.invocation_id is not None # Should be generated
+ assert actual_input.trace_fields == start_input.trace_fields
+ assert actual_input.tenant_id == start_input.tenant_id
+ assert actual_input.input == start_input.input
+ mock_execution.start.assert_called_once()
+ mock_store.save.assert_called_once_with(mock_execution)
+ mock_scheduler.create_event.assert_called_once()
+
+ # Verify execution timeout was scheduled
+ assert mock_scheduler.call_later.called
+ timeout_call = mock_scheduler.call_later.call_args
+ assert timeout_call.kwargs["delay"] == start_input.execution_timeout_seconds
+ assert timeout_call.kwargs["completion_event"] == mock_event
+
+ mock_invoke.assert_called_once_with("test-arn")
+ assert result.execution_arn == "test-arn"
+
+ # Test that completion event was created by verifying wait_until_complete works
+ # This tests the same functionality without accessing private members
+ mock_event.wait.return_value = True
+ wait_result = executor.wait_until_complete("test-arn", timeout=1)
+ assert wait_result is True
+ mock_event.wait.assert_called_once_with(1)
+
+
+@patch("aws_durable_execution_sdk_python_testing.executor.Execution")
+def test_start_execution_with_provided_invocation_id(
+ mock_execution_class, executor, mock_store, mock_scheduler
+):
+ # Create input with invocation_id already provided
+ provided_invocation_id = "user-provided-id-123"
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id=provided_invocation_id,
+ )
+
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution_class.new.return_value = mock_execution
+ mock_event = Mock()
+ mock_scheduler.create_event.return_value = mock_event
+
+ with patch.object(executor, "_invoke_execution") as mock_invoke:
+ result = executor.start_execution(start_input)
+
+ # Should use the provided invocation_id unchanged
+ mock_execution_class.new.assert_called_once_with(input=start_input)
+ mock_execution.start.assert_called_once()
+ mock_store.save.assert_called_once_with(mock_execution)
+ mock_scheduler.create_event.assert_called_once()
+ mock_invoke.assert_called_once_with("test-arn")
+ assert result.execution_arn == "test-arn"
+
+ mock_execution = Mock()
+ mock_store.load.return_value = mock_execution
+
+ result = executor.get_execution("test-arn")
+
+ mock_store.load.assert_called_once_with("test-arn")
+ assert result == mock_execution
+
+
+def test_should_complete_workflow_with_error_when_invocation_fails(
+ executor, mock_store, mock_scheduler, mock_invoker, start_input
+):
+ """Test that failed invocation responses trigger workflow completion with error."""
+ # Arrange
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution.is_complete = False
+ mock_execution.start_input = start_input
+ mock_execution.consecutive_failed_invocation_attempts = 0
+
+ # Mock invoker to return failed response
+ mock_invocation_input = Mock()
+ mock_invoker.create_invocation_input.return_value = mock_invocation_input
+ failed_response = DurableExecutionInvocationOutput(
+ status=InvocationStatus.FAILED, error=ErrorObject.from_message("Test error")
+ )
+ mock_invoker.invoke.return_value = InvokeResponse(
+ invocation_output=failed_response, request_id="test-request-id"
+ )
+
+ # Mock execution creation and store behavior
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution_class.new.return_value = mock_execution
+ mock_store.load.return_value = mock_execution
+
+ # Mock the workflow completion methods
+ with patch.object(executor, "fail_execution") as mock_fail:
+ # Act - trigger invocation through public start_execution method
+ executor.start_execution(start_input)
+
+ # Get the handler that was passed to the scheduler and execute it manually
+ assert mock_scheduler.call_later.call_count >= 1
+ handler = mock_scheduler.call_later.call_args_list[-1][0][0]
+
+ # Execute the handler to trigger the invocation logic
+ import asyncio
+
+ asyncio.run(handler())
+
+ # Assert - verify workflow was completed with error
+ mock_fail.assert_called_once_with("test-arn", failed_response.error)
+
+
+def test_should_complete_workflow_with_result_when_invocation_succeeds(
+ executor, mock_store, mock_scheduler, mock_invoker, start_input
+):
+ """Test that successful invocation responses trigger workflow completion with result."""
+ # Arrange
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution.is_complete = False
+ mock_execution.start_input = start_input
+ mock_execution.consecutive_failed_invocation_attempts = 0
+
+ # Mock invoker to return successful response
+ mock_invocation_input = Mock()
+ mock_invoker.create_invocation_input.return_value = mock_invocation_input
+ success_response = DurableExecutionInvocationOutput(
+ status=InvocationStatus.SUCCEEDED, result="success result"
+ )
+ mock_invoker.invoke.return_value = InvokeResponse(
+ invocation_output=success_response, request_id="test-request-id"
+ )
+
+ # Mock execution creation and store behavior
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution_class.new.return_value = mock_execution
+ mock_store.load.return_value = mock_execution
+
+ # Mock the workflow completion methods
+ with patch.object(executor, "complete_execution") as mock_complete:
+ # Act - trigger invocation through public start_execution method
+ executor.start_execution(start_input)
+
+ # Get the handler that was passed to the scheduler and execute it manually
+ assert mock_scheduler.call_later.call_count >= 1
+ handler = mock_scheduler.call_later.call_args_list[-1][0][0]
+
+ # Execute the handler to trigger the invocation logic
+ import asyncio
+
+ asyncio.run(handler())
+
+ # Assert - verify workflow was completed with result
+ mock_complete.assert_called_once_with("test-arn", "success result")
+
+
+def test_should_handle_pending_status_when_operations_exist(
+ executor, mock_store, mock_scheduler, mock_invoker, start_input
+):
+ """Test that pending invocation responses are handled when operations exist."""
+ # Arrange
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution.is_complete = False
+ mock_execution.start_input = start_input
+ mock_execution.consecutive_failed_invocation_attempts = 0
+ mock_execution.has_pending_operations.return_value = True
+
+ # Mock invoker to return pending response
+ mock_invocation_input = Mock()
+ mock_invoker.create_invocation_input.return_value = mock_invocation_input
+ pending_response = DurableExecutionInvocationOutput(status=InvocationStatus.PENDING)
+ mock_invoker.invoke.return_value = InvokeResponse(
+ invocation_output=pending_response, request_id="test-request-id"
+ )
+
+ # Mock execution creation and store behavior
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution_class.new.return_value = mock_execution
+ mock_store.load.return_value = mock_execution
+
+ # Act - trigger invocation through public start_execution method
+ executor.start_execution(start_input)
+
+ # Get the handler that was passed to the scheduler and execute it manually
+ assert mock_scheduler.call_later.call_count >= 1
+ handler = mock_scheduler.call_later.call_args_list[-1][0][0]
+
+ # Execute the handler to trigger the invocation logic
+ import asyncio
+
+ asyncio.run(handler())
+
+ # Assert - verify pending operations were checked
+ mock_execution.has_pending_operations.assert_called_once_with(mock_execution)
+
+
+def test_should_ignore_response_when_execution_already_complete(
+ executor, mock_store, mock_scheduler, mock_invoker, start_input
+):
+ """Test that responses are ignored when execution is already complete."""
+ # Arrange
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution.is_complete = True # Already complete
+ mock_execution.start_input = start_input
+
+ # Mock invoker - this shouldn't be called since execution is complete
+ mock_invoker.create_invocation_input.return_value = Mock()
+ mock_invoker.invoke.return_value = (
+ DurableExecutionInvocationOutput(status=InvocationStatus.SUCCEEDED),
+ "test-request-id",
+ )
+
+ # Mock execution creation and store behavior
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution_class.new.return_value = mock_execution
+ mock_store.load.return_value = mock_execution
+
+ # Act - trigger invocation through public start_execution method
+ executor.start_execution(start_input)
+
+ # Get the handler that was passed to the scheduler and execute it manually
+ assert mock_scheduler.call_later.call_count >= 1
+ handler = mock_scheduler.call_later.call_args_list[-1][0][0]
+
+ # Execute the handler to trigger the invocation logic
+ import asyncio
+
+ asyncio.run(handler())
+
+ # Assert - verify invoker was not called since execution was already complete
+ mock_invoker.create_invocation_input.assert_not_called()
+ mock_invoker.invoke.assert_not_called()
+
+
+def test_should_retry_when_response_has_no_status(
+ executor, mock_store, mock_scheduler, mock_invoker, start_input
+):
+ """Test that invocation responses without status trigger retry logic."""
+ # Arrange
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution.is_complete = False
+ mock_execution.start_input = start_input
+ mock_execution.consecutive_failed_invocation_attempts = 0
+
+ # Mock invoker to return response without status
+ mock_invocation_input = Mock()
+ mock_invoker.create_invocation_input.return_value = mock_invocation_input
+ no_status_response = DurableExecutionInvocationOutput(status=None)
+ mock_invoker.invoke.return_value = InvokeResponse(
+ invocation_output=no_status_response, request_id="test-request-id"
+ )
+
+ # Mock execution creation and store behavior
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution_class.new.return_value = mock_execution
+ mock_store.load.return_value = mock_execution
+
+ # Act - trigger invocation through public start_execution method
+ executor.start_execution(start_input)
+
+ # Get the handler that was passed to the scheduler and execute it manually
+ assert mock_scheduler.call_later.call_count >= 1
+ handler = mock_scheduler.call_later.call_args_list[-1][0][0]
+
+ # Execute the handler to trigger the invocation logic
+ asyncio.run(handler())
+
+ # Assert - verify retry was triggered due to validation error
+ assert mock_execution.consecutive_failed_invocation_attempts == 1
+ mock_store.save.assert_called_with(mock_execution)
+ # Verify retry was scheduled (call_later should be called 3 times: timeout + initial + retry)
+ assert mock_scheduler.call_later.call_count == 3
+
+
+def test_should_retry_when_failed_response_has_result(
+ executor, mock_store, mock_scheduler, mock_invoker, start_input
+):
+ """Test that failed responses with result trigger retry logic."""
+ # Arrange
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution.is_complete = False
+ mock_execution.start_input = start_input
+ mock_execution.consecutive_failed_invocation_attempts = 0
+
+ # Mock invoker to return invalid failed response (with result)
+ mock_invocation_input = Mock()
+ mock_invoker.create_invocation_input.return_value = mock_invocation_input
+ invalid_response = DurableExecutionInvocationOutput(
+ status=InvocationStatus.FAILED, result="should not have result"
+ )
+ mock_invoker.invoke.return_value = InvokeResponse(
+ invocation_output=invalid_response, request_id="test-request-id"
+ )
+
+ # Mock execution creation and store behavior
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution_class.new.return_value = mock_execution
+ mock_store.load.return_value = mock_execution
+
+ # Act - trigger invocation through public start_execution method
+ executor.start_execution(start_input)
+
+ # Get the handler that was passed to the scheduler and execute it manually
+ assert mock_scheduler.call_later.call_count >= 1
+ handler = mock_scheduler.call_later.call_args_list[-1][0][0]
+
+ # Execute the handler to trigger the invocation logic
+ asyncio.run(handler())
+
+ # Assert - verify retry was triggered due to validation error
+ assert mock_execution.consecutive_failed_invocation_attempts == 1
+ mock_store.save.assert_called_with(mock_execution)
+ # Verify retry was scheduled (call_later should be called 3 times: timeout + initial + retry)
+ assert mock_scheduler.call_later.call_count == 3
+
+
+def test_should_retry_when_success_response_has_error(
+ executor, mock_store, mock_scheduler, mock_invoker, start_input
+):
+ """Test that successful responses with error trigger retry logic."""
+ # Arrange
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution.is_complete = False
+ mock_execution.start_input = start_input
+ mock_execution.consecutive_failed_invocation_attempts = 0
+
+ # Mock invoker to return invalid success response (with error)
+ mock_invocation_input = Mock()
+ mock_invoker.create_invocation_input.return_value = mock_invocation_input
+ invalid_response = DurableExecutionInvocationOutput(
+ status=InvocationStatus.SUCCEEDED,
+ error=ErrorObject.from_message("should not have error"),
+ )
+ mock_invoker.invoke.return_value = InvokeResponse(
+ invocation_output=invalid_response, request_id="test-request-id"
+ )
+
+ # Mock execution creation and store behavior
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution_class.new.return_value = mock_execution
+ mock_store.load.return_value = mock_execution
+
+ # Act - trigger invocation through public start_execution method
+ executor.start_execution(start_input)
+
+ # Get the handler that was passed to the scheduler and execute it manually
+ assert mock_scheduler.call_later.call_count >= 1
+ handler = mock_scheduler.call_later.call_args_list[-1][0][0]
+
+ # Execute the handler to trigger the invocation logic
+ asyncio.run(handler())
+
+ # Assert - verify retry was triggered due to validation error
+ assert mock_execution.consecutive_failed_invocation_attempts == 1
+ mock_store.save.assert_called_with(mock_execution)
+ # Verify retry was scheduled (call_later should be called 3 times: timeout + initial + retry)
+ assert mock_scheduler.call_later.call_count == 3
+
+
+def test_should_retry_when_pending_response_has_no_operations(
+ executor, mock_store, mock_scheduler, mock_invoker, start_input
+):
+ """Test that pending responses without operations trigger retry logic."""
+ # Arrange
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution.is_complete = False
+ mock_execution.start_input = start_input
+ mock_execution.consecutive_failed_invocation_attempts = 0
+ mock_execution.has_pending_operations.return_value = False # No pending operations
+
+ # Mock invoker to return pending response
+ mock_invocation_input = Mock()
+ mock_invoker.create_invocation_input.return_value = mock_invocation_input
+ pending_response = DurableExecutionInvocationOutput(status=InvocationStatus.PENDING)
+ mock_invoker.invoke.return_value = InvokeResponse(
+ invocation_output=pending_response, request_id="test-request-id"
+ )
+
+ # Mock execution creation and store behavior
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution_class.new.return_value = mock_execution
+ mock_store.load.return_value = mock_execution
+
+ # Act - trigger invocation through public start_execution method
+ executor.start_execution(start_input)
+
+ # Get the handler that was passed to the scheduler and execute it manually
+ assert mock_scheduler.call_later.call_count >= 1
+ handler = mock_scheduler.call_later.call_args_list[-1][0][0]
+
+ # Execute the handler to trigger the invocation logic
+ asyncio.run(handler())
+
+ # Assert - verify retry was triggered due to validation error
+ assert mock_execution.consecutive_failed_invocation_attempts == 1
+ mock_store.save.assert_called_with(mock_execution)
+ # Verify retry was scheduled (call_later should be called 3 times: timeout + initial + retry)
+ assert mock_scheduler.call_later.call_count == 3
+
+
+def test_invoke_handler_success(
+ executor, mock_store, mock_scheduler, mock_invoker, start_input
+):
+ """Test successful invocation through public API."""
+ # Arrange
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution.is_complete = False
+ mock_execution.start_input = start_input
+
+ mock_invocation_input = Mock()
+ mock_invoker.create_invocation_input.return_value = mock_invocation_input
+ mock_response = DurableExecutionInvocationOutput(
+ status=InvocationStatus.SUCCEEDED, result="test"
+ )
+ mock_invoker.invoke.return_value = InvokeResponse(
+ invocation_output=mock_response, request_id="test-request-id"
+ )
+
+ # Mock execution creation and store behavior
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution_class.new.return_value = mock_execution
+ mock_store.load.return_value = mock_execution
+
+ # Act - trigger invocation through public start_execution method
+ executor.start_execution(start_input)
+
+ # Get the handler that was passed to the scheduler and execute it manually
+ assert mock_scheduler.call_later.call_count >= 1
+ handler = mock_scheduler.call_later.call_args_list[-1][0][0]
+
+ # Execute the handler to trigger the invocation logic
+ asyncio.run(handler())
+
+ # Verify the invocation process was executed
+ mock_invoker.create_invocation_input.assert_called_once_with(
+ execution=mock_execution
+ )
+ mock_invoker.invoke.assert_called_once_with(
+ "test-function", mock_invocation_input, None
+ )
+
+
+def test_invoke_handler_execution_already_complete(
+ executor, mock_store, mock_scheduler, mock_invoker, start_input
+):
+ """Test that completed executions are handled properly through public API."""
+ # Arrange
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution.is_complete = True
+ mock_execution.start_input = start_input
+
+ # Mock execution creation and store behavior
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution_class.new.return_value = mock_execution
+ mock_store.load.return_value = mock_execution
+
+ # Act - trigger invocation through public start_execution method
+ executor.start_execution(start_input)
+
+ # Get the handler that was passed to the scheduler and execute it manually
+ assert mock_scheduler.call_later.call_count >= 1
+ handler = mock_scheduler.call_later.call_args_list[-1][0][0]
+
+ # Execute the handler to trigger the invocation logic
+ asyncio.run(handler())
+
+ # Verify store was accessed to check execution status
+ mock_store.load.assert_called_with("test-arn")
+
+
+def test_invoke_handler_execution_completed_during_invocation(
+ executor, mock_store, mock_scheduler, mock_invoker, start_input
+):
+ """Test execution completing during invocation through public API."""
+ # Arrange
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution.is_complete = False
+ mock_execution.start_input = start_input
+
+ mock_invocation_input = Mock()
+ mock_invoker.create_invocation_input.return_value = mock_invocation_input
+ mock_response = Mock()
+ mock_invoker.invoke.return_value = InvokeResponse(
+ invocation_output=mock_response, request_id="test-request-id"
+ )
+
+ # Create a completed execution mock
+ completed_execution = Mock()
+ completed_execution.durable_execution_arn = "test-arn"
+ completed_execution.is_complete = True
+ completed_execution.start_input = start_input
+
+ # First call returns incomplete execution, second call returns completed execution
+ mock_store.load.side_effect = [mock_execution, completed_execution]
+
+ # Mock execution creation
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution_class.new.return_value = mock_execution
+
+ # Act - trigger invocation through public start_execution method
+ executor.start_execution(start_input)
+
+ # Get the handler that was passed to the scheduler and execute it manually
+ assert mock_scheduler.call_later.call_count >= 1
+ handler = mock_scheduler.call_later.call_args_list[-1][0][0]
+
+ # Execute the handler to trigger the invocation logic
+ asyncio.run(handler())
+
+ # Verify the execution was checked for completion
+ assert mock_store.load.call_count >= 2
+
+
+def test_invoke_handler_resource_not_found(
+ executor, mock_store, mock_scheduler, mock_invoker, start_input
+):
+ """Test resource not found handling causes workflow failure through public API."""
+ # Arrange
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution.is_complete = False
+ mock_execution.start_input = start_input
+
+ mock_invoker.create_invocation_input.side_effect = ResourceNotFoundException(
+ "Function not found"
+ )
+
+ # Mock execution creation and store behavior
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution_class.new.return_value = mock_execution
+ mock_store.load.return_value = mock_execution
+
+ # Mock the public fail_execution method to verify it gets called
+ with patch.object(executor, "fail_execution") as mock_fail:
+ # Act - trigger invocation through public start_execution method
+ executor.start_execution(start_input)
+
+ # Get the handler that was passed to the scheduler and execute it manually
+ assert mock_scheduler.call_later.call_count >= 1
+ handler = mock_scheduler.call_later.call_args_list[-1][0][0]
+
+ # Execute the handler to trigger the invocation logic
+ asyncio.run(handler())
+
+ # Assert - verify workflow failure was triggered through public API
+ mock_fail.assert_called_once()
+ # Verify the error contains the expected message
+ call_args = mock_fail.call_args
+ assert call_args[0][0] == "test-arn" # execution_arn is first positional arg
+ assert "Function not found" in str(
+ call_args[0][1]
+ ) # error is second positional arg
+
+
+def test_invoke_handler_general_exception(
+ executor, mock_store, mock_scheduler, mock_invoker, start_input
+):
+ """Test general exception handling triggers retry through public API."""
+ # Arrange
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution.is_complete = False
+ mock_execution.start_input = start_input
+ mock_execution.consecutive_failed_invocation_attempts = 0
+
+ # Configure invoker to fail
+ mock_invoker.create_invocation_input.side_effect = Exception("General error")
+
+ # Mock execution creation and store behavior
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution_class.new.return_value = mock_execution
+ mock_store.load.return_value = mock_execution
+
+ # Act - trigger invocation through public start_execution method
+ executor.start_execution(start_input)
+
+ # Get the handler that was passed to the scheduler and execute it manually
+ assert mock_scheduler.call_later.call_count >= 1
+ handler = mock_scheduler.call_later.call_args_list[-1][0][0]
+
+ # Execute the handler to trigger the invocation logic
+ asyncio.run(handler())
+
+ # Assert - verify retry was scheduled through observable behavior
+ assert mock_execution.consecutive_failed_invocation_attempts == 1
+ mock_store.save.assert_called_with(mock_execution)
+ # Verify retry was scheduled (call_later should be called 3 times: timeout + initial + retry)
+ assert mock_scheduler.call_later.call_count == 3
+
+
+def test_invoke_execution_through_start_execution(
+ executor, mock_scheduler, start_input
+):
+ """Test execution invocation behavior through public start_execution method."""
+ mock_event = Mock()
+ mock_scheduler.create_event.return_value = mock_event
+
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution_class.new.return_value = mock_execution
+
+ # Start execution which internally calls _invoke_execution
+ executor.start_execution(start_input)
+
+ # Verify scheduler was called with the completion event
+ mock_scheduler.call_later.assert_called()
+ args = mock_scheduler.call_later.call_args
+ assert args[1]["delay"] == 0 # Initial invocation has no delay
+ assert args[1]["completion_event"] == mock_event
+
+
+def test_should_complete_workflow_successfully_through_public_api(
+ executor, mock_store, mock_execution
+):
+ """Test workflow completion through public complete_execution method."""
+ # Arrange
+ mock_execution.result = "test result" # Mock result after completion
+ mock_store.load.return_value = mock_execution
+
+ with patch.object(executor, "_complete_events") as mock_complete_events:
+ # Act - Use public API to complete workflow
+ executor.complete_execution("test-arn", "result")
+
+ # Assert - Verify final execution status and stored results
+ mock_store.load.assert_called_once_with(execution_arn="test-arn")
+ mock_execution.complete_success.assert_called_once_with(result="result")
+ mock_store.update.assert_called_once_with(mock_execution)
+ mock_complete_events.assert_called_once_with(execution_arn="test-arn")
+
+
+def test_should_complete_workflow_with_failure_through_public_api(
+ executor, mock_store, mock_execution
+):
+ """Test workflow failure completion through public fail_execution method."""
+ # Arrange
+ error = ErrorObject.from_message("test error")
+ mock_execution.result = "error result" # Mock result after failure
+ mock_store.load.return_value = mock_execution
+
+ with patch.object(executor, "_complete_events") as mock_complete_events:
+ # Act - Use public API to fail workflow
+ executor.fail_execution("test-arn", error)
+
+ # Assert - Verify final execution status and stored error
+ mock_store.load.assert_called_once_with(execution_arn="test-arn")
+ mock_execution.complete_fail.assert_called_once_with(error=error)
+ mock_store.update.assert_called_once_with(mock_execution)
+ mock_complete_events.assert_called_once_with(execution_arn="test-arn")
+
+
+def test_should_handle_workflow_completion_state_through_public_api(
+ executor, mock_store, mock_execution
+):
+ """Test workflow completion behavior and state management through public API."""
+ # Arrange
+ mock_execution.result = "final result" # Mock result after completion
+ mock_store.load.return_value = mock_execution
+
+ with patch.object(executor, "_complete_events") as mock_complete_events:
+ # Act - Complete workflow through public API
+ executor.complete_execution("test-arn", "result")
+
+ # Assert - Verify completion was processed and observer notifications sent
+ mock_store.load.assert_called_once_with(execution_arn="test-arn")
+ mock_execution.complete_success.assert_called_once_with(result="result")
+ mock_store.update.assert_called_once_with(mock_execution)
+ mock_complete_events.assert_called_once_with(execution_arn="test-arn")
+
+
+def test_should_fail_execution_when_function_not_found(
+ executor, mock_store, mock_scheduler, mock_invoker, start_input
+):
+ """Test that workflow fails when function is not found during invocation."""
+ # Arrange
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution.is_complete = False
+ mock_execution.start_input = start_input
+ mock_execution.consecutive_failed_invocation_attempts = 0
+
+ # Mock invoker to raise function not found error
+ mock_invoker.create_invocation_input.side_effect = ResourceNotFoundException(
+ "Function not found: test_function"
+ )
+
+ # Mock execution creation and store behavior
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution_class.new.return_value = mock_execution
+ mock_store.load.return_value = mock_execution
+
+ with patch.object(executor, "fail_execution") as mock_fail:
+ # Act - trigger invocation through public start_execution method
+ executor.start_execution(start_input)
+
+ # Get the handler that was passed to the scheduler and execute it manually
+ assert mock_scheduler.call_later.call_count >= 1
+ handler = mock_scheduler.call_later.call_args_list[-1][0][0]
+
+ # Execute the handler to trigger the invocation logic
+ import asyncio
+
+ asyncio.run(handler())
+
+ # Assert - verify failure was triggered with correct error
+ mock_fail.assert_called_once()
+ call_args = mock_fail.call_args
+ assert call_args[0][0] == "test-arn" # execution_arn
+ assert "Function not found" in call_args[0][1].message # error message
+
+
+def test_should_fail_execution_when_retries_exhausted(
+ executor, mock_store, mock_scheduler, mock_invoker, start_input
+):
+ """Test that workflow fails when maximum retry attempts are exhausted."""
+ # Arrange
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution.is_complete = False
+ mock_execution.start_input = start_input
+ mock_execution.consecutive_failed_invocation_attempts = (
+ executor.MAX_CONSECUTIVE_FAILED_ATTEMPTS + 1
+ )
+
+ # Mock invoker to raise exception (simulating network/invocation failure)
+ mock_invoker.create_invocation_input.side_effect = Exception("Network error")
+
+ # Mock execution creation and store behavior
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution_class.new.return_value = mock_execution
+ mock_store.load.return_value = mock_execution
+
+ with patch.object(executor, "fail_execution") as mock_fail:
+ # Act - trigger invocation through public start_execution method
+ # This will cause an exception during invocation, which triggers retry logic
+ executor.start_execution(start_input)
+
+ # Get the handler that was passed to the scheduler and execute it manually
+ assert mock_scheduler.call_later.call_count >= 1
+ handler = mock_scheduler.call_later.call_args_list[-1][0][0]
+
+ # Execute the handler to trigger the invocation logic
+ import asyncio
+
+ asyncio.run(handler())
+
+ # Assert - verify failure was triggered when retries exhausted
+ mock_fail.assert_called_once()
+ call_args = mock_fail.call_args
+ assert call_args[0][0] == "test-arn" # execution_arn
+
+
+def test_should_prevent_multiple_workflow_failures_on_complete_execution(
+ executor, mock_store, mock_scheduler, mock_invoker, start_input
+):
+ """Test that attempting to fail an already completed execution raises an exception."""
+ # Arrange - execution starts incomplete but becomes complete during processing
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution.is_complete = False # Initially incomplete
+ mock_execution.start_input = start_input
+ mock_execution.consecutive_failed_invocation_attempts = 0
+
+ # Create a completed execution for the _fail_workflow call
+ completed_execution = Mock()
+ completed_execution.is_complete = True
+
+ # Mock invoker to raise ResourceNotFoundException (triggers _fail_workflow)
+ mock_invoker.create_invocation_input.side_effect = ResourceNotFoundException(
+ "Function not found"
+ )
+
+ # Mock execution creation and store behavior
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution_class.new.return_value = mock_execution
+ # First load returns incomplete, second load (in _fail_workflow) returns complete
+ mock_store.load.side_effect = [mock_execution, completed_execution]
+
+ # Act & Assert - triggering workflow failure on completed execution should raise exception
+ executor.start_execution(start_input)
+ handler = mock_scheduler.call_later.call_args_list[-1][0][0]
+
+ # Execute the handler to trigger the invocation logic - this should raise the exception
+ with pytest.raises(
+ IllegalStateException, match="Cannot make multiple close workflow decisions"
+ ):
+ asyncio.run(handler())
+
+
+def test_should_retry_invocation_when_under_limit_through_public_api(
+ executor, mock_store, mock_scheduler, mock_invoker, start_input
+):
+ """Test that invocation retries when under limit through public API with final outcome verification."""
+ # Arrange - Set up execution that will trigger retry logic
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution.is_complete = False
+ mock_execution.start_input = start_input
+ mock_execution.consecutive_failed_invocation_attempts = 3 # Under limit (5 is max)
+
+ # Configure invoker to fail initially with validation error, then succeed on retry
+ mock_invocation_input = Mock()
+ mock_invoker.create_invocation_input.return_value = mock_invocation_input
+
+ # First invocation: invalid response triggers retry
+ invalid_response = DurableExecutionInvocationOutput(
+ status=InvocationStatus.FAILED,
+ result="should not have result", # Invalid: failed response with result
+ )
+ # Second invocation: valid success response
+ success_response = DurableExecutionInvocationOutput(
+ status=InvocationStatus.SUCCEEDED, result="final success"
+ )
+ mock_invoker.invoke.side_effect = [
+ InvokeResponse(
+ invocation_output=invalid_response, request_id="test-request-id-1"
+ ),
+ InvokeResponse(
+ invocation_output=success_response, request_id="test-request-id-2"
+ ),
+ ]
+
+ # Mock execution creation and store behavior
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution_class.new.return_value = mock_execution
+ mock_store.load.return_value = mock_execution
+
+ # Act - trigger the retry scenario through public API
+ executor.start_execution(start_input)
+
+ # Simulate scheduler executing the initial invocation handler
+ initial_handler = mock_scheduler.call_later.call_args_list[-1][0][0]
+ import asyncio
+
+ asyncio.run(initial_handler())
+
+ # Verify retry was scheduled due to validation error
+ assert mock_scheduler.call_later.call_count == 3 # timeout + initial + retry
+ retry_call = mock_scheduler.call_later.call_args_list[
+ 2
+ ] # Third call is the retry
+ retry_handler = retry_call[0][0]
+ retry_delay = retry_call[1]["delay"]
+
+ # Execute the retry handler to complete the scenario
+ asyncio.run(retry_handler())
+
+ # Assert - verify final outcome after retry sequence
+ assert (
+ mock_execution.consecutive_failed_invocation_attempts == 4
+ ) # Incremented from 3 to 4
+ assert retry_delay == Executor.RETRY_BACKOFF_SECONDS # Correct backoff delay used
+ mock_store.save.assert_called_with(mock_execution) # Execution state saved
+ assert mock_invoker.invoke.call_count == 2 # Initial + retry invocation
+
+
+def test_should_fail_workflow_when_retry_limit_exceeded(
+ executor, mock_store, mock_scheduler, mock_invoker, start_input
+):
+ """Test that workflow fails when retry limit is exceeded through public API."""
+ # Arrange
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution.is_complete = False
+ mock_execution.start_input = start_input
+ mock_execution.consecutive_failed_invocation_attempts = 6 # Over limit
+
+ # Mock invoker to consistently fail
+ mock_invoker.create_invocation_input.side_effect = Exception("Persistent error")
+ mock_store.load.return_value = mock_execution
+
+ # Mock execution creation
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution_class.new.return_value = mock_execution
+
+ # Mock the public fail_execution method to verify it gets called
+ with patch.object(executor, "fail_execution") as mock_fail:
+ # Act - trigger execution that will exceed retry limit
+ executor.start_execution(start_input)
+
+ # Get the handler that was passed to the scheduler and execute it manually
+ assert mock_scheduler.call_later.call_count >= 1
+ handler = mock_scheduler.call_later.call_args_list[-1][0][0]
+
+ # Execute the handler to trigger the invocation logic
+ asyncio.run(handler())
+
+ # Assert - verify workflow failed due to retry limit exceeded
+ mock_fail.assert_called_once()
+ # Verify the error contains the expected message
+ call_args = mock_fail.call_args
+ assert call_args[0][0] == "test-arn" # execution_arn is first positional arg
+ assert "Persistent error" in str(
+ call_args[0][1]
+ ) # error is second positional arg
+
+
+def test_complete_events_through_complete_execution(
+ executor, mock_store, mock_scheduler
+):
+ """Test completion event behavior through public complete_execution method."""
+ mock_execution = Mock()
+ mock_execution.result = "test result"
+ mock_store.load.return_value = mock_execution
+
+ # Set up completion event through start_execution
+ mock_event = Mock()
+ mock_scheduler.create_event.return_value = mock_event
+
+ # Mock the timeout future that will be created
+ mock_timeout_future = Mock()
+ mock_scheduler.call_later.return_value = mock_timeout_future
+
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_exec = Mock()
+ mock_exec.durable_execution_arn = "test-arn"
+ mock_execution_class.new.return_value = mock_exec
+
+ start_input = Mock()
+ start_input.execution_timeout_seconds = 300
+ executor.start_execution(start_input)
+
+ # Now complete the execution - this should trigger event.set() and cancel timeout
+ executor.complete_execution("test-arn", "result")
+
+ # Verify the event was set and timeout was cancelled
+ mock_event.set.assert_called_once()
+ mock_timeout_future.cancel.assert_called_once()
+
+
+def test_complete_events_no_event_through_public_api(executor, mock_store):
+ """Test that completing non-existent execution handles missing events gracefully."""
+ mock_execution = Mock()
+ mock_execution.result = "test result"
+ mock_store.load.return_value = mock_execution
+
+ # Complete execution without setting up completion event first
+ # Should not raise exception when event doesn't exist
+ executor.complete_execution("nonexistent-arn", "result")
+
+
+def test_wait_until_complete_success(executor, mock_scheduler):
+ """Test wait until complete success through public API."""
+ mock_event = Mock()
+ mock_event.wait.return_value = True
+ mock_scheduler.create_event.return_value = mock_event
+
+ # Set up completion event through start_execution
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution_class.new.return_value = mock_execution
+
+ start_input = Mock()
+ start_input.execution_timeout_seconds = 0
+ executor.start_execution(start_input)
+
+ result = executor.wait_until_complete("test-arn", timeout=10)
+
+ assert result is True
+ mock_event.wait.assert_called_once_with(10)
+
+
+def test_wait_until_complete_timeout(executor, mock_scheduler):
+ """Test wait until complete timeout through public API."""
+ mock_event = Mock()
+ mock_event.wait.return_value = False
+ mock_scheduler.create_event.return_value = mock_event
+
+ # Set up completion event through start_execution
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution_class.new.return_value = mock_execution
+
+ start_input = Mock()
+ start_input.execution_timeout_seconds = 0
+ executor.start_execution(start_input)
+
+ result = executor.wait_until_complete("test-arn", timeout=10)
+
+ assert result is False
+
+
+def test_wait_until_complete_no_event(executor):
+ with pytest.raises(ResourceNotFoundException, match="execution does not exist"):
+ executor.wait_until_complete("nonexistent-arn")
+
+
+def test_complete_execution(executor, mock_store, mock_execution):
+ mock_execution.result = "test result"
+ mock_store.load.return_value = mock_execution
+
+ with patch.object(executor, "_complete_events") as mock_complete_events:
+ executor.complete_execution("test-arn", "result")
+
+ mock_store.load.assert_called_once_with(execution_arn="test-arn")
+ mock_execution.complete_success.assert_called_once_with(result="result")
+ mock_store.update.assert_called_once_with(mock_execution)
+ mock_complete_events.assert_called_once_with(execution_arn="test-arn")
+
+
+def test_fail_execution(executor, mock_store, mock_execution):
+ error = ErrorObject.from_message("test error")
+ mock_execution.result = "error result"
+ mock_store.load.return_value = mock_execution
+
+ with patch.object(executor, "_complete_events") as mock_complete_events:
+ executor.fail_execution("test-arn", error)
+
+ mock_store.load.assert_called_once_with(execution_arn="test-arn")
+ mock_execution.complete_fail.assert_called_once_with(error=error)
+ mock_store.update.assert_called_once_with(mock_execution)
+ mock_complete_events.assert_called_once_with(execution_arn="test-arn")
+
+
+def test_should_schedule_wait_timer_correctly(executor, mock_scheduler):
+ """Test that wait timer is scheduled correctly through public method."""
+ # Arrange
+ mock_event = Mock()
+ mock_scheduler.create_event.return_value = mock_event
+
+ # Set up completion event through start_execution
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution_class.new.return_value = mock_execution
+
+ start_input = Mock()
+ start_input.execution_timeout_seconds = 0
+ executor.start_execution(start_input)
+
+ # Act - schedule wait timer through public method
+ executor.on_wait_timer_scheduled("test-arn", "op-123", delay=5.0)
+
+ # Assert - verify scheduler was called correctly
+ assert mock_scheduler.call_later.call_count == 2 # start_execution + wait timer
+ wait_call = mock_scheduler.call_later.call_args_list[1] # Second call is wait timer
+ assert wait_call[1]["delay"] == 5.0
+ assert wait_call[1]["completion_event"] == mock_event
+
+
+def test_should_ignore_wait_completion_for_completed_execution(
+ executor, mock_store, mock_execution
+):
+ """Test that wait completion logic correctly handles completed executions."""
+ # Arrange
+ mock_execution.is_complete = True
+ mock_store.load.return_value = mock_execution
+
+ # Act - simulate the wait completion logic for a completed execution
+ execution = mock_store.load("test-arn")
+
+ # The logic should check if execution is complete before attempting to complete wait
+ if not execution.is_complete:
+ execution.complete_wait(operation_id="op-123")
+ mock_store.update(execution)
+
+ # Assert - verify that complete_wait was not called for completed execution
+ mock_execution.complete_wait.assert_not_called()
+ mock_store.update.assert_not_called()
+
+
+def test_should_handle_wait_completion_exception_gracefully(
+ executor, mock_store, mock_execution
+):
+ """Test that wait completion exceptions are handled through error handling."""
+ # Arrange
+ mock_store.load.return_value = mock_execution
+ mock_execution.is_complete = False
+ mock_execution.complete_wait.side_effect = Exception("test error")
+
+ # Act & Assert - test that exception handling works correctly
+ # This tests the error handling logic without scheduler timing dependencies
+ execution = mock_store.load("test-arn")
+
+ with pytest.raises(Exception, match="test error"):
+ execution.complete_wait(operation_id="op-123")
+
+
+def test_should_complete_retry_when_retry_scheduled(
+ executor, mock_store, mock_scheduler, mock_execution
+):
+ """Test retry completion through public scheduler callback API."""
+ # Arrange
+ mock_store.load.return_value = mock_execution
+
+ # Configure scheduler to immediately execute the callback
+ def immediate_callback(func, delay=0, count=1, completion_event=None):
+ func() # Execute the retry handler immediately
+ return Mock()
+
+ mock_scheduler.call_later.side_effect = immediate_callback
+
+ # Mock _invoke_execution to prevent async warnings
+ with patch.object(executor, "_invoke_execution"):
+ # Act - trigger retry through public API
+ executor.on_step_retry_scheduled("test-arn", "op-123", 10.0)
+
+ # Assert - verify observable behavior
+ mock_store.load.assert_called_with("test-arn")
+ mock_execution.complete_retry.assert_called_once_with(operation_id="op-123")
+ mock_store.update.assert_called_with(mock_execution)
+
+
+def test_should_ignore_retry_when_execution_complete(
+ executor, mock_store, mock_scheduler, mock_execution
+):
+ """Test that completed executions ignore retry events through public API."""
+ # Arrange
+ mock_execution.is_complete = True
+ mock_store.load.return_value = mock_execution
+
+ # Configure scheduler to immediately execute the callback
+ def immediate_callback(func, delay=0, count=1, completion_event=None):
+ func() # Execute the retry handler immediately
+ return Mock()
+
+ mock_scheduler.call_later.side_effect = immediate_callback
+
+ # Mock _invoke_execution to prevent async warnings
+ with patch.object(executor, "_invoke_execution"):
+ # Act - trigger retry through public API
+ executor.on_step_retry_scheduled("test-arn", "op-123", 10.0)
+
+ # Assert - verify no retry processing occurs
+ mock_execution.complete_retry.assert_not_called()
+ mock_store.update.assert_not_called()
+
+
+def test_should_handle_retry_exception_gracefully(
+ executor, mock_store, mock_scheduler, mock_execution
+):
+ """Test that retry exceptions are handled gracefully through public API."""
+ # Arrange
+ mock_store.load.return_value = mock_execution
+ mock_execution.complete_retry.side_effect = Exception("test error")
+
+ # Configure scheduler to immediately execute the callback
+ def immediate_callback(func, delay=0, count=1, completion_event=None):
+ func() # Execute the retry handler immediately
+ return Mock()
+
+ mock_scheduler.call_later.side_effect = immediate_callback
+
+ # Mock _invoke_execution to prevent async warnings
+ with patch.object(executor, "_invoke_execution"):
+ # Act - should not raise exception
+ executor.on_step_retry_scheduled("test-arn", "op-123", 10.0)
+
+ # Assert - verify the retry was attempted but exception was caught
+ mock_execution.complete_retry.assert_called_once_with(operation_id="op-123")
+
+
+def test_on_completed(executor):
+ with patch.object(executor, "complete_execution") as mock_complete:
+ executor.on_completed("test-arn", "result")
+
+ mock_complete.assert_called_once_with("test-arn", "result")
+
+
+def test_on_failed(executor):
+ error = ErrorObject.from_message("test error")
+
+ with patch.object(executor, "fail_execution") as mock_fail:
+ executor.on_failed("test-arn", error)
+
+ mock_fail.assert_called_once_with("test-arn", error)
+
+
+def test_on_wait_timer_scheduled(executor, mock_scheduler):
+ """Test wait timer scheduling through public observer method."""
+ mock_event = Mock()
+ mock_scheduler.create_event.return_value = mock_event
+
+ # Set up completion event through start_execution
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution_class.new.return_value = mock_execution
+
+ start_input = Mock()
+ start_input.execution_timeout_seconds = 0
+ executor.start_execution(start_input)
+
+ with patch.object(executor, "_on_wait_succeeded"):
+ with patch.object(executor, "_invoke_execution"):
+ executor.on_wait_timer_scheduled("test-arn", "op-123", 10.0)
+
+ # Verify scheduler was called with correct parameters
+ assert (
+ mock_scheduler.call_later.call_count == 2
+ ) # Once for start_execution, once for wait timer
+ wait_timer_call = mock_scheduler.call_later.call_args_list[
+ 1
+ ] # Second call is for wait timer
+ assert wait_timer_call[1]["delay"] == 10.0
+ assert wait_timer_call[1]["completion_event"] == mock_event
+
+
+def test_should_retry_when_response_has_unexpected_status(
+ executor, mock_store, mock_scheduler, mock_invoker, start_input
+):
+ """Test that responses with unexpected status trigger retry logic."""
+ # Arrange
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution.is_complete = False
+ mock_execution.start_input = start_input
+ mock_execution.consecutive_failed_invocation_attempts = 0
+
+ # Mock invoker to return response with unexpected status
+ mock_invocation_input = Mock()
+ mock_invoker.create_invocation_input.return_value = mock_invocation_input
+ unexpected_response = Mock()
+ unexpected_response.status = "UNKNOWN_STATUS"
+ mock_invoker.invoke.return_value = InvokeResponse(
+ invocation_output=unexpected_response, request_id="test-request-id"
+ )
+
+ # Mock execution creation and store behavior
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution_class.new.return_value = mock_execution
+ mock_store.load.return_value = mock_execution
+
+ # Act - trigger invocation through public start_execution method
+ executor.start_execution(start_input)
+
+ # Get the handler that was passed to the scheduler and execute it manually
+ assert mock_scheduler.call_later.call_count >= 1
+ handler = mock_scheduler.call_later.call_args_list[-1][0][0]
+
+ # Execute the handler to trigger the invocation logic
+ asyncio.run(handler())
+
+ # Assert - verify retry was triggered due to validation error
+ assert mock_execution.consecutive_failed_invocation_attempts == 1
+ mock_store.save.assert_called_with(mock_execution)
+ # Verify retry was scheduled (call_later should be called 3 times: timeout + initial + retry)
+ assert mock_scheduler.call_later.call_count == 3
+
+
+def test_invoke_handler_execution_completed_during_invocation_async(
+ executor, mock_store, mock_scheduler, mock_invoker, start_input
+):
+ """Test execution completing during invocation through public API."""
+ # First call returns incomplete execution, second call returns completed execution
+ incomplete_execution = Mock(spec=Execution)
+ incomplete_execution.is_complete = False
+ incomplete_execution.start_input = start_input
+ incomplete_execution.consecutive_failed_invocation_attempts = 0
+ incomplete_execution.durable_execution_arn = "test-arn"
+
+ completed_execution = Mock(spec=Execution)
+ completed_execution.is_complete = True
+
+ mock_store.load.side_effect = [incomplete_execution, completed_execution]
+
+ mock_invocation_input = Mock()
+ mock_invoker.create_invocation_input.return_value = mock_invocation_input
+ mock_response = Mock()
+ mock_invoker.invoke.return_value = InvokeResponse(
+ invocation_output=mock_response, request_id="test-request-id"
+ )
+
+ # Mock execution creation
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution_class.new.return_value = incomplete_execution
+
+ # Act - trigger invocation through public start_execution method
+ executor.start_execution(start_input)
+
+ # Get the handler that was passed to the scheduler and execute it manually
+ assert mock_scheduler.call_later.call_count >= 1
+ handler = mock_scheduler.call_later.call_args_list[-1][0][0]
+
+ # Execute the handler to trigger the invocation logic
+ asyncio.run(handler())
+
+ # Verify the execution was loaded multiple times (before and after invocation)
+ assert mock_store.load.call_count >= 2
+
+
+def test_invoke_handler_resource_not_found_async(
+ executor, mock_store, mock_scheduler, mock_invoker, start_input
+):
+ """Test resource not found handling causes workflow failure through public API (async version)."""
+ # Arrange
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution.is_complete = False
+ mock_execution.start_input = start_input
+
+ mock_invoker.create_invocation_input.side_effect = ResourceNotFoundException(
+ "Function not found"
+ )
+
+ # Mock execution creation and store behavior
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution_class.new.return_value = mock_execution
+ mock_store.load.return_value = mock_execution
+
+ # Mock the public fail_execution method to verify it gets called
+ with patch.object(executor, "fail_execution") as mock_fail:
+ # Act - trigger invocation through public start_execution method
+ executor.start_execution(start_input)
+
+ # Get the handler that was passed to the scheduler and execute it manually
+ assert mock_scheduler.call_later.call_count >= 1
+ handler = mock_scheduler.call_later.call_args_list[-1][0][0]
+
+ # Execute the handler to trigger the invocation logic
+ asyncio.run(handler())
+
+ # Assert - verify workflow failure was triggered through public API
+ mock_fail.assert_called_once()
+ # Verify the error contains the expected message
+ call_args = mock_fail.call_args
+ assert call_args[0][0] == "test-arn" # execution_arn is first positional arg
+ assert "Function not found" in str(
+ call_args[0][1]
+ ) # error is second positional arg
+
+
+def test_invoke_handler_general_exception_async(
+ executor, mock_store, mock_scheduler, mock_invoker, start_input
+):
+ """Test general exception handling triggers retry through public API (async version)."""
+ # Arrange
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution.is_complete = False
+ mock_execution.start_input = start_input
+ mock_execution.consecutive_failed_invocation_attempts = 0
+
+ # Configure invoker to fail initially, then succeed on retry
+ mock_invoker.create_invocation_input.side_effect = [
+ Exception("General error"), # First call fails
+ Mock(), # Second call succeeds (returns invocation input)
+ ]
+
+ # Mock successful response for retry
+ success_response = DurableExecutionInvocationOutput(
+ status=InvocationStatus.SUCCEEDED, result="success"
+ )
+ mock_invoker.invoke.return_value = InvokeResponse(
+ invocation_output=success_response, request_id="test-request-id"
+ )
+
+ # Mock execution creation and store behavior
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution_class.new.return_value = mock_execution
+ mock_store.load.return_value = mock_execution
+
+ # Act - trigger invocation through public start_execution method
+ executor.start_execution(start_input)
+
+ # Get the handler that was passed to the scheduler and execute it manually
+ assert mock_scheduler.call_later.call_count >= 1
+ handler = mock_scheduler.call_later.call_args_list[-1][0][0]
+
+ # Execute the handler to trigger the invocation logic
+ asyncio.run(handler())
+
+ # Assert - verify retry was scheduled through observable behavior
+ assert mock_execution.consecutive_failed_invocation_attempts == 1
+ mock_store.save.assert_called_with(mock_execution)
+ # Verify retry was scheduled (call_later should be called 3 times: timeout + initial + retry)
+ assert mock_scheduler.call_later.call_count == 3
+
+
+def test_invoke_execution_with_delay_through_wait_timer(executor, mock_scheduler):
+ """Test execution invocation with delay through wait timer scheduling."""
+ mock_event = Mock()
+ mock_scheduler.create_event.return_value = mock_event
+
+ # Set up completion event through start_execution
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution_class.new.return_value = mock_execution
+
+ start_input = Mock()
+ start_input.execution_timeout_seconds = 0
+ executor.start_execution(start_input)
+
+ # Test delay behavior through wait timer scheduling
+ with patch.object(executor, "_on_wait_succeeded"):
+ executor.on_wait_timer_scheduled("test-arn", "op-123", 10.0)
+
+ # Verify scheduler was called with delay for wait timer
+ wait_timer_call = mock_scheduler.call_later.call_args_list[
+ 1
+ ] # Second call is for wait timer
+ assert wait_timer_call[1]["delay"] == 10.0
+
+
+def test_invoke_execution_no_delay_through_start_execution(executor, mock_scheduler):
+ """Test execution invocation with no delay through start_execution."""
+ mock_event = Mock()
+ mock_scheduler.create_event.return_value = mock_event
+
+ # Test no delay behavior through start_execution
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution_class.new.return_value = mock_execution
+
+ start_input = Mock()
+ start_input.execution_timeout_seconds = 0
+ executor.start_execution(start_input)
+
+ # Verify scheduler was called with no delay for initial execution
+ initial_call = mock_scheduler.call_later.call_args_list[
+ 0
+ ] # First call is for initial execution
+ assert initial_call[1]["delay"] == 0
+
+
+def test_on_step_retry_scheduled(executor, mock_scheduler):
+ """Test step retry scheduling through public observer method."""
+ mock_event = Mock()
+ mock_scheduler.create_event.return_value = mock_event
+
+ # Set up completion event through start_execution
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution_class.new.return_value = mock_execution
+
+ start_input = Mock()
+ start_input.execution_timeout_seconds = 0
+ executor.start_execution(start_input)
+
+ with patch.object(executor, "_on_retry_ready"):
+ with patch.object(executor, "_invoke_execution"):
+ executor.on_step_retry_scheduled("test-arn", "op-123", 10.0)
+
+ # Verify scheduler was called with correct parameters
+ assert (
+ mock_scheduler.call_later.call_count == 2
+ ) # Once for start_execution, once for retry
+ retry_call = mock_scheduler.call_later.call_args_list[1] # Second call is for retry
+ assert retry_call[1]["delay"] == 10.0
+ assert retry_call[1]["completion_event"] == mock_event
+
+
+def test_wait_handler_execution(executor, mock_scheduler):
+ """Test wait handler execution through public observer method."""
+ mock_event = Mock()
+ mock_scheduler.create_event.return_value = mock_event
+
+ # Set up completion event through start_execution
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution_class.new.return_value = mock_execution
+
+ start_input = Mock()
+ start_input.execution_timeout_seconds = 0
+ executor.start_execution(start_input)
+
+ with patch.object(executor, "_on_wait_succeeded") as mock_wait:
+ with patch.object(executor, "_invoke_execution") as mock_invoke:
+ executor.on_wait_timer_scheduled("test-arn", "op-123", 10.0)
+
+ # Get the handler that was passed to call_later (second call for wait timer)
+ wait_timer_call = mock_scheduler.call_later.call_args_list[1]
+ wait_handler = wait_timer_call[0][0]
+
+ # Execute the handler to test the inner function
+ wait_handler()
+
+ mock_wait.assert_called_once_with("test-arn", "op-123")
+ mock_invoke.assert_called_once_with("test-arn", delay=0)
+
+
+def test_retry_handler_execution(executor, mock_scheduler):
+ """Test retry handler execution through public observer method."""
+ mock_event = Mock()
+ mock_scheduler.create_event.return_value = mock_event
+
+ # Set up completion event through start_execution
+ with patch(
+ "aws_durable_execution_sdk_python_testing.executor.Execution"
+ ) as mock_execution_class:
+ mock_execution = Mock()
+ mock_execution.durable_execution_arn = "test-arn"
+ mock_execution_class.new.return_value = mock_execution
+
+ start_input = Mock()
+ start_input.execution_timeout_seconds = 0
+ executor.start_execution(start_input)
+
+ with patch.object(executor, "_on_retry_ready") as mock_retry:
+ with patch.object(executor, "_invoke_execution") as mock_invoke:
+ executor.on_step_retry_scheduled("test-arn", "op-123", 10.0)
+
+ # Get the handler that was passed to call_later (second call for retry)
+ retry_call = mock_scheduler.call_later.call_args_list[1]
+ retry_handler = retry_call[0][0]
+
+ # Execute the handler to test the inner function
+ retry_handler()
+
+ mock_retry.assert_called_once_with("test-arn", "op-123")
+ mock_invoke.assert_called_once_with("test-arn", delay=0)
+
+
+# Tests for new web handler methods
+
+
+def test_get_execution_details(executor, mock_store):
+ """Test get_execution_details method."""
+
+ # Create real execution instance with mocked start_input
+ mock_start_input = Mock()
+ mock_start_input.execution_name = "test-execution"
+ mock_start_input.function_name = "test-function"
+
+ execution = Execution(
+ durable_execution_arn="test-arn", start_input=mock_start_input, operations=[]
+ )
+ execution.is_complete = True
+
+ # Create mock result
+ mock_result = DurableExecutionInvocationOutput(
+ status=InvocationStatus.SUCCEEDED, result="test-result"
+ )
+ execution.result = mock_result
+ execution.close_status = ExecutionStatus.SUCCEEDED
+
+ # Create mock operation and add to execution
+ mock_operation = Operation(
+ operation_id="op-1",
+ parent_id=None,
+ name="test-execution",
+ start_timestamp=datetime.now(UTC),
+ end_timestamp=datetime.now(UTC),
+ operation_type=OperationType.EXECUTION,
+ status=OperationStatus.SUCCEEDED,
+ execution_details=ExecutionDetails(input_payload='{"test": "data"}'),
+ )
+ execution.operations = [mock_operation]
+
+ mock_store.load.return_value = execution
+
+ result = executor.get_execution_details("test-arn")
+
+ assert result.durable_execution_arn == "test-arn"
+ assert result.durable_execution_name == "test-execution"
+ assert result.status == "SUCCEEDED"
+ assert result.result == "test-result"
+ assert result.error is None
+ mock_store.load.assert_called_once_with("test-arn")
+
+
+def test_get_execution_details_not_found(executor, mock_store):
+ """Test get_execution_details with non-existent execution."""
+ mock_store.load.side_effect = KeyError("Execution not found")
+
+ with pytest.raises(ResourceNotFoundException, match="Execution test-arn not found"):
+ executor.get_execution_details("test-arn")
+
+
+def test_get_execution_details_failed_execution(executor, mock_store):
+ """Test get_execution_details with failed execution."""
+
+ # Create real execution instance with mocked start_input
+ mock_start_input = Mock()
+ mock_start_input.execution_name = "test-execution"
+ mock_start_input.function_name = "test-function"
+
+ execution = Execution(
+ durable_execution_arn="test-arn", start_input=mock_start_input, operations=[]
+ )
+ execution.is_complete = True
+
+ error = ErrorObject.from_message("Test error")
+ mock_result = DurableExecutionInvocationOutput(
+ status=InvocationStatus.FAILED, error=error
+ )
+ execution.result = mock_result
+
+ # Create mock operation and add to execution
+ mock_operation = Operation(
+ operation_id="op-1",
+ parent_id=None,
+ name="test-execution",
+ start_timestamp=datetime.now(UTC),
+ operation_type=OperationType.EXECUTION,
+ status=OperationStatus.FAILED,
+ execution_details=ExecutionDetails(input_payload='{"test": "data"}'),
+ )
+ execution.operations = [mock_operation]
+
+ mock_store.load.return_value = execution
+ with pytest.raises(
+ IllegalStateException,
+ match="close_status must be set when execution is complete",
+ ):
+ executor.get_execution_details("test-arn")
+ execution.close_status = ExecutionStatus.FAILED
+ result = executor.get_execution_details("test-arn")
+ assert result.status == "FAILED"
+ assert result.result is None
+ assert result.error == error
+
+
+def test_list_executions_empty(executor, mock_store):
+ """Test list_executions with no executions."""
+ query_result = ([], None)
+ mock_store.query.return_value = query_result
+
+ result = executor.list_executions()
+
+ assert result.durable_executions == []
+ assert result.next_marker is None
+ mock_store.query.assert_called_once()
+
+
+def test_list_executions_with_filtering(executor, mock_store):
+ """Test list_executions with function name filtering."""
+ # Create real execution instance
+ mock_start_input = Mock()
+ mock_start_input.execution_name = "exec1"
+ mock_start_input.function_name = "function1"
+
+ execution1 = Execution(
+ durable_execution_arn="arn1", start_input=mock_start_input, operations=[]
+ )
+ execution1.is_complete = False
+ execution1.result = None
+
+ # Create mock operations
+ op1 = Operation(
+ operation_id="op-1",
+ parent_id=None,
+ name="exec1",
+ start_timestamp=datetime.now(UTC),
+ operation_type=OperationType.EXECUTION,
+ status=OperationStatus.STARTED,
+ execution_details=ExecutionDetails(input_payload="{}"),
+ )
+ execution1.operations = [op1]
+
+ # Mock the query method to return filtered results
+ query_result = ([execution1], "1")
+ mock_store.query.return_value = query_result
+
+ # Test filtering by function name
+ result = executor.list_executions(function_name="function1")
+
+ assert len(result.durable_executions) == 1
+ assert result.durable_executions[0].durable_execution_arn == "arn1"
+ assert result.durable_executions[0].status == "RUNNING"
+
+
+def test_list_executions_with_pagination(executor, mock_store):
+ """Test list_executions with pagination."""
+ # Create multiple mock executions for first page
+ executions_page1 = []
+ for i in range(2):
+ execution = Mock()
+ execution.durable_execution_arn = f"arn{i}"
+ execution.start_input.execution_name = f"exec{i}"
+ execution.start_input.function_name = "test-function"
+ execution.is_complete = False
+ execution.result = None
+
+ op = Operation(
+ operation_id=f"op-{i}",
+ parent_id=None,
+ name=f"exec{i}",
+ start_timestamp=datetime.now(UTC),
+ operation_type=OperationType.EXECUTION,
+ status=OperationStatus.STARTED,
+ execution_details=ExecutionDetails(input_payload="{}"),
+ )
+ execution.get_operation_execution_started.return_value = op
+ executions_page1.append(execution)
+
+ # Create executions for second page
+ executions_page2 = []
+ for i in range(2, 4):
+ execution = Mock()
+ execution.durable_execution_arn = f"arn{i}"
+ execution.start_input.execution_name = f"exec{i}"
+ execution.start_input.function_name = "test-function"
+ execution.is_complete = False
+ execution.result = None
+
+ op = Operation(
+ operation_id=f"op-{i}",
+ parent_id=None,
+ name=f"exec{i}",
+ start_timestamp=datetime.now(UTC),
+ operation_type=OperationType.EXECUTION,
+ status=OperationStatus.STARTED,
+ execution_details=ExecutionDetails(input_payload="{}"),
+ )
+ execution.get_operation_execution_started.return_value = op
+ executions_page2.append(execution)
+
+ # Mock query responses for pagination
+ query_result1 = (executions_page1, "2")
+
+ query_result2 = (executions_page2, "4")
+
+ mock_store.query.side_effect = [query_result1, query_result2]
+
+ # Test pagination with max_items=2
+ result = executor.list_executions(max_items=2)
+
+ assert len(result.durable_executions) == 2
+ assert result.next_marker == "2"
+
+ # Test second page
+ result2 = executor.list_executions(max_items=2, marker="2")
+
+ assert len(result2.durable_executions) == 2
+ assert result2.next_marker == "4"
+
+
+def test_list_executions_by_function(executor):
+ """Test list_executions_by_function delegates to list_executions."""
+ with patch.object(executor, "list_executions") as mock_list:
+ mock_response = ListDurableExecutionsResponse(
+ durable_executions=[], next_marker=None
+ )
+ mock_list.return_value = mock_response
+
+ result = executor.list_executions_by_function(
+ "test-function", status_filter="RUNNING"
+ )
+
+ mock_list.assert_called_once_with(
+ function_name="test-function",
+ execution_name=None,
+ status_filter="RUNNING",
+ started_after=None,
+ started_before=None,
+ marker=None,
+ max_items=None,
+ reverse_order=False,
+ )
+ assert result.durable_executions == []
+ assert result.next_marker is None
+
+
+def test_stop_execution(executor, mock_store):
+ """Test stop_execution method."""
+ # Create real execution instance with mocked start_input
+ mock_start_input = Mock()
+ mock_start_input.execution_name = "test-execution"
+ mock_start_input.function_name = "test-function"
+
+ execution = Execution(
+ durable_execution_arn="test-arn",
+ start_input=mock_start_input,
+ operations=[Mock()],
+ )
+ execution.is_complete = False
+ mock_store.load.return_value = execution
+
+ result = executor.stop_execution("test-arn")
+
+ mock_store.load.assert_called_once_with("test-arn")
+ mock_store.update.assert_called_once_with(execution)
+ assert result.stop_timestamp is not None
+ assert execution.is_complete is True
+ assert execution.close_status == ExecutionStatus.STOPPED
+
+
+def test_stop_execution_already_complete(executor, mock_store):
+ """Test stop_execution with already completed execution returns idempotent response."""
+ mock_execution = Mock()
+ mock_execution.is_complete = True
+ mock_execution.durable_execution_arn = "test-arn"
+
+ # Mock the execution operation with end_timestamp
+ mock_execution_op = Mock()
+ mock_execution_op.end_timestamp = datetime(2023, 1, 1, 0, 1, 0, tzinfo=UTC)
+ mock_execution.get_operation_execution_started.return_value = mock_execution_op
+
+ mock_store.load.return_value = mock_execution
+
+ result = executor.stop_execution("test-arn")
+
+ assert isinstance(result, StopDurableExecutionResponse)
+ assert result.stop_timestamp == datetime(2023, 1, 1, 0, 1, 0, tzinfo=UTC)
+
+
+def test_stop_execution_with_custom_error(executor, mock_store):
+ """Test stop_execution with custom error."""
+ # Create real execution instance with mocked start_input
+ mock_start_input = Mock()
+ mock_start_input.execution_name = "test-execution"
+ mock_start_input.function_name = "test-function"
+
+ execution = Execution(
+ durable_execution_arn="test-arn",
+ start_input=mock_start_input,
+ operations=[Mock()],
+ )
+ execution.is_complete = False
+ mock_store.load.return_value = execution
+
+ custom_error = ErrorObject.from_message("Custom stop error")
+
+ executor.stop_execution("test-arn", error=custom_error)
+
+ mock_store.load.assert_called_once_with("test-arn")
+ mock_store.update.assert_called_once_with(execution)
+ assert execution.is_complete is True
+ assert execution.close_status == ExecutionStatus.STOPPED
+ assert execution.result.error == custom_error
+
+
+def test_get_execution_not_found(executor, mock_store):
+ mock_store.load.side_effect = KeyError("not found")
+
+ with pytest.raises(ResourceNotFoundException):
+ executor.get_execution("test-arn")
+
+
+def test_get_execution_state(executor, mock_store):
+ """Test get_execution_state method."""
+
+ mock_execution = Mock()
+ mock_execution.used_tokens = {"token1", "token2"}
+
+ # Create mock operations
+ operations = [
+ Operation(
+ operation_id="op-1",
+ parent_id=None,
+ name="step1",
+ start_timestamp=datetime.now(UTC),
+ operation_type=OperationType.STEP,
+ status=OperationStatus.SUCCEEDED,
+ ),
+ Operation(
+ operation_id="op-2",
+ parent_id=None,
+ name="step2",
+ start_timestamp=datetime.now(UTC),
+ operation_type=OperationType.STEP,
+ status=OperationStatus.STARTED,
+ ),
+ ]
+ mock_execution.get_assertable_operations.return_value = operations
+
+ mock_store.load.return_value = mock_execution
+
+ result = executor.get_execution_state("test-arn", checkpoint_token="token1") # noqa: S106
+
+ assert len(result.operations) == 2
+ assert result.next_marker is None
+ mock_store.load.assert_called_once_with("test-arn")
+
+
+def test_get_execution_state_invalid_token(executor, mock_store):
+ """Test get_execution_state with invalid checkpoint token."""
+ mock_execution = Mock()
+ mock_execution.used_tokens = {"token1", "token2"}
+ mock_store.load.return_value = mock_execution
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Invalid checkpoint token"
+ ):
+ executor.get_execution_state("test-arn", checkpoint_token="invalid-token") # noqa: S106
+
+
+def test_get_execution_history(executor, mock_store):
+ """Test get_execution_history method."""
+ mock_execution = Mock()
+ mock_execution.operations = [] # Empty operations list
+ mock_execution.updates = []
+ mock_execution.invocation_completions = []
+ mock_execution.durable_execution_arn = ""
+ mock_execution.start_input = Mock()
+ mock_execution.result = Mock()
+
+ mock_store.load.return_value = mock_execution
+
+ result = executor.get_execution_history("test-arn")
+
+ assert result.events == []
+ assert result.next_marker is None
+ mock_store.load.assert_called_once_with("test-arn")
+
+
+def test_get_execution_history_with_events(executor, mock_store):
+ """Test get_execution_history with actual events."""
+ from aws_durable_execution_sdk_python.lambda_service import StepDetails
+
+ # Create operations that will generate events
+ op1 = Operation(
+ operation_id="op-1",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.SUCCEEDED,
+ start_timestamp=datetime.now(UTC),
+ end_timestamp=datetime.now(UTC),
+ step_details=StepDetails(result="test_result"),
+ )
+ mock_execution = Mock()
+ mock_execution.operations = [op1]
+ mock_execution.updates = []
+ mock_execution.invocation_completions = []
+ mock_execution.durable_execution_arn = ""
+ mock_execution.start_input = Mock()
+ mock_execution.result = Mock()
+ mock_store.load.return_value = mock_execution
+
+ result = executor.get_execution_history("test-arn", include_execution_data=True)
+
+ assert len(result.events) == 2 # Started + Succeeded events
+ assert result.events[0].event_type == "StepStarted"
+ assert result.events[1].event_type == "StepSucceeded"
+
+
+def test_get_execution_history_reverse_order(executor, mock_store):
+ """Test get_execution_history with reverse order."""
+ op1 = Operation(
+ operation_id="op-1",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.SUCCEEDED,
+ start_timestamp=datetime.now(UTC),
+ end_timestamp=datetime.now(UTC),
+ )
+
+ mock_execution = Mock()
+ mock_execution.operations = [op1]
+ mock_execution.updates = []
+ mock_execution.invocation_completions = []
+ mock_execution.durable_execution_arn = ""
+ mock_execution.start_input = Mock()
+ mock_execution.result = Mock()
+ mock_store.load.return_value = mock_execution
+
+ result = executor.get_execution_history("test-arn", reverse_order=True)
+
+ assert len(result.events) == 2
+ # In reverse order, succeeded event should come first
+ assert result.events[0].event_type == "StepSucceeded"
+ assert result.events[1].event_type == "StepStarted"
+
+
+def test_get_execution_history_pagination(executor, mock_store):
+ """Test get_execution_history with pagination."""
+ # Create multiple operations to generate many events
+ operations = []
+ for i in range(3):
+ op = Operation(
+ operation_id=f"op-{i}",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.SUCCEEDED,
+ start_timestamp=datetime.now(UTC),
+ end_timestamp=datetime.now(UTC),
+ )
+ operations.append(op)
+
+ mock_execution = Mock()
+ mock_execution.operations = operations
+ mock_execution.updates = []
+ mock_execution.invocation_completions = []
+ mock_execution.durable_execution_arn = ""
+ mock_execution.start_input = Mock()
+ mock_execution.result = Mock()
+ mock_store.load.return_value = mock_execution
+
+ # Test with max_items=2
+ result = executor.get_execution_history("test-arn", max_items=2)
+
+ assert len(result.events) == 2
+ assert result.next_marker == "3" # Next event_id
+
+
+def test_get_execution_history_pagination_with_marker(executor, mock_store):
+ """Test get_execution_history pagination with marker."""
+ operations = []
+ for i in range(3):
+ op = Operation(
+ operation_id=f"op-{i}",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.SUCCEEDED,
+ start_timestamp=datetime.now(UTC),
+ end_timestamp=datetime.now(UTC),
+ )
+ operations.append(op)
+
+ mock_execution = Mock()
+ mock_execution.operations = operations
+ mock_execution.updates = []
+ mock_execution.invocation_completions = []
+ mock_execution.durable_execution_arn = ""
+ mock_execution.start_input = Mock()
+ mock_execution.result = Mock()
+ mock_store.load.return_value = mock_execution
+
+ # Test with marker (start from event_id 3)
+ result = executor.get_execution_history("test-arn", marker="3", max_items=2)
+
+ assert len(result.events) == 2
+ # Should get events with event_id >= 3
+
+
+def test_get_execution_history_invalid_marker(executor, mock_store):
+ """Test get_execution_history with invalid marker."""
+ mock_execution = Mock()
+ mock_execution.operations = []
+ mock_execution.updates = []
+ mock_execution.invocation_completions = []
+ mock_execution.durable_execution_arn = ""
+ mock_execution.start_input = Mock()
+ mock_execution.result = Mock()
+ mock_store.load.return_value = mock_execution
+
+ # Invalid marker should default to 1
+ result = executor.get_execution_history("test-arn", marker="invalid")
+
+ assert result.events == []
+ assert result.next_marker is None
+
+
+def test_checkpoint_execution(executor, mock_store):
+ """Test checkpoint_execution method."""
+ mock_execution = Mock()
+ mock_execution.used_tokens = {"token1", "token2"}
+ mock_execution.get_new_checkpoint_token.return_value = "new-token"
+ mock_store.load.return_value = mock_execution
+
+ result = executor.checkpoint_execution("test-arn", "token1")
+
+ assert result.checkpoint_token == "new-token" # noqa: S105
+ assert result.new_execution_state is None
+ mock_store.load.assert_called_once_with("test-arn")
+ mock_execution.get_new_checkpoint_token.assert_called_once()
+
+
+def test_checkpoint_execution_invalid_token(executor, mock_store):
+ """Test checkpoint_execution with invalid checkpoint token."""
+ mock_execution = Mock()
+ mock_execution.used_tokens = {"token1", "token2"}
+ mock_store.load.return_value = mock_execution
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Invalid checkpoint token"
+ ):
+ executor.checkpoint_execution("test-arn", "invalid-token")
+
+
+# Callback method tests
+
+
+def test_send_callback_success(executor, mock_store):
+ """Test send_callback_success method."""
+ from aws_durable_execution_sdk_python_testing.token import CallbackToken
+
+ # Create valid callback token
+ callback_token = CallbackToken(execution_arn="test-arn", operation_id="op-123")
+ callback_id = callback_token.to_str()
+
+ # Create mock execution with callback operation
+ mock_execution = Mock()
+ mock_execution.find_callback_operation.return_value = (0, Mock())
+ mock_execution.complete_callback_success.return_value = Mock()
+ mock_store.load.return_value = mock_execution
+
+ with patch.object(executor, "_invoke_execution") as mock_invoke:
+ result = executor.send_callback_success(callback_id, b"success-result")
+
+ assert isinstance(result, SendDurableExecutionCallbackSuccessResponse)
+ mock_store.load.assert_called_once_with("test-arn")
+ mock_execution.complete_callback_success.assert_called_once_with(
+ callback_id, b"success-result"
+ )
+ mock_store.update.assert_called_once_with(mock_execution)
+ # Verify execution is invoked after callback success
+ mock_invoke.assert_called_once_with("test-arn")
+
+
+def test_send_callback_success_empty_callback_id(executor):
+ """Test send_callback_success with empty callback_id."""
+ with pytest.raises(InvalidParameterValueException, match="callback_id is required"):
+ executor.send_callback_success("")
+
+
+def test_send_callback_success_none_callback_id(executor):
+ """Test send_callback_success with None callback_id."""
+ with pytest.raises(InvalidParameterValueException, match="callback_id is required"):
+ executor.send_callback_success(None)
+
+
+def test_send_callback_success_with_result(executor, mock_store):
+ """Test send_callback_success with result data."""
+ from aws_durable_execution_sdk_python_testing.token import CallbackToken
+
+ # Create valid callback token
+ callback_token = CallbackToken(execution_arn="test-arn", operation_id="op-123")
+ callback_id = callback_token.to_str()
+
+ # Create mock execution with callback operation
+ mock_execution = Mock()
+ mock_execution.find_callback_operation.return_value = (0, Mock())
+ mock_execution.complete_callback_success.return_value = Mock()
+ mock_store.load.return_value = mock_execution
+
+ with patch.object(executor, "_invoke_execution") as mock_invoke:
+ result = executor.send_callback_success(callback_id, b"test-result")
+
+ assert isinstance(result, SendDurableExecutionCallbackSuccessResponse)
+ mock_execution.complete_callback_success.assert_called_once_with(
+ callback_id, b"test-result"
+ )
+ # Verify execution is invoked after callback success
+ mock_invoke.assert_called_once_with("test-arn")
+
+
+def test_send_callback_failure(executor, mock_store):
+ """Test send_callback_failure method."""
+ from aws_durable_execution_sdk_python_testing.token import CallbackToken
+
+ # Create valid callback token
+ callback_token = CallbackToken(execution_arn="test-arn", operation_id="op-123")
+ callback_id = callback_token.to_str()
+
+ # Create mock execution with callback operation
+ mock_execution = Mock()
+ mock_execution.find_callback_operation.return_value = (0, Mock())
+ mock_execution.complete_callback_failure.return_value = Mock()
+ mock_store.load.return_value = mock_execution
+
+ with patch.object(executor, "_invoke_execution") as mock_invoke:
+ result = executor.send_callback_failure(callback_id)
+
+ assert isinstance(result, SendDurableExecutionCallbackFailureResponse)
+ mock_store.load.assert_called_once_with("test-arn")
+ mock_store.update.assert_called_once_with(mock_execution)
+ # Verify execution is invoked after callback failure
+ mock_invoke.assert_called_once_with("test-arn")
+
+
+def test_send_callback_failure_empty_callback_id(executor):
+ """Test send_callback_failure with empty callback_id."""
+ with pytest.raises(InvalidParameterValueException, match="callback_id is required"):
+ executor.send_callback_failure("")
+
+
+def test_send_callback_failure_none_callback_id(executor):
+ """Test send_callback_failure with None callback_id."""
+ with pytest.raises(InvalidParameterValueException, match="callback_id is required"):
+ executor.send_callback_failure(None)
+
+
+def test_send_callback_failure_with_error(executor, mock_store):
+ """Test send_callback_failure with error object."""
+ # Create valid callback token
+ callback_token = CallbackToken(execution_arn="test-arn", operation_id="op-123")
+ callback_id = callback_token.to_str()
+
+ # Create mock execution with callback operation
+ mock_execution = Mock()
+ mock_execution.find_callback_operation.return_value = (0, Mock())
+ mock_execution.complete_callback_failure.return_value = Mock()
+ mock_store.load.return_value = mock_execution
+
+ error = ErrorObject.from_message("Test callback error")
+ with patch.object(executor, "_invoke_execution") as mock_invoke:
+ result = executor.send_callback_failure(callback_id, error)
+
+ assert isinstance(result, SendDurableExecutionCallbackFailureResponse)
+ mock_execution.complete_callback_failure.assert_called_once_with(callback_id, error)
+ # Verify execution is invoked after callback failure
+ mock_invoke.assert_called_once_with("test-arn")
+
+
+def test_send_callback_heartbeat(executor, mock_store):
+ """Test send_callback_heartbeat method."""
+ # Create valid callback token
+ callback_token = CallbackToken(execution_arn="test-arn", operation_id="op-123")
+ callback_id = callback_token.to_str()
+
+ # Create mock execution with callback operation
+ mock_execution = Mock()
+ mock_operation = Mock()
+ mock_operation.status = OperationStatus.STARTED
+ mock_execution.find_callback_operation.return_value = (0, mock_operation)
+ mock_execution.updates = [] # No callback options to reset timeout
+ mock_execution.invocation_completions = []
+ mock_store.load.return_value = mock_execution
+
+ result = executor.send_callback_heartbeat(callback_id)
+
+ assert isinstance(result, SendDurableExecutionCallbackHeartbeatResponse)
+ # Called twice: once in get_execution, once in _reset_callback_heartbeat_timeout
+ assert mock_store.load.call_count == 2
+ mock_execution.find_callback_operation.assert_called_once_with(callback_id)
+
+
+def test_send_callback_heartbeat_empty_callback_id(executor):
+ """Test send_callback_heartbeat with empty callback_id."""
+ with pytest.raises(InvalidParameterValueException, match="callback_id is required"):
+ executor.send_callback_heartbeat("")
+
+
+def test_send_callback_heartbeat_none_callback_id(executor):
+ """Test send_callback_heartbeat with None callback_id."""
+ with pytest.raises(InvalidParameterValueException, match="callback_id is required"):
+ executor.send_callback_heartbeat(None)
+
+
+def test_complete_execution_no_result(mock_store, executor):
+ """Test complete_execution when execution has no result after completion."""
+ mock_execution = Mock()
+ mock_execution.result = None # No result after completion
+ mock_store.load.return_value = mock_execution
+
+ with patch.object(executor, "_complete_events"):
+ with pytest.raises(IllegalStateException, match="Execution result is required"):
+ executor.complete_execution("test-arn", "result")
+
+
+def test_fail_execution_no_result(mock_store, executor):
+ """Test fail_execution when execution has no result after failure."""
+ mock_execution = Mock()
+ mock_execution.result = None # No result after failure
+ mock_store.load.return_value = mock_execution
+ error = ErrorObject.from_message("test error")
+
+ with patch.object(executor, "_complete_events"):
+ with pytest.raises(IllegalStateException, match="Execution result is required"):
+ executor.fail_execution("test-arn", error)
+
+
+def test_send_callback_heartbeat_inactive_callback(mock_store, executor):
+ """Test send_callback_heartbeat with inactive callback."""
+
+ # Create valid callback token
+ callback_token = CallbackToken(execution_arn="test-arn", operation_id="op-123")
+ callback_id = callback_token.to_str()
+
+ # Create mock execution with inactive callback operation
+ mock_execution = Mock()
+ mock_operation = Mock()
+ mock_operation.status = OperationStatus.SUCCEEDED # Not STARTED
+ mock_execution.find_callback_operation.return_value = (0, mock_operation)
+ mock_store.load.return_value = mock_execution
+
+ with pytest.raises(ResourceNotFoundException, match="Callback .* is not active"):
+ executor.send_callback_heartbeat(callback_id)
+
+
+def test_send_callback_success_invalid_token(executor):
+ """Test send_callback_success with invalid token format."""
+ with pytest.raises(
+ ResourceNotFoundException, match="Failed to process callback success"
+ ):
+ executor.send_callback_success("invalid-token")
+
+
+def test_send_callback_failure_invalid_token(executor):
+ """Test send_callback_failure with invalid token format."""
+ with pytest.raises(
+ ResourceNotFoundException, match="Failed to process callback failure"
+ ):
+ executor.send_callback_failure("invalid-token")
+
+
+def test_send_callback_heartbeat_invalid_token(executor):
+ """Test send_callback_heartbeat with invalid token format."""
+ with pytest.raises(
+ ResourceNotFoundException, match="Failed to process callback heartbeat"
+ ):
+ executor.send_callback_heartbeat("invalid-token")
+
+
+def test_complete_events_no_event(executor):
+ """Test _complete_events when no event exists."""
+ # Should not raise exception when event doesn't exist
+ executor._complete_events("nonexistent-arn") # Should handle gracefully
+
+
+# Tests for callback timeout functionality
+
+
+def test_callback_timeout_scheduling(executor, mock_store, mock_scheduler):
+ """Test that callback timeouts are scheduled when callback is created."""
+ # Create callback options with both timeouts
+ callback_options = CallbackOptions(timeout_seconds=60, heartbeat_timeout_seconds=30)
+
+ # Set up completion event
+ executor._completion_events["test-arn"] = Mock()
+
+ # Test the timeout scheduling directly with correct parameters
+ executor._schedule_callback_timeouts("test-arn", callback_options, "callback-id")
+
+ # Verify scheduler was called for both timeouts
+ assert mock_scheduler.call_later.call_count == 2 # main timeout + heartbeat timeout
+
+
+def test_callback_timeout_cleanup(executor, mock_store):
+ """Test that callback timeouts are cleaned up when callback completes."""
+ # Create mock timeout events
+ timeout_event = Mock()
+ heartbeat_event = Mock()
+
+ executor._callback_timeouts["callback-id"] = timeout_event
+ executor._callback_heartbeats["callback-id"] = heartbeat_event
+
+ # Trigger cleanup
+ executor._cleanup_callback_timeouts("callback-id")
+
+ # Verify events were cancelled and removed
+ timeout_event.cancel.assert_called_once()
+ heartbeat_event.cancel.assert_called_once()
+ assert "callback-id" not in executor._callback_timeouts
+ assert "callback-id" not in executor._callback_heartbeats
+
+
+def test_callback_heartbeat_timeout_reset(executor, mock_store, mock_scheduler):
+ """Test that heartbeat timeout is reset when heartbeat is received."""
+
+ # Create callback token
+ callback_token = CallbackToken(execution_arn="test-arn", operation_id="op-123")
+ callback_id = callback_token.to_str()
+
+ # Create mock execution with callback options
+ mock_execution = Mock()
+ callback_options = CallbackOptions(heartbeat_timeout_seconds=30)
+ update = OperationUpdate(
+ operation_id="op-123",
+ operation_type=OperationType.CALLBACK,
+ action=OperationAction.START,
+ callback_options=callback_options,
+ )
+ mock_execution.updates = [update]
+
+ mock_store.load.return_value = mock_execution
+ mock_scheduler.create_event.return_value = Mock()
+
+ # Set up existing heartbeat event
+ old_event = Mock()
+ executor._callback_heartbeats[callback_id] = old_event
+
+ # Reset heartbeat timeout
+ executor._reset_callback_heartbeat_timeout(callback_id, "test-arn")
+
+ # Verify old event was cancelled and new one scheduled
+ old_event.cancel.assert_called_once()
+ mock_scheduler.call_later.assert_called()
+
+
+def test_callback_timeout_handlers(executor, mock_store):
+ """Test callback timeout and heartbeat timeout handlers."""
+ # Create callback token
+ callback_token = CallbackToken(execution_arn="test-arn", operation_id="op-123")
+ callback_id = callback_token.to_str()
+
+ # Create mock execution
+ mock_execution = Mock()
+ mock_execution.is_complete = False
+ mock_store.load.return_value = mock_execution
+
+ # Test main timeout handler
+ executor._on_callback_timeout("test-arn", callback_id)
+
+ # Verify callback was failed with timeout error
+ mock_execution.complete_callback_timeout.assert_called()
+ timeout_error = mock_execution.complete_callback_timeout.call_args[0][1]
+ assert "Callback timed out" in str(timeout_error.message)
+
+ # Reset mocks for heartbeat test
+ mock_execution.reset_mock()
+
+ # Test heartbeat timeout handler
+ executor._on_callback_heartbeat_timeout("test-arn", callback_id)
+
+ # Verify callback was failed with heartbeat timeout error
+ mock_execution.complete_callback_timeout.assert_called()
+ heartbeat_error = mock_execution.complete_callback_timeout.call_args[0][1]
+ assert "Callback heartbeat timed out" in str(heartbeat_error.message)
+
+
+def test_callback_timeout_completed_execution(executor, mock_store):
+ """Test that timeout handlers ignore completed executions."""
+
+ # Create callback token
+ callback_token = CallbackToken(execution_arn="test-arn", operation_id="op-123")
+ callback_id = callback_token.to_str()
+
+ # Create completed execution
+ mock_execution = Mock()
+ mock_execution.is_complete = True
+ mock_store.load.return_value = mock_execution
+
+ # Test timeout handlers with completed execution
+ executor._on_callback_timeout("test-arn", callback_id)
+ executor._on_callback_heartbeat_timeout("test-arn", callback_id)
+
+ # Verify no callback operations were performed
+ mock_execution.complete_callback_timeout.assert_not_called()
+ mock_store.update.assert_not_called()
+
+
+def test_schedule_callback_timeouts_no_callback_details(executor, mock_store):
+ """Test _schedule_callback_timeouts when operation has no callback details."""
+
+ # Create operation without callback details
+ operation = Operation(
+ operation_id="op-123",
+ operation_type=OperationType.CALLBACK,
+ status=OperationStatus.STARTED,
+ callback_details=None,
+ )
+
+ mock_execution = Mock()
+ mock_execution.find_operation.return_value = (0, operation)
+ mock_store.load.return_value = mock_execution
+
+ # Should return early without scheduling
+ executor._schedule_callback_timeouts("test-arn", "op-123", "callback-id")
+
+ # No scheduler calls should be made
+ assert len(executor._callback_timeouts) == 0
+ assert len(executor._callback_heartbeats) == 0
+
+
+def test_schedule_callback_timeouts_no_callback_options(executor, mock_store):
+ """Test _schedule_callback_timeouts when no callback options are found."""
+
+ # Create operation with callback details but no matching updates
+ operation = Operation(
+ operation_id="op-123",
+ operation_type=OperationType.CALLBACK,
+ status=OperationStatus.STARTED,
+ callback_details=CallbackDetails(callback_id="callback-id"),
+ )
+
+ mock_execution = Mock()
+ mock_execution.find_operation.return_value = (0, operation)
+ mock_execution.updates = [] # No updates with callback options
+ mock_execution.invocation_completions = []
+ mock_store.load.return_value = mock_execution
+
+ # Should return early without scheduling
+ executor._schedule_callback_timeouts("test-arn", "op-123", "callback-id")
+
+ # No scheduler calls should be made
+ assert len(executor._callback_timeouts) == 0
+ assert len(executor._callback_heartbeats) == 0
+
+
+def test_schedule_callback_timeouts_zero_timeouts(executor, mock_store, mock_scheduler):
+ """Test _schedule_callback_timeouts with zero timeout values."""
+ # Create operation with callback details
+ operation = Operation(
+ operation_id="op-123",
+ operation_type=OperationType.CALLBACK,
+ status=OperationStatus.STARTED,
+ callback_details=CallbackDetails(callback_id="callback-id"),
+ )
+
+ mock_execution = Mock()
+ mock_execution.find_operation.return_value = (0, operation)
+
+ # Create update with zero timeouts (disabled)
+ callback_options = CallbackOptions(timeout_seconds=0, heartbeat_timeout_seconds=0)
+ update = OperationUpdate(
+ operation_id="op-123",
+ operation_type=OperationType.CALLBACK,
+ action=OperationAction.START,
+ callback_options=callback_options,
+ )
+ mock_execution.updates = [update]
+
+ mock_store.load.return_value = mock_execution
+ executor._completion_events["test-arn"] = Mock()
+
+ # Should not schedule any timeouts
+ executor._schedule_callback_timeouts("test-arn", "op-123", "callback-id")
+
+ # No scheduler calls should be made
+ mock_scheduler.call_later.assert_not_called()
+ assert len(executor._callback_timeouts) == 0
+ assert len(executor._callback_heartbeats) == 0
+
+
+def test_schedule_callback_timeouts_only_main_timeout(
+ executor, mock_store, mock_scheduler
+):
+ """Test _schedule_callback_timeouts with only main timeout configured."""
+
+ # Create callback options with only main timeout
+ callback_options = CallbackOptions(timeout_seconds=60, heartbeat_timeout_seconds=0)
+
+ executor._completion_events["test-arn"] = Mock()
+
+ executor._schedule_callback_timeouts("test-arn", callback_options, "callback-id")
+
+ # Only main timeout should be scheduled
+ assert mock_scheduler.call_later.call_count == 1
+ assert len(executor._callback_timeouts) == 1
+ assert len(executor._callback_heartbeats) == 0
+
+
+def test_schedule_callback_timeouts_only_heartbeat_timeout(
+ executor, mock_store, mock_scheduler
+):
+ """Test _schedule_callback_timeouts with only heartbeat timeout configured."""
+ # Create callback options with only heartbeat timeout
+ callback_options = CallbackOptions(timeout_seconds=0, heartbeat_timeout_seconds=30)
+
+ executor._completion_events["test-arn"] = Mock()
+
+ executor._schedule_callback_timeouts("test-arn", callback_options, "callback-id")
+
+ # Only heartbeat timeout should be scheduled
+ assert mock_scheduler.call_later.call_count == 1
+ assert len(executor._callback_timeouts) == 0
+ assert len(executor._callback_heartbeats) == 1
+
+
+def test_schedule_callback_timeouts_exception_handling(executor, mock_store):
+ """Test _schedule_callback_timeouts handles exceptions gracefully."""
+ # Make get_execution raise an exception
+ mock_store.load.side_effect = Exception("Test error")
+
+ # Should not raise exception
+ executor._schedule_callback_timeouts("test-arn", "op-123", "callback-id")
+
+ # No timeouts should be scheduled
+ assert len(executor._callback_timeouts) == 0
+ assert len(executor._callback_heartbeats) == 0
+
+
+def test_on_timed_out(executor, mock_store):
+ """Test on_timed_out method."""
+ # Create real execution instance
+ mock_start_input = Mock()
+ mock_start_input.execution_name = "test-execution"
+ mock_start_input.function_name = "test-function"
+
+ execution = Execution(
+ durable_execution_arn="test-arn",
+ start_input=mock_start_input,
+ operations=[Mock()],
+ )
+ execution.is_complete = False
+ mock_store.load.return_value = execution
+
+ error = ErrorObject.from_message("Execution timeout")
+
+ with patch.object(executor, "_complete_events") as mock_complete_events:
+ executor.on_timed_out("test-arn", error)
+
+ mock_store.load.assert_called_once_with(execution_arn="test-arn")
+ mock_store.update.assert_called_once_with(execution)
+ mock_complete_events.assert_called_once_with(execution_arn="test-arn")
+ assert execution.is_complete is True
+ assert execution.close_status == ExecutionStatus.TIMED_OUT
+ assert execution.result.error == error
+
+
+def test_on_stopped(executor):
+ """Test on_stopped method."""
+ error = ErrorObject.from_message("Execution stopped")
+
+ with patch.object(executor, "fail_execution") as mock_fail:
+ executor.on_stopped("test-arn", error)
+
+ mock_fail.assert_called_once_with("test-arn", error)
+
+
+def test_notify_timed_out():
+ """Test notify_timed_out method."""
+ notifier = ExecutionNotifier()
+ observer = Mock()
+ notifier.add_observer(observer)
+
+ error = ErrorObject.from_message("Timeout error")
+ notifier.notify_timed_out("test-arn", error)
+
+ observer.on_timed_out.assert_called_once_with(execution_arn="test-arn", error=error)
+
+
+def test_notify_stopped():
+ """Test notify_stopped method."""
+ notifier = ExecutionNotifier()
+ observer = Mock()
+ notifier.add_observer(observer)
+
+ error = ErrorObject.from_message("Stop error")
+ notifier.notify_stopped("test-arn", error)
+
+ observer.on_stopped.assert_called_once_with(execution_arn="test-arn", error=error)
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/invoker_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/invoker_test.py
new file mode 100644
index 0000000..1270f50
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/invoker_test.py
@@ -0,0 +1,652 @@
+"""Tests for invoker module."""
+
+import json
+from unittest.mock import Mock, patch
+
+import pytest
+from aws_durable_execution_sdk_python.execution import (
+ DurableExecutionInvocationInput,
+ DurableExecutionInvocationInputWithClient,
+ DurableExecutionInvocationOutput,
+ InitialExecutionState,
+ InvocationStatus,
+)
+
+from aws_durable_execution_sdk_python.lambda_service import (
+ ExecutionDetails,
+ Operation,
+ OperationStatus,
+ OperationType,
+)
+
+from datetime import datetime, UTC
+
+from aws_durable_execution_sdk_python_testing.execution import Execution
+from aws_durable_execution_sdk_python_testing.invoker import (
+ InProcessInvoker,
+ LambdaInvoker,
+ _LAMBDA_CLIENT_CONFIG,
+ create_lambda_client,
+ create_test_lambda_context,
+)
+from aws_durable_execution_sdk_python_testing.model import (
+ LambdaContext,
+ StartDurableExecutionInput,
+)
+
+
+def test_create_test_lambda_context():
+ """Test creating a test lambda context."""
+ context = create_test_lambda_context()
+
+ assert (
+ context.invoked_function_arn
+ == "arn:aws:lambda:us-west-2:123456789012:function:test-function"
+ )
+ assert context.tenant_id == "test-tenant-789"
+ assert context.client_context is not None
+
+
+def test_in_process_invoker_init():
+ """Test InProcessInvoker initialization."""
+ handler = Mock()
+ service_client = Mock()
+
+ invoker = InProcessInvoker(handler, service_client)
+
+ assert invoker.handler is handler
+ assert invoker.service_client is service_client
+
+
+def test_in_process_invoker_create_invocation_input():
+ """Test creating invocation input for in-process invoker."""
+ handler = Mock()
+ service_client = Mock()
+ invoker = InProcessInvoker(handler, service_client)
+
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ execution = Execution.new(input_data)
+
+ invocation_input = invoker.create_invocation_input(execution)
+
+ assert isinstance(invocation_input, DurableExecutionInvocationInputWithClient)
+ assert invocation_input.durable_execution_arn == execution.durable_execution_arn
+ assert invocation_input.checkpoint_token is not None
+ assert isinstance(invocation_input.initial_execution_state, InitialExecutionState)
+ assert invocation_input.service_client is service_client
+
+
+def test_in_process_invoker_invoke():
+ """Test invoking function with in-process invoker."""
+ # Mock handler that returns a valid response
+ handler = Mock()
+ handler.return_value = {"Status": "SUCCEEDED", "Result": "test-result"}
+
+ service_client = Mock()
+ invoker = InProcessInvoker(handler, service_client)
+
+ input_data = DurableExecutionInvocationInput(
+ durable_execution_arn="test-arn",
+ checkpoint_token="test-token", # noqa: S106
+ initial_execution_state=InitialExecutionState(operations=[], next_marker=""),
+ )
+
+ response = invoker.invoke("test-function", input_data)
+
+ assert isinstance(response.invocation_output, DurableExecutionInvocationOutput)
+ assert response.invocation_output.status == InvocationStatus.SUCCEEDED
+ assert response.invocation_output.result == "test-result"
+ assert isinstance(response.request_id, str)
+
+ # Verify handler was called with correct arguments
+ handler.assert_called_once()
+ call_args = handler.call_args[0]
+ assert isinstance(call_args[0], DurableExecutionInvocationInputWithClient)
+ assert isinstance(call_args[1], LambdaContext)
+
+
+def test_lambda_invoker_init():
+ """Test LambdaInvoker initialization."""
+ lambda_client = Mock()
+
+ invoker = LambdaInvoker(lambda_client)
+
+ assert invoker.lambda_client is lambda_client
+
+
+def test_lambda_invoker_create():
+ """Test creating LambdaInvoker with boto3 client."""
+ with patch("aws_durable_execution_sdk_python_testing.invoker.boto3") as mock_boto3:
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+
+ invoker = LambdaInvoker.create("http://localhost:3001", "us-west-2")
+
+ assert isinstance(invoker, LambdaInvoker)
+ assert invoker.lambda_client is mock_client
+ mock_boto3.client.assert_called_once_with(
+ "lambda",
+ endpoint_url="http://localhost:3001",
+ region_name="us-west-2",
+ config=_LAMBDA_CLIENT_CONFIG,
+ )
+
+
+def test_lambda_invoker_create_invocation_input():
+ """Test creating invocation input for lambda invoker."""
+ lambda_client = Mock()
+ invoker = LambdaInvoker(lambda_client)
+
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation",
+ )
+ execution = Execution.new(input_data)
+
+ invocation_input = invoker.create_invocation_input(execution)
+
+ assert isinstance(invocation_input, DurableExecutionInvocationInput)
+ assert invocation_input.durable_execution_arn == execution.durable_execution_arn
+ assert invocation_input.checkpoint_token is not None
+ assert isinstance(invocation_input.initial_execution_state, InitialExecutionState)
+
+
+def test_lambda_invoker_invoke_success():
+ """Test successful lambda invocation."""
+ lambda_client = Mock()
+
+ # Mock successful response
+ mock_payload = Mock()
+ mock_payload.read.return_value = json.dumps(
+ {"Status": "SUCCEEDED", "Result": "lambda-result"}
+ ).encode("utf-8")
+
+ lambda_client.invoke.return_value = {
+ "StatusCode": 200,
+ "Payload": mock_payload,
+ "ResponseMetadata": {"HTTPHeaders": {"x-amzn-RequestId": "test-request-id"}},
+ }
+
+ invoker = LambdaInvoker(lambda_client)
+
+ mock_operation = Operation(
+ operation_id="op-1",
+ parent_id=None,
+ name="test-execution",
+ start_timestamp=datetime.now(UTC),
+ end_timestamp=datetime.now(UTC),
+ operation_type=OperationType.EXECUTION,
+ status=OperationStatus.SUCCEEDED,
+ execution_details=ExecutionDetails(input_payload='{"test": "data"}'),
+ )
+
+ input_data = DurableExecutionInvocationInput(
+ durable_execution_arn="test-arn",
+ checkpoint_token="test-token", # noqa: S106
+ initial_execution_state=InitialExecutionState(
+ operations=[mock_operation], next_marker=""
+ ),
+ )
+
+ response = invoker.invoke("test-function", input_data)
+
+ assert isinstance(response.invocation_output, DurableExecutionInvocationOutput)
+ assert response.invocation_output.status == InvocationStatus.SUCCEEDED
+ assert response.invocation_output.result == "lambda-result"
+ assert response.request_id == "test-request-id"
+
+ # Verify lambda client was called correctly
+ lambda_client.invoke.assert_called_once_with(
+ FunctionName="test-function",
+ InvocationType="RequestResponse",
+ Payload=json.dumps(input_data.to_json_dict()),
+ )
+
+
+def test_lambda_invoker_invoke_failure():
+ """Test lambda invocation failure."""
+ from aws_durable_execution_sdk_python_testing.exceptions import (
+ DurableFunctionsTestError,
+ )
+
+ lambda_client, _ = _create_mock_lambda_client_with_exceptions()
+
+ # Mock failed response
+ mock_payload = Mock()
+ lambda_client.invoke.return_value = {
+ "StatusCode": 500,
+ "Payload": mock_payload,
+ }
+
+ invoker = LambdaInvoker(lambda_client)
+
+ input_data = DurableExecutionInvocationInput(
+ durable_execution_arn="test-arn",
+ checkpoint_token="test-token", # noqa: S106
+ initial_execution_state=InitialExecutionState(operations=[], next_marker=""),
+ )
+
+ with pytest.raises(
+ DurableFunctionsTestError,
+ match="Lambda invocation failed with status code: 500",
+ ):
+ invoker.invoke("test-function", input_data)
+
+
+def test_in_process_invoker_invoke_with_execution_operations():
+ """Test in-process invoker with execution that has operations."""
+ handler = Mock()
+ handler.return_value = {"Status": "SUCCEEDED", "Result": None}
+
+ service_client = Mock()
+ invoker = InProcessInvoker(handler, service_client)
+
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation",
+ )
+ execution = Execution.new(input_data)
+ execution.start() # This adds operations
+
+ invocation_input = invoker.create_invocation_input(execution)
+ response = invoker.invoke("test-function", invocation_input)
+
+ assert isinstance(response.invocation_output, DurableExecutionInvocationOutput)
+ assert isinstance(response.request_id, str)
+ assert response.invocation_output.status == InvocationStatus.SUCCEEDED
+ assert len(invocation_input.initial_execution_state.operations) > 0
+
+
+def test_lambda_invoker_create_invocation_input_with_operations():
+ """Test lambda invoker creating input with execution operations."""
+ lambda_client = Mock()
+ invoker = LambdaInvoker(lambda_client)
+
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation",
+ )
+ execution = Execution.new(input_data)
+ execution.start() # This adds operations
+
+ invocation_input = invoker.create_invocation_input(execution)
+
+ assert isinstance(invocation_input, DurableExecutionInvocationInput)
+ assert len(invocation_input.initial_execution_state.operations) > 0
+ assert invocation_input.initial_execution_state.next_marker == ""
+
+
+def test_lambda_invoker_invoke_empty_function_name():
+ """Test lambda invocation with empty function name."""
+ from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+ )
+
+ lambda_client = Mock()
+ invoker = LambdaInvoker(lambda_client)
+
+ input_data = DurableExecutionInvocationInput(
+ durable_execution_arn="test-arn",
+ checkpoint_token="test-token",
+ initial_execution_state=InitialExecutionState(operations=[], next_marker=""),
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Function name is required"
+ ):
+ invoker.invoke("", input_data)
+
+
+def test_lambda_invoker_invoke_whitespace_function_name():
+ """Test lambda invocation with whitespace-only function name."""
+ from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+ )
+
+ lambda_client = Mock()
+ invoker = LambdaInvoker(lambda_client)
+
+ input_data = DurableExecutionInvocationInput(
+ durable_execution_arn="test-arn",
+ checkpoint_token="test-token",
+ initial_execution_state=InitialExecutionState(operations=[], next_marker=""),
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Function name is required"
+ ):
+ invoker.invoke(" ", input_data)
+
+
+def test_lambda_invoker_invoke_status_202():
+ """Test lambda invocation with status code 202."""
+ lambda_client = Mock()
+
+ mock_payload = Mock()
+ mock_payload.read.return_value = json.dumps(
+ {"Status": "SUCCEEDED", "Result": "async-result"}
+ ).encode("utf-8")
+
+ lambda_client.invoke.return_value = {
+ "StatusCode": 202,
+ "Payload": mock_payload,
+ "ResponseMetadata": {
+ "HTTPHeaders": {"x-amzn-RequestId": "test-request-id-202"}
+ },
+ }
+
+ invoker = LambdaInvoker(lambda_client)
+
+ input_data = DurableExecutionInvocationInput(
+ durable_execution_arn="test-arn",
+ checkpoint_token="test-token",
+ initial_execution_state=InitialExecutionState(operations=[], next_marker=""),
+ )
+
+ response = invoker.invoke("test-function", input_data)
+ assert isinstance(response.invocation_output, DurableExecutionInvocationOutput)
+ assert response.request_id == "test-request-id-202"
+
+
+def test_lambda_invoker_invoke_function_error():
+ """Test lambda invocation with function error."""
+ from aws_durable_execution_sdk_python_testing.exceptions import (
+ DurableFunctionsTestError,
+ )
+
+ lambda_client, _ = _create_mock_lambda_client_with_exceptions()
+
+ mock_payload = Mock()
+ mock_payload.read.return_value = b'{"errorMessage": "Function failed"}'
+
+ lambda_client.invoke.return_value = {
+ "StatusCode": 200,
+ "FunctionError": "Unhandled",
+ "Payload": mock_payload,
+ }
+
+ invoker = LambdaInvoker(lambda_client)
+
+ input_data = DurableExecutionInvocationInput(
+ durable_execution_arn="test-arn",
+ checkpoint_token="test-token",
+ initial_execution_state=InitialExecutionState(operations=[], next_marker=""),
+ )
+
+ with pytest.raises(
+ DurableFunctionsTestError, match="Lambda invocation failed with status 200"
+ ):
+ invoker.invoke("test-function", input_data)
+
+
+def _create_mock_lambda_client_with_exceptions():
+ """Helper to create mock lambda client with all exception types."""
+ lambda_client = Mock()
+
+ class MockException(Exception):
+ pass
+
+ exceptions_mock = Mock()
+ for exc_name in [
+ "ResourceNotFoundException",
+ "InvalidParameterValueException",
+ "TooManyRequestsException",
+ "ServiceException",
+ "ResourceConflictException",
+ "InvalidRequestContentException",
+ "RequestTooLargeException",
+ "UnsupportedMediaTypeException",
+ "InvalidRuntimeException",
+ "InvalidZipFileException",
+ "ResourceNotReadyException",
+ "SnapStartTimeoutException",
+ "SnapStartNotReadyException",
+ "SnapStartException",
+ "RecursiveInvocationException",
+ "InvalidSecurityGroupIDException",
+ "EC2ThrottledException",
+ "EFSMountConnectivityException",
+ "SubnetIPAddressLimitReachedException",
+ "EC2UnexpectedException",
+ "InvalidSubnetIDException",
+ "EC2AccessDeniedException",
+ "EFSIOException",
+ "ENILimitReachedException",
+ "EFSMountTimeoutException",
+ "EFSMountFailureException",
+ "KMSAccessDeniedException",
+ "KMSDisabledException",
+ "KMSNotFoundException",
+ "KMSInvalidStateException",
+ ]:
+ setattr(exceptions_mock, exc_name, MockException)
+
+ lambda_client.exceptions = exceptions_mock
+ return lambda_client, MockException
+
+
+def test_lambda_invoker_invoke_resource_not_found():
+ """Test lambda invocation with ResourceNotFoundException."""
+ from aws_durable_execution_sdk_python_testing.exceptions import (
+ ResourceNotFoundException,
+ )
+
+ lambda_client, _ = _create_mock_lambda_client_with_exceptions()
+
+ # Create specific exception for ResourceNotFoundException
+ class MockResourceNotFoundException(Exception):
+ pass
+
+ lambda_client.exceptions.ResourceNotFoundException = MockResourceNotFoundException
+
+ lambda_client.invoke.side_effect = MockResourceNotFoundException(
+ "Function not found"
+ )
+
+ invoker = LambdaInvoker(lambda_client)
+
+ input_data = DurableExecutionInvocationInput(
+ durable_execution_arn="test-arn",
+ checkpoint_token="test-token",
+ initial_execution_state=InitialExecutionState(operations=[], next_marker=""),
+ )
+
+ with pytest.raises(
+ ResourceNotFoundException, match="Function not found: test-function"
+ ):
+ invoker.invoke("test-function", input_data)
+
+
+def test_lambda_invoker_invoke_invalid_parameter():
+ """Test lambda invocation with InvalidParameterValueException."""
+ from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+ )
+
+ lambda_client, MockException = _create_mock_lambda_client_with_exceptions()
+
+ # Override specific exception for this test
+ class MockInvalidParameterValueException(Exception):
+ pass
+
+ lambda_client.exceptions.InvalidParameterValueException = (
+ MockInvalidParameterValueException
+ )
+
+ lambda_client.invoke.side_effect = MockInvalidParameterValueException(
+ "Invalid param"
+ )
+
+ invoker = LambdaInvoker(lambda_client)
+
+ input_data = DurableExecutionInvocationInput(
+ durable_execution_arn="test-arn",
+ checkpoint_token="test-token",
+ initial_execution_state=InitialExecutionState(operations=[], next_marker=""),
+ )
+
+ with pytest.raises(InvalidParameterValueException, match="Invalid parameter"):
+ invoker.invoke("test-function", input_data)
+
+
+def test_lambda_invoker_invoke_service_exception():
+ """Test lambda invocation with ServiceException."""
+ from aws_durable_execution_sdk_python_testing.exceptions import (
+ DurableFunctionsTestError,
+ )
+
+ lambda_client, _ = _create_mock_lambda_client_with_exceptions()
+
+ # Create specific exception for ServiceException
+ class MockServiceException(Exception):
+ pass
+
+ lambda_client.exceptions.ServiceException = MockServiceException
+
+ lambda_client.invoke.side_effect = MockServiceException("Service error")
+
+ invoker = LambdaInvoker(lambda_client)
+
+ input_data = DurableExecutionInvocationInput(
+ durable_execution_arn="test-arn",
+ checkpoint_token="test-token",
+ initial_execution_state=InitialExecutionState(operations=[], next_marker=""),
+ )
+
+ with pytest.raises(DurableFunctionsTestError, match="Lambda invocation failed"):
+ invoker.invoke("test-function", input_data)
+
+
+def test_lambda_invoker_invoke_ec2_exception():
+ """Test lambda invocation with EC2 exception."""
+ from aws_durable_execution_sdk_python_testing.exceptions import (
+ DurableFunctionsTestError,
+ )
+
+ lambda_client, _ = _create_mock_lambda_client_with_exceptions()
+
+ # Create specific exception for EC2AccessDeniedException
+ class MockEC2Exception(Exception):
+ pass
+
+ lambda_client.exceptions.EC2AccessDeniedException = MockEC2Exception
+
+ lambda_client.invoke.side_effect = MockEC2Exception("Access denied")
+
+ invoker = LambdaInvoker(lambda_client)
+
+ input_data = DurableExecutionInvocationInput(
+ durable_execution_arn="test-arn",
+ checkpoint_token="test-token",
+ initial_execution_state=InitialExecutionState(operations=[], next_marker=""),
+ )
+
+ with pytest.raises(DurableFunctionsTestError, match="Lambda infrastructure error"):
+ invoker.invoke("test-function", input_data)
+
+
+def test_lambda_invoker_invoke_kms_exception():
+ """Test lambda invocation with KMS exception."""
+ from aws_durable_execution_sdk_python_testing.exceptions import (
+ DurableFunctionsTestError,
+ )
+
+ lambda_client, _ = _create_mock_lambda_client_with_exceptions()
+
+ # Create specific exception for KMSAccessDeniedException
+ class MockKMSException(Exception):
+ pass
+
+ lambda_client.exceptions.KMSAccessDeniedException = MockKMSException
+
+ lambda_client.invoke.side_effect = MockKMSException("KMS access denied")
+
+ invoker = LambdaInvoker(lambda_client)
+
+ input_data = DurableExecutionInvocationInput(
+ durable_execution_arn="test-arn",
+ checkpoint_token="test-token",
+ initial_execution_state=InitialExecutionState(operations=[], next_marker=""),
+ )
+
+ with pytest.raises(DurableFunctionsTestError, match="Lambda KMS error"):
+ invoker.invoke("test-function", input_data)
+
+
+def test_lambda_invoker_invoke_durable_execution_already_started():
+ """Test lambda invocation with DurableExecutionAlreadyStartedException."""
+ from aws_durable_execution_sdk_python_testing.exceptions import (
+ DurableFunctionsTestError,
+ )
+
+ lambda_client, _ = _create_mock_lambda_client_with_exceptions()
+
+ class MockDurableExecutionAlreadyStartedException(Exception):
+ pass
+
+ MockDurableExecutionAlreadyStartedException.__name__ = (
+ "DurableExecutionAlreadyStartedException"
+ )
+
+ lambda_client.invoke.side_effect = MockDurableExecutionAlreadyStartedException(
+ "Already started"
+ )
+
+ invoker = LambdaInvoker(lambda_client)
+
+ input_data = DurableExecutionInvocationInput(
+ durable_execution_arn="test-arn",
+ checkpoint_token="test-token",
+ initial_execution_state=InitialExecutionState(operations=[], next_marker=""),
+ )
+
+ with pytest.raises(
+ DurableFunctionsTestError, match="Durable execution already started"
+ ):
+ invoker.invoke("test-function", input_data)
+
+
+def test_lambda_invoker_invoke_unexpected_exception():
+ """Test lambda invocation with unexpected exception."""
+ from aws_durable_execution_sdk_python_testing.exceptions import (
+ DurableFunctionsTestError,
+ )
+
+ lambda_client, _ = _create_mock_lambda_client_with_exceptions()
+ lambda_client.invoke.side_effect = RuntimeError("Unexpected error")
+
+ invoker = LambdaInvoker(lambda_client)
+
+ input_data = DurableExecutionInvocationInput(
+ durable_execution_arn="test-arn",
+ checkpoint_token="test-token",
+ initial_execution_state=InitialExecutionState(operations=[], next_marker=""),
+ )
+
+ with pytest.raises(
+ DurableFunctionsTestError, match="Unexpected error during Lambda invocation"
+ ):
+ invoker.invoke("test-function", input_data)
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/model_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/model_test.py
new file mode 100644
index 0000000..10076c0
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/model_test.py
@@ -0,0 +1,3667 @@
+"""Tests for model serialization dataclasses."""
+
+from __future__ import annotations
+
+import datetime
+import json
+
+import pytest
+from aws_durable_execution_sdk_python.lambda_service import (
+ OperationStatus,
+ OperationType,
+)
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+from aws_durable_execution_sdk_python_testing.model import (
+ CallbackFailedDetails,
+ CallbackStartedDetails,
+ CallbackSucceededDetails,
+ CallbackTimedOutDetails,
+ ChainedInvokeFailedDetails,
+ ChainedInvokeStartedDetails,
+ ChainedInvokeStoppedDetails,
+ ChainedInvokeSucceededDetails,
+ ChainedInvokeTimedOutDetails,
+ CheckpointDurableExecutionRequest,
+ CheckpointDurableExecutionResponse,
+ CheckpointUpdatedExecutionState,
+ ContextFailedDetails,
+ ContextStartedDetails,
+ ContextSucceededDetails,
+ ErrorResponse,
+ Event,
+ EventError,
+ EventInput,
+ EventResult,
+ Execution,
+ ExecutionFailedDetails,
+ ExecutionStartedDetails,
+ ExecutionStoppedDetails,
+ ExecutionSucceededDetails,
+ ExecutionTimedOutDetails,
+ GetDurableExecutionHistoryRequest,
+ GetDurableExecutionHistoryResponse,
+ GetDurableExecutionRequest,
+ GetDurableExecutionResponse,
+ GetDurableExecutionStateRequest,
+ GetDurableExecutionStateResponse,
+ InvocationCompletedDetails,
+ ListDurableExecutionsByFunctionRequest,
+ ListDurableExecutionsByFunctionResponse,
+ ListDurableExecutionsRequest,
+ ListDurableExecutionsResponse,
+ RetryDetails,
+ SendDurableExecutionCallbackFailureRequest,
+ SendDurableExecutionCallbackFailureResponse,
+ SendDurableExecutionCallbackHeartbeatRequest,
+ SendDurableExecutionCallbackHeartbeatResponse,
+ SendDurableExecutionCallbackSuccessRequest,
+ SendDurableExecutionCallbackSuccessResponse,
+ StartDurableExecutionInput,
+ StartDurableExecutionOutput,
+ StepFailedDetails,
+ StepStartedDetails,
+ StepSucceededDetails,
+ StopDurableExecutionRequest,
+ StopDurableExecutionResponse,
+ WaitCancelledDetails,
+ WaitStartedDetails,
+ WaitSucceededDetails,
+ events_to_operations,
+)
+
+
+# Test timestamp constants
+TIMESTAMP_2023_01_01_00_00 = datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC)
+TIMESTAMP_2023_01_01_00_01 = datetime.datetime(2023, 1, 1, 0, 1, 0, tzinfo=datetime.UTC)
+TIMESTAMP_2023_01_01_00_02 = datetime.datetime(2023, 1, 1, 0, 2, 0, tzinfo=datetime.UTC)
+TIMESTAMP_2023_01_02_00_00 = datetime.datetime(2023, 1, 2, 0, 0, 0, tzinfo=datetime.UTC)
+
+DEFAULT_START_DURABLE_EXECUTION_INPUT_DATA = {
+ "AccountId": "123456789012",
+ "FunctionName": "my-function",
+ "FunctionQualifier": "$LATEST",
+ "ExecutionName": "test-execution",
+ "ExecutionTimeoutSeconds": 300,
+ "ExecutionRetentionPeriodDays": 7,
+ "InvocationId": "invocation-123",
+ "TraceFields": {"key": "value"},
+ "TenantId": "tenant-123",
+ "Input": "test-input",
+}
+
+
+def test_start_durable_execution_input_serialization():
+ """Test StartDurableExecutionInput from_dict/to_dict round-trip."""
+ data = DEFAULT_START_DURABLE_EXECUTION_INPUT_DATA
+
+ # Test from_dict
+ input_obj = StartDurableExecutionInput.from_dict(data)
+ assert input_obj.account_id == "123456789012"
+ assert input_obj.function_name == "my-function"
+ assert input_obj.function_qualifier == "$LATEST"
+ assert input_obj.execution_name == "test-execution"
+ assert input_obj.execution_timeout_seconds == 300
+ assert input_obj.execution_retention_period_days == 7
+ assert input_obj.invocation_id == "invocation-123"
+ assert input_obj.trace_fields == {"key": "value"}
+ assert input_obj.tenant_id == "tenant-123"
+ assert input_obj.input == "test-input"
+
+ # Test to_dict
+ result_data = input_obj.to_dict()
+ assert result_data == data
+
+ # Test round-trip
+ round_trip = StartDurableExecutionInput.from_dict(result_data)
+ assert round_trip == input_obj
+
+
+def test_start_durable_execution_input_get_input_json_input():
+ """Test StartDurableExecutionInput from_dict/to_dict round-trip."""
+ data = DEFAULT_START_DURABLE_EXECUTION_INPUT_DATA
+ data["Input"] = '{"message": "hello"}'
+
+ input_obj = StartDurableExecutionInput.from_dict(data)
+ assert '{"message": "hello"}' == input_obj.get_normalized_input()
+
+
+def test_start_durable_execution_input_get_input_str_non_json_input():
+ """Test StartDurableExecutionInput from_dict/to_dict round-trip."""
+ data = DEFAULT_START_DURABLE_EXECUTION_INPUT_DATA
+ data["Input"] = "hello"
+
+ input_obj = StartDurableExecutionInput.from_dict(data)
+ assert '"hello"' == input_obj.get_normalized_input()
+
+
+def test_start_durable_execution_input_get_input_str_json_input():
+ """Test StartDurableExecutionInput from_dict/to_dict round-trip."""
+ data = DEFAULT_START_DURABLE_EXECUTION_INPUT_DATA
+ data["Input"] = '"hello"'
+
+ input_obj = StartDurableExecutionInput.from_dict(data)
+ assert '"hello"' == input_obj.get_normalized_input()
+
+
+def test_start_durable_execution_input_get_input_list_json_input():
+ """Test StartDurableExecutionInput from_dict/to_dict round-trip."""
+ data = DEFAULT_START_DURABLE_EXECUTION_INPUT_DATA
+ data["Input"] = "[1,2,3]"
+
+ input_obj = StartDurableExecutionInput.from_dict(data)
+ assert "[1,2,3]" == input_obj.get_normalized_input()
+
+
+def test_start_durable_execution_input_minimal():
+ """Test StartDurableExecutionInput with only required fields."""
+ data = {
+ "AccountId": "123456789012",
+ "FunctionName": "my-function",
+ "FunctionQualifier": "$LATEST",
+ "ExecutionName": "test-execution",
+ "ExecutionTimeoutSeconds": 300,
+ "ExecutionRetentionPeriodDays": 7,
+ }
+
+ input_obj = StartDurableExecutionInput.from_dict(data)
+ assert input_obj.invocation_id is None
+ assert input_obj.trace_fields is None
+ assert input_obj.tenant_id is None
+ assert input_obj.input is None
+
+ result_data = input_obj.to_dict()
+ assert result_data == data
+
+
+def test_start_durable_execution_output_serialization():
+ """Test StartDurableExecutionOutput from_dict/to_dict round-trip."""
+ data = {
+ "ExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test"
+ }
+
+ output_obj = StartDurableExecutionOutput.from_dict(data)
+ assert (
+ output_obj.execution_arn
+ == "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test"
+ )
+
+ result_data = output_obj.to_dict()
+ assert result_data == data
+
+ # Test round-trip
+ round_trip = StartDurableExecutionOutput.from_dict(result_data)
+ assert round_trip == output_obj
+
+
+def test_start_durable_execution_output_empty():
+ """Test StartDurableExecutionOutput with empty data."""
+ data = {}
+
+ output_obj = StartDurableExecutionOutput.from_dict(data)
+ assert output_obj.execution_arn is None
+
+ result_data = output_obj.to_dict()
+ assert result_data == {}
+
+
+def test_get_durable_execution_request_serialization():
+ """Test GetDurableExecutionRequest from_dict/to_dict round-trip."""
+ data = {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test"
+ }
+
+ request_obj = GetDurableExecutionRequest.from_dict(data)
+ assert (
+ request_obj.durable_execution_arn
+ == "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test"
+ )
+
+ result_data = request_obj.to_dict()
+ assert result_data == data
+
+ # Test round-trip
+ round_trip = GetDurableExecutionRequest.from_dict(result_data)
+ assert round_trip == request_obj
+
+
+def test_get_durable_execution_response_serialization():
+ """Test GetDurableExecutionResponse from_dict/to_dict round-trip."""
+ data = {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test",
+ "DurableExecutionName": "test-execution",
+ "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function",
+ "Status": "SUCCEEDED",
+ "StartTimestamp": TIMESTAMP_2023_01_01_00_00,
+ "InputPayload": "test-input",
+ "Result": "test-result",
+ "Error": {"ErrorMessage": "test error"},
+ "EndTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "Version": "1.0",
+ }
+
+ response_obj = GetDurableExecutionResponse.from_dict(data)
+ assert (
+ response_obj.durable_execution_arn
+ == "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test"
+ )
+ assert response_obj.durable_execution_name == "test-execution"
+ assert (
+ response_obj.function_arn
+ == "arn:aws:lambda:us-east-1:123456789012:function:my-function"
+ )
+ assert response_obj.status == "SUCCEEDED"
+ assert response_obj.start_timestamp == TIMESTAMP_2023_01_01_00_00
+ assert response_obj.input_payload == "test-input"
+ assert response_obj.result == "test-result"
+ assert response_obj.error.message == "test error"
+ assert response_obj.end_timestamp == TIMESTAMP_2023_01_01_00_01
+ assert response_obj.version == "1.0"
+
+ result_data = response_obj.to_dict()
+ assert result_data == data
+
+ # Test round-trip
+ round_trip = GetDurableExecutionResponse.from_dict(result_data)
+ assert round_trip == response_obj
+
+
+def test_get_durable_execution_response_minimal():
+ """Test GetDurableExecutionResponse with only required fields."""
+ data = {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test",
+ "DurableExecutionName": "test-execution",
+ "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function",
+ "Status": "RUNNING",
+ "StartTimestamp": TIMESTAMP_2023_01_01_00_00,
+ }
+
+ response_obj = GetDurableExecutionResponse.from_dict(data)
+ assert response_obj.input_payload is None
+ assert response_obj.result is None
+ assert response_obj.error is None
+ assert response_obj.end_timestamp is None
+ assert response_obj.version is None
+
+ result_data = response_obj.to_dict()
+ assert result_data == data
+
+
+def test_list_durable_executions_request_serialization():
+ """Test ListDurableExecutionsRequest from_dict/to_dict round-trip."""
+ data = {
+ "FunctionName": "my-function",
+ "FunctionVersion": "$LATEST",
+ "DurableExecutionName": "test-execution",
+ "StatusFilter": ["RUNNING", "SUCCEEDED"],
+ "StartedAfter": TIMESTAMP_2023_01_01_00_00,
+ "StartedBefore": TIMESTAMP_2023_01_02_00_00,
+ "Marker": "marker-123",
+ "MaxItems": 10,
+ "ReverseOrder": True,
+ }
+
+ request_obj = ListDurableExecutionsRequest.from_dict(data)
+ assert request_obj.function_name == "my-function"
+ assert request_obj.function_version == "$LATEST"
+ assert request_obj.durable_execution_name == "test-execution"
+ assert request_obj.status_filter == ["RUNNING", "SUCCEEDED"]
+ assert request_obj.started_after == TIMESTAMP_2023_01_01_00_00
+ assert request_obj.started_before == TIMESTAMP_2023_01_02_00_00
+ assert request_obj.marker == "marker-123"
+ assert request_obj.max_items == 10
+ assert request_obj.reverse_order is True
+
+ result_data = request_obj.to_dict()
+ assert result_data == data
+
+ # Test round-trip
+ round_trip = ListDurableExecutionsRequest.from_dict(result_data)
+ assert round_trip == request_obj
+
+
+def test_list_durable_executions_request_empty():
+ """Test ListDurableExecutionsRequest with empty data."""
+ data = {}
+
+ request_obj = ListDurableExecutionsRequest.from_dict(data)
+ assert request_obj.function_name is None
+ assert request_obj.function_version is None
+ assert request_obj.durable_execution_name is None
+ assert request_obj.status_filter is None
+ assert request_obj.started_after is None
+ assert request_obj.started_before is None
+ assert request_obj.marker is None
+ assert request_obj.max_items == 0 # Default value from Smithy
+ assert request_obj.reverse_order is None
+
+ result_data = request_obj.to_dict()
+ # The result should include the default MaxItems
+ expected_data = {"MaxItems": 0}
+ assert result_data == expected_data
+
+
+def test_durable_execution_summary_serialization():
+ """Test Execution from_dict/to_dict round-trip."""
+ data = {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test",
+ "DurableExecutionName": "test-execution",
+ "Status": "SUCCEEDED",
+ "StartTimestamp": TIMESTAMP_2023_01_01_00_00,
+ "EndTimestamp": TIMESTAMP_2023_01_01_00_01,
+ }
+
+ summary_obj = Execution.from_dict(data)
+ assert (
+ summary_obj.durable_execution_arn
+ == "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test"
+ )
+ assert summary_obj.durable_execution_name == "test-execution"
+ assert summary_obj.status == "SUCCEEDED"
+ assert summary_obj.start_timestamp == TIMESTAMP_2023_01_01_00_00
+ assert summary_obj.end_timestamp == TIMESTAMP_2023_01_01_00_01
+
+ result_data = summary_obj.to_dict()
+ assert result_data == data
+
+ # Test round-trip
+ round_trip = Execution.from_dict(result_data)
+ assert round_trip == summary_obj
+
+
+def test_durable_execution_summary_no_end_timestamp():
+ """Test Execution without end timestamp."""
+ data = {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test",
+ "DurableExecutionName": "test-execution",
+ "Status": "RUNNING",
+ "StartTimestamp": TIMESTAMP_2023_01_01_00_00,
+ }
+
+ summary_obj = Execution.from_dict(data)
+ assert summary_obj.end_timestamp is None
+
+ result_data = summary_obj.to_dict()
+ assert result_data == data
+
+
+def test_list_durable_executions_response_serialization():
+ """Test ListDurableExecutionsResponse from_dict/to_dict round-trip."""
+ data = {
+ "DurableExecutions": [
+ {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test1",
+ "DurableExecutionName": "test-execution-1",
+ "Status": "SUCCEEDED",
+ "StartTimestamp": TIMESTAMP_2023_01_01_00_00,
+ "EndTimestamp": TIMESTAMP_2023_01_01_00_01,
+ },
+ {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test2",
+ "DurableExecutionName": "test-execution-2",
+ "Status": "RUNNING",
+ "StartTimestamp": TIMESTAMP_2023_01_01_00_02,
+ },
+ ],
+ "NextMarker": "next-marker-123",
+ }
+
+ response_obj = ListDurableExecutionsResponse.from_dict(data)
+ assert len(response_obj.durable_executions) == 2
+ assert (
+ response_obj.durable_executions[0].durable_execution_name == "test-execution-1"
+ )
+ assert (
+ response_obj.durable_executions[1].durable_execution_name == "test-execution-2"
+ )
+ assert response_obj.next_marker == "next-marker-123"
+
+ result_data = response_obj.to_dict()
+ assert result_data == data
+
+ # Test round-trip
+ round_trip = ListDurableExecutionsResponse.from_dict(result_data)
+ assert round_trip == response_obj
+
+
+def test_list_durable_executions_response_empty():
+ """Test ListDurableExecutionsResponse with empty executions."""
+ data = {"DurableExecutions": []}
+
+ response_obj = ListDurableExecutionsResponse.from_dict(data)
+ assert len(response_obj.durable_executions) == 0
+ assert response_obj.next_marker is None
+
+ result_data = response_obj.to_dict()
+ assert result_data == {"DurableExecutions": []}
+
+
+def test_stop_durable_execution_request_serialization():
+ """Test StopDurableExecutionRequest from_dict/to_dict round-trip."""
+ data = {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test",
+ "Error": {"ErrorMessage": "Stopped by user"},
+ }
+
+ request_obj = StopDurableExecutionRequest.from_dict(data)
+ assert (
+ request_obj.durable_execution_arn
+ == "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test"
+ )
+ assert request_obj.error.message == "Stopped by user"
+
+ result_data = request_obj.to_dict()
+ assert result_data == data
+
+ # Test round-trip
+ round_trip = StopDurableExecutionRequest.from_dict(result_data)
+ assert round_trip == request_obj
+
+
+def test_stop_durable_execution_request_minimal():
+ """Test StopDurableExecutionRequest with only required fields."""
+ data = {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test"
+ }
+
+ request_obj = StopDurableExecutionRequest.from_dict(data)
+ assert request_obj.error is None
+
+ result_data = request_obj.to_dict()
+ assert result_data == data
+
+
+def test_stop_durable_execution_response_serialization():
+ """Test StopDurableExecutionResponse from_dict/to_dict round-trip."""
+ data = {"StopTimestamp": "2023-01-01T00:01:00Z"}
+
+ response_obj = StopDurableExecutionResponse.from_dict(data)
+ assert response_obj.stop_timestamp == "2023-01-01T00:01:00Z"
+
+ result_data = response_obj.to_dict()
+ assert result_data == data
+
+ # Test round-trip
+ round_trip = StopDurableExecutionResponse.from_dict(result_data)
+ assert round_trip == response_obj
+
+
+def test_get_durable_execution_state_request_serialization():
+ """Test GetDurableExecutionStateRequest from_dict/to_dict round-trip."""
+ data = {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test",
+ "CheckpointToken": "checkpoint-123",
+ "Marker": "marker-123",
+ "MaxItems": 10,
+ }
+
+ request_obj = GetDurableExecutionStateRequest.from_dict(data)
+ assert (
+ request_obj.durable_execution_arn
+ == "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test"
+ )
+ assert request_obj.checkpoint_token == "checkpoint-123" # noqa: S105
+ assert request_obj.marker == "marker-123"
+ assert request_obj.max_items == 10
+
+ result_data = request_obj.to_dict()
+ assert result_data == data
+
+ # Test round-trip
+ round_trip = GetDurableExecutionStateRequest.from_dict(result_data)
+ assert round_trip == request_obj
+
+
+def test_get_durable_execution_state_request_minimal():
+ """Test GetDurableExecutionStateRequest with only required fields."""
+ data = {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test",
+ "CheckpointToken": "checkpoint-123",
+ }
+
+ request_obj = GetDurableExecutionStateRequest.from_dict(data)
+ assert request_obj.marker is None
+ assert request_obj.max_items == 0 # Default value from Smithy
+
+ result_data = request_obj.to_dict()
+ # The result should include the default MaxItems
+ expected_data = {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test",
+ "CheckpointToken": "checkpoint-123",
+ "MaxItems": 0,
+ }
+ assert result_data == expected_data
+
+
+def test_get_durable_execution_state_response_serialization():
+ """Test GetDurableExecutionStateResponse from_dict/to_dict round-trip."""
+ data = {
+ "Operations": [
+ {"Id": "op-1", "Type": "STEP", "Status": "SUCCEEDED"},
+ {"Id": "op-2", "Type": "CONTEXT", "Status": "STARTED"},
+ ],
+ "NextMarker": "next-marker-123",
+ }
+
+ response_obj = GetDurableExecutionStateResponse.from_dict(data)
+ assert len(response_obj.operations) == 2
+ assert response_obj.operations[0].operation_id == "op-1"
+ assert response_obj.operations[0].operation_type.value == "STEP"
+ assert response_obj.operations[1].operation_id == "op-2"
+ assert response_obj.operations[1].operation_type.value == "CONTEXT"
+ assert response_obj.next_marker == "next-marker-123"
+
+ result_data = response_obj.to_dict()
+ assert result_data == data
+
+ # Test round-trip
+ round_trip = GetDurableExecutionStateResponse.from_dict(result_data)
+ assert round_trip == response_obj
+
+
+def test_get_durable_execution_state_response_empty():
+ """Test GetDurableExecutionStateResponse with empty operations."""
+ data = {"Operations": []}
+
+ response_obj = GetDurableExecutionStateResponse.from_dict(data)
+ assert len(response_obj.operations) == 0
+ assert response_obj.next_marker is None
+
+ result_data = response_obj.to_dict()
+ assert result_data == {"Operations": []}
+
+
+def test_get_durable_execution_history_request_serialization():
+ """Test GetDurableExecutionHistoryRequest from_dict/to_dict round-trip."""
+ data = {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test",
+ "IncludeExecutionData": True,
+ "ReverseOrder": False,
+ "Marker": "marker-123",
+ "MaxItems": 20,
+ }
+
+ request_obj = GetDurableExecutionHistoryRequest.from_dict(data)
+ assert (
+ request_obj.durable_execution_arn
+ == "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test"
+ )
+ assert request_obj.include_execution_data is True
+ assert request_obj.reverse_order is False
+ assert request_obj.marker == "marker-123"
+ assert request_obj.max_items == 20
+
+ result_data = request_obj.to_dict()
+ assert result_data == data
+
+ # Test round-trip
+ round_trip = GetDurableExecutionHistoryRequest.from_dict(result_data)
+ assert round_trip == request_obj
+
+
+def test_get_durable_execution_history_request_minimal():
+ """Test GetDurableExecutionHistoryRequest with only required fields."""
+ data = {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test"
+ }
+
+ request_obj = GetDurableExecutionHistoryRequest.from_dict(data)
+ assert request_obj.include_execution_data is None
+ assert request_obj.reverse_order is None
+ assert request_obj.marker is None
+ assert request_obj.max_items == 0 # Default value from Smithy
+
+ result_data = request_obj.to_dict()
+ # The result should include the default MaxItems
+ expected_data = {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test",
+ "MaxItems": 0,
+ }
+ assert result_data == expected_data
+
+
+def test_execution_event_serialization():
+ """Test Event from_dict/to_dict round-trip."""
+ data = {
+ "EventType": "ExecutionStarted",
+ "EventId": 123,
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_00,
+ "SubType": "UserInitiated",
+ "Id": "op-123",
+ "Name": "test-operation",
+ "ParentId": "parent-op-123",
+ "ExecutionStartedDetails": {
+ "Input": {"Payload": "test-input", "Truncated": False},
+ "ExecutionTimeout": 300,
+ },
+ }
+
+ event_obj = Event.from_dict(data)
+ assert event_obj.event_type == "ExecutionStarted"
+ assert event_obj.event_id == 123
+ assert event_obj.event_timestamp == TIMESTAMP_2023_01_01_00_00
+ assert event_obj.sub_type == "UserInitiated"
+ assert event_obj.operation_id == "op-123"
+ assert event_obj.name == "test-operation"
+ assert event_obj.parent_id == "parent-op-123"
+ assert event_obj.execution_started_details is not None
+ assert event_obj.execution_started_details.input.payload == "test-input"
+ assert event_obj.execution_started_details.execution_timeout == 300
+
+ result_data = event_obj.to_dict()
+ assert result_data == data
+
+ # Test round-trip
+ round_trip = Event.from_dict(result_data)
+ assert round_trip == event_obj
+
+
+def test_execution_event_minimal():
+ """Test Event with only required fields."""
+ data = {
+ "EventType": "ExecutionStarted",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_00,
+ }
+
+ event_obj = Event.from_dict(data)
+ assert event_obj.event_id == 1 # Default value from Smithy
+ assert event_obj.sub_type is None
+ assert event_obj.operation_id is None
+ assert event_obj.name is None
+ assert event_obj.parent_id is None
+ assert event_obj.execution_started_details is None
+
+ result_data = event_obj.to_dict()
+ # The result should include the default EventId
+ expected_data = {
+ "EventType": "ExecutionStarted",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_00,
+ "EventId": 1,
+ }
+ assert result_data == expected_data
+
+
+def test_get_durable_execution_history_response_serialization():
+ """Test GetDurableExecutionHistoryResponse from_dict/to_dict round-trip."""
+ data = {
+ "Events": [
+ {
+ "EventType": "ExecutionStarted",
+ "EventId": 1,
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_00,
+ },
+ {
+ "EventType": "ExecutionSucceeded",
+ "EventId": 2,
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "ExecutionSucceededDetails": {
+ "Result": {"Payload": "success", "Truncated": False}
+ },
+ },
+ ],
+ "NextMarker": "next-marker-123",
+ }
+
+ response_obj = GetDurableExecutionHistoryResponse.from_dict(data)
+ assert len(response_obj.events) == 2
+ assert response_obj.events[0].event_type == "ExecutionStarted"
+ assert response_obj.events[1].event_type == "ExecutionSucceeded"
+ assert response_obj.events[1].execution_succeeded_details is not None
+ assert (
+ response_obj.events[1].execution_succeeded_details.result.payload == "success"
+ )
+ assert response_obj.next_marker == "next-marker-123"
+
+ result_data = response_obj.to_dict()
+ assert result_data == data
+
+ # Test round-trip
+ round_trip = GetDurableExecutionHistoryResponse.from_dict(result_data)
+ assert round_trip == response_obj
+
+
+def test_get_durable_execution_history_response_empty():
+ """Test GetDurableExecutionHistoryResponse with empty events."""
+ data = {"Events": []}
+
+ response_obj = GetDurableExecutionHistoryResponse.from_dict(data)
+ assert len(response_obj.events) == 0
+ assert response_obj.next_marker is None
+
+ result_data = response_obj.to_dict()
+ assert result_data == {"Events": []}
+
+
+def test_list_durable_executions_by_function_request_serialization():
+ """Test ListDurableExecutionsByFunctionRequest from_dict/to_dict round-trip."""
+ data = {
+ "FunctionName": "my-function",
+ "Qualifier": "$LATEST",
+ "StatusFilter": ["RUNNING", "SUCCEEDED"],
+ "StartedAfter": TIMESTAMP_2023_01_01_00_00,
+ "StartedBefore": TIMESTAMP_2023_01_02_00_00,
+ "Marker": "marker-123",
+ "MaxItems": 10,
+ "ReverseOrder": True,
+ }
+
+ request_obj = ListDurableExecutionsByFunctionRequest.from_dict(data)
+ assert request_obj.function_name == "my-function"
+ assert request_obj.qualifier == "$LATEST"
+ assert request_obj.status_filter == ["RUNNING", "SUCCEEDED"]
+ assert request_obj.started_after == TIMESTAMP_2023_01_01_00_00
+ assert request_obj.started_before == TIMESTAMP_2023_01_02_00_00
+ assert request_obj.marker == "marker-123"
+ assert request_obj.max_items == 10
+ assert request_obj.reverse_order is True
+
+ result_data = request_obj.to_dict()
+ assert result_data == data
+
+ # Test round-trip
+ round_trip = ListDurableExecutionsByFunctionRequest.from_dict(result_data)
+ assert round_trip == request_obj
+
+
+def test_list_durable_executions_by_function_request_minimal():
+ """Test ListDurableExecutionsByFunctionRequest with only required fields."""
+ data = {"FunctionName": "my-function"}
+
+ request_obj = ListDurableExecutionsByFunctionRequest.from_dict(data)
+ assert request_obj.qualifier is None
+ assert request_obj.status_filter is None
+ assert request_obj.started_after is None
+ assert request_obj.started_before is None
+ assert request_obj.marker is None
+ assert request_obj.max_items == 0 # Default value from Smithy
+ assert request_obj.reverse_order is None
+
+ result_data = request_obj.to_dict()
+ # The result should include the default MaxItems
+ expected_data = {"FunctionName": "my-function", "MaxItems": 0}
+ assert result_data == expected_data
+
+
+def test_list_durable_executions_by_function_response_serialization():
+ """Test ListDurableExecutionsByFunctionResponse from_dict/to_dict round-trip."""
+ data = {
+ "DurableExecutions": [
+ {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test1",
+ "DurableExecutionName": "test-execution-1",
+ "Status": "SUCCEEDED",
+ "StartTimestamp": TIMESTAMP_2023_01_01_00_00,
+ "EndTimestamp": TIMESTAMP_2023_01_01_00_01,
+ }
+ ],
+ "NextMarker": "next-marker-123",
+ }
+
+ response_obj = ListDurableExecutionsByFunctionResponse.from_dict(data)
+ assert len(response_obj.durable_executions) == 1
+ assert (
+ response_obj.durable_executions[0].durable_execution_name == "test-execution-1"
+ )
+ assert response_obj.next_marker == "next-marker-123"
+
+ result_data = response_obj.to_dict()
+ assert result_data == data
+
+ # Test round-trip
+ round_trip = ListDurableExecutionsByFunctionResponse.from_dict(result_data)
+ assert round_trip == response_obj
+
+
+def test_send_durable_execution_callback_success_request_serialization():
+ """Test SendDurableExecutionCallbackSuccessRequest from_dict/to_dict round-trip."""
+ data = {
+ "CallbackId": "callback-123",
+ "Result": "success-result",
+ }
+
+ request_obj = SendDurableExecutionCallbackSuccessRequest.from_dict(data)
+ assert request_obj.callback_id == "callback-123"
+ assert request_obj.result == "success-result"
+
+ result_data = request_obj.to_dict()
+ assert result_data == data
+
+ # Test round-trip
+ round_trip = SendDurableExecutionCallbackSuccessRequest.from_dict(result_data)
+ assert round_trip == request_obj
+
+
+def test_send_durable_execution_callback_success_request_minimal():
+ """Test SendDurableExecutionCallbackSuccessRequest with only required fields."""
+ data = {"CallbackId": "callback-123"}
+
+ request_obj = SendDurableExecutionCallbackSuccessRequest.from_dict(data)
+ assert request_obj.result is None
+
+ result_data = request_obj.to_dict()
+ assert result_data == data
+
+
+def test_send_durable_execution_callback_success_response_creation():
+ """Test SendDurableExecutionCallbackSuccessResponse creation."""
+ response_obj = SendDurableExecutionCallbackSuccessResponse()
+ assert isinstance(response_obj, SendDurableExecutionCallbackSuccessResponse)
+
+
+def test_send_durable_execution_callback_failure_request_serialization():
+ """Test SendDurableExecutionCallbackFailureRequest from_dict/to_dict round-trip."""
+ data = {"ErrorMessage": "callback failed"}
+
+ request_obj = SendDurableExecutionCallbackFailureRequest.from_dict(
+ data, "callback-123"
+ )
+ assert request_obj.callback_id == "callback-123"
+ assert request_obj.error.message == "callback failed"
+
+ result_data = request_obj.to_dict()
+ expected_data = {
+ "CallbackId": "callback-123",
+ "Error": {"ErrorMessage": "callback failed"},
+ }
+ assert result_data == expected_data
+
+ # Test round-trip
+ round_trip = SendDurableExecutionCallbackFailureRequest.from_dict(
+ result_data.get("Error", {}), result_data["CallbackId"]
+ )
+ assert round_trip == request_obj
+
+
+def test_send_durable_execution_callback_failure_request_minimal():
+ """Test SendDurableExecutionCallbackFailureRequest with only required fields."""
+
+ request_obj = SendDurableExecutionCallbackFailureRequest.from_dict(
+ {}, "callback-123"
+ )
+ assert request_obj.error is None
+
+ result_data = request_obj.to_dict()
+ assert result_data == {"CallbackId": "callback-123"}
+
+
+def test_send_durable_execution_callback_failure_response_creation():
+ """Test SendDurableExecutionCallbackFailureResponse creation."""
+ response_obj = SendDurableExecutionCallbackFailureResponse()
+ assert isinstance(response_obj, SendDurableExecutionCallbackFailureResponse)
+
+
+def test_send_durable_execution_callback_heartbeat_request_serialization():
+ """Test SendDurableExecutionCallbackHeartbeatRequest from_dict/to_dict round-trip."""
+ data = {"CallbackId": "callback-123"}
+
+ request_obj = SendDurableExecutionCallbackHeartbeatRequest.from_dict(data)
+ assert request_obj.callback_id == "callback-123"
+
+ result_data = request_obj.to_dict()
+ assert result_data == data
+
+ # Test round-trip
+ round_trip = SendDurableExecutionCallbackHeartbeatRequest.from_dict(result_data)
+ assert round_trip == request_obj
+
+
+def test_send_durable_execution_callback_heartbeat_response_creation():
+ """Test SendDurableExecutionCallbackHeartbeatResponse creation."""
+ response_obj = SendDurableExecutionCallbackHeartbeatResponse()
+ assert isinstance(response_obj, SendDurableExecutionCallbackHeartbeatResponse)
+
+
+def test_checkpoint_durable_execution_request_serialization():
+ """Test CheckpointDurableExecutionRequest from_dict/to_dict round-trip."""
+ execution_arn = (
+ "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test"
+ )
+ data = {
+ "CheckpointToken": "checkpoint-123",
+ "Updates": [
+ {"Id": "op-1", "Type": "STEP", "Action": "SUCCEED"},
+ {"Id": "op-2", "Type": "CONTEXT", "Action": "START"},
+ ],
+ "ClientToken": "client-token-123",
+ }
+
+ request_obj = CheckpointDurableExecutionRequest.from_dict(data, execution_arn)
+ assert request_obj.durable_execution_arn == execution_arn
+ assert request_obj.checkpoint_token == "checkpoint-123" # noqa: S105
+ assert len(request_obj.updates) == 2
+ assert request_obj.updates[0].operation_id == "op-1"
+ assert request_obj.updates[0].operation_type.value == "STEP"
+ assert request_obj.updates[0].action.value == "SUCCEED"
+ assert request_obj.updates[1].operation_id == "op-2"
+ assert request_obj.updates[1].operation_type.value == "CONTEXT"
+ assert request_obj.updates[1].action.value == "START"
+ assert request_obj.client_token == "client-token-123" # noqa: S105
+
+ result_data = request_obj.to_dict()
+ expected_data = {"DurableExecutionArn": execution_arn, **data}
+ assert result_data == expected_data
+
+ # Test round-trip
+ round_trip = CheckpointDurableExecutionRequest.from_dict(result_data, execution_arn)
+ assert round_trip == request_obj
+
+
+def test_checkpoint_durable_execution_request_minimal():
+ """Test CheckpointDurableExecutionRequest with only required fields."""
+ execution_arn = (
+ "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test"
+ )
+ data = {
+ "CheckpointToken": "checkpoint-123",
+ }
+
+ request_obj = CheckpointDurableExecutionRequest.from_dict(data, execution_arn)
+ assert request_obj.updates is None
+ assert request_obj.client_token is None
+
+ result_data = request_obj.to_dict()
+ expected_data = {"DurableExecutionArn": execution_arn, **data}
+ assert result_data == expected_data
+
+
+def test_checkpoint_durable_execution_response_serialization():
+ """Test CheckpointDurableExecutionResponse from_dict/to_dict round-trip."""
+ data = {
+ "CheckpointToken": "new-checkpoint-123",
+ "NewExecutionState": {
+ "Operations": [{"Id": "op-1", "Type": "STEP", "Status": "SUCCEEDED"}],
+ "NextMarker": "marker-123",
+ },
+ }
+
+ response_obj = CheckpointDurableExecutionResponse.from_dict(data)
+ assert response_obj.checkpoint_token == "new-checkpoint-123" # noqa: S105
+ assert response_obj.new_execution_state is not None
+ assert len(response_obj.new_execution_state.operations) == 1
+ assert response_obj.new_execution_state.operations[0].operation_id == "op-1"
+ assert response_obj.new_execution_state.next_marker == "marker-123"
+
+ result_data = response_obj.to_dict()
+ assert result_data == data
+
+ # Test round-trip
+ round_trip = CheckpointDurableExecutionResponse.from_dict(result_data)
+ assert round_trip == response_obj
+
+
+def test_checkpoint_durable_execution_response_minimal():
+ """Test CheckpointDurableExecutionResponse with only required fields."""
+ data = {"CheckpointToken": "new-checkpoint-123"}
+
+ response_obj = CheckpointDurableExecutionResponse.from_dict(data)
+ assert response_obj.new_execution_state is None
+
+ result_data = response_obj.to_dict()
+ assert result_data == data
+
+
+def test_error_response_creation():
+ """Test ErrorResponse creation with all fields."""
+ error_response = ErrorResponse(
+ error_type="InvalidParameterValueException",
+ error_message="Invalid parameter value",
+ error_code="INVALID_PARAMETER",
+ request_id="req-123",
+ )
+
+ assert error_response.error_type == "InvalidParameterValueException"
+ assert error_response.error_message == "Invalid parameter value"
+ assert error_response.error_code == "INVALID_PARAMETER"
+ assert error_response.request_id == "req-123"
+
+
+def test_error_response_creation_minimal():
+ """Test ErrorResponse creation with minimal fields."""
+ error_response = ErrorResponse(
+ error_type="ServiceException",
+ error_message="Internal server error",
+ )
+
+ assert error_response.error_type == "ServiceException"
+ assert error_response.error_message == "Internal server error"
+ assert error_response.error_code is None
+ assert error_response.request_id is None
+
+
+def test_error_response_to_dict_complete():
+ """Test ErrorResponse.to_dict() with all fields."""
+ error_response = ErrorResponse(
+ error_type="ResourceNotFoundException",
+ error_message="Resource not found",
+ error_code="RESOURCE_NOT_FOUND",
+ request_id="req-456",
+ )
+
+ result = error_response.to_dict()
+
+ expected = {
+ "error": {
+ "type": "ResourceNotFoundException",
+ "message": "Resource not found",
+ "code": "RESOURCE_NOT_FOUND",
+ "requestId": "req-456",
+ }
+ }
+
+ assert result == expected
+
+
+def test_error_response_to_dict_minimal():
+ """Test ErrorResponse.to_dict() with minimal fields."""
+ error_response = ErrorResponse(
+ error_type="ConflictException",
+ error_message="Resource conflict",
+ )
+
+ result = error_response.to_dict()
+
+ expected = {
+ "error": {
+ "type": "ConflictException",
+ "message": "Resource conflict",
+ }
+ }
+
+ assert result == expected
+
+
+def test_error_response_from_dict_nested():
+ """Test ErrorResponse.from_dict() with nested error structure."""
+ data = {
+ "error": {
+ "type": "InvalidParameterValueException",
+ "message": "Invalid input",
+ "code": "INVALID_INPUT",
+ "requestId": "req-789",
+ }
+ }
+
+ error_response = ErrorResponse.from_dict(data)
+
+ assert error_response.error_type == "InvalidParameterValueException"
+ assert error_response.error_message == "Invalid input"
+ assert error_response.error_code == "INVALID_INPUT"
+ assert error_response.request_id == "req-789"
+
+
+def test_error_response_from_dict_flat():
+ """Test ErrorResponse.from_dict() with flat error structure."""
+ data = {
+ "type": "ServiceException",
+ "message": "Internal error",
+ "code": "INTERNAL_ERROR",
+ }
+
+ error_response = ErrorResponse.from_dict(data)
+
+ assert error_response.error_type == "ServiceException"
+ assert error_response.error_message == "Internal error"
+ assert error_response.error_code == "INTERNAL_ERROR"
+ assert error_response.request_id is None
+
+
+def test_error_response_from_dict_minimal():
+ """Test ErrorResponse.from_dict() with minimal fields."""
+ data = {
+ "error": {
+ "type": "TooManyRequestsException",
+ "message": "Rate limit exceeded",
+ }
+ }
+
+ error_response = ErrorResponse.from_dict(data)
+
+ assert error_response.error_type == "TooManyRequestsException"
+ assert error_response.error_message == "Rate limit exceeded"
+ assert error_response.error_code is None
+ assert error_response.request_id is None
+
+
+def test_error_response_round_trip():
+ """Test ErrorResponse round-trip serialization."""
+ original = ErrorResponse(
+ error_type="ExecutionAlreadyStartedException",
+ error_message="Execution already exists",
+ error_code="EXECUTION_ALREADY_STARTED",
+ request_id="req-round-trip",
+ )
+
+ # Convert to dict and back
+ data = original.to_dict()
+ restored = ErrorResponse.from_dict(data)
+
+ assert restored.error_type == original.error_type
+ assert restored.error_message == original.error_message
+ assert restored.error_code == original.error_code
+ assert restored.request_id == original.request_id
+
+
+def test_error_response_immutable():
+ """Test that ErrorResponse is immutable (frozen dataclass)."""
+ error_response = ErrorResponse(
+ error_type="TestException",
+ error_message="Test message",
+ )
+
+ with pytest.raises(AttributeError):
+ error_response.error_type = "ModifiedException" # type: ignore
+
+
+# Tests for missing coverage in StartDurableExecutionInput
+def test_start_durable_execution_input_missing_required_fields():
+ """Test StartDurableExecutionInput validation with missing required fields."""
+ # Test missing AccountId
+ data = {
+ "FunctionName": "my-function",
+ "FunctionQualifier": "$LATEST",
+ "ExecutionName": "test-execution",
+ "ExecutionTimeoutSeconds": 300,
+ "ExecutionRetentionPeriodDays": 7,
+ }
+
+ with pytest.raises(InvalidParameterValueException) as exc_info:
+ StartDurableExecutionInput.from_dict(data)
+ assert "Missing required field: AccountId" in str(exc_info.value)
+
+ # Test missing FunctionName
+ data = {
+ "AccountId": "123456789012",
+ "FunctionQualifier": "$LATEST",
+ "ExecutionName": "test-execution",
+ "ExecutionTimeoutSeconds": 300,
+ "ExecutionRetentionPeriodDays": 7,
+ }
+
+ with pytest.raises(InvalidParameterValueException) as exc_info:
+ StartDurableExecutionInput.from_dict(data)
+ assert "Missing required field: FunctionName" in str(exc_info.value)
+
+ # Test missing FunctionQualifier
+ data = {
+ "AccountId": "123456789012",
+ "FunctionName": "my-function",
+ "ExecutionName": "test-execution",
+ "ExecutionTimeoutSeconds": 300,
+ "ExecutionRetentionPeriodDays": 7,
+ }
+
+ with pytest.raises(InvalidParameterValueException) as exc_info:
+ StartDurableExecutionInput.from_dict(data)
+ assert "Missing required field: FunctionQualifier" in str(exc_info.value)
+
+ # Test missing ExecutionName
+ data = {
+ "AccountId": "123456789012",
+ "FunctionName": "my-function",
+ "FunctionQualifier": "$LATEST",
+ "ExecutionTimeoutSeconds": 300,
+ "ExecutionRetentionPeriodDays": 7,
+ }
+
+ with pytest.raises(InvalidParameterValueException) as exc_info:
+ StartDurableExecutionInput.from_dict(data)
+ assert "Missing required field: ExecutionName" in str(exc_info.value)
+
+ # Test missing ExecutionTimeoutSeconds
+ data = {
+ "AccountId": "123456789012",
+ "FunctionName": "my-function",
+ "FunctionQualifier": "$LATEST",
+ "ExecutionName": "test-execution",
+ "ExecutionRetentionPeriodDays": 7,
+ }
+
+ with pytest.raises(InvalidParameterValueException) as exc_info:
+ StartDurableExecutionInput.from_dict(data)
+ assert "Missing required field: ExecutionTimeoutSeconds" in str(exc_info.value)
+
+ # Test missing ExecutionRetentionPeriodDays
+ data = {
+ "AccountId": "123456789012",
+ "FunctionName": "my-function",
+ "FunctionQualifier": "$LATEST",
+ "ExecutionName": "test-execution",
+ "ExecutionTimeoutSeconds": 300,
+ }
+
+ with pytest.raises(InvalidParameterValueException) as exc_info:
+ StartDurableExecutionInput.from_dict(data)
+ assert "Missing required field: ExecutionRetentionPeriodDays" in str(exc_info.value)
+
+
+# Tests for Execution backward compatibility
+def test_execution_backward_compatibility_empty_function_arn():
+ """Test Execution with empty FunctionArn for backward compatibility."""
+ data = {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test",
+ "DurableExecutionName": "test-execution",
+ "Status": "SUCCEEDED",
+ "StartTimestamp": TIMESTAMP_2023_01_01_00_00,
+ "EndTimestamp": TIMESTAMP_2023_01_01_00_01,
+ }
+
+ execution_obj = Execution.from_dict(data)
+ assert (
+ execution_obj.function_arn == ""
+ ) # Default empty string for backward compatibility
+
+ result_data = execution_obj.to_dict()
+ # Empty function_arn should not be included in output
+ expected_data = {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test",
+ "DurableExecutionName": "test-execution",
+ "Status": "SUCCEEDED",
+ "StartTimestamp": TIMESTAMP_2023_01_01_00_00,
+ "EndTimestamp": TIMESTAMP_2023_01_01_00_01,
+ }
+ assert result_data == expected_data
+
+
+def test_execution_with_function_arn():
+ """Test Execution with non-empty FunctionArn."""
+ data = {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test",
+ "DurableExecutionName": "test-execution",
+ "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function",
+ "Status": "SUCCEEDED",
+ "StartTimestamp": TIMESTAMP_2023_01_01_00_00,
+ "EndTimestamp": TIMESTAMP_2023_01_01_00_01,
+ }
+
+ execution_obj = Execution.from_dict(data)
+ assert (
+ execution_obj.function_arn
+ == "arn:aws:lambda:us-east-1:123456789012:function:my-function"
+ )
+
+ result_data = execution_obj.to_dict()
+ assert result_data == data
+
+
+# Tests for ListDurableExecutionsRequest with all optional fields
+def test_list_durable_executions_request_all_optional_fields():
+ """Test ListDurableExecutionsRequest to_dict with all optional fields as None."""
+ request_obj = ListDurableExecutionsRequest(
+ function_name=None,
+ function_version=None,
+ durable_execution_name=None,
+ status_filter=None,
+ started_after=None,
+ started_before=None,
+ marker=None,
+ max_items=None,
+ reverse_order=None,
+ )
+
+ result_data = request_obj.to_dict()
+ # Only non-None fields should be included
+ expected_data = {}
+ assert result_data == expected_data
+
+
+def test_list_durable_executions_request_partial_fields():
+ """Test ListDurableExecutionsRequest to_dict with some optional fields."""
+ request_obj = ListDurableExecutionsRequest(
+ function_name="my-function",
+ function_version=None,
+ durable_execution_name="test-execution",
+ status_filter=None,
+ started_after=TIMESTAMP_2023_01_01_00_00,
+ started_before=None,
+ marker="marker-123",
+ max_items=10,
+ reverse_order=None,
+ )
+
+ result_data = request_obj.to_dict()
+ expected_data = {
+ "FunctionName": "my-function",
+ "DurableExecutionName": "test-execution",
+ "StartedAfter": TIMESTAMP_2023_01_01_00_00,
+ "Marker": "marker-123",
+ "MaxItems": 10,
+ }
+ assert result_data == expected_data
+
+
+# Tests for GetDurableExecutionStateRequest with all optional fields
+def test_get_durable_execution_state_request_all_optional_fields():
+ """Test GetDurableExecutionStateRequest to_dict with all optional fields as None."""
+ request_obj = GetDurableExecutionStateRequest(
+ durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test",
+ checkpoint_token="checkpoint-123", # noqa: S106
+ marker=None,
+ max_items=None,
+ )
+
+ result_data = request_obj.to_dict()
+ expected_data = {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test",
+ "CheckpointToken": "checkpoint-123",
+ }
+ assert result_data == expected_data
+
+
+# Tests for EventInput
+def test_event_input_serialization():
+ """Test EventInput from_dict/to_dict round-trip."""
+ data = {
+ "Payload": "test-payload",
+ "Truncated": True,
+ }
+
+ event_input = EventInput.from_dict(data)
+ assert event_input.payload == "test-payload"
+ assert event_input.truncated is True
+
+ result_data = event_input.to_dict()
+ assert result_data == data
+
+
+def test_event_input_minimal():
+ """Test EventInput with minimal data."""
+ data = {}
+
+ event_input = EventInput.from_dict(data)
+ assert event_input.payload is None
+ assert event_input.truncated is False
+
+ result_data = event_input.to_dict()
+ assert result_data == {"Truncated": False}
+
+
+def test_event_input_with_payload_only():
+ """Test EventInput with payload but default truncated."""
+ data = {"Payload": "test-payload"}
+
+ event_input = EventInput.from_dict(data)
+ assert event_input.payload == "test-payload"
+ assert event_input.truncated is False
+
+ result_data = event_input.to_dict()
+ assert result_data == {"Payload": "test-payload", "Truncated": False}
+
+
+# Tests for EventResult
+def test_event_result_serialization():
+ """Test EventResult from_dict/to_dict round-trip."""
+ data = {
+ "Payload": "test-result",
+ "Truncated": True,
+ }
+
+ event_result = EventResult.from_dict(data)
+ assert event_result.payload == "test-result"
+ assert event_result.truncated is True
+
+ result_data = event_result.to_dict()
+ assert result_data == data
+
+
+def test_event_result_minimal():
+ """Test EventResult with minimal data."""
+ data = {}
+
+ event_result = EventResult.from_dict(data)
+ assert event_result.payload is None
+ assert event_result.truncated is False
+
+ result_data = event_result.to_dict()
+ assert result_data == {"Truncated": False}
+
+
+# Tests for EventError
+def test_event_error_serialization():
+ """Test EventError from_dict/to_dict round-trip."""
+ data = {
+ "Payload": {"ErrorMessage": "test error"},
+ "Truncated": True,
+ }
+
+ event_error = EventError.from_dict(data)
+ assert event_error.payload.message == "test error"
+ assert event_error.truncated is True
+
+ result_data = event_error.to_dict()
+ assert result_data == data
+
+
+def test_event_error_minimal():
+ """Test EventError with minimal data."""
+ data = {}
+
+ event_error = EventError.from_dict(data)
+ assert event_error.payload is None
+ assert event_error.truncated is False
+
+ result_data = event_error.to_dict()
+ assert result_data == {"Truncated": False}
+
+
+def test_event_error_with_payload_only():
+ """Test EventError with payload but default truncated."""
+ data = {"Payload": {"ErrorMessage": "test error"}}
+
+ event_error = EventError.from_dict(data)
+ assert event_error.payload.message == "test error"
+ assert event_error.truncated is False
+
+ result_data = event_error.to_dict()
+ assert result_data == {
+ "Payload": {"ErrorMessage": "test error"},
+ "Truncated": False,
+ }
+
+
+# Tests for RetryDetails
+def test_retry_details_serialization():
+ """Test RetryDetails from_dict/to_dict round-trip."""
+ data = {
+ "CurrentAttempt": 3,
+ "NextAttemptDelaySeconds": 60,
+ }
+
+ retry_details = RetryDetails.from_dict(data)
+ assert retry_details.current_attempt == 3
+ assert retry_details.next_attempt_delay_seconds == 60
+
+ result_data = retry_details.to_dict()
+ assert result_data == data
+
+
+def test_retry_details_minimal():
+ """Test RetryDetails with minimal data."""
+ data = {}
+
+ retry_details = RetryDetails.from_dict(data)
+ assert retry_details.current_attempt == 0
+ assert retry_details.next_attempt_delay_seconds is None
+
+ result_data = retry_details.to_dict()
+ assert result_data == {"CurrentAttempt": 0}
+
+
+def test_retry_details_with_current_attempt_only():
+ """Test RetryDetails with current attempt but no delay."""
+ data = {"CurrentAttempt": 2}
+
+ retry_details = RetryDetails.from_dict(data)
+ assert retry_details.current_attempt == 2
+ assert retry_details.next_attempt_delay_seconds is None
+
+ result_data = retry_details.to_dict()
+ assert result_data == {"CurrentAttempt": 2}
+
+
+# Tests for ExecutionStartedDetails
+def test_execution_started_details_serialization():
+ """Test ExecutionStartedDetails from_dict/to_dict round-trip."""
+ data = {
+ "Input": {"Payload": "test-input", "Truncated": False},
+ "ExecutionTimeout": 300,
+ }
+
+ details = ExecutionStartedDetails.from_dict(data)
+ assert details.input.payload == "test-input"
+ assert details.execution_timeout == 300
+
+ result_data = details.to_dict()
+ assert result_data == data
+
+
+def test_execution_started_details_minimal():
+ """Test ExecutionStartedDetails with minimal data."""
+ data = {}
+
+ details = ExecutionStartedDetails.from_dict(data)
+ assert details.input is None
+ assert details.execution_timeout is None
+
+ result_data = details.to_dict()
+ assert result_data == {}
+
+
+def test_execution_started_details_with_input_only():
+ """Test ExecutionStartedDetails with input but no timeout."""
+ data = {"Input": {"Payload": "test-input", "Truncated": False}}
+
+ details = ExecutionStartedDetails.from_dict(data)
+ assert details.input.payload == "test-input"
+ assert details.execution_timeout is None
+
+ result_data = details.to_dict()
+ assert result_data == {"Input": {"Payload": "test-input", "Truncated": False}}
+
+
+# Tests for ExecutionSucceededDetails
+def test_execution_succeeded_details_serialization():
+ """Test ExecutionSucceededDetails from_dict/to_dict round-trip."""
+ data = {
+ "Result": {"Payload": "success-result", "Truncated": False},
+ }
+
+ details = ExecutionSucceededDetails.from_dict(data)
+ assert details.result.payload == "success-result"
+
+ result_data = details.to_dict()
+ assert result_data == data
+
+
+def test_execution_succeeded_details_minimal():
+ """Test ExecutionSucceededDetails with minimal data."""
+ data = {}
+
+ details = ExecutionSucceededDetails.from_dict(data)
+ assert details.result is None
+
+ result_data = details.to_dict()
+ assert result_data == {}
+
+
+# Tests for ExecutionFailedDetails
+def test_execution_failed_details_serialization():
+ """Test ExecutionFailedDetails from_dict/to_dict round-trip."""
+ data = {
+ "Error": {"Payload": {"ErrorMessage": "execution failed"}, "Truncated": False},
+ }
+
+ details = ExecutionFailedDetails.from_dict(data)
+ assert details.error.payload.message == "execution failed"
+
+ result_data = details.to_dict()
+ assert result_data == data
+
+
+def test_execution_failed_details_minimal():
+ """Test ExecutionFailedDetails with minimal data."""
+ data = {}
+
+ details = ExecutionFailedDetails.from_dict(data)
+ assert details.error is None
+
+ result_data = details.to_dict()
+ assert result_data == {}
+
+
+# Tests for ExecutionTimedOutDetails
+def test_execution_timed_out_details_serialization():
+ """Test ExecutionTimedOutDetails from_dict/to_dict round-trip."""
+ data = {
+ "Error": {
+ "Payload": {"ErrorMessage": "execution timed out"},
+ "Truncated": False,
+ },
+ }
+
+ details = ExecutionTimedOutDetails.from_dict(data)
+ assert details.error.payload.message == "execution timed out"
+
+ result_data = details.to_dict()
+ assert result_data == data
+
+
+def test_execution_timed_out_details_minimal():
+ """Test ExecutionTimedOutDetails with minimal data."""
+ data = {}
+
+ details = ExecutionTimedOutDetails.from_dict(data)
+ assert details.error is None
+
+ result_data = details.to_dict()
+ assert result_data == {}
+
+
+# Tests for ExecutionStoppedDetails
+def test_execution_stopped_details_serialization():
+ """Test ExecutionStoppedDetails from_dict/to_dict round-trip."""
+ data = {
+ "Error": {"Payload": {"ErrorMessage": "execution stopped"}, "Truncated": False},
+ }
+
+ details = ExecutionStoppedDetails.from_dict(data)
+ assert details.error.payload.message == "execution stopped"
+
+ result_data = details.to_dict()
+ assert result_data == data
+
+
+def test_execution_stopped_details_minimal():
+ """Test ExecutionStoppedDetails with minimal data."""
+ data = {}
+
+ details = ExecutionStoppedDetails.from_dict(data)
+ assert details.error is None
+
+ result_data = details.to_dict()
+ assert result_data == {}
+
+
+# Tests for ContextStartedDetails
+def test_context_started_details_serialization():
+ """Test ContextStartedDetails from_dict/to_dict round-trip."""
+ # ContextStartedDetails ignores input data and always returns empty dict
+ data = {"dummy": "value"} # Can provide any data
+
+ details = ContextStartedDetails.from_dict(data)
+ assert isinstance(details, ContextStartedDetails)
+
+ result_data = details.to_dict()
+ assert result_data == {}
+
+
+# Tests for ContextSucceededDetails
+def test_context_succeeded_details_serialization():
+ """Test ContextSucceededDetails from_dict/to_dict round-trip."""
+ data = {
+ "Result": {"Payload": "context-result", "Truncated": False},
+ }
+
+ details = ContextSucceededDetails.from_dict(data)
+ assert details.result.payload == "context-result"
+
+ result_data = details.to_dict()
+ assert result_data == data
+
+
+def test_context_succeeded_details_minimal():
+ """Test ContextSucceededDetails with minimal data."""
+ data = {}
+
+ details = ContextSucceededDetails.from_dict(data)
+ assert details.result is None
+
+ result_data = details.to_dict()
+ assert result_data == {}
+
+
+# Tests for ContextFailedDetails
+def test_context_failed_details_serialization():
+ """Test ContextFailedDetails from_dict/to_dict round-trip."""
+ data = {
+ "Error": {"Payload": {"ErrorMessage": "context failed"}, "Truncated": False},
+ }
+
+ details = ContextFailedDetails.from_dict(data)
+ assert details.error.payload.message == "context failed"
+
+ result_data = details.to_dict()
+ assert result_data == data
+
+
+def test_context_failed_details_minimal():
+ """Test ContextFailedDetails with minimal data."""
+ data = {}
+
+ details = ContextFailedDetails.from_dict(data)
+ assert details.error is None
+
+ result_data = details.to_dict()
+ assert result_data == {}
+
+
+# Tests for WaitStartedDetails
+def test_wait_started_details_serialization():
+ """Test WaitStartedDetails from_dict/to_dict round-trip."""
+ data = {
+ "Duration": 60,
+ "ScheduledEndTimestamp": TIMESTAMP_2023_01_01_00_01,
+ }
+
+ details = WaitStartedDetails.from_dict(data)
+ assert details.duration == 60
+ assert details.scheduled_end_timestamp == TIMESTAMP_2023_01_01_00_01
+
+ result_data = details.to_dict()
+ assert result_data == data
+
+
+def test_wait_started_details_minimal():
+ """Test WaitStartedDetails with minimal data."""
+ data = {}
+
+ details = WaitStartedDetails.from_dict(data)
+ assert details.duration is None
+ assert details.scheduled_end_timestamp is None
+
+ result_data = details.to_dict()
+ assert result_data == {}
+
+
+def test_wait_started_details_with_duration_only():
+ """Test WaitStartedDetails with duration but no timestamp."""
+ data = {"Duration": 30}
+
+ details = WaitStartedDetails.from_dict(data)
+ assert details.duration == 30
+ assert details.scheduled_end_timestamp is None
+
+ result_data = details.to_dict()
+ assert result_data == {"Duration": 30}
+
+
+# Tests for WaitSucceededDetails
+def test_wait_succeeded_details_serialization():
+ """Test WaitSucceededDetails from_dict/to_dict round-trip."""
+ data = {"Duration": 60}
+
+ details = WaitSucceededDetails.from_dict(data)
+ assert details.duration == 60
+
+ result_data = details.to_dict()
+ assert result_data == data
+
+
+def test_wait_succeeded_details_minimal():
+ """Test WaitSucceededDetails with minimal data."""
+ data = {}
+
+ details = WaitSucceededDetails.from_dict(data)
+ assert details.duration is None
+
+ result_data = details.to_dict()
+ assert result_data == {}
+
+
+# Tests for WaitCancelledDetails
+def test_wait_cancelled_details_serialization():
+ """Test WaitCancelledDetails from_dict/to_dict round-trip."""
+ data = {
+ "Error": {"Payload": {"ErrorMessage": "wait cancelled"}, "Truncated": False},
+ }
+
+ details = WaitCancelledDetails.from_dict(data)
+ assert details.error.payload.message == "wait cancelled"
+
+ result_data = details.to_dict()
+ assert result_data == data
+
+
+def test_wait_cancelled_details_minimal():
+ """Test WaitCancelledDetails with minimal data."""
+ data = {}
+
+ details = WaitCancelledDetails.from_dict(data)
+ assert details.error is None
+
+ result_data = details.to_dict()
+ assert result_data == {}
+
+
+# Tests for StepStartedDetails
+def test_step_started_details_serialization():
+ """Test StepStartedDetails from_dict/to_dict round-trip."""
+ # StepStartedDetails ignores input data and always returns empty dict
+ data = {"dummy": "value"} # Can provide any data
+
+ details = StepStartedDetails.from_dict(data)
+ assert isinstance(details, StepStartedDetails)
+
+ result_data = details.to_dict()
+ assert result_data == {}
+
+
+# Tests for StepSucceededDetails
+def test_step_succeeded_details_serialization():
+ """Test StepSucceededDetails from_dict/to_dict round-trip."""
+ data = {
+ "Result": {"Payload": "step-result", "Truncated": False},
+ "RetryDetails": {"CurrentAttempt": 2, "NextAttemptDelaySeconds": 30},
+ }
+
+ details = StepSucceededDetails.from_dict(data)
+ assert details.result.payload == "step-result"
+ assert details.retry_details.current_attempt == 2
+
+ result_data = details.to_dict()
+ assert result_data == data
+
+
+def test_step_succeeded_details_minimal():
+ """Test StepSucceededDetails with minimal data."""
+ data = {}
+
+ details = StepSucceededDetails.from_dict(data)
+ assert details.result is None
+ assert details.retry_details is None
+
+ result_data = details.to_dict()
+ assert result_data == {}
+
+
+def test_step_succeeded_details_with_result_only():
+ """Test StepSucceededDetails with result but no retry details."""
+ data = {"Result": {"Payload": "step-result", "Truncated": False}}
+
+ details = StepSucceededDetails.from_dict(data)
+ assert details.result.payload == "step-result"
+ assert details.retry_details is None
+
+ result_data = details.to_dict()
+ assert result_data == {"Result": {"Payload": "step-result", "Truncated": False}}
+
+
+# Tests for StepFailedDetails
+def test_step_failed_details_serialization():
+ """Test StepFailedDetails from_dict/to_dict round-trip."""
+ data = {
+ "Error": {"Payload": {"ErrorMessage": "step failed"}, "Truncated": False},
+ "RetryDetails": {"CurrentAttempt": 1, "NextAttemptDelaySeconds": 15},
+ }
+
+ details = StepFailedDetails.from_dict(data)
+ assert details.error.payload.message == "step failed"
+ assert details.retry_details.current_attempt == 1
+
+ result_data = details.to_dict()
+ assert result_data == data
+
+
+def test_step_failed_details_minimal():
+ """Test StepFailedDetails with minimal data."""
+ data = {}
+
+ details = StepFailedDetails.from_dict(data)
+ assert details.error is None
+ assert details.retry_details is None
+
+ result_data = details.to_dict()
+ assert result_data == {}
+
+
+def test_step_failed_details_with_error_only():
+ """Test StepFailedDetails with error but no retry details."""
+ data = {"Error": {"Payload": {"ErrorMessage": "step failed"}, "Truncated": False}}
+
+ details = StepFailedDetails.from_dict(data)
+ assert details.error.payload.message == "step failed"
+ assert details.retry_details is None
+
+ result_data = details.to_dict()
+ assert result_data == {
+ "Error": {"Payload": {"ErrorMessage": "step failed"}, "Truncated": False}
+ }
+
+
+# Tests for ChainedInvokeStartedDetails
+def test_invoke_started_details_serialization():
+ """Test ChainedInvokeStartedDetails from_dict/to_dict round-trip."""
+ data = {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test",
+ }
+
+ details = ChainedInvokeStartedDetails.from_dict(data)
+ assert (
+ details.durable_execution_arn
+ == "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test"
+ )
+
+ result_data = details.to_dict()
+ assert result_data == data
+
+
+def test_invoke_started_details_minimal():
+ """Test ChainedInvokeStartedDetails with minimal data."""
+ data = {}
+
+ details = ChainedInvokeStartedDetails.from_dict(data)
+ assert details.durable_execution_arn is None
+
+ result_data = details.to_dict()
+ assert result_data == {}
+
+
+def test_invoke_started_details_partial():
+ """Test ChainedInvokeStartedDetails with partial data."""
+ data = {}
+ details = ChainedInvokeStartedDetails.from_dict(data)
+ assert details.durable_execution_arn is None
+
+
+# Tests for ChainedInvokeSucceededDetails
+def test_invoke_succeeded_details_serialization():
+ """Test ChainedInvokeSucceededDetails from_dict/to_dict round-trip."""
+ data = {
+ "Result": {"Payload": "invoke-result", "Truncated": False},
+ }
+
+ details = ChainedInvokeSucceededDetails.from_dict(data)
+ assert details.result.payload == "invoke-result"
+
+ result_data = details.to_dict()
+ assert result_data == data
+
+
+def test_invoke_succeeded_details_minimal():
+ """Test ChainedInvokeSucceededDetails with minimal data."""
+ data = {}
+
+ details = ChainedInvokeSucceededDetails.from_dict(data)
+ assert details.result is None
+
+ result_data = details.to_dict()
+ assert result_data == {}
+
+
+# Tests for ChainedInvokeFailedDetails
+def test_invoke_failed_details_serialization():
+ """Test ChainedInvokeFailedDetails from_dict/to_dict round-trip."""
+ data = {
+ "Error": {"Payload": {"ErrorMessage": "invoke failed"}, "Truncated": False},
+ }
+
+ details = ChainedInvokeFailedDetails.from_dict(data)
+ assert details.error.payload.message == "invoke failed"
+
+ result_data = details.to_dict()
+ assert result_data == data
+
+
+def test_invoke_failed_details_minimal():
+ """Test ChainedInvokeFailedDetails with minimal data."""
+ data = {}
+
+ details = ChainedInvokeFailedDetails.from_dict(data)
+ assert details.error is None
+
+ result_data = details.to_dict()
+ assert result_data == {}
+
+
+# Tests for ChainedInvokeTimedOutDetails
+def test_invoke_timed_out_details_serialization():
+ """Test ChainedInvokeTimedOutDetails from_dict/to_dict round-trip."""
+ data = {
+ "Error": {"Payload": {"ErrorMessage": "invoke timed out"}, "Truncated": False},
+ }
+
+ details = ChainedInvokeTimedOutDetails.from_dict(data)
+ assert details.error.payload.message == "invoke timed out"
+
+ result_data = details.to_dict()
+ assert result_data == data
+
+
+def test_invoke_timed_out_details_minimal():
+ """Test ChainedInvokeTimedOutDetails with minimal data."""
+ data = {}
+
+ details = ChainedInvokeTimedOutDetails.from_dict(data)
+ assert details.error is None
+
+ result_data = details.to_dict()
+ assert result_data == {}
+
+
+# Tests for ChainedInvokeStoppedDetails
+def test_invoke_stopped_details_serialization():
+ """Test ChainedInvokeStoppedDetails from_dict/to_dict round-trip."""
+ data = {
+ "Error": {"Payload": {"ErrorMessage": "invoke stopped"}, "Truncated": False},
+ }
+
+ details = ChainedInvokeStoppedDetails.from_dict(data)
+ assert details.error.payload.message == "invoke stopped"
+
+ result_data = details.to_dict()
+ assert result_data == data
+
+
+def test_invoke_stopped_details_minimal():
+ """Test ChainedInvokeStoppedDetails with minimal data."""
+ data = {}
+
+ details = ChainedInvokeStoppedDetails.from_dict(data)
+ assert details.error is None
+
+ result_data = details.to_dict()
+ assert result_data == {}
+
+
+# Tests for CallbackStartedDetails
+def test_callback_started_details_serialization():
+ """Test CallbackStartedDetails from_dict/to_dict round-trip."""
+ data = {
+ "CallbackId": "callback-123",
+ "HeartbeatTimeout": 60,
+ "Timeout": 300,
+ }
+
+ details = CallbackStartedDetails.from_dict(data)
+ assert details.callback_id == "callback-123"
+ assert details.heartbeat_timeout == 60
+ assert details.timeout == 300
+
+ result_data = details.to_dict()
+ assert result_data == data
+
+
+def test_callback_started_details_minimal():
+ """Test CallbackStartedDetails with minimal data."""
+ data = {}
+
+ details = CallbackStartedDetails.from_dict(data)
+ assert details.callback_id is None
+ assert details.heartbeat_timeout is None
+ assert details.timeout is None
+
+ result_data = details.to_dict()
+ assert result_data == {}
+
+
+def test_callback_started_details_partial():
+ """Test CallbackStartedDetails with partial data."""
+ data = {
+ "CallbackId": "callback-123",
+ "Timeout": 300,
+ }
+
+ details = CallbackStartedDetails.from_dict(data)
+ assert details.callback_id == "callback-123"
+ assert details.heartbeat_timeout is None
+ assert details.timeout == 300
+
+ result_data = details.to_dict()
+ assert result_data == data
+
+
+# Tests for CallbackSucceededDetails
+def test_callback_succeeded_details_serialization():
+ """Test CallbackSucceededDetails from_dict/to_dict round-trip."""
+ data = {
+ "Result": {"Payload": "callback-result", "Truncated": False},
+ }
+
+ details = CallbackSucceededDetails.from_dict(data)
+ assert details.result.payload == "callback-result"
+
+ result_data = details.to_dict()
+ assert result_data == data
+
+
+def test_callback_succeeded_details_minimal():
+ """Test CallbackSucceededDetails with minimal data."""
+ data = {}
+
+ details = CallbackSucceededDetails.from_dict(data)
+ assert details.result is None
+
+ result_data = details.to_dict()
+ assert result_data == {}
+
+
+# Tests for CallbackFailedDetails
+def test_callback_failed_details_serialization():
+ """Test CallbackFailedDetails from_dict/to_dict round-trip."""
+ data = {
+ "Error": {"Payload": {"ErrorMessage": "callback failed"}, "Truncated": False},
+ }
+
+ details = CallbackFailedDetails.from_dict(data)
+ assert details.error.payload.message == "callback failed"
+
+ result_data = details.to_dict()
+ assert result_data == data
+
+
+def test_callback_failed_details_minimal():
+ """Test CallbackFailedDetails with minimal data."""
+ data = {}
+
+ details = CallbackFailedDetails.from_dict(data)
+ assert details.error is None
+
+ result_data = details.to_dict()
+ assert result_data == {}
+
+
+# Tests for CallbackTimedOutDetails
+def test_callback_timed_out_details_serialization():
+ """Test CallbackTimedOutDetails from_dict/to_dict round-trip."""
+ data = {
+ "Error": {
+ "Payload": {"ErrorMessage": "callback timed out"},
+ "Truncated": False,
+ },
+ }
+
+ details = CallbackTimedOutDetails.from_dict(data)
+ assert details.error.payload.message == "callback timed out"
+
+ result_data = details.to_dict()
+ assert result_data == data
+
+
+def test_callback_timed_out_details_minimal():
+ """Test CallbackTimedOutDetails with minimal data."""
+ data = {}
+
+ details = CallbackTimedOutDetails.from_dict(data)
+ assert details.error is None
+
+ result_data = details.to_dict()
+ assert result_data == {}
+
+
+# Tests for Event class with all detail types
+def test_event_with_execution_succeeded_details():
+ """Test Event with ExecutionSucceededDetails."""
+ data = {
+ "EventType": "ExecutionSucceeded",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "ExecutionSucceededDetails": {
+ "Result": {"Payload": "success", "Truncated": False}
+ },
+ }
+
+ event_obj = Event.from_dict(data)
+ assert event_obj.event_type == "ExecutionSucceeded"
+ assert event_obj.execution_succeeded_details is not None
+ assert event_obj.execution_succeeded_details.result.payload == "success"
+
+ result_data = event_obj.to_dict()
+ expected_data = {
+ "EventType": "ExecutionSucceeded",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "EventId": 1, # Default value
+ "ExecutionSucceededDetails": {
+ "Result": {"Payload": "success", "Truncated": False}
+ },
+ }
+ assert result_data == expected_data
+
+
+def test_event_with_execution_failed_details():
+ """Test Event with ExecutionFailedDetails."""
+ data = {
+ "EventType": "ExecutionFailed",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "ExecutionFailedDetails": {
+ "Error": {
+ "Payload": {"ErrorMessage": "execution failed"},
+ "Truncated": False,
+ }
+ },
+ }
+
+ event_obj = Event.from_dict(data)
+ assert event_obj.event_type == "ExecutionFailed"
+ assert event_obj.execution_failed_details is not None
+ assert (
+ event_obj.execution_failed_details.error.payload.message == "execution failed"
+ )
+
+ result_data = event_obj.to_dict()
+ expected_data = {
+ "EventType": "ExecutionFailed",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "EventId": 1,
+ "ExecutionFailedDetails": {
+ "Error": {
+ "Payload": {"ErrorMessage": "execution failed"},
+ "Truncated": False,
+ }
+ },
+ }
+ assert result_data == expected_data
+
+
+def test_event_with_execution_timed_out_details():
+ """Test Event with ExecutionTimedOutDetails."""
+ data = {
+ "EventType": "ExecutionTimedOut",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "ExecutionTimedOutDetails": {
+ "Error": {
+ "Payload": {"ErrorMessage": "execution timed out"},
+ "Truncated": False,
+ }
+ },
+ }
+
+ event_obj = Event.from_dict(data)
+ assert event_obj.event_type == "ExecutionTimedOut"
+ assert event_obj.execution_timed_out_details is not None
+ assert (
+ event_obj.execution_timed_out_details.error.payload.message
+ == "execution timed out"
+ )
+
+ result_data = event_obj.to_dict()
+ expected_data = {
+ "EventType": "ExecutionTimedOut",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "EventId": 1,
+ "ExecutionTimedOutDetails": {
+ "Error": {
+ "Payload": {"ErrorMessage": "execution timed out"},
+ "Truncated": False,
+ }
+ },
+ }
+ assert result_data == expected_data
+
+
+def test_event_with_execution_stopped_details():
+ """Test Event with ExecutionStoppedDetails."""
+ data = {
+ "EventType": "ExecutionStopped",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "ExecutionStoppedDetails": {
+ "Error": {
+ "Payload": {"ErrorMessage": "execution stopped"},
+ "Truncated": False,
+ }
+ },
+ }
+
+ event_obj = Event.from_dict(data)
+ assert event_obj.event_type == "ExecutionStopped"
+ assert event_obj.execution_stopped_details is not None
+ assert (
+ event_obj.execution_stopped_details.error.payload.message == "execution stopped"
+ )
+
+ result_data = event_obj.to_dict()
+ expected_data = {
+ "EventType": "ExecutionStopped",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "EventId": 1,
+ "ExecutionStoppedDetails": {
+ "Error": {
+ "Payload": {"ErrorMessage": "execution stopped"},
+ "Truncated": False,
+ }
+ },
+ }
+ assert result_data == expected_data
+
+
+def test_event_with_context_started_details():
+ """Test Event with ContextStartedDetails."""
+ # Since ContextStartedDetails has no fields and empty dict is falsy,
+ # we need to provide a non-empty dict or test without the key
+ data = {
+ "EventType": "ContextStarted",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "ContextStartedDetails": {"dummy": "value"}, # Non-empty to be truthy
+ }
+
+ event_obj = Event.from_dict(data)
+ assert event_obj.event_type == "ContextStarted"
+ assert event_obj.context_started_details is not None
+
+ result_data = event_obj.to_dict()
+ expected_data = {
+ "EventType": "ContextStarted",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "EventId": 1,
+ "ContextStartedDetails": {}, # to_dict() returns empty dict
+ }
+ assert result_data == expected_data
+
+
+def test_event_with_context_succeeded_details():
+ """Test Event with ContextSucceededDetails."""
+ data = {
+ "EventType": "ContextSucceeded",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "ContextSucceededDetails": {
+ "Result": {"Payload": "context result", "Truncated": False}
+ },
+ }
+
+ event_obj = Event.from_dict(data)
+ assert event_obj.event_type == "ContextSucceeded"
+ assert event_obj.context_succeeded_details is not None
+ assert event_obj.context_succeeded_details.result.payload == "context result"
+
+ result_data = event_obj.to_dict()
+ expected_data = {
+ "EventType": "ContextSucceeded",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "EventId": 1,
+ "ContextSucceededDetails": {
+ "Result": {"Payload": "context result", "Truncated": False}
+ },
+ }
+ assert result_data == expected_data
+
+
+def test_event_with_context_failed_details():
+ """Test Event with ContextFailedDetails."""
+ data = {
+ "EventType": "ContextFailed",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "ContextFailedDetails": {
+ "Error": {"Payload": {"ErrorMessage": "context failed"}, "Truncated": False}
+ },
+ }
+
+ event_obj = Event.from_dict(data)
+ assert event_obj.event_type == "ContextFailed"
+ assert event_obj.context_failed_details is not None
+ assert event_obj.context_failed_details.error.payload.message == "context failed"
+
+ result_data = event_obj.to_dict()
+ expected_data = {
+ "EventType": "ContextFailed",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "EventId": 1,
+ "ContextFailedDetails": {
+ "Error": {"Payload": {"ErrorMessage": "context failed"}, "Truncated": False}
+ },
+ }
+ assert result_data == expected_data
+
+
+def test_event_with_wait_started_details():
+ """Test Event with WaitStartedDetails."""
+ data = {
+ "EventType": "WaitStarted",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "WaitStartedDetails": {
+ "Duration": 60,
+ "ScheduledEndTimestamp": TIMESTAMP_2023_01_01_00_02,
+ },
+ }
+
+ event_obj = Event.from_dict(data)
+ assert event_obj.event_type == "WaitStarted"
+ assert event_obj.wait_started_details is not None
+ assert event_obj.wait_started_details.duration == 60
+
+ result_data = event_obj.to_dict()
+ expected_data = {
+ "EventType": "WaitStarted",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "EventId": 1,
+ "WaitStartedDetails": {
+ "Duration": 60,
+ "ScheduledEndTimestamp": TIMESTAMP_2023_01_01_00_02,
+ },
+ }
+ assert result_data == expected_data
+
+
+def test_event_with_wait_succeeded_details():
+ """Test Event with WaitSucceededDetails."""
+ data = {
+ "EventType": "WaitSucceeded",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "WaitSucceededDetails": {"Duration": 60},
+ }
+
+ event_obj = Event.from_dict(data)
+ assert event_obj.event_type == "WaitSucceeded"
+ assert event_obj.wait_succeeded_details is not None
+ assert event_obj.wait_succeeded_details.duration == 60
+
+ result_data = event_obj.to_dict()
+ expected_data = {
+ "EventType": "WaitSucceeded",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "EventId": 1,
+ "WaitSucceededDetails": {"Duration": 60},
+ }
+ assert result_data == expected_data
+
+
+def test_event_with_wait_cancelled_details():
+ """Test Event with WaitCancelledDetails."""
+ data = {
+ "EventType": "WaitCancelled",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "WaitCancelledDetails": {
+ "Error": {"Payload": {"ErrorMessage": "wait cancelled"}, "Truncated": False}
+ },
+ }
+
+ event_obj = Event.from_dict(data)
+ assert event_obj.event_type == "WaitCancelled"
+ assert event_obj.wait_cancelled_details is not None
+ assert event_obj.wait_cancelled_details.error.payload.message == "wait cancelled"
+
+ result_data = event_obj.to_dict()
+ expected_data = {
+ "EventType": "WaitCancelled",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "EventId": 1,
+ "WaitCancelledDetails": {
+ "Error": {"Payload": {"ErrorMessage": "wait cancelled"}, "Truncated": False}
+ },
+ }
+ assert result_data == expected_data
+
+
+def test_event_with_step_started_details():
+ """Test Event with StepStartedDetails."""
+ # Since StepStartedDetails has no fields and empty dict is falsy,
+ # we need to provide a non-empty dict or test without the key
+ data = {
+ "EventType": "StepStarted",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "StepStartedDetails": {"dummy": "value"}, # Non-empty to be truthy
+ }
+
+ event_obj = Event.from_dict(data)
+ assert event_obj.event_type == "StepStarted"
+ assert event_obj.step_started_details is not None
+
+ result_data = event_obj.to_dict()
+ expected_data = {
+ "EventType": "StepStarted",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "EventId": 1,
+ "StepStartedDetails": {}, # to_dict() returns empty dict
+ }
+ assert result_data == expected_data
+
+
+def test_event_with_step_succeeded_details():
+ """Test Event with StepSucceededDetails."""
+ data = {
+ "EventType": "StepSucceeded",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "StepSucceededDetails": {
+ "Result": {"Payload": "step result", "Truncated": False},
+ "RetryDetails": {"CurrentAttempt": 1},
+ },
+ }
+
+ event_obj = Event.from_dict(data)
+ assert event_obj.event_type == "StepSucceeded"
+ assert event_obj.step_succeeded_details is not None
+ assert event_obj.step_succeeded_details.result.payload == "step result"
+
+ result_data = event_obj.to_dict()
+ expected_data = {
+ "EventType": "StepSucceeded",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "EventId": 1,
+ "StepSucceededDetails": {
+ "Result": {"Payload": "step result", "Truncated": False},
+ "RetryDetails": {"CurrentAttempt": 1},
+ },
+ }
+ assert result_data == expected_data
+
+
+def test_event_with_step_failed_details():
+ """Test Event with StepFailedDetails."""
+ data = {
+ "EventType": "StepFailed",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "StepFailedDetails": {
+ "Error": {"Payload": {"ErrorMessage": "step failed"}, "Truncated": False},
+ "RetryDetails": {"CurrentAttempt": 2, "NextAttemptDelaySeconds": 30},
+ },
+ }
+
+ event_obj = Event.from_dict(data)
+ assert event_obj.event_type == "StepFailed"
+ assert event_obj.step_failed_details is not None
+ assert event_obj.step_failed_details.error.payload.message == "step failed"
+
+ result_data = event_obj.to_dict()
+ expected_data = {
+ "EventType": "StepFailed",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "EventId": 1,
+ "StepFailedDetails": {
+ "Error": {"Payload": {"ErrorMessage": "step failed"}, "Truncated": False},
+ "RetryDetails": {"CurrentAttempt": 2, "NextAttemptDelaySeconds": 30},
+ },
+ }
+ assert result_data == expected_data
+
+
+def test_event_with_invoke_started_details():
+ """Test Event with ChainedInvokeStartedDetails."""
+ data = {
+ "EventType": "ChainedInvokeStarted",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "ChainedInvokeStartedDetails": {
+ "DurableExecutionArn": "arn:aws:durable-execution:us-east-1:123456789012:execution:my-execution:1234567890",
+ },
+ }
+
+ event_obj = Event.from_dict(data)
+ assert event_obj.event_type == "ChainedInvokeStarted"
+ assert event_obj.chained_invoke_started_details is not None
+
+ result_data = event_obj.to_dict()
+ expected_data = {
+ "EventType": "ChainedInvokeStarted",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "EventId": 1,
+ "ChainedInvokeStartedDetails": {
+ "DurableExecutionArn": "arn:aws:durable-execution:us-east-1:123456789012:execution:my-execution:1234567890",
+ },
+ }
+ assert result_data == expected_data
+
+
+def test_event_with_invoke_succeeded_details():
+ """Test Event with ChainedInvokeSucceededDetails."""
+ data = {
+ "EventType": "ChainedInvokeSucceeded",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "ChainedInvokeSucceededDetails": {
+ "Result": {"Payload": "invoke result", "Truncated": False}
+ },
+ }
+
+ event_obj = Event.from_dict(data)
+ assert event_obj.event_type == "ChainedInvokeSucceeded"
+ assert event_obj.chained_invoke_succeeded_details is not None
+ assert event_obj.chained_invoke_succeeded_details.result.payload == "invoke result"
+
+ result_data = event_obj.to_dict()
+ expected_data = {
+ "EventType": "ChainedInvokeSucceeded",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "EventId": 1,
+ "ChainedInvokeSucceededDetails": {
+ "Result": {"Payload": "invoke result", "Truncated": False}
+ },
+ }
+ assert result_data == expected_data
+
+
+def test_event_with_invoke_failed_details():
+ """Test Event with ChainedInvokeFailedDetails."""
+ data = {
+ "EventType": "ChainedInvokeFailed",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "ChainedInvokeFailedDetails": {
+ "Error": {"Payload": {"ErrorMessage": "invoke failed"}, "Truncated": False}
+ },
+ }
+
+ event_obj = Event.from_dict(data)
+ assert event_obj.event_type == "ChainedInvokeFailed"
+ assert event_obj.chained_invoke_failed_details is not None
+ assert (
+ event_obj.chained_invoke_failed_details.error.payload.message == "invoke failed"
+ )
+
+ result_data = event_obj.to_dict()
+ expected_data = {
+ "EventType": "ChainedInvokeFailed",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "EventId": 1,
+ "ChainedInvokeFailedDetails": {
+ "Error": {"Payload": {"ErrorMessage": "invoke failed"}, "Truncated": False}
+ },
+ }
+ assert result_data == expected_data
+
+
+def test_event_with_invoke_timed_out_details():
+ """Test Event with ChainedInvokeTimedOutDetails."""
+ data = {
+ "EventType": "ChainedInvokeTimedOut",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "ChainedInvokeTimedOutDetails": {
+ "Error": {
+ "Payload": {"ErrorMessage": "invoke timed out"},
+ "Truncated": False,
+ }
+ },
+ }
+
+ event_obj = Event.from_dict(data)
+ assert event_obj.event_type == "ChainedInvokeTimedOut"
+ assert event_obj.chained_invoke_timed_out_details is not None
+ assert (
+ event_obj.chained_invoke_timed_out_details.error.payload.message
+ == "invoke timed out"
+ )
+
+ result_data = event_obj.to_dict()
+ expected_data = {
+ "EventType": "ChainedInvokeTimedOut",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "EventId": 1,
+ "ChainedInvokeTimedOutDetails": {
+ "Error": {
+ "Payload": {"ErrorMessage": "invoke timed out"},
+ "Truncated": False,
+ }
+ },
+ }
+ assert result_data == expected_data
+
+
+def test_event_with_invoke_stopped_details():
+ """Test Event with ChainedInvokeStoppedDetails."""
+ data = {
+ "EventType": "ChainedInvokeStopped",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "ChainedInvokeStoppedDetails": {
+ "Error": {"Payload": {"ErrorMessage": "invoke stopped"}, "Truncated": False}
+ },
+ }
+
+ event_obj = Event.from_dict(data)
+ assert event_obj.event_type == "ChainedInvokeStopped"
+ assert event_obj.chained_invoke_stopped_details is not None
+ assert (
+ event_obj.chained_invoke_stopped_details.error.payload.message
+ == "invoke stopped"
+ )
+
+ result_data = event_obj.to_dict()
+ expected_data = {
+ "EventType": "ChainedInvokeStopped",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "EventId": 1,
+ "ChainedInvokeStoppedDetails": {
+ "Error": {"Payload": {"ErrorMessage": "invoke stopped"}, "Truncated": False}
+ },
+ }
+ assert result_data == expected_data
+
+
+def test_event_with_callback_started_details():
+ """Test Event with CallbackStartedDetails."""
+ data = {
+ "EventType": "CallbackStarted",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "CallbackStartedDetails": {
+ "CallbackId": "callback-123",
+ "HeartbeatTimeout": 60,
+ "Timeout": 300,
+ },
+ }
+
+ event_obj = Event.from_dict(data)
+ assert event_obj.event_type == "CallbackStarted"
+ assert event_obj.callback_started_details is not None
+ assert event_obj.callback_started_details.callback_id == "callback-123"
+
+ result_data = event_obj.to_dict()
+ expected_data = {
+ "EventType": "CallbackStarted",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "EventId": 1,
+ "CallbackStartedDetails": {
+ "CallbackId": "callback-123",
+ "HeartbeatTimeout": 60,
+ "Timeout": 300,
+ },
+ }
+ assert result_data == expected_data
+
+
+def test_event_with_callback_succeeded_details():
+ """Test Event with CallbackSucceededDetails."""
+ data = {
+ "EventType": "CallbackSucceeded",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "CallbackSucceededDetails": {
+ "Result": {"Payload": "callback result", "Truncated": False}
+ },
+ }
+
+ event_obj = Event.from_dict(data)
+ assert event_obj.event_type == "CallbackSucceeded"
+ assert event_obj.callback_succeeded_details is not None
+ assert event_obj.callback_succeeded_details.result.payload == "callback result"
+
+ result_data = event_obj.to_dict()
+ expected_data = {
+ "EventType": "CallbackSucceeded",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "EventId": 1,
+ "CallbackSucceededDetails": {
+ "Result": {"Payload": "callback result", "Truncated": False}
+ },
+ }
+ assert result_data == expected_data
+
+
+def test_event_with_callback_failed_details():
+ """Test Event with CallbackFailedDetails."""
+ data = {
+ "EventType": "CallbackFailed",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "CallbackFailedDetails": {
+ "Error": {
+ "Payload": {"ErrorMessage": "callback failed"},
+ "Truncated": False,
+ }
+ },
+ }
+
+ event_obj = Event.from_dict(data)
+ assert event_obj.event_type == "CallbackFailed"
+ assert event_obj.callback_failed_details is not None
+ assert event_obj.callback_failed_details.error.payload.message == "callback failed"
+
+ result_data = event_obj.to_dict()
+ expected_data = {
+ "EventType": "CallbackFailed",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "EventId": 1,
+ "CallbackFailedDetails": {
+ "Error": {
+ "Payload": {"ErrorMessage": "callback failed"},
+ "Truncated": False,
+ }
+ },
+ }
+ assert result_data == expected_data
+
+
+def test_event_with_callback_timed_out_details():
+ """Test Event with CallbackTimedOutDetails."""
+ data = {
+ "EventType": "CallbackTimedOut",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "CallbackTimedOutDetails": {
+ "Error": {
+ "Payload": {"ErrorMessage": "callback timed out"},
+ "Truncated": False,
+ }
+ },
+ }
+
+ event_obj = Event.from_dict(data)
+ assert event_obj.event_type == "CallbackTimedOut"
+ assert event_obj.callback_timed_out_details is not None
+ assert (
+ event_obj.callback_timed_out_details.error.payload.message
+ == "callback timed out"
+ )
+
+ result_data = event_obj.to_dict()
+ expected_data = {
+ "EventType": "CallbackTimedOut",
+ "EventTimestamp": TIMESTAMP_2023_01_01_00_01,
+ "EventId": 1,
+ "CallbackTimedOutDetails": {
+ "Error": {
+ "Payload": {"ErrorMessage": "callback timed out"},
+ "Truncated": False,
+ }
+ },
+ }
+ assert result_data == expected_data
+
+
+# Tests for GetDurableExecutionHistoryRequest with all optional fields
+def test_get_durable_execution_history_request_all_optional_fields():
+ """Test GetDurableExecutionHistoryRequest to_dict with all optional fields as None."""
+ request_obj = GetDurableExecutionHistoryRequest(
+ durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test",
+ include_execution_data=None,
+ reverse_order=None,
+ marker=None,
+ max_items=None,
+ )
+
+ result_data = request_obj.to_dict()
+ expected_data = {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test",
+ }
+ assert result_data == expected_data
+
+
+def test_get_durable_execution_history_request_partial_fields():
+ """Test GetDurableExecutionHistoryRequest to_dict with some optional fields."""
+ request_obj = GetDurableExecutionHistoryRequest(
+ durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test",
+ include_execution_data=True,
+ reverse_order=None,
+ marker="marker-123",
+ max_items=20,
+ )
+
+ result_data = request_obj.to_dict()
+ expected_data = {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test",
+ "IncludeExecutionData": True,
+ "Marker": "marker-123",
+ "MaxItems": 20,
+ }
+ assert result_data == expected_data
+
+
+# Tests for ListDurableExecutionsByFunctionRequest with all optional fields
+def test_list_durable_executions_by_function_request_all_optional_fields():
+ """Test ListDurableExecutionsByFunctionRequest to_dict with all optional fields as None."""
+ request_obj = ListDurableExecutionsByFunctionRequest(
+ function_name="my-function",
+ qualifier=None,
+ status_filter=None,
+ started_after=None,
+ started_before=None,
+ marker=None,
+ max_items=None,
+ reverse_order=None,
+ )
+
+ result_data = request_obj.to_dict()
+ expected_data = {
+ "FunctionName": "my-function",
+ }
+ assert result_data == expected_data
+
+
+def test_list_durable_executions_by_function_request_partial_fields():
+ """Test ListDurableExecutionsByFunctionRequest to_dict with some optional fields."""
+ request_obj = ListDurableExecutionsByFunctionRequest(
+ function_name="my-function",
+ qualifier="$LATEST",
+ status_filter=["RUNNING"],
+ started_after=None,
+ started_before=TIMESTAMP_2023_01_02_00_00,
+ marker=None,
+ max_items=15,
+ reverse_order=True,
+ )
+
+ result_data = request_obj.to_dict()
+ expected_data = {
+ "FunctionName": "my-function",
+ "Qualifier": "$LATEST",
+ "StatusFilter": ["RUNNING"],
+ "StartedBefore": TIMESTAMP_2023_01_02_00_00,
+ "MaxItems": 15,
+ "ReverseOrder": True,
+ }
+ assert result_data == expected_data
+
+
+# Tests for SendDurableExecutionCallbackSuccessRequest with optional result
+def test_send_durable_execution_callback_success_request_with_result():
+ """Test SendDurableExecutionCallbackSuccessRequest to_dict with result."""
+ request_obj = SendDurableExecutionCallbackSuccessRequest(
+ callback_id="callback-123",
+ result="success-result",
+ )
+
+ result_data = request_obj.to_dict()
+ expected_data = {
+ "CallbackId": "callback-123",
+ "Result": "success-result",
+ }
+ assert result_data == expected_data
+
+
+# Tests for SendDurableExecutionCallbackFailureRequest with optional error
+def test_send_durable_execution_callback_failure_request_with_error():
+ """Test SendDurableExecutionCallbackFailureRequest to_dict with error."""
+ request_obj = SendDurableExecutionCallbackFailureRequest(
+ callback_id="callback-123",
+ error=None,
+ )
+
+ result_data = request_obj.to_dict()
+ expected_data = {
+ "CallbackId": "callback-123",
+ }
+ assert result_data == expected_data
+
+
+# Test for missing coverage in ListDurableExecutionsByFunctionRequest
+def test_list_durable_executions_by_function_request_with_durable_execution_name():
+ """Test ListDurableExecutionsByFunctionRequest to_dict with durable_execution_name."""
+ request_obj = ListDurableExecutionsByFunctionRequest(
+ function_name="my-function",
+ qualifier=None,
+ durable_execution_name="specific-execution",
+ status_filter=None,
+ started_after=None,
+ started_before=None,
+ marker=None,
+ max_items=None,
+ reverse_order=None,
+ )
+
+ result_data = request_obj.to_dict()
+ expected_data = {
+ "FunctionName": "my-function",
+ "DurableExecutionName": "specific-execution",
+ }
+ assert result_data == expected_data
+
+
+# Test for missing branch coverage in CheckpointDurableExecutionResponse
+def test_checkpoint_updated_execution_state_with_next_marker():
+ """Test CheckpointUpdatedExecutionState to_dict with next_marker."""
+ from aws_durable_execution_sdk_python.lambda_service import (
+ Operation,
+ OperationStatus,
+ OperationType,
+ )
+
+ operation = Operation(
+ operation_id="op-1",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.SUCCEEDED,
+ )
+
+ state_obj = CheckpointUpdatedExecutionState(
+ operations=[operation],
+ next_marker="next-marker-123",
+ )
+
+ result_data = state_obj.to_dict()
+ expected_data = {
+ "Operations": [{"Id": "op-1", "Type": "STEP", "Status": "SUCCEEDED"}],
+ "NextMarker": "next-marker-123",
+ }
+ assert result_data == expected_data
+
+
+# Tests for events_to_operations function
+
+
+def test_events_to_operations_empty_list():
+ """Test events_to_operations with empty event list."""
+ from aws_durable_execution_sdk_python_testing.model import events_to_operations
+
+ operations = events_to_operations([])
+ assert operations == []
+
+
+def test_events_to_operations_execution_started():
+ """Test events_to_operations with ExecutionStarted event."""
+ event = Event(
+ event_type="ExecutionStarted",
+ event_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC),
+ operation_id="exec-1",
+ execution_started_details=ExecutionStartedDetails(
+ input=EventInput(payload="test-input", truncated=False),
+ execution_timeout=300,
+ ),
+ )
+
+ operations = events_to_operations([event])
+
+ assert len(operations) == 1
+ assert operations[0].operation_id == "exec-1"
+ assert operations[0].operation_type == OperationType.EXECUTION
+ assert operations[0].status == OperationStatus.STARTED
+ assert operations[0].execution_details.input_payload == "test-input"
+
+
+def test_events_to_operations_callback_lifecycle():
+ """Test events_to_operations with complete callback lifecycle."""
+
+ started_event = Event(
+ event_type="CallbackStarted",
+ event_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC),
+ operation_id="cb-1",
+ name="test-callback",
+ callback_started_details=CallbackStartedDetails(callback_id="callback-123"),
+ )
+
+ succeeded_event = Event(
+ event_type="CallbackSucceeded",
+ event_timestamp=datetime.datetime(2023, 1, 1, 0, 1, 0, tzinfo=datetime.UTC),
+ operation_id="cb-1",
+ callback_succeeded_details=CallbackSucceededDetails(
+ result=EventResult(payload="callback-result", truncated=False)
+ ),
+ )
+
+ operations = events_to_operations([started_event, succeeded_event])
+
+ assert len(operations) == 1
+ assert operations[0].operation_id == "cb-1"
+ assert operations[0].operation_type == OperationType.CALLBACK
+ assert operations[0].status == OperationStatus.SUCCEEDED
+ assert operations[0].name == "test-callback"
+ assert operations[0].callback_details.callback_id == "callback-123"
+ assert operations[0].callback_details.result == "callback-result"
+ assert operations[0].callback_details.error is None
+
+
+def test_events_to_operations_missing_event_type():
+ """Test events_to_operations raises error for missing event_type."""
+ event = Event(
+ event_type=None,
+ event_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC),
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Missing required 'event_type' field"
+ ):
+ events_to_operations([event])
+
+
+def test_events_to_operations_unknown_event_type():
+ """Test events_to_operations raises error for unknown event type."""
+ event = Event(
+ event_type="UnknownEventType",
+ event_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC),
+ operation_id="op-1",
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Unknown event type: UnknownEventType"
+ ):
+ events_to_operations([event])
+
+
+def test_events_to_operations_missing_operation_id():
+ """Test events_to_operations raises error for missing operation_id."""
+ event = Event(
+ event_type="StepStarted",
+ event_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC),
+ operation_id=None,
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Missing required 'operation_id' field"
+ ):
+ events_to_operations([event])
+
+
+def test_events_to_operations_step_with_retry():
+ """Test events_to_operations with step retry details."""
+ import datetime
+
+ from aws_durable_execution_sdk_python.lambda_service import (
+ OperationStatus,
+ OperationType,
+ )
+
+ from aws_durable_execution_sdk_python_testing.model import (
+ Event,
+ EventResult,
+ RetryDetails,
+ StepSucceededDetails,
+ events_to_operations,
+ )
+
+ succeeded_event = Event(
+ event_type="StepSucceeded",
+ event_timestamp=datetime.datetime(2023, 1, 1, 0, 1, 0, tzinfo=datetime.UTC),
+ operation_id="step-1",
+ name="test-step",
+ step_succeeded_details=StepSucceededDetails(
+ result=EventResult(payload="step-result", truncated=False),
+ retry_details=RetryDetails(current_attempt=2),
+ ),
+ )
+
+ operations = events_to_operations([succeeded_event])
+
+ assert len(operations) == 1
+ assert operations[0].operation_type == OperationType.STEP
+ assert operations[0].status == OperationStatus.SUCCEEDED
+ assert operations[0].step_details.result == "step-result"
+ assert operations[0].step_details.attempt == 2
+
+
+def test_events_to_operations_step_failed_with_next_attempt():
+ """Test events_to_operations with failed step and next attempt timestamp."""
+ import datetime
+
+ from aws_durable_execution_sdk_python.lambda_service import (
+ ErrorObject,
+ OperationStatus,
+ )
+
+ from aws_durable_execution_sdk_python_testing.model import (
+ Event,
+ EventError,
+ RetryDetails,
+ StepFailedDetails,
+ events_to_operations,
+ )
+
+ event_time = datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC)
+ failed_event = Event(
+ event_type="StepFailed",
+ event_timestamp=event_time,
+ operation_id="step-1",
+ step_failed_details=StepFailedDetails(
+ error=EventError(
+ payload=ErrorObject(
+ message="step failed", type=None, data=None, stack_trace=None
+ )
+ ),
+ retry_details=RetryDetails(
+ current_attempt=1, next_attempt_delay_seconds=10
+ ),
+ ),
+ )
+
+ operations = events_to_operations([failed_event])
+
+ assert len(operations) == 1
+ assert operations[0].status == OperationStatus.FAILED
+ assert operations[0].step_details.error.message == "step failed"
+ assert operations[0].step_details.attempt == 1
+ expected_next_attempt = event_time + datetime.timedelta(seconds=10)
+ assert operations[0].step_details.next_attempt_timestamp == expected_next_attempt
+
+
+def test_events_to_operations_context_succeeded():
+ """Test events_to_operations with successful context."""
+ import datetime
+
+ from aws_durable_execution_sdk_python.lambda_service import (
+ OperationStatus,
+ OperationType,
+ )
+
+ from aws_durable_execution_sdk_python_testing.model import (
+ ContextSucceededDetails,
+ Event,
+ EventResult,
+ events_to_operations,
+ )
+
+ succeeded_event = Event(
+ event_type="ContextSucceeded",
+ event_timestamp=datetime.datetime(2023, 1, 1, 0, 1, 0, tzinfo=datetime.UTC),
+ operation_id="ctx-1",
+ name="test-context",
+ context_succeeded_details=ContextSucceededDetails(
+ result=EventResult(payload="context-result", truncated=False)
+ ),
+ )
+
+ operations = events_to_operations([succeeded_event])
+
+ assert len(operations) == 1
+ assert operations[0].operation_type == OperationType.CONTEXT
+ assert operations[0].status == OperationStatus.SUCCEEDED
+ assert operations[0].context_details.result == "context-result"
+ assert operations[0].context_details.error is None
+
+
+def test_events_to_operations_chained_invoke_succeeded():
+ """Test events_to_operations with successful chained invoke."""
+ import datetime
+
+ from aws_durable_execution_sdk_python.lambda_service import (
+ OperationStatus,
+ OperationType,
+ )
+
+ from aws_durable_execution_sdk_python_testing.model import (
+ ChainedInvokeSucceededDetails,
+ Event,
+ EventResult,
+ events_to_operations,
+ )
+
+ succeeded_event = Event(
+ event_type="ChainedInvokeSucceeded",
+ event_timestamp=datetime.datetime(2023, 1, 1, 0, 1, 0, tzinfo=datetime.UTC),
+ operation_id="invoke-1",
+ name="test-invoke",
+ chained_invoke_succeeded_details=ChainedInvokeSucceededDetails(
+ result=EventResult(payload="invoke-result", truncated=False)
+ ),
+ )
+
+ operations = events_to_operations([succeeded_event])
+
+ assert len(operations) == 1
+ assert operations[0].operation_type == OperationType.CHAINED_INVOKE
+ assert operations[0].status == OperationStatus.SUCCEEDED
+ assert operations[0].chained_invoke_details.result == "invoke-result"
+ assert operations[0].chained_invoke_details.error is None
+
+
+def test_events_to_operations_skips_invocation_completed():
+ """Test events_to_operations skips InvocationCompleted events."""
+ invocation_event = Event(
+ event_type="InvocationCompleted",
+ event_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC),
+ operation_id="invocation-1",
+ )
+
+ operations = events_to_operations([invocation_event])
+ assert len(operations) == 0
+
+
+def test_events_to_operations_callback_failed():
+ """Test events_to_operations with failed callback."""
+ import datetime
+
+ from aws_durable_execution_sdk_python.lambda_service import (
+ ErrorObject,
+ OperationStatus,
+ )
+
+ from aws_durable_execution_sdk_python_testing.model import (
+ CallbackFailedDetails,
+ CallbackStartedDetails,
+ Event,
+ EventError,
+ events_to_operations,
+ )
+
+ started_event = Event(
+ event_type="CallbackStarted",
+ event_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC),
+ operation_id="cb-1",
+ callback_started_details=CallbackStartedDetails(callback_id="callback-123"),
+ )
+
+ failed_event = Event(
+ event_type="CallbackFailed",
+ event_timestamp=datetime.datetime(2023, 1, 1, 0, 1, 0, tzinfo=datetime.UTC),
+ operation_id="cb-1",
+ callback_failed_details=CallbackFailedDetails(
+ error=EventError(
+ payload=ErrorObject(
+ message="callback failed", type=None, data=None, stack_trace=None
+ )
+ )
+ ),
+ )
+
+ operations = events_to_operations([started_event, failed_event])
+
+ assert len(operations) == 1
+ assert operations[0].status == OperationStatus.FAILED
+ assert operations[0].callback_details.error.message == "callback failed"
+ assert operations[0].callback_details.result is None
+
+
+def test_events_to_operations_callback_timed_out():
+ """Test events_to_operations with timed out callback."""
+ import datetime
+
+ from aws_durable_execution_sdk_python.lambda_service import (
+ ErrorObject,
+ OperationStatus,
+ )
+
+ from aws_durable_execution_sdk_python_testing.model import (
+ CallbackStartedDetails,
+ CallbackTimedOutDetails,
+ Event,
+ EventError,
+ events_to_operations,
+ )
+
+ started_event = Event(
+ event_type="CallbackStarted",
+ event_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC),
+ operation_id="cb-1",
+ callback_started_details=CallbackStartedDetails(callback_id="callback-123"),
+ )
+
+ timed_out_event = Event(
+ event_type="CallbackTimedOut",
+ event_timestamp=datetime.datetime(2023, 1, 1, 0, 1, 0, tzinfo=datetime.UTC),
+ operation_id="cb-1",
+ callback_timed_out_details=CallbackTimedOutDetails(
+ error=EventError(
+ payload=ErrorObject(
+ message="callback timed out", type=None, data=None, stack_trace=None
+ )
+ )
+ ),
+ )
+
+ operations = events_to_operations([started_event, timed_out_event])
+
+ assert len(operations) == 1
+ assert operations[0].status == OperationStatus.TIMED_OUT
+ assert operations[0].callback_details.error.message == "callback timed out"
+
+
+def test_events_to_operations_wait_started():
+ """Test events_to_operations with wait operation."""
+ import datetime
+
+ from aws_durable_execution_sdk_python.lambda_service import (
+ OperationStatus,
+ OperationType,
+ )
+
+ from aws_durable_execution_sdk_python_testing.model import (
+ Event,
+ WaitStartedDetails,
+ events_to_operations,
+ )
+
+ scheduled_time = datetime.datetime(2023, 1, 1, 1, 0, 0, tzinfo=datetime.UTC)
+ wait_event = Event(
+ event_type="WaitStarted",
+ event_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC),
+ operation_id="wait-1",
+ name="test-wait",
+ wait_started_details=WaitStartedDetails(
+ duration=3600, scheduled_end_timestamp=scheduled_time
+ ),
+ )
+
+ operations = events_to_operations([wait_event])
+
+ assert len(operations) == 1
+ assert operations[0].operation_type == OperationType.WAIT
+ assert operations[0].status == OperationStatus.STARTED
+ assert operations[0].wait_details.scheduled_end_timestamp == scheduled_time
+
+
+def test_events_to_operations_context_failed():
+ """Test events_to_operations with failed context."""
+ import datetime
+
+ from aws_durable_execution_sdk_python.lambda_service import (
+ ErrorObject,
+ OperationStatus,
+ )
+
+ from aws_durable_execution_sdk_python_testing.model import (
+ ContextFailedDetails,
+ Event,
+ EventError,
+ events_to_operations,
+ )
+
+ failed_event = Event(
+ event_type="ContextFailed",
+ event_timestamp=datetime.datetime(2023, 1, 1, 0, 1, 0, tzinfo=datetime.UTC),
+ operation_id="ctx-1",
+ context_failed_details=ContextFailedDetails(
+ error=EventError(
+ payload=ErrorObject(
+ message="context failed", type=None, data=None, stack_trace=None
+ )
+ )
+ ),
+ )
+
+ operations = events_to_operations([failed_event])
+
+ assert len(operations) == 1
+ assert operations[0].status == OperationStatus.FAILED
+ assert operations[0].context_details.error.message == "context failed"
+ assert operations[0].context_details.result is None
+
+
+def test_events_to_operations_chained_invoke_failed():
+ """Test events_to_operations with failed chained invoke."""
+ import datetime
+
+ from aws_durable_execution_sdk_python.lambda_service import (
+ ErrorObject,
+ OperationStatus,
+ )
+
+ from aws_durable_execution_sdk_python_testing.model import (
+ ChainedInvokeFailedDetails,
+ Event,
+ EventError,
+ events_to_operations,
+ )
+
+ failed_event = Event(
+ event_type="ChainedInvokeFailed",
+ event_timestamp=datetime.datetime(2023, 1, 1, 0, 1, 0, tzinfo=datetime.UTC),
+ operation_id="invoke-1",
+ chained_invoke_failed_details=ChainedInvokeFailedDetails(
+ error=EventError(
+ payload=ErrorObject(
+ message="invoke failed", type=None, data=None, stack_trace=None
+ )
+ )
+ ),
+ )
+
+ operations = events_to_operations([failed_event])
+
+ assert len(operations) == 1
+ assert operations[0].status == OperationStatus.FAILED
+ assert operations[0].chained_invoke_details.error.message == "invoke failed"
+ assert operations[0].chained_invoke_details.result is None
+
+
+def test_events_to_operations_multiple_operations():
+ """Test events_to_operations with multiple different operations."""
+ import datetime
+
+ from aws_durable_execution_sdk_python.lambda_service import (
+ OperationStatus,
+ OperationType,
+ )
+
+ from aws_durable_execution_sdk_python_testing.model import (
+ Event,
+ EventResult,
+ StepSucceededDetails,
+ events_to_operations,
+ )
+
+ events = [
+ Event(
+ event_type="StepStarted",
+ event_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC),
+ operation_id="step-1",
+ name="step-one",
+ ),
+ Event(
+ event_type="StepSucceeded",
+ event_timestamp=datetime.datetime(2023, 1, 1, 0, 1, 0, tzinfo=datetime.UTC),
+ operation_id="step-1",
+ step_succeeded_details=StepSucceededDetails(
+ result=EventResult(payload="result-1", truncated=False)
+ ),
+ ),
+ Event(
+ event_type="WaitStarted",
+ event_timestamp=datetime.datetime(2023, 1, 1, 0, 2, 0, tzinfo=datetime.UTC),
+ operation_id="wait-1",
+ name="wait-one",
+ ),
+ ]
+
+ operations = events_to_operations(events)
+
+ assert len(operations) == 2
+ step_op = next(op for op in operations if op.operation_id == "step-1")
+ wait_op = next(op for op in operations if op.operation_id == "wait-1")
+
+ assert step_op.operation_type == OperationType.STEP
+ assert step_op.status == OperationStatus.SUCCEEDED
+ assert step_op.name == "step-one"
+
+ assert wait_op.operation_type == OperationType.WAIT
+ assert wait_op.status == OperationStatus.STARTED
+ assert wait_op.name == "wait-one"
+
+
+def test_events_to_operations_merges_timestamps():
+ """Test events_to_operations correctly merges start and end timestamps."""
+ import datetime
+
+ from aws_durable_execution_sdk_python_testing.model import (
+ Event,
+ EventResult,
+ StepSucceededDetails,
+ events_to_operations,
+ )
+
+ start_time = datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC)
+ end_time = datetime.datetime(2023, 1, 1, 0, 1, 0, tzinfo=datetime.UTC)
+
+ events = [
+ Event(
+ event_type="StepStarted",
+ event_timestamp=start_time,
+ operation_id="step-1",
+ ),
+ Event(
+ event_type="StepSucceeded",
+ event_timestamp=end_time,
+ operation_id="step-1",
+ step_succeeded_details=StepSucceededDetails(
+ result=EventResult(payload="result", truncated=False)
+ ),
+ ),
+ ]
+
+ operations = events_to_operations(events)
+
+ assert len(operations) == 1
+ assert operations[0].start_timestamp == start_time
+ assert operations[0].end_timestamp == end_time
+
+
+def test_events_to_operations_preserves_parent_id():
+ """Test events_to_operations preserves parent_id from events."""
+ event = Event(
+ event_type="StepStarted",
+ event_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC),
+ operation_id="step-1",
+ parent_id="parent-ctx",
+ name="child-step",
+ )
+
+ operations = events_to_operations([event])
+
+ assert len(operations) == 1
+ assert operations[0].parent_id == "parent-ctx"
+
+
+def test_events_to_operations_preserves_sub_type():
+ """Test events_to_operations preserves sub_type from events."""
+ event = Event(
+ event_type="StepStarted",
+ event_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC),
+ operation_id="step-1",
+ sub_type="Step",
+ )
+
+ operations = events_to_operations([event])
+
+ assert len(operations) == 1
+ assert operations[0].sub_type is not None
+ assert operations[0].sub_type.value == "Step"
+
+
+def test_events_to_operations_invalid_sub_type():
+ """Test events_to_operations raises InvalidParameterValueException when sub_type is invalid."""
+ invalid_sub_type: str = "INVALID_SUB_TYPE"
+ event = Event(
+ event_type="StepStarted",
+ event_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC),
+ operation_id="step-1",
+ sub_type=invalid_sub_type,
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match=f"'{invalid_sub_type}' is not a valid OperationSubType",
+ ):
+ events_to_operations([event])
+
+
+def test_invocation_completed_details_to_json_dict():
+ """Test InvocationCompletedDetails.to_json_dict() converts datetime to Unix milliseconds."""
+ start_time = datetime.datetime(2023, 1, 1, 0, 0, 0, 123456, tzinfo=datetime.UTC)
+ end_time = datetime.datetime(2023, 1, 1, 0, 1, 0, 456789, tzinfo=datetime.UTC)
+
+ details = InvocationCompletedDetails(
+ start_timestamp=start_time, end_timestamp=end_time, request_id="req-123"
+ )
+
+ json_dict = details.to_json_dict()
+
+ # Verify timestamps are converted to Unix milliseconds (integers)
+ assert json_dict["StartTimestamp"] == 1672531200123
+ assert json_dict["EndTimestamp"] == 1672531260456
+ assert json_dict["RequestId"] == "req-123"
+
+ # Verify all values are JSON-serializable
+ json_str = json.dumps(json_dict)
+ assert json_str is not None
+
+
+def test_invocation_completed_details_from_json_dict():
+ """Test InvocationCompletedDetails.from_json_dict() converts Unix milliseconds to datetime."""
+ json_dict = {
+ "StartTimestamp": 1672531200123,
+ "EndTimestamp": 1672531260456,
+ "RequestId": "req-456",
+ }
+
+ details = InvocationCompletedDetails.from_json_dict(json_dict)
+
+ # Verify timestamps are converted to datetime objects
+ assert details.start_timestamp == datetime.datetime(
+ 2023, 1, 1, 0, 0, 0, 123000, tzinfo=datetime.UTC
+ )
+ assert details.end_timestamp == datetime.datetime(
+ 2023, 1, 1, 0, 1, 0, 456000, tzinfo=datetime.UTC
+ )
+ assert details.request_id == "req-456"
+
+
+def test_invocation_completed_details_json_round_trip():
+ """Test InvocationCompletedDetails to_json_dict/from_json_dict round-trip."""
+ original = InvocationCompletedDetails(
+ start_timestamp=datetime.datetime(
+ 2023, 6, 15, 12, 30, 45, 678000, tzinfo=datetime.UTC
+ ),
+ end_timestamp=datetime.datetime(
+ 2023, 6, 15, 12, 31, 50, 123000, tzinfo=datetime.UTC
+ ),
+ request_id="round-trip-test",
+ )
+
+ # Serialize to JSON dict
+ json_dict = original.to_json_dict()
+
+ # Deserialize back
+ restored = InvocationCompletedDetails.from_json_dict(json_dict)
+
+ # Verify round-trip preserves data
+ assert restored.start_timestamp == original.start_timestamp
+ assert restored.end_timestamp == original.end_timestamp
+ assert restored.request_id == original.request_id
+
+
+def test_invocation_completed_details_to_dict_preserves_datetime():
+ """Test InvocationCompletedDetails.to_dict() preserves datetime objects (not converted)."""
+ start_time = datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC)
+ end_time = datetime.datetime(2023, 1, 1, 0, 1, 0, tzinfo=datetime.UTC)
+
+ details = InvocationCompletedDetails(
+ start_timestamp=start_time, end_timestamp=end_time, request_id="req-789"
+ )
+
+ regular_dict = details.to_dict()
+
+ # Verify to_dict() preserves datetime objects (not converted to Unix milliseconds)
+ assert regular_dict["StartTimestamp"] == start_time
+ assert regular_dict["EndTimestamp"] == end_time
+ assert isinstance(regular_dict["StartTimestamp"], datetime.datetime)
+ assert isinstance(regular_dict["EndTimestamp"], datetime.datetime)
+
+
+def test_invocation_completed_details_from_json_dict_invalid_timestamp():
+ """Test InvocationCompletedDetails.from_json_dict() raises error for invalid timestamps."""
+ # Test with invalid timestamp that would return None
+ json_dict = {
+ "StartTimestamp": None,
+ "EndTimestamp": 1672531260456,
+ "RequestId": "req-error",
+ }
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="StartTimestamp and EndTimestamp cannot be null",
+ ):
+ InvocationCompletedDetails.from_json_dict(json_dict)
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/observer_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/observer_test.py
new file mode 100644
index 0000000..193f395
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/observer_test.py
@@ -0,0 +1,355 @@
+"""Tests for observer module."""
+
+import inspect
+import threading
+from unittest.mock import Mock
+
+import pytest
+from aws_durable_execution_sdk_python.lambda_service import ErrorObject, CallbackOptions
+
+from aws_durable_execution_sdk_python_testing.observer import (
+ ExecutionNotifier,
+ ExecutionObserver,
+)
+from aws_durable_execution_sdk_python_testing.token import CallbackToken
+
+
+class MockExecutionObserver(ExecutionObserver):
+ """Mock implementation of ExecutionObserver for testing."""
+
+ def __init__(self):
+ self.on_completed_calls = []
+ self.on_failed_calls = []
+ self.on_timed_out_calls = []
+ self.on_stopped_calls = []
+ self.on_wait_timer_scheduled_calls = []
+ self.on_step_retry_scheduled_calls = []
+ self.on_callback_created_calls = []
+
+ def on_completed(self, execution_arn: str, result: str | None = None) -> None:
+ self.on_completed_calls.append((execution_arn, result))
+
+ def on_failed(self, execution_arn: str, error: ErrorObject) -> None:
+ self.on_failed_calls.append((execution_arn, error))
+
+ def on_timed_out(self, execution_arn: str, error: ErrorObject) -> None:
+ self.on_timed_out_calls.append((execution_arn, error))
+
+ def on_stopped(self, execution_arn: str, error: ErrorObject) -> None:
+ self.on_stopped_calls.append((execution_arn, error))
+
+ def on_wait_timer_scheduled(
+ self, execution_arn: str, operation_id: str, delay: float
+ ) -> None:
+ self.on_wait_timer_scheduled_calls.append((execution_arn, operation_id, delay))
+
+ def on_step_retry_scheduled(
+ self, execution_arn: str, operation_id: str, delay: float
+ ) -> None:
+ self.on_step_retry_scheduled_calls.append((execution_arn, operation_id, delay))
+
+ def on_callback_created(
+ self,
+ execution_arn: str,
+ operation_id: str,
+ callback_options: CallbackOptions | None,
+ callback_token: CallbackToken,
+ ) -> None:
+ self.on_callback_created_calls.append(
+ (execution_arn, operation_id, callback_options, callback_token)
+ )
+
+
+def test_execution_notifier_init():
+ """Test ExecutionNotifier initialization."""
+ notifier = ExecutionNotifier()
+
+ assert notifier._observers == [] # noqa: SLF001
+ assert notifier._lock is not None # noqa: SLF001
+
+
+def test_execution_notifier_add_observer():
+ """Test adding an observer to ExecutionNotifier."""
+ notifier = ExecutionNotifier()
+ observer = MockExecutionObserver()
+
+ notifier.add_observer(observer)
+
+ assert len(notifier._observers) == 1 # noqa: SLF001
+ assert notifier._observers[0] is observer # noqa: SLF001
+
+
+def test_execution_notifier_add_multiple_observers():
+ """Test adding multiple observers to ExecutionNotifier."""
+ notifier = ExecutionNotifier()
+ observer1 = MockExecutionObserver()
+ observer2 = MockExecutionObserver()
+
+ notifier.add_observer(observer1)
+ notifier.add_observer(observer2)
+
+ assert len(notifier._observers) == 2 # noqa: SLF001
+ assert observer1 in notifier._observers # noqa: SLF001
+ assert observer2 in notifier._observers # noqa: SLF001
+
+
+def test_execution_notifier_notify_completed():
+ """Test notifying observers about execution completion."""
+ notifier = ExecutionNotifier()
+ observer = MockExecutionObserver()
+ notifier.add_observer(observer)
+
+ execution_arn = "test-arn"
+ result = "test-result"
+
+ notifier.notify_completed(execution_arn, result)
+
+ assert len(observer.on_completed_calls) == 1
+ assert observer.on_completed_calls[0] == (execution_arn, result)
+
+
+def test_execution_notifier_notify_completed_no_result():
+ """Test notifying observers about execution completion with no result."""
+ notifier = ExecutionNotifier()
+ observer = MockExecutionObserver()
+ notifier.add_observer(observer)
+
+ execution_arn = "test-arn"
+
+ notifier.notify_completed(execution_arn)
+
+ assert len(observer.on_completed_calls) == 1
+ assert observer.on_completed_calls[0] == (execution_arn, None)
+
+
+def test_execution_notifier_notify_failed():
+ """Test notifying observers about execution failure."""
+ notifier = ExecutionNotifier()
+ observer = MockExecutionObserver()
+ notifier.add_observer(observer)
+
+ execution_arn = "test-arn"
+ error = ErrorObject(
+ "TestError", "Test error message", "test-data", ["stack", "trace"]
+ )
+
+ notifier.notify_failed(execution_arn, error)
+
+ assert len(observer.on_failed_calls) == 1
+ assert observer.on_failed_calls[0] == (execution_arn, error)
+
+
+def test_execution_notifier_notify_wait_timer_scheduled():
+ """Test notifying observers about wait timer scheduling."""
+ notifier = ExecutionNotifier()
+ observer = MockExecutionObserver()
+ notifier.add_observer(observer)
+
+ execution_arn = "test-arn"
+ operation_id = "test-operation"
+ delay = 5.0
+
+ notifier.notify_wait_timer_scheduled(execution_arn, operation_id, delay)
+
+ assert len(observer.on_wait_timer_scheduled_calls) == 1
+ assert observer.on_wait_timer_scheduled_calls[0] == (
+ execution_arn,
+ operation_id,
+ delay,
+ )
+
+
+def test_execution_notifier_notify_step_retry_scheduled():
+ """Test notifying observers about step retry scheduling."""
+ notifier = ExecutionNotifier()
+ observer = MockExecutionObserver()
+ notifier.add_observer(observer)
+
+ execution_arn = "test-arn"
+ operation_id = "test-operation"
+ delay = 10.0
+
+ notifier.notify_step_retry_scheduled(execution_arn, operation_id, delay)
+
+ assert len(observer.on_step_retry_scheduled_calls) == 1
+ assert observer.on_step_retry_scheduled_calls[0] == (
+ execution_arn,
+ operation_id,
+ delay,
+ )
+
+
+def test_execution_notifier_multiple_observers_all_notified():
+ """Test that all observers are notified when multiple are registered."""
+ notifier = ExecutionNotifier()
+ observer1 = MockExecutionObserver()
+ observer2 = MockExecutionObserver()
+
+ notifier.add_observer(observer1)
+ notifier.add_observer(observer2)
+
+ execution_arn = "test-arn"
+ result = "test-result"
+
+ notifier.notify_completed(execution_arn, result)
+
+ # Both observers should be notified
+ assert len(observer1.on_completed_calls) == 1
+ assert observer1.on_completed_calls[0] == (execution_arn, result)
+ assert len(observer2.on_completed_calls) == 1
+ assert observer2.on_completed_calls[0] == (execution_arn, result)
+
+
+def test_execution_notifier_no_observers():
+ """Test that notifications work even with no observers."""
+ notifier = ExecutionNotifier()
+
+ # Should not raise any exceptions
+ notifier.notify_completed("test-arn", "result")
+ notifier.notify_failed(
+ "test-arn", ErrorObject("Error", "Message", "data", ["trace"])
+ )
+ notifier.notify_wait_timer_scheduled("test-arn", "op-id", 1.0)
+ notifier.notify_step_retry_scheduled("test-arn", "op-id", 2.0)
+
+
+def test_execution_notifier_thread_safety():
+ """Test that ExecutionNotifier is thread-safe."""
+ notifier = ExecutionNotifier()
+ observer = MockExecutionObserver()
+ notifier.add_observer(observer)
+
+ # Test concurrent access
+ def add_observer_thread():
+ new_observer = MockExecutionObserver()
+ notifier.add_observer(new_observer)
+
+ def notify_thread():
+ notifier.notify_completed("test-arn", "result")
+
+ threads = []
+ for _ in range(5):
+ threads.append(threading.Thread(target=add_observer_thread))
+ threads.append(threading.Thread(target=notify_thread))
+
+ for thread in threads:
+ thread.start()
+
+ for thread in threads:
+ thread.join()
+
+ # Should have original observer plus 5 more
+ assert len(notifier._observers) == 6 # noqa: SLF001
+ # Original observer should have been notified multiple times
+ assert len(observer.on_completed_calls) >= 1
+
+
+def test_execution_observer_abstract_methods():
+ """Test that ExecutionObserver is abstract and cannot be instantiated."""
+ with pytest.raises(TypeError):
+ ExecutionObserver()
+
+
+def test_mock_execution_observer_implementation():
+ """Test that MockExecutionObserver properly implements all abstract methods."""
+ observer = MockExecutionObserver()
+
+ # Test all methods can be called
+ error = ErrorObject("Error", "Message", "data", ["trace"])
+ observer.on_completed("arn", "result")
+ observer.on_failed("arn", error)
+ observer.on_timed_out("arn", error)
+ observer.on_stopped("arn", error)
+ observer.on_wait_timer_scheduled("arn", "op", 1.0)
+ observer.on_step_retry_scheduled("arn", "op", 2.0)
+
+ # Verify calls were recorded
+ assert len(observer.on_completed_calls) == 1
+ assert len(observer.on_failed_calls) == 1
+ assert len(observer.on_timed_out_calls) == 1
+ assert len(observer.on_stopped_calls) == 1
+ assert len(observer.on_wait_timer_scheduled_calls) == 1
+ assert len(observer.on_step_retry_scheduled_calls) == 1
+
+
+def test_execution_notifier_notify_observers_with_exception():
+ """Test that exceptions in one observer don't affect others."""
+ notifier = ExecutionNotifier()
+
+ # Create a mock observer that raises an exception
+ failing_observer = Mock(spec=ExecutionObserver)
+ failing_observer.on_completed.side_effect = ValueError("Test exception")
+
+ # Create a normal observer
+ normal_observer = MockExecutionObserver()
+
+ notifier.add_observer(failing_observer)
+ notifier.add_observer(normal_observer)
+
+ # This should raise an exception from the failing observer
+ with pytest.raises(ValueError, match="Test exception"):
+ notifier.notify_completed("test-arn", "result")
+
+ # The normal observer should still have been called before the exception
+ failing_observer.on_completed.assert_called_once_with(
+ execution_arn="test-arn", result="result"
+ )
+
+
+def test_execution_observer_abstract_method_coverage():
+ """Test coverage of abstract methods in ExecutionObserver."""
+ # This test ensures we cover the abstract method definitions
+ # by checking they exist and have the correct signatures
+
+ methods = inspect.getmembers(ExecutionObserver, predicate=inspect.isfunction)
+ method_names = [name for name, _ in methods]
+
+ assert "on_completed" in method_names
+ assert "on_failed" in method_names
+ assert "on_timed_out" in method_names
+ assert "on_stopped" in method_names
+ assert "on_wait_timer_scheduled" in method_names
+ assert "on_step_retry_scheduled" in method_names
+
+
+def test_execution_notifier_notify_observers_internal():
+ """Test the internal _notify_observers method behavior."""
+ notifier = ExecutionNotifier()
+ observer = MockExecutionObserver()
+ notifier.add_observer(observer)
+
+ # Test that _notify_observers correctly calls the method on observers
+ notifier._notify_observers( # noqa: SLF001
+ ExecutionObserver.on_completed, execution_arn="test", result="success"
+ )
+
+ assert len(observer.on_completed_calls) == 1
+ assert observer.on_completed_calls[0] == ("test", "success")
+
+
+def test_execution_notifier_all_notification_methods():
+ """Test all notification methods with various parameter combinations."""
+ notifier = ExecutionNotifier()
+ observer = MockExecutionObserver()
+ notifier.add_observer(observer)
+
+ # Test notify_completed with positional args
+ notifier.notify_completed("arn1", "result1")
+ assert observer.on_completed_calls[-1] == ("arn1", "result1")
+
+ # Test notify_completed with keyword args
+ notifier.notify_completed(execution_arn="arn2", result="result2")
+ assert observer.on_completed_calls[-1] == ("arn2", "result2")
+
+ # Test notify_failed
+ error = ErrorObject("TestError", "Message", "data", ["trace"])
+ notifier.notify_failed("arn3", error)
+ assert observer.on_failed_calls[-1] == ("arn3", error)
+
+ # Test notify_wait_timer_scheduled
+ notifier.notify_wait_timer_scheduled("arn4", "op1", 5.5)
+ assert observer.on_wait_timer_scheduled_calls[-1] == ("arn4", "op1", 5.5)
+
+ # Test notify_step_retry_scheduled
+ notifier.notify_step_retry_scheduled("arn5", "op2", 10.5)
+ assert observer.on_step_retry_scheduled_calls[-1] == ("arn5", "op2", 10.5)
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/pending_operation_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/pending_operation_test.py
new file mode 100644
index 0000000..d508fbc
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/pending_operation_test.py
@@ -0,0 +1,129 @@
+# """Test for pending operation handling in get_execution_history."""
+#
+# from datetime import UTC, datetime
+# from unittest.mock import Mock
+#
+# from aws_durable_execution_sdk_python.lambda_service import (
+# OperationStatus,
+# OperationType,
+# )
+#
+# from aws_durable_execution_sdk_python_testing.executor import Executor
+# from aws_durable_execution_sdk_python_testing.model import StartDurableExecutionInput
+#
+#
+# def test_get_execution_history_with_pending_chained_invoke():
+# """Test get_execution_history handles pending CHAINED_INVOKE operations correctly."""
+# # Create mocks
+# mock_store = Mock()
+# mock_scheduler = Mock()
+# mock_invoker = Mock()
+# mock_checkpoint_processor = Mock()
+#
+# executor = Executor(mock_store, mock_scheduler, mock_invoker, mock_checkpoint_processor)
+#
+# # Create mock execution
+# mock_execution = Mock()
+# mock_execution.durable_execution_arn = "test-arn"
+# mock_execution.start_input = StartDurableExecutionInput(
+# account_id="123",
+# function_name="test",
+# function_qualifier="$LATEST",
+# execution_name="test",
+# execution_timeout_seconds=300,
+# execution_retention_period_days=7,
+# )
+# mock_execution.result = None
+# mock_execution.updates = []
+#
+# # Create a pending CHAINED_INVOKE operation with start_timestamp
+# pending_op = Mock()
+# pending_op.operation_id = "invoke-1"
+# pending_op.operation_type = OperationType.CHAINED_INVOKE
+# pending_op.status = OperationStatus.PENDING
+# pending_op.start_timestamp = datetime.now(UTC)
+# pending_op.end_timestamp = None
+#
+# # Create a non-CHAINED_INVOKE pending operation (should be skipped)
+# pending_step = Mock()
+# pending_step.operation_id = "step-1"
+# pending_step.operation_type = OperationType.STEP
+# pending_step.status = OperationStatus.PENDING
+# pending_step.start_timestamp = datetime.now(UTC)
+# pending_step.end_timestamp = None
+#
+# # Create a CHAINED_INVOKE pending operation without start_timestamp (should be skipped)
+# pending_invoke_no_timestamp = Mock()
+# pending_invoke_no_timestamp.operation_id = "invoke-2"
+# pending_invoke_no_timestamp.operation_type = OperationType.CHAINED_INVOKE
+# pending_invoke_no_timestamp.status = OperationStatus.PENDING
+# pending_invoke_no_timestamp.start_timestamp = None
+# pending_invoke_no_timestamp.end_timestamp = None
+#
+# mock_execution.operations = [pending_op, pending_step, pending_invoke_no_timestamp]
+# mock_store.load.return_value = mock_execution
+#
+# # Call get_execution_history
+# result = executor.get_execution_history("test-arn", include_execution_data=True)
+#
+# # Should have 2 events: 1 pending event + 1 started event for the valid pending CHAINED_INVOKE
+# assert len(result.events) == 2
+#
+# # First event should be the pending event
+# assert result.events[0].event_type == "ChainedInvokeStarted"
+# assert result.events[0].operation_id == "invoke-1"
+# assert result.events[0].chained_invoke_pending_details is not None
+#
+# # Second event should be the started event
+# assert result.events[1].event_type == "ChainedInvokeStarted"
+# assert result.events[1].operation_id == "invoke-1"
+# assert result.events[1].chained_invoke_started_details is not None
+#
+#
+# def test_get_execution_history_skips_invalid_pending_operations():
+# """Test that invalid pending operations are skipped."""
+# # Create mocks
+# mock_store = Mock()
+# mock_scheduler = Mock()
+# mock_invoker = Mock()
+# mock_checkpoint_processor = Mock()
+#
+# executor = Executor(mock_store, mock_scheduler, mock_invoker, mock_checkpoint_processor)
+#
+# # Create mock execution
+# mock_execution = Mock()
+# mock_execution.durable_execution_arn = "test-arn"
+# mock_execution.start_input = StartDurableExecutionInput(
+# account_id="123",
+# function_name="test",
+# function_qualifier="$LATEST",
+# execution_name="test",
+# execution_timeout_seconds=300,
+# execution_retention_period_days=7,
+# )
+# mock_execution.result = None
+# mock_execution.updates = []
+#
+# # Create operations that should be skipped
+# # 1. Non-CHAINED_INVOKE pending operation
+# pending_step = Mock()
+# pending_step.operation_id = "step-1"
+# pending_step.operation_type = OperationType.STEP
+# pending_step.status = OperationStatus.PENDING
+# pending_step.start_timestamp = datetime.now(UTC)
+#
+# # 2. CHAINED_INVOKE pending operation without start_timestamp
+# pending_invoke_no_timestamp = Mock()
+# pending_invoke_no_timestamp.operation_id = "invoke-1"
+# pending_invoke_no_timestamp.operation_type = OperationType.CHAINED_INVOKE
+# pending_invoke_no_timestamp.status = OperationStatus.PENDING
+# pending_invoke_no_timestamp.start_timestamp = None
+#
+# mock_execution.operations = [pending_step, pending_invoke_no_timestamp]
+# mock_store.load.return_value = mock_execution
+#
+# # Call get_execution_history
+# result = executor.get_execution_history("test-arn")
+#
+# # Should have no events since all pending operations are invalid
+# assert len(result.events) == 0
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/runner_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/runner_test.py
new file mode 100644
index 0000000..3b81269
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/runner_test.py
@@ -0,0 +1,2180 @@
+"""Unit tests for runner module."""
+
+import datetime
+import json
+from unittest.mock import Mock, patch
+
+import pytest
+from aws_durable_execution_sdk_python.execution import InvocationStatus
+from aws_durable_execution_sdk_python.lambda_service import (
+ CallbackDetails,
+ ChainedInvokeDetails,
+ ContextDetails,
+ ExecutionDetails,
+ OperationStatus,
+ OperationType,
+ StepDetails,
+ WaitDetails,
+)
+from aws_durable_execution_sdk_python.lambda_service import Operation as SvcOperation
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ DurableFunctionsTestError,
+ InvalidParameterValueException,
+ ResourceNotFoundException,
+)
+from aws_durable_execution_sdk_python_testing.execution import Execution
+from aws_durable_execution_sdk_python_testing.model import (
+ StartDurableExecutionInput,
+ StartDurableExecutionOutput,
+ GetDurableExecutionHistoryResponse,
+)
+from aws_durable_execution_sdk_python_testing.runner import (
+ OPERATION_FACTORIES,
+ CallbackOperation,
+ ContextOperation,
+ DurableChildContextTestRunner,
+ DurableFunctionTestResult,
+ DurableFunctionTestRunner,
+ ExecutionOperation,
+ InvokeOperation,
+ Operation,
+ StepOperation,
+ WaitOperation,
+ create_operation,
+)
+
+
+def test_operation_creation():
+ """Test basic Operation creation."""
+ op = Operation(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.SUCCEEDED,
+ parent_id="parent-id",
+ name="test-name",
+ sub_type="test-subtype",
+ start_timestamp=datetime.datetime.now(tz=datetime.UTC),
+ end_timestamp=datetime.datetime.now(tz=datetime.UTC),
+ )
+
+ assert op.operation_id == "test-id"
+ assert op.operation_type is OperationType.STEP
+ assert op.status is OperationStatus.SUCCEEDED
+ assert op.parent_id == "parent-id"
+ assert op.name == "test-name"
+ assert op.sub_type == "test-subtype"
+
+
+def test_execution_operation_from_svc_operation():
+ """Test ExecutionOperation creation from service operation."""
+ execution_details = ExecutionDetails(input_payload="test-input")
+ svc_op = SvcOperation(
+ operation_id="exec-id",
+ operation_type=OperationType.EXECUTION,
+ status=OperationStatus.SUCCEEDED,
+ execution_details=execution_details,
+ )
+
+ exec_op = ExecutionOperation.from_svc_operation(svc_op)
+
+ assert exec_op.operation_id == "exec-id"
+ assert exec_op.operation_type is OperationType.EXECUTION
+ assert exec_op.input_payload == "test-input"
+
+
+def test_execution_operation_wrong_type():
+ """Test ExecutionOperation raises error for wrong operation type."""
+ svc_op = SvcOperation(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.SUCCEEDED,
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Expected EXECUTION operation, got OperationType.STEP",
+ ):
+ ExecutionOperation.from_svc_operation(svc_op)
+
+
+def test_context_operation_from_svc_operation():
+ """Test ContextOperation creation from service operation."""
+ context_details = ContextDetails(result=json.dumps("test-result"), error=None)
+ svc_op = SvcOperation(
+ operation_id="ctx-id",
+ operation_type=OperationType.CONTEXT,
+ status=OperationStatus.SUCCEEDED,
+ context_details=context_details,
+ )
+
+ ctx_op = ContextOperation.from_svc_operation(svc_op)
+
+ assert ctx_op.operation_id == "ctx-id"
+ assert ctx_op.operation_type is OperationType.CONTEXT
+ assert ctx_op.result == json.dumps("test-result")
+ assert ctx_op.child_operations == []
+
+
+def test_context_operation_with_children():
+ """Test ContextOperation with child operations."""
+ parent_op = SvcOperation(
+ operation_id="parent-id",
+ operation_type=OperationType.CONTEXT,
+ status=OperationStatus.SUCCEEDED,
+ context_details=ContextDetails(result=json.dumps("parent-result")),
+ )
+
+ child_op = SvcOperation(
+ operation_id="child-id",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.SUCCEEDED,
+ parent_id="parent-id",
+ name="child-step",
+ step_details=StepDetails(result=json.dumps("child-result")),
+ )
+
+ all_ops = [parent_op, child_op]
+ ctx_op = ContextOperation.from_svc_operation(parent_op, all_ops)
+
+ assert len(ctx_op.child_operations) == 1
+ assert ctx_op.child_operations[0].name == "child-step"
+
+
+def test_context_operation_get_operation_by_name():
+ """Test ContextOperation get_operation_by_name method."""
+ child_op = Operation(
+ operation_id="child-id",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.SUCCEEDED,
+ name="test-child",
+ )
+
+ ctx_op = ContextOperation(
+ operation_id="ctx-id",
+ operation_type=OperationType.CONTEXT,
+ status=OperationStatus.SUCCEEDED,
+ child_operations=[child_op],
+ )
+
+ found_op = ctx_op.get_operation_by_name("test-child")
+ assert found_op == child_op
+
+
+def test_context_operation_get_operation_by_name_not_found():
+ """Test ContextOperation get_operation_by_name raises error when not found."""
+ ctx_op = ContextOperation(
+ operation_id="ctx-id",
+ operation_type=OperationType.CONTEXT,
+ status=OperationStatus.SUCCEEDED,
+ child_operations=[],
+ )
+
+ with pytest.raises(
+ DurableFunctionsTestError, match="Child Operation with name 'missing' not found"
+ ):
+ ctx_op.get_operation_by_name("missing")
+
+
+def test_context_operation_get_step():
+ """Test ContextOperation get_step method."""
+ step_op = StepOperation(
+ operation_id="step-id",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.SUCCEEDED,
+ name="test-step",
+ child_operations=[],
+ )
+
+ ctx_op = ContextOperation(
+ operation_id="ctx-id",
+ operation_type=OperationType.CONTEXT,
+ status=OperationStatus.SUCCEEDED,
+ child_operations=[step_op],
+ )
+
+ found_step = ctx_op.get_step("test-step")
+ assert isinstance(found_step, StepOperation)
+ assert found_step.name == "test-step"
+
+
+def test_context_operation_get_wait():
+ """Test ContextOperation get_wait method."""
+ wait_op = WaitOperation(
+ operation_id="wait-id",
+ operation_type=OperationType.WAIT,
+ status=OperationStatus.SUCCEEDED,
+ name="test-wait",
+ )
+
+ ctx_op = ContextOperation(
+ operation_id="ctx-id",
+ operation_type=OperationType.CONTEXT,
+ status=OperationStatus.SUCCEEDED,
+ child_operations=[wait_op],
+ )
+
+ found_wait = ctx_op.get_wait("test-wait")
+ assert isinstance(found_wait, WaitOperation)
+ assert found_wait.name == "test-wait"
+
+
+def test_context_operation_get_context():
+ """Test ContextOperation get_context method."""
+ nested_ctx_op = ContextOperation(
+ operation_id="nested-ctx-id",
+ operation_type=OperationType.CONTEXT,
+ status=OperationStatus.SUCCEEDED,
+ name="nested-context",
+ child_operations=[],
+ )
+
+ ctx_op = ContextOperation(
+ operation_id="ctx-id",
+ operation_type=OperationType.CONTEXT,
+ status=OperationStatus.SUCCEEDED,
+ child_operations=[nested_ctx_op],
+ )
+
+ found_ctx = ctx_op.get_context("nested-context")
+ assert isinstance(found_ctx, ContextOperation)
+ assert found_ctx.name == "nested-context"
+
+
+def test_context_operation_get_callback():
+ """Test ContextOperation get_callback method."""
+ callback_op = CallbackOperation(
+ operation_id="callback-id",
+ operation_type=OperationType.CALLBACK,
+ status=OperationStatus.SUCCEEDED,
+ name="test-callback",
+ child_operations=[],
+ )
+
+ ctx_op = ContextOperation(
+ operation_id="ctx-id",
+ operation_type=OperationType.CONTEXT,
+ status=OperationStatus.SUCCEEDED,
+ child_operations=[callback_op],
+ )
+
+ found_callback = ctx_op.get_callback("test-callback")
+ assert isinstance(found_callback, CallbackOperation)
+ assert found_callback.name == "test-callback"
+
+
+def test_context_operation_get_invoke():
+ """Test ContextOperation get_invoke method."""
+ invoke_op = InvokeOperation(
+ operation_id="invoke-id",
+ operation_type=OperationType.CHAINED_INVOKE,
+ status=OperationStatus.SUCCEEDED,
+ name="test-invoke",
+ )
+
+ ctx_op = ContextOperation(
+ operation_id="ctx-id",
+ operation_type=OperationType.CONTEXT,
+ status=OperationStatus.SUCCEEDED,
+ child_operations=[invoke_op],
+ )
+
+ found_invoke = ctx_op.get_invoke("test-invoke")
+ assert isinstance(found_invoke, InvokeOperation)
+ assert found_invoke.name == "test-invoke"
+
+
+def test_context_operation_get_execution():
+ """Test ContextOperation get_execution method."""
+ exec_op = ExecutionOperation(
+ operation_id="exec-id",
+ operation_type=OperationType.EXECUTION,
+ status=OperationStatus.SUCCEEDED,
+ name="test-execution",
+ )
+
+ ctx_op = ContextOperation(
+ operation_id="ctx-id",
+ operation_type=OperationType.CONTEXT,
+ status=OperationStatus.SUCCEEDED,
+ child_operations=[exec_op],
+ )
+
+ found_exec = ctx_op.get_execution("test-execution")
+ assert isinstance(found_exec, ExecutionOperation)
+ assert found_exec.name == "test-execution"
+
+
+def test_step_operation_from_svc_operation():
+ """Test StepOperation creation from service operation."""
+ step_details = StepDetails(attempt=2, result=json.dumps("step-result"), error=None)
+ svc_op = SvcOperation(
+ operation_id="step-id",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.SUCCEEDED,
+ step_details=step_details,
+ )
+
+ step_op = StepOperation.from_svc_operation(svc_op)
+
+ assert step_op.operation_id == "step-id"
+ assert step_op.operation_type is OperationType.STEP
+ assert step_op.attempt == 2
+ assert step_op.result == json.dumps("step-result")
+
+
+def test_step_operation_wrong_type():
+ """Test StepOperation raises error for wrong operation type."""
+ svc_op = SvcOperation(
+ operation_id="test-id",
+ operation_type=OperationType.CONTEXT,
+ status=OperationStatus.SUCCEEDED,
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Expected STEP operation, got OperationType.CONTEXT",
+ ):
+ StepOperation.from_svc_operation(svc_op)
+
+
+def test_wait_operation_from_svc_operation():
+ """Test WaitOperation creation from service operation."""
+ scheduled_time = datetime.datetime.now(tz=datetime.UTC)
+ wait_details = WaitDetails(scheduled_end_timestamp=scheduled_time)
+ svc_op = SvcOperation(
+ operation_id="wait-id",
+ operation_type=OperationType.WAIT,
+ status=OperationStatus.SUCCEEDED,
+ wait_details=wait_details,
+ )
+
+ wait_op = WaitOperation.from_svc_operation(svc_op)
+
+ assert wait_op.operation_id == "wait-id"
+ assert wait_op.operation_type is OperationType.WAIT
+ assert wait_op.scheduled_end_timestamp == scheduled_time
+
+
+def test_wait_operation_wrong_type():
+ """Test WaitOperation raises error for wrong operation type."""
+ svc_op = SvcOperation(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.SUCCEEDED,
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Expected WAIT operation, got OperationType.STEP",
+ ):
+ WaitOperation.from_svc_operation(svc_op)
+
+
+def test_callback_operation_from_svc_operation():
+ """Test CallbackOperation creation from service operation."""
+ callback_details = CallbackDetails(
+ callback_id="cb-123", result=json.dumps("callback-result")
+ )
+ svc_op = SvcOperation(
+ operation_id="callback-id",
+ operation_type=OperationType.CALLBACK,
+ status=OperationStatus.SUCCEEDED,
+ callback_details=callback_details,
+ )
+
+ callback_op = CallbackOperation.from_svc_operation(svc_op)
+
+ assert callback_op.operation_id == "callback-id"
+ assert callback_op.operation_type is OperationType.CALLBACK
+ assert callback_op.callback_id == "cb-123"
+ assert callback_op.result == json.dumps("callback-result")
+
+
+def test_callback_operation_wrong_type():
+ """Test CallbackOperation raises error for wrong operation type."""
+ svc_op = SvcOperation(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.SUCCEEDED,
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Expected CALLBACK operation, got OperationType.STEP",
+ ):
+ CallbackOperation.from_svc_operation(svc_op)
+
+
+def test_invoke_operation_from_svc_operation():
+ """Test InvokeOperation creation from service operation."""
+ invoke_details = ChainedInvokeDetails(
+ result=json.dumps("invoke-result"),
+ )
+ svc_op = SvcOperation(
+ operation_id="invoke-id",
+ operation_type=OperationType.CHAINED_INVOKE,
+ status=OperationStatus.SUCCEEDED,
+ chained_invoke_details=invoke_details,
+ )
+
+ invoke_op = InvokeOperation.from_svc_operation(svc_op)
+
+ assert invoke_op.operation_id == "invoke-id"
+ assert invoke_op.operation_type is OperationType.CHAINED_INVOKE
+ assert invoke_op.result == json.dumps("invoke-result")
+
+
+def test_invoke_operation_wrong_type():
+ """Test InvokeOperation raises error for wrong operation type."""
+ svc_op = SvcOperation(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.SUCCEEDED,
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Expected INVOKE operation, got OperationType.STEP",
+ ):
+ InvokeOperation.from_svc_operation(svc_op)
+
+
+def test_operation_factories_mapping():
+ """Test OPERATION_FACTORIES contains all expected mappings."""
+ expected_types = {
+ OperationType.EXECUTION: ExecutionOperation,
+ OperationType.CONTEXT: ContextOperation,
+ OperationType.STEP: StepOperation,
+ OperationType.WAIT: WaitOperation,
+ OperationType.CHAINED_INVOKE: InvokeOperation,
+ OperationType.CALLBACK: CallbackOperation,
+ }
+
+ assert expected_types == OPERATION_FACTORIES
+
+
+def test_create_operation_step():
+ """Test create_operation function with STEP operation."""
+ svc_op = SvcOperation(
+ operation_id="step-id",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.SUCCEEDED,
+ step_details=StepDetails(result=json.dumps("test-result")),
+ )
+
+ operation = create_operation(svc_op)
+
+ assert isinstance(operation, StepOperation)
+ assert operation.operation_id == "step-id"
+
+
+def test_create_operation_unknown_type():
+ """Test create_operation raises error for unknown operation type."""
+ # Create a mock operation with an invalid type
+ svc_op = Mock()
+ svc_op.operation_type = "UNKNOWN_TYPE"
+
+ with pytest.raises(
+ DurableFunctionsTestError, match="Unknown operation type: UNKNOWN_TYPE"
+ ):
+ create_operation(svc_op)
+
+
+def test_durable_function_test_result_create():
+ """Test DurableFunctionTestResult.create method."""
+ # Create mock execution with operations
+ execution = Mock(spec=Execution)
+
+ # Create mock operations - one EXECUTION (should be filtered) and one STEP
+ exec_op = Mock()
+ exec_op.operation_type = OperationType.EXECUTION
+ exec_op.parent_id = None
+
+ step_op = Mock()
+ step_op.operation_type = OperationType.STEP
+ step_op.parent_id = None
+ step_op.operation_id = "step-id"
+ step_op.status = OperationStatus.SUCCEEDED
+ step_op.name = "test-step"
+ step_op.step_details = StepDetails(result=json.dumps("step-result"))
+
+ execution.operations = [exec_op, step_op]
+
+ # Mock execution result
+ execution.result = Mock()
+ execution.result.status = InvocationStatus.SUCCEEDED
+ execution.result.result = json.dumps("test-result")
+ execution.result.error = None
+
+ result = DurableFunctionTestResult.create(execution)
+
+ assert result.status is InvocationStatus.SUCCEEDED
+ assert result.result == json.dumps("test-result")
+ assert result.error is None
+ assert len(result.operations) == 1 # EXECUTION operation filtered out
+
+
+def test_durable_function_test_result_get_operation_by_name():
+ """Test DurableFunctionTestResult get_operation_by_name method."""
+ step_op = StepOperation(
+ operation_id="step-id",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.SUCCEEDED,
+ name="test-step",
+ child_operations=[],
+ )
+
+ result = DurableFunctionTestResult(
+ status=InvocationStatus.SUCCEEDED,
+ operations=[step_op],
+ )
+
+ found_op = result.get_operation_by_name("test-step")
+ assert found_op == step_op
+
+
+def test_durable_function_test_result_get_operation_by_name_not_found():
+ """Test DurableFunctionTestResult get_operation_by_name raises error when not found."""
+ result = DurableFunctionTestResult(
+ status=InvocationStatus.SUCCEEDED,
+ operations=[],
+ )
+
+ with pytest.raises(
+ DurableFunctionsTestError, match="Operation with name 'missing' not found"
+ ):
+ result.get_operation_by_name("missing")
+
+
+def test_durable_function_test_result_get_step():
+ """Test DurableFunctionTestResult get_step method."""
+ step_op = StepOperation(
+ operation_id="step-id",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.SUCCEEDED,
+ name="test-step",
+ child_operations=[],
+ )
+
+ result = DurableFunctionTestResult(
+ status=InvocationStatus.SUCCEEDED,
+ operations=[step_op],
+ )
+
+ found_step = result.get_step("test-step")
+ assert isinstance(found_step, StepOperation)
+ assert found_step.name == "test-step"
+
+
+def test_durable_function_test_result_get_wait():
+ """Test DurableFunctionTestResult get_wait method."""
+ wait_op = WaitOperation(
+ operation_id="wait-id",
+ operation_type=OperationType.WAIT,
+ status=OperationStatus.SUCCEEDED,
+ name="test-wait",
+ )
+
+ result = DurableFunctionTestResult(
+ status=InvocationStatus.SUCCEEDED,
+ operations=[wait_op],
+ )
+
+ found_wait = result.get_wait("test-wait")
+ assert isinstance(found_wait, WaitOperation)
+ assert found_wait.name == "test-wait"
+
+
+def test_durable_function_test_result_get_context():
+ """Test DurableFunctionTestResult get_context method."""
+ ctx_op = ContextOperation(
+ operation_id="ctx-id",
+ operation_type=OperationType.CONTEXT,
+ status=OperationStatus.SUCCEEDED,
+ name="test-context",
+ child_operations=[],
+ )
+
+ result = DurableFunctionTestResult(
+ status=InvocationStatus.SUCCEEDED,
+ operations=[ctx_op],
+ )
+
+ found_ctx = result.get_context("test-context")
+ assert isinstance(found_ctx, ContextOperation)
+ assert found_ctx.name == "test-context"
+
+
+def test_durable_function_test_result_get_callback():
+ """Test DurableFunctionTestResult get_callback method."""
+ callback_op = CallbackOperation(
+ operation_id="callback-id",
+ operation_type=OperationType.CALLBACK,
+ status=OperationStatus.SUCCEEDED,
+ name="test-callback",
+ child_operations=[],
+ )
+
+ result = DurableFunctionTestResult(
+ status=InvocationStatus.SUCCEEDED,
+ operations=[callback_op],
+ )
+
+ found_callback = result.get_callback("test-callback")
+ assert isinstance(found_callback, CallbackOperation)
+ assert found_callback.name == "test-callback"
+
+
+def test_durable_function_test_result_get_invoke():
+ """Test DurableFunctionTestResult get_invoke method."""
+ invoke_op = InvokeOperation(
+ operation_id="invoke-id",
+ operation_type=OperationType.CHAINED_INVOKE,
+ status=OperationStatus.SUCCEEDED,
+ name="test-invoke",
+ )
+
+ result = DurableFunctionTestResult(
+ status=InvocationStatus.SUCCEEDED,
+ operations=[invoke_op],
+ )
+
+ found_invoke = result.get_invoke("test-invoke")
+ assert isinstance(found_invoke, InvokeOperation)
+ assert found_invoke.name == "test-invoke"
+
+
+def test_durable_function_test_result_get_execution():
+ """Test DurableFunctionTestResult get_execution method."""
+ exec_op = ExecutionOperation(
+ operation_id="exec-id",
+ operation_type=OperationType.EXECUTION,
+ status=OperationStatus.SUCCEEDED,
+ name="test-execution",
+ )
+
+ result = DurableFunctionTestResult(
+ status=InvocationStatus.SUCCEEDED,
+ operations=[exec_op],
+ )
+
+ found_exec = result.get_execution("test-execution")
+ assert isinstance(found_exec, ExecutionOperation)
+ assert found_exec.name == "test-execution"
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.Scheduler")
+@patch("aws_durable_execution_sdk_python_testing.runner.InMemoryExecutionStore")
+@patch("aws_durable_execution_sdk_python_testing.runner.CheckpointProcessor")
+@patch("aws_durable_execution_sdk_python_testing.runner.InMemoryServiceClient")
+@patch("aws_durable_execution_sdk_python_testing.runner.InProcessInvoker")
+@patch("aws_durable_execution_sdk_python_testing.runner.Executor")
+def test_durable_function_test_runner_init(
+ mock_executor, mock_invoker, mock_client, mock_processor, mock_store, mock_scheduler
+):
+ """Test DurableFunctionTestRunner initialization."""
+ handler = Mock()
+
+ DurableFunctionTestRunner(handler)
+
+ # Verify all components are initialized
+ mock_scheduler.assert_called_once()
+ mock_scheduler.return_value.start.assert_called_once()
+ mock_store.assert_called_once()
+ mock_processor.assert_called_once()
+ mock_client.assert_called_once()
+ mock_invoker.assert_called_once_with(handler, mock_client.return_value)
+ mock_executor.assert_called_once()
+
+ # Verify observer pattern setup
+ mock_processor.return_value.add_execution_observer.assert_called_once_with(
+ mock_executor.return_value
+ )
+
+
+def test_durable_function_test_runner_context_manager():
+ """Test DurableFunctionTestRunner context manager."""
+ handler = Mock()
+
+ with patch.object(DurableFunctionTestRunner, "__init__", return_value=None):
+ with patch.object(DurableFunctionTestRunner, "close") as mock_close:
+ runner = DurableFunctionTestRunner(handler)
+
+ with runner:
+ pass
+
+ mock_close.assert_called_once()
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.Scheduler")
+def test_durable_function_test_runner_close(mock_scheduler):
+ """Test DurableFunctionTestRunner close method."""
+ handler = Mock()
+
+ # Let the constructor run normally with mocked dependencies
+ mock_scheduler_instance = Mock()
+ mock_scheduler.return_value = mock_scheduler_instance
+
+ runner = DurableFunctionTestRunner(handler)
+ runner.close()
+
+ # Verify scheduler.stop() was called
+ mock_scheduler_instance.stop.assert_called_once()
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.Executor")
+@patch("aws_durable_execution_sdk_python_testing.runner.InMemoryExecutionStore")
+def test_durable_function_test_runner_run(mock_store_class, mock_executor_class):
+ """Test DurableFunctionTestRunner run method."""
+ handler = Mock()
+
+ # Mock the class instances
+ mock_executor = Mock()
+ mock_store = Mock()
+ mock_executor_class.return_value = mock_executor
+ mock_store_class.return_value = mock_store
+
+ # Mock execution output
+ output = StartDurableExecutionOutput(execution_arn="test-arn")
+ mock_executor.start_execution.return_value = output
+ mock_executor.wait_until_complete.return_value = True
+
+ # Mock execution for result creation
+ mock_execution = Mock(spec=Execution)
+ mock_execution.operations = []
+ mock_execution.result = Mock()
+ mock_execution.result.status = InvocationStatus.SUCCEEDED
+ mock_execution.result.result = json.dumps("test-result")
+ mock_execution.result.error = None
+ mock_store.load.return_value = mock_execution
+
+ runner = DurableFunctionTestRunner(handler)
+ result = runner.run("test-input")
+
+ # Verify start_execution was called with correct input
+ mock_executor.start_execution.assert_called_once()
+ start_input = mock_executor.start_execution.call_args[0][0]
+ assert isinstance(start_input, StartDurableExecutionInput)
+ assert start_input.input == "test-input"
+ assert start_input.function_name == "test-function"
+ assert start_input.execution_name == "execution-name"
+ assert start_input.account_id == "123456789012"
+
+ # Verify wait_until_complete was called
+ mock_executor.wait_until_complete.assert_called_once_with("test-arn", 900)
+
+ # Verify store.load was called
+ mock_store.load.assert_called_once_with("test-arn")
+
+ # Verify result
+ assert isinstance(result, DurableFunctionTestResult)
+ assert result.status is InvocationStatus.SUCCEEDED
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.Executor")
+@patch("aws_durable_execution_sdk_python_testing.runner.InMemoryExecutionStore")
+def test_durable_function_test_runner_run_with_custom_params(
+ mock_store_class, mock_executor_class
+):
+ """Test DurableFunctionTestRunner run method with custom parameters."""
+ handler = Mock()
+
+ # Mock the class instances
+ mock_executor = Mock()
+ mock_store = Mock()
+ mock_executor_class.return_value = mock_executor
+ mock_store_class.return_value = mock_store
+
+ # Mock execution output
+ output = StartDurableExecutionOutput(execution_arn="test-arn")
+ mock_executor.start_execution.return_value = output
+ mock_executor.wait_until_complete.return_value = True
+
+ # Mock execution for result creation
+ mock_execution = Mock(spec=Execution)
+ mock_execution.operations = []
+ mock_execution.result = Mock()
+ mock_execution.result.status = InvocationStatus.SUCCEEDED
+ mock_execution.result.result = json.dumps("test-result")
+ mock_execution.result.error = None
+ mock_store.load.return_value = mock_execution
+
+ runner = DurableFunctionTestRunner(handler)
+ result = runner.run(
+ input="custom-input",
+ timeout=1800,
+ function_name="custom-function",
+ execution_name="custom-execution",
+ account_id="987654321098",
+ )
+
+ # Verify start_execution was called with custom parameters
+ start_input = mock_executor.start_execution.call_args[0][0]
+ assert start_input.input == "custom-input"
+ assert start_input.function_name == "custom-function"
+ assert start_input.execution_name == "custom-execution"
+ assert start_input.account_id == "987654321098"
+ assert start_input.execution_timeout_seconds == 1800
+
+ # Verify wait_until_complete was called with custom timeout
+ mock_executor.wait_until_complete.assert_called_once_with("test-arn", 1800)
+
+ assert result.status is InvocationStatus.SUCCEEDED
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.Executor")
+def test_durable_function_test_runner_run_timeout(mock_executor_class):
+ """Test DurableFunctionTestRunner run method with timeout."""
+ handler = Mock()
+
+ # Mock the class instance
+ mock_executor = Mock()
+ mock_executor_class.return_value = mock_executor
+
+ # Mock execution output
+ output = StartDurableExecutionOutput(execution_arn="test-arn")
+ mock_executor.start_execution.return_value = output
+ mock_executor.wait_until_complete.return_value = False # Timeout
+
+ runner = DurableFunctionTestRunner(handler)
+
+ with pytest.raises(TimeoutError, match="Execution did not complete within timeout"):
+ runner.run("test-input")
+
+
+def test_context_operation_wrong_type():
+ """Test ContextOperation raises error for wrong operation type."""
+ svc_op = SvcOperation(
+ operation_id="test-id",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.SUCCEEDED,
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Expected CONTEXT operation, got OperationType.STEP",
+ ):
+ ContextOperation.from_svc_operation(svc_op)
+
+
+def test_context_operation_with_child_operations_none():
+ """Test ContextOperation with None child operations."""
+ svc_op = SvcOperation(
+ operation_id="ctx-id",
+ operation_type=OperationType.CONTEXT,
+ status=OperationStatus.SUCCEEDED,
+ context_details=ContextDetails(result=json.dumps("test-result")),
+ )
+
+ ctx_op = ContextOperation.from_svc_operation(svc_op, None)
+
+ assert ctx_op.child_operations == []
+
+
+def test_callback_operation_with_child_operations_none():
+ """Test CallbackOperation with None child operations."""
+ svc_op = SvcOperation(
+ operation_id="callback-id",
+ operation_type=OperationType.CALLBACK,
+ status=OperationStatus.SUCCEEDED,
+ callback_details=CallbackDetails(callback_id="cb-123"),
+ )
+
+ callback_op = CallbackOperation.from_svc_operation(svc_op, None)
+
+ assert callback_op.child_operations == []
+
+
+def test_step_operation_with_child_operations_none():
+ """Test StepOperation with None child operations."""
+ svc_op = SvcOperation(
+ operation_id="step-id",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.SUCCEEDED,
+ step_details=StepDetails(result=json.dumps("step-result")),
+ )
+
+ step_op = StepOperation.from_svc_operation(svc_op, None)
+
+ assert step_op.child_operations == []
+
+
+def test_durable_function_test_result_create_with_parent_operations():
+ """Test DurableFunctionTestResult.create with operations that have parent_id."""
+ execution = Mock(spec=Execution)
+
+ # Create operation with parent_id (should be filtered out)
+ child_op = Mock()
+ child_op.operation_type = OperationType.STEP
+ child_op.parent_id = "parent-id"
+
+ # Create operation without parent_id (should be included)
+ root_op = Mock()
+ root_op.operation_type = OperationType.STEP
+ root_op.parent_id = None
+ root_op.operation_id = "root-id"
+ root_op.status = OperationStatus.SUCCEEDED
+ root_op.name = "root-step"
+ root_op.step_details = StepDetails(result=json.dumps("root-result"))
+
+ execution.operations = [child_op, root_op]
+ execution.result = Mock()
+ execution.result.status = InvocationStatus.SUCCEEDED
+ execution.result.result = json.dumps("test-result")
+ execution.result.error = None
+
+ result = DurableFunctionTestResult.create(execution)
+
+ assert len(result.operations) == 1 # Only root operation included
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.Scheduler")
+@patch("aws_durable_execution_sdk_python_testing.runner.InMemoryExecutionStore")
+@patch("aws_durable_execution_sdk_python_testing.runner.CheckpointProcessor")
+@patch("aws_durable_execution_sdk_python_testing.runner.InMemoryServiceClient")
+@patch("aws_durable_execution_sdk_python_testing.runner.InProcessInvoker")
+@patch("aws_durable_execution_sdk_python_testing.runner.Executor")
+@patch("aws_durable_execution_sdk_python_testing.runner.durable_execution")
+def test_durable_context_test_runner_init(
+ mock_durable_execution_handler,
+ mock_executor,
+ mock_invoker,
+ mock_client,
+ mock_processor,
+ mock_store,
+ mock_scheduler,
+):
+ """Test DurableContextTestRunner initialization."""
+ handler = Mock()
+ decorated_handler = Mock()
+ mock_durable_execution_handler.return_value = decorated_handler
+
+ DurableChildContextTestRunner(handler) # type: ignore
+
+ # Verify all components are initialized
+ mock_scheduler.assert_called_once()
+ mock_scheduler.return_value.start.assert_called_once()
+ mock_store.assert_called_once()
+ mock_processor.assert_called_once()
+ mock_client.assert_called_once()
+ mock_invoker.assert_called_once_with(decorated_handler, mock_client.return_value)
+ mock_executor.assert_called_once()
+
+ # Verify observer pattern setup
+ mock_processor.return_value.add_execution_observer.assert_called_once_with(
+ mock_executor.return_value
+ )
+
+ # Verify durable_execution was called (with internal lambda function)
+ mock_durable_execution_handler.assert_called_once()
+
+ # Verify the lambda function calls our handler
+ durable_execution_func = mock_durable_execution_handler.call_args.args[0]
+ assert callable(durable_execution_func)
+
+ # verify handler is called when durable function is invoked
+ durable_execution_func(Mock(), Mock())
+ handler.assert_called_once()
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.Scheduler")
+@patch("aws_durable_execution_sdk_python_testing.runner.InMemoryExecutionStore")
+@patch("aws_durable_execution_sdk_python_testing.runner.CheckpointProcessor")
+@patch("aws_durable_execution_sdk_python_testing.runner.InMemoryServiceClient")
+@patch("aws_durable_execution_sdk_python_testing.runner.InProcessInvoker")
+@patch("aws_durable_execution_sdk_python_testing.runner.Executor")
+@patch("aws_durable_execution_sdk_python_testing.runner.durable_execution")
+def test_durable_child_context_test_runner_init_with_args(
+ mock_durable_execution_handler,
+ mock_executor,
+ mock_invoker,
+ mock_client,
+ mock_processor,
+ mock_store,
+ mock_scheduler,
+):
+ """Test DurableChildContextTestRunner initialization with additional args."""
+ handler = Mock()
+ decorated_handler = Mock()
+ mock_durable_execution_handler.return_value = decorated_handler
+
+ str_input = "a random string input"
+ num_input = 10
+ DurableChildContextTestRunner(handler, str_input, num=num_input) # type: ignore
+
+ # Verify all components are initialized
+ mock_scheduler.assert_called_once()
+ mock_scheduler.return_value.start.assert_called_once()
+ mock_store.assert_called_once()
+ mock_processor.assert_called_once()
+ mock_client.assert_called_once()
+ mock_invoker.assert_called_once_with(decorated_handler, mock_client.return_value)
+ mock_executor.assert_called_once()
+
+ # Verify observer pattern setup
+ mock_processor.return_value.add_execution_observer.assert_called_once_with(
+ mock_executor.return_value
+ )
+
+ # Verify durable_execution was called (with internal lambda function)
+ mock_durable_execution_handler.assert_called_once()
+ # Verify the lambda function calls our handler
+ durable_execution_func = mock_durable_execution_handler.call_args.args[0]
+ assert callable(durable_execution_func)
+
+ # verify that handler is called with expected args when durable function is invoked
+ durable_execution_func(Mock(), Mock())
+ handler.assert_called_once_with(str_input, num=num_input)
+
+
+# Tests for DurableFunctionCloudTestRunner and from_execution_history
+
+
+def test_durable_function_test_result_from_execution_history():
+ """Test DurableFunctionTestResult.from_execution_history factory method."""
+ import datetime
+
+ from aws_durable_execution_sdk_python.execution import InvocationStatus
+
+ from aws_durable_execution_sdk_python_testing.model import (
+ Event,
+ EventResult,
+ GetDurableExecutionHistoryResponse,
+ GetDurableExecutionResponse,
+ StepSucceededDetails,
+ )
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionTestResult,
+ )
+
+ execution_response = GetDurableExecutionResponse(
+ durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1",
+ durable_execution_name="test-execution",
+ function_arn="arn:aws:lambda:us-east-1:123456789012:function:test",
+ status="SUCCEEDED",
+ start_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC),
+ end_timestamp=datetime.datetime(2023, 1, 1, 0, 1, 0, tzinfo=datetime.UTC),
+ result="test-result",
+ error=None,
+ )
+
+ history_response = GetDurableExecutionHistoryResponse(
+ events=[
+ Event(
+ event_type="ExecutionStarted",
+ event_timestamp=datetime.datetime(
+ 2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC
+ ),
+ operation_id="exec-1",
+ ),
+ Event(
+ event_type="StepStarted",
+ event_timestamp=datetime.datetime(
+ 2023, 1, 1, 0, 0, 10, tzinfo=datetime.UTC
+ ),
+ operation_id="step-1",
+ name="test-step",
+ ),
+ Event(
+ event_type="StepSucceeded",
+ event_timestamp=datetime.datetime(
+ 2023, 1, 1, 0, 0, 20, tzinfo=datetime.UTC
+ ),
+ operation_id="step-1",
+ step_succeeded_details=StepSucceededDetails(
+ result=EventResult(payload="step-result", truncated=False)
+ ),
+ ),
+ ]
+ )
+
+ result = DurableFunctionTestResult.from_execution_history(
+ execution_response, history_response
+ )
+
+ assert result.status == InvocationStatus.SUCCEEDED
+ assert result.result == "test-result"
+ assert result.error is None
+ assert len(result.operations) == 1
+ assert result.operations[0].name == "test-step"
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+def test_cloud_runner_init(mock_boto3):
+ """Test DurableFunctionCloudTestRunner initialization."""
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+
+ runner = DurableFunctionCloudTestRunner(
+ function_name="test-function",
+ region="us-west-2",
+ poll_interval=0.5,
+ )
+
+ assert runner.function_name == "test-function"
+ assert runner.region == "us-west-2"
+ assert runner.poll_interval == 0.5
+ mock_boto3.client.assert_called_once()
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+def test_cloud_runner_run_success(mock_boto3):
+ """Test DurableFunctionCloudTestRunner.run with successful execution."""
+ from aws_durable_execution_sdk_python.execution import InvocationStatus
+
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+
+ mock_client.invoke.return_value = {
+ "StatusCode": 200,
+ "Payload": Mock(read=lambda: b'{"result": "success"}'),
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1",
+ }
+
+ mock_client.get_durable_execution.return_value = {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1",
+ "DurableExecutionName": "test-execution",
+ "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:test",
+ "Status": "SUCCEEDED",
+ "StartTimestamp": "2023-01-01T00:00:00Z",
+ "EndTimestamp": "2023-01-01T00:01:00Z",
+ "Result": "test-result",
+ }
+
+ mock_client.get_durable_execution_history.return_value = {
+ "Events": [
+ {
+ "EventType": "ExecutionStarted",
+ "EventTimestamp": "2023-01-01T00:00:00Z",
+ "Id": "exec-1",
+ }
+ ]
+ }
+
+ runner = DurableFunctionCloudTestRunner(
+ function_name="test-function", poll_interval=0.01
+ )
+
+ result = runner.run(input="test-input", timeout=10)
+
+ assert result.status == InvocationStatus.SUCCEEDED
+ assert result.result == "test-result"
+ mock_client.invoke.assert_called_once()
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+def test_cloud_runner_run_invoke_failure(mock_boto3):
+ """Test DurableFunctionCloudTestRunner.run with invoke failure."""
+ from aws_durable_execution_sdk_python_testing.exceptions import (
+ DurableFunctionsTestError,
+ )
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+ mock_client.invoke.side_effect = Exception("Invoke failed")
+
+ runner = DurableFunctionCloudTestRunner(function_name="test-function")
+
+ with pytest.raises(
+ DurableFunctionsTestError, match="Failed to invoke Lambda function"
+ ):
+ runner.run(input="test-input")
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+@patch("aws_durable_execution_sdk_python_testing.runner.time")
+def test_cloud_runner_wait_for_completion_timeout(mock_time, mock_boto3):
+ """Test DurableFunctionCloudTestRunner._wait_for_completion with timeout."""
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+ mock_time.time.side_effect = [0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0]
+
+ mock_client.get_durable_execution.return_value = {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1",
+ "DurableExecutionName": "test-execution",
+ "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:test",
+ "Status": "RUNNING",
+ "StartTimestamp": "2023-01-01T00:00:00Z",
+ }
+
+ runner = DurableFunctionCloudTestRunner(
+ function_name="test-function", poll_interval=0.01
+ )
+
+ with pytest.raises(TimeoutError, match="Execution did not complete within"):
+ runner._wait_for_completion("test-arn", timeout=2)
+
+
+def test_durable_function_test_result_from_execution_history_with_exception():
+ """Test from_execution_history handles events_to_operations exception."""
+ import datetime
+
+ from aws_durable_execution_sdk_python.execution import InvocationStatus
+
+ from aws_durable_execution_sdk_python_testing.model import (
+ Event,
+ GetDurableExecutionHistoryResponse,
+ GetDurableExecutionResponse,
+ )
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionTestResult,
+ )
+
+ execution_response = GetDurableExecutionResponse(
+ durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1",
+ durable_execution_name="test-execution",
+ function_arn="arn:aws:lambda:us-east-1:123456789012:function:test",
+ status="SUCCEEDED",
+ start_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC),
+ )
+
+ history_response = GetDurableExecutionHistoryResponse(
+ events=[
+ Event(
+ event_type="StepStarted",
+ event_timestamp=datetime.datetime(
+ 2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC
+ ),
+ operation_id=None,
+ )
+ ]
+ )
+
+ result = DurableFunctionTestResult.from_execution_history(
+ execution_response, history_response
+ )
+
+ assert result.status == InvocationStatus.SUCCEEDED
+ assert len(result.operations) == 0
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+def test_cloud_runner_wait_for_completion_failed_status(mock_boto3):
+ """Test DurableFunctionCloudTestRunner._wait_for_completion with FAILED status."""
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+
+ mock_client.get_durable_execution.return_value = {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1",
+ "DurableExecutionName": "test-execution",
+ "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:test",
+ "Status": "FAILED",
+ "StartTimestamp": "2023-01-01T00:00:00Z",
+ "EndTimestamp": "2023-01-01T00:01:00Z",
+ "Error": {"ErrorMessage": "execution failed"},
+ }
+
+ runner = DurableFunctionCloudTestRunner(function_name="test-function")
+ result = runner._wait_for_completion("test-arn", timeout=10)
+
+ assert result.status == "FAILED"
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+def test_cloud_runner_run_bad_status_code(mock_boto3):
+ """Test DurableFunctionCloudTestRunner.run with bad HTTP status code."""
+ from aws_durable_execution_sdk_python_testing.exceptions import (
+ DurableFunctionsTestError,
+ )
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+
+ mock_client.invoke.return_value = {
+ "StatusCode": 500,
+ "Payload": Mock(read=lambda: b"Internal Server Error"),
+ }
+
+ runner = DurableFunctionCloudTestRunner(function_name="test-function")
+
+ with pytest.raises(
+ DurableFunctionsTestError, match="Lambda invocation failed with status 500"
+ ):
+ runner.run(input="test-input")
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+def test_cloud_runner_run_function_error(mock_boto3):
+ """Test DurableFunctionCloudTestRunner.run with function error."""
+ from aws_durable_execution_sdk_python_testing.exceptions import (
+ DurableFunctionsTestError,
+ )
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+
+ mock_client.invoke.return_value = {
+ "StatusCode": 200,
+ "FunctionError": "Unhandled",
+ "Payload": Mock(read=lambda: b'{"errorMessage": "Function failed"}'),
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1",
+ }
+
+ mock_client.get_durable_execution.return_value = {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1",
+ "DurableExecutionName": "test-execution",
+ "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:test",
+ "Status": "FAILED",
+ "StartTimestamp": "2023-01-01T00:00:00Z",
+ "EndTimestamp": "2023-01-01T00:01:00Z",
+ "Error": {"ErrorMessage": "execution failed"},
+ }
+
+ mock_client.get_durable_execution_history.return_value = {
+ "Events": [
+ {
+ "EventType": "ExecutionStarted",
+ "EventTimestamp": "2023-01-01T00:00:00Z",
+ "Id": "exec-1",
+ }
+ ]
+ }
+ runner = DurableFunctionCloudTestRunner(function_name="test-function")
+ result = runner.run(input="test-input")
+ assert result.status is InvocationStatus.FAILED
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+def test_cloud_runner_run_missing_execution_arn(mock_boto3):
+ """Test DurableFunctionCloudTestRunner.run with missing execution ARN."""
+ from aws_durable_execution_sdk_python_testing.exceptions import (
+ DurableFunctionsTestError,
+ )
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+
+ mock_client.invoke.return_value = {
+ "StatusCode": 200,
+ "Payload": Mock(read=lambda: b'{"result": "success"}'),
+ }
+
+ runner = DurableFunctionCloudTestRunner(function_name="test-function")
+
+ with pytest.raises(
+ DurableFunctionsTestError, match="No DurableExecutionArn in response"
+ ):
+ runner.run(input="test-input")
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+def test_cloud_runner_wait_for_completion_get_execution_failure(mock_boto3):
+ """Test DurableFunctionCloudTestRunner._wait_for_completion with API failure."""
+ from aws_durable_execution_sdk_python_testing.exceptions import (
+ DurableFunctionsTestError,
+ )
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+ mock_client.get_durable_execution.side_effect = Exception("API error")
+
+ runner = DurableFunctionCloudTestRunner(function_name="test-function")
+
+ with pytest.raises(
+ DurableFunctionsTestError, match="Failed to get execution status"
+ ):
+ runner._wait_for_completion("test-arn", timeout=10)
+
+
+def test_durable_function_test_result_from_execution_history_filters_execution_type():
+ """Test from_execution_history filters out EXECUTION type operations."""
+ import datetime
+
+ from aws_durable_execution_sdk_python.execution import InvocationStatus
+
+ from aws_durable_execution_sdk_python_testing.model import (
+ Event,
+ GetDurableExecutionHistoryResponse,
+ GetDurableExecutionResponse,
+ )
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionTestResult,
+ )
+
+ execution_response = GetDurableExecutionResponse(
+ durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1",
+ durable_execution_name="test-execution",
+ function_arn="arn:aws:lambda:us-east-1:123456789012:function:test",
+ status="SUCCEEDED",
+ start_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC),
+ )
+
+ history_response = GetDurableExecutionHistoryResponse(
+ events=[
+ Event(
+ event_type="ExecutionStarted",
+ event_timestamp=datetime.datetime(
+ 2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC
+ ),
+ operation_id="exec-1",
+ ),
+ ]
+ )
+
+ result = DurableFunctionTestResult.from_execution_history(
+ execution_response, history_response
+ )
+
+ assert len(result.operations) == 0
+
+
+def test_durable_function_test_result_from_execution_history_unknown_status():
+ """Test from_execution_history with unknown status defaults to FAILED."""
+ import datetime
+
+ from aws_durable_execution_sdk_python.execution import InvocationStatus
+
+ from aws_durable_execution_sdk_python_testing.model import (
+ GetDurableExecutionHistoryResponse,
+ GetDurableExecutionResponse,
+ )
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionTestResult,
+ )
+
+ execution_response = GetDurableExecutionResponse(
+ durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1",
+ durable_execution_name="test-execution",
+ function_arn="arn:aws:lambda:us-east-1:123456789012:function:test",
+ status="UNKNOWN_STATUS",
+ start_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC),
+ )
+
+ history_response = GetDurableExecutionHistoryResponse(events=[])
+
+ result = DurableFunctionTestResult.from_execution_history(
+ execution_response, history_response
+ )
+
+ assert result.status == InvocationStatus.FAILED
+
+
+def test_durable_function_test_result_from_execution_history_with_parent_operations():
+ """Test from_execution_history filters operations with parent_id."""
+ import datetime
+
+ from aws_durable_execution_sdk_python.execution import InvocationStatus
+
+ from aws_durable_execution_sdk_python_testing.model import (
+ Event,
+ GetDurableExecutionHistoryResponse,
+ GetDurableExecutionResponse,
+ )
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionTestResult,
+ )
+
+ execution_response = GetDurableExecutionResponse(
+ durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1",
+ durable_execution_name="test-execution",
+ function_arn="arn:aws:lambda:us-east-1:123456789012:function:test",
+ status="SUCCEEDED",
+ start_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC),
+ )
+
+ history_response = GetDurableExecutionHistoryResponse(
+ events=[
+ Event(
+ event_type="StepStarted",
+ event_timestamp=datetime.datetime(
+ 2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC
+ ),
+ operation_id="step-1",
+ name="parent-step",
+ ),
+ Event(
+ event_type="StepStarted",
+ event_timestamp=datetime.datetime(
+ 2023, 1, 1, 0, 0, 10, tzinfo=datetime.UTC
+ ),
+ operation_id="step-2",
+ name="child-step",
+ parent_id="step-1",
+ ),
+ ]
+ )
+
+ result = DurableFunctionTestResult.from_execution_history(
+ execution_response, history_response
+ )
+
+ assert len(result.operations) == 1
+ assert result.operations[0].name == "parent-step"
+
+
+def test_durable_function_test_result_from_execution_history_failed():
+ """Test from_execution_history with failed execution."""
+ import datetime
+
+ from aws_durable_execution_sdk_python.execution import InvocationStatus
+ from aws_durable_execution_sdk_python.lambda_service import ErrorObject
+
+ from aws_durable_execution_sdk_python_testing.model import (
+ GetDurableExecutionHistoryResponse,
+ GetDurableExecutionResponse,
+ )
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionTestResult,
+ )
+
+ execution_response = GetDurableExecutionResponse(
+ durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1",
+ durable_execution_name="test-execution",
+ function_arn="arn:aws:lambda:us-east-1:123456789012:function:test",
+ status="FAILED",
+ start_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC),
+ end_timestamp=datetime.datetime(2023, 1, 1, 0, 1, 0, tzinfo=datetime.UTC),
+ error=ErrorObject(
+ message="execution failed", type=None, data=None, stack_trace=None
+ ),
+ )
+
+ history_response = GetDurableExecutionHistoryResponse(events=[])
+
+ result = DurableFunctionTestResult.from_execution_history(
+ execution_response, history_response
+ )
+
+ assert result.status == InvocationStatus.FAILED
+ assert result.error.message == "execution failed"
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+def test_cloud_runner_wait_for_completion_timed_out_status(mock_boto3):
+ """Test DurableFunctionCloudTestRunner._wait_for_completion with TIMED_OUT status."""
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+
+ mock_client.get_durable_execution.return_value = {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1",
+ "DurableExecutionName": "test-execution",
+ "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:test",
+ "Status": "TIMED_OUT",
+ "StartTimestamp": "2023-01-01T00:00:00Z",
+ "EndTimestamp": "2023-01-01T00:01:00Z",
+ }
+
+ runner = DurableFunctionCloudTestRunner(function_name="test-function")
+ result = runner._wait_for_completion("test-arn", timeout=10)
+
+ assert result.status == "TIMED_OUT"
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+def test_cloud_runner_wait_for_completion_aborted_status(mock_boto3):
+ """Test DurableFunctionCloudTestRunner._wait_for_completion with ABORTED status."""
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+
+ mock_client.get_durable_execution.return_value = {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1",
+ "DurableExecutionName": "test-execution",
+ "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:test",
+ "Status": "ABORTED",
+ "StartTimestamp": "2023-01-01T00:00:00Z",
+ "EndTimestamp": "2023-01-01T00:01:00Z",
+ }
+
+ runner = DurableFunctionCloudTestRunner(function_name="test-function")
+ result = runner._wait_for_completion("test-arn", timeout=10)
+
+ assert result.status == "ABORTED"
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+def test_cloud_runner_run_async_success(mock_boto3):
+ """Test DurableFunctionCloudTestRunner.run_async with successful invocation."""
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+
+ mock_client.invoke.return_value = {
+ "StatusCode": 202,
+ "Payload": Mock(read=lambda: b'{"result": "success"}'),
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1",
+ }
+
+ runner = DurableFunctionCloudTestRunner(function_name="test-function")
+ execution_arn = runner.run_async(input="test-input")
+
+ assert (
+ execution_arn
+ == "arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1"
+ )
+ mock_client.invoke.assert_called_once_with(
+ FunctionName="test-function",
+ InvocationType="Event",
+ Payload='"test-input"',
+ )
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+def test_cloud_runner_run_async_with_400(mock_boto3):
+ """Test DurableFunctionCloudTestRunner.run_async with successful invocation."""
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+
+ mock_client.invoke.return_value = {
+ "StatusCode": 400,
+ "Payload": Mock(read=lambda: b'{"result": "failed"}'),
+ }
+
+ runner = DurableFunctionCloudTestRunner(function_name="test-function")
+
+ with pytest.raises(
+ DurableFunctionsTestError, match="Lambda invocation failed with status 400"
+ ):
+ runner.run_async(input="test-input")
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+def test_cloud_runner_run_async_failure(mock_boto3):
+ """Test DurableFunctionCloudTestRunner.run_async with invocation failure."""
+ from aws_durable_execution_sdk_python_testing.exceptions import (
+ DurableFunctionsTestError,
+ )
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+ mock_client.invoke.side_effect = Exception("Async invoke failed")
+
+ runner = DurableFunctionCloudTestRunner(function_name="test-function")
+
+ with pytest.raises(
+ DurableFunctionsTestError, match="Failed to invoke Lambda function"
+ ):
+ runner.run_async(input="test-input")
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+def test_cloud_runner_send_callback_success(mock_boto3):
+ """Test DurableFunctionCloudTestRunner.send_callback_success."""
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+
+ runner = DurableFunctionCloudTestRunner(function_name="test-function")
+ runner.send_callback_success("callback-123")
+
+ mock_client.send_durable_execution_callback_success.assert_called_once_with(
+ CallbackId="callback-123", Result=None
+ )
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+def test_cloud_runner_send_callback_failure(mock_boto3):
+ """Test DurableFunctionCloudTestRunner.send_callback_failure."""
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+
+ runner = DurableFunctionCloudTestRunner(function_name="test-function")
+ runner.send_callback_failure("callback-123")
+
+ mock_client.send_durable_execution_callback_failure.assert_called_once_with(
+ CallbackId="callback-123", Error=None
+ )
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+def test_cloud_runner_send_callback_heartbeat(mock_boto3):
+ """Test DurableFunctionCloudTestRunner.send_callback_heartbeat."""
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+
+ runner = DurableFunctionCloudTestRunner(function_name="test-function")
+ runner.send_callback_heartbeat("callback-123")
+
+ mock_client.send_durable_execution_callback_heartbeat.assert_called_once_with(
+ CallbackId="callback-123"
+ )
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+def test_cloud_runner_send_callback_error(mock_boto3):
+ """Test DurableFunctionCloudTestRunner callback methods with API errors."""
+ from aws_durable_execution_sdk_python_testing.exceptions import (
+ DurableFunctionsTestError,
+ )
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+ mock_client.send_durable_execution_callback_success.side_effect = Exception(
+ "API error"
+ )
+
+ runner = DurableFunctionCloudTestRunner(function_name="test-function")
+
+ with pytest.raises(
+ DurableFunctionsTestError, match="Failed to send callback success"
+ ):
+ runner.send_callback_success("callback-123")
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+def test_cloud_runner_wait_for_callback_success(mock_boto3):
+ """Test DurableFunctionCloudTestRunner.wait_for_callback success."""
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+
+ mock_client.get_durable_execution_history.return_value = {
+ "Events": [
+ {
+ "EventType": "CallbackStarted",
+ "EventTimestamp": "2023-01-01T00:00:00Z",
+ "Id": "callback-event-1",
+ "Name": "test-callback",
+ "CallbackStartedDetails": {"CallbackId": "callback-123"},
+ }
+ ]
+ }
+
+ runner = DurableFunctionCloudTestRunner(
+ function_name="test-function", poll_interval=0.01
+ )
+ callback_id = runner.wait_for_callback("test-arn", name="test-callback", timeout=10)
+
+ assert callback_id == "callback-123"
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+def test_cloud_runner_wait_for_callback_none(mock_boto3):
+ """Test DurableFunctionCloudTestRunner.wait_for_callback none."""
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+
+ mock_client.get_durable_execution_history.return_value = {
+ "Events": [
+ {
+ "EventType": "CallbackStarted",
+ "EventTimestamp": "2023-01-01T00:00:00Z",
+ "Id": "callback-event-1",
+ "Name": "test-callback",
+ "CallbackStartedDetails": {"CallbackId": "callback-123"},
+ }
+ ]
+ }
+
+ runner = DurableFunctionCloudTestRunner(
+ function_name="test-function", poll_interval=0.01
+ )
+
+ with pytest.raises(TimeoutError, match="Callback did not available within"):
+ runner.wait_for_callback("test-arn", name="test-callback1", timeout=2)
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+def test_cloud_runner_wait_for_callback_success_without_name(mock_boto3):
+ """Test DurableFunctionCloudTestRunner.wait_for_callback success."""
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+
+ mock_client.get_durable_execution_history.return_value = {
+ "Events": [
+ {
+ "EventType": "CallbackStarted",
+ "EventTimestamp": "2023-01-01T00:00:00Z",
+ "Id": "callback-event-1",
+ "Name": "test-callback",
+ "CallbackStartedDetails": {"CallbackId": "callback-123"},
+ }
+ ]
+ }
+
+ runner = DurableFunctionCloudTestRunner(
+ function_name="test-function", poll_interval=0.01
+ )
+ callback_id = runner.wait_for_callback("test-arn")
+
+ assert callback_id == "callback-123"
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+def test_cloud_runner_wait_for_callback_all_done_without_name(mock_boto3):
+ """Test DurableFunctionCloudTestRunner.wait_for_callback all_done_without_name."""
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+
+ mock_client.get_durable_execution_history.return_value = {
+ "Events": [
+ {
+ "EventType": "CallbackStarted",
+ "EventTimestamp": "2023-01-01T00:00:00Z",
+ "Id": "callback-event-1",
+ "Name": "test-callback",
+ "CallbackStartedDetails": {"CallbackId": "callback-123"},
+ },
+ {
+ "EventType": "CallbackSucceeded",
+ "EventTimestamp": "2023-01-01T00:05:00Z",
+ "Id": "callback-event-1",
+ "Name": "test-callback",
+ },
+ ]
+ }
+
+ runner = DurableFunctionCloudTestRunner(
+ function_name="test-function", poll_interval=0.01
+ )
+ with pytest.raises(TimeoutError, match="Callback did not available within"):
+ runner.wait_for_callback("test-arn", timeout=2)
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.Executor")
+def test_local_runner_wait_for_callback_all_done_without_name(mock_executor_class):
+ """Test DurableFunctionCloudTestRunner.wait_for_callback all_done_without_name."""
+ handler = Mock()
+ mock_executor = Mock()
+ mock_executor_class.return_value = mock_executor
+ mock_executor.get_execution_history.return_value = (
+ GetDurableExecutionHistoryResponse.from_dict(
+ {
+ "Events": [
+ {
+ "EventType": "CallbackStarted",
+ "EventTimestamp": "2023-01-01T00:00:00Z",
+ "Id": "callback-event-1",
+ "Name": "test-callback",
+ "CallbackStartedDetails": {"CallbackId": "callback-123"},
+ },
+ {
+ "EventType": "CallbackSucceeded",
+ "EventTimestamp": "2023-01-01T00:05:00Z",
+ "Id": "callback-event-1",
+ "Name": "test-callback",
+ },
+ ]
+ }
+ )
+ )
+
+ runner = DurableFunctionTestRunner(handler)
+ with pytest.raises(TimeoutError, match="Callback did not available within"):
+ runner.wait_for_callback("test-arn", timeout=2)
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.Executor")
+def test_local_runner_wait_for_callback_with_exception(mock_executor_class):
+ """Test DurableFunctionCloudTestRunner.wait_for_callback with exception"""
+ handler = Mock()
+ mock_executor = Mock()
+ mock_executor_class.return_value = mock_executor
+ mock_executor.get_execution_history.side_effect = Exception("error")
+
+ runner = DurableFunctionTestRunner(handler)
+ with pytest.raises(
+ DurableFunctionsTestError, match="Failed to fetch execution history"
+ ):
+ runner.wait_for_callback("test-arn", timeout=10)
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.Executor")
+def test_local_runner_wait_for_callback_with_resource_not_found_exception(
+ mock_executor_class,
+):
+ """Test DurableFunctionCloudTestRunner.wait_for_callback with resource_not_found exception"""
+ handler = Mock()
+ mock_executor = Mock()
+ mock_executor_class.return_value = mock_executor
+ mock_executor.get_execution_history.side_effect = ResourceNotFoundException("error")
+
+ runner = DurableFunctionTestRunner(handler)
+ with pytest.raises(TimeoutError, match="Callback did not available within"):
+ runner.wait_for_callback("test-arn", timeout=2)
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+@patch("aws_durable_execution_sdk_python_testing.runner.time")
+def test_cloud_runner_wait_for_callback_timeout(mock_time, mock_boto3):
+ """Test DurableFunctionCloudTestRunner.wait_for_callback timeout."""
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+ mock_time.time.side_effect = [0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0]
+
+ mock_client.get_durable_execution_history.return_value = {"Events": []}
+
+ runner = DurableFunctionCloudTestRunner(
+ function_name="test-function", poll_interval=0.01
+ )
+
+ with pytest.raises(TimeoutError, match="Callback did not available within"):
+ runner.wait_for_callback("test-arn", timeout=2)
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+def test_cloud_runner_wait_for_callback_already_completed(mock_boto3):
+ """Test DurableFunctionCloudTestRunner.wait_for_callback already completed."""
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+
+ mock_client.get_durable_execution_history.return_value = {
+ "Events": [
+ {
+ "EventType": "CallbackStarted",
+ "EventTimestamp": "2023-01-01T00:00:00Z",
+ "Id": "callback-event-1",
+ "Name": "test-callback",
+ "CallbackStartedDetails": {"CallbackId": "callback-123"},
+ },
+ {
+ "EventType": "CallbackSucceeded",
+ "EventTimestamp": "2023-01-01T00:05:00Z",
+ "Id": "callback-event-1",
+ "Name": "test-callback",
+ },
+ ]
+ }
+
+ runner = DurableFunctionCloudTestRunner(
+ function_name="test-function", poll_interval=0.01
+ )
+
+ with pytest.raises(
+ DurableFunctionsTestError, match="Callback test-callback has already completed"
+ ):
+ runner.wait_for_callback("test-arn", "test-callback", timeout=2)
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+def test_cloud_runner_wait_for_callback_client_error_retryable(mock_boto3):
+ """Test wait_for_callback with retryable ClientError."""
+ from botocore.exceptions import ClientError # type: ignore
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+
+ # First call raises ResourceNotFoundException, second succeeds
+ mock_client.get_durable_execution_history.side_effect = [
+ ClientError(
+ error_response={"Error": {"Code": "ResourceNotFoundException"}},
+ operation_name="GetDurableExecutionHistory",
+ ),
+ {
+ "Events": [
+ {
+ "EventType": "CallbackStarted",
+ "EventTimestamp": "2023-01-01T00:00:00Z",
+ "Id": "callback-event-1",
+ "Name": "test-callback",
+ "CallbackStartedDetails": {"CallbackId": "callback-123"},
+ }
+ ]
+ },
+ ]
+
+ runner = DurableFunctionCloudTestRunner(
+ function_name="test-function", poll_interval=0.01
+ )
+ callback_id = runner.wait_for_callback("test-arn", name="test-callback", timeout=10)
+
+ assert callback_id == "callback-123"
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+def test_cloud_runner_wait_for_callback_client_error_non_retryable(
+ mock_boto3,
+):
+ """Test wait_for_callback with non-retryable ClientError."""
+ from botocore.exceptions import ClientError
+ from aws_durable_execution_sdk_python_testing.exceptions import (
+ DurableFunctionsTestError,
+ )
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+
+ mock_client.get_durable_execution_history.side_effect = ClientError(
+ error_response={"Error": {"Code": "AccessDeniedException"}},
+ operation_name="GetDurableExecutionHistory",
+ )
+
+ runner = DurableFunctionCloudTestRunner(function_name="test-function")
+
+ with pytest.raises(
+ DurableFunctionsTestError, match="Failed to fetch execution history"
+ ):
+ runner.wait_for_callback("test-arn", timeout=10)
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+def test_cloud_runner_wait_for_callback_generic_exception(mock_boto3):
+ """Test wait_for_callback with generic Exception."""
+ from aws_durable_execution_sdk_python_testing.exceptions import (
+ DurableFunctionsTestError,
+ )
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+
+ mock_client.get_durable_execution_history.side_effect = Exception("Network error")
+
+ runner = DurableFunctionCloudTestRunner(function_name="test-function")
+
+ with pytest.raises(
+ DurableFunctionsTestError, match="Failed to fetch execution history"
+ ):
+ runner.wait_for_callback("test-arn", timeout=10)
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+def test_cloud_runner_wait_for_result_fetch_history_exception(mock_boto3):
+ """Test wait_for_result with exception in _fetch_execution_history."""
+ from aws_durable_execution_sdk_python_testing.exceptions import (
+ DurableFunctionsTestError,
+ )
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+
+ # Mock successful _wait_for_completion
+ mock_execution_response = Mock()
+ mock_execution_response.status = "SUCCEEDED"
+
+ # Mock _fetch_execution_history to raise exception
+ runner = DurableFunctionCloudTestRunner(function_name="test-function")
+ runner._wait_for_completion = Mock(return_value=mock_execution_response)
+ runner._fetch_execution_history = Mock(
+ side_effect=Exception("History fetch failed")
+ )
+
+ with pytest.raises(
+ DurableFunctionsTestError,
+ match="Failed to fetch execution history: History fetch failed",
+ ):
+ runner.wait_for_result("test-arn", timeout=60)
+
+
+@patch("aws_durable_execution_sdk_python_testing.runner.boto3")
+def test_cloud_runner_wait_for_result_success(mock_boto3):
+ """Test wait_for_result successful execution."""
+ from aws_durable_execution_sdk_python.execution import InvocationStatus
+ from aws_durable_execution_sdk_python_testing.runner import (
+ DurableFunctionCloudTestRunner,
+ )
+
+ mock_client = Mock()
+ mock_boto3.client.return_value = mock_client
+
+ # Mock successful responses
+ mock_execution_response = Mock()
+ mock_execution_response.status = "SUCCEEDED"
+ mock_history_response = Mock()
+ mock_history_response.events = []
+
+ runner = DurableFunctionCloudTestRunner(function_name="test-function")
+ runner._wait_for_completion = Mock(return_value=mock_execution_response)
+ runner._fetch_execution_history = Mock(return_value=mock_history_response)
+
+ # Mock the from_execution_history method
+ with patch(
+ "aws_durable_execution_sdk_python_testing.runner.DurableFunctionTestResult.from_execution_history"
+ ) as mock_from_history:
+ mock_result = Mock()
+ mock_result.status = InvocationStatus.SUCCEEDED
+ mock_from_history.return_value = mock_result
+
+ result = runner.wait_for_result("test-arn", timeout=60)
+
+ assert result.status == InvocationStatus.SUCCEEDED
+ mock_from_history.assert_called_once_with(
+ mock_execution_response, mock_history_response
+ )
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/runner_web_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/runner_web_test.py
new file mode 100644
index 0000000..f810bf3
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/runner_web_test.py
@@ -0,0 +1,1753 @@
+"""Unit tests for web runner components in runner module."""
+
+from __future__ import annotations
+
+import logging
+import os
+from unittest.mock import Mock, patch
+
+import pytest
+
+from aws_durable_execution_sdk_python_testing.cli import CliApp
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ DurableFunctionsLocalRunnerError,
+)
+from aws_durable_execution_sdk_python_testing.invoker import _LAMBDA_CLIENT_CONFIG
+from aws_durable_execution_sdk_python_testing.runner import (
+ WebRunner,
+ WebRunnerConfig,
+)
+from aws_durable_execution_sdk_python_testing.web.server import WebServiceConfig
+
+
+def test_should_create_config_with_web_service_and_defaults():
+ """Test creating WebRunnerConfig with WebServiceConfig and default Lambda settings."""
+ # Arrange
+ web_config = WebServiceConfig(
+ host="localhost",
+ port=8080,
+ log_level=logging.DEBUG,
+ max_request_size=5 * 1024 * 1024,
+ )
+
+ # Act
+ config = WebRunnerConfig(web_service=web_config)
+
+ # Assert
+ assert config.web_service == web_config
+ assert config.lambda_endpoint == "http://127.0.0.1:3001"
+ assert config.local_runner_endpoint == "http://0.0.0.0:5000"
+ assert config.local_runner_region == "us-west-2"
+ assert config.local_runner_mode == "local"
+
+
+def test_should_create_config_with_custom_lambda_settings():
+ """Test creating WebRunnerConfig with custom Lambda configuration."""
+ # Arrange
+ web_config = WebServiceConfig(host="0.0.0.0", port=5000) # noqa: S104
+ custom_lambda_endpoint = "http://custom-lambda:4000"
+ custom_runner_endpoint = "http://custom-runner:6000"
+ custom_region = "us-east-1"
+ custom_mode = "remote"
+
+ # Act
+ config = WebRunnerConfig(
+ web_service=web_config,
+ lambda_endpoint=custom_lambda_endpoint,
+ local_runner_endpoint=custom_runner_endpoint,
+ local_runner_region=custom_region,
+ local_runner_mode=custom_mode,
+ )
+
+ # Assert
+ assert config.web_service == web_config
+ assert config.lambda_endpoint == custom_lambda_endpoint
+ assert config.local_runner_endpoint == custom_runner_endpoint
+ assert config.local_runner_region == custom_region
+ assert config.local_runner_mode == custom_mode
+
+
+def test_should_access_web_service_config_fields():
+ """Test accessing WebServiceConfig fields through composition."""
+ # Arrange
+ web_config = WebServiceConfig(
+ host="test-host",
+ port=9999,
+ log_level=logging.WARNING,
+ max_request_size=1024,
+ )
+ config = WebRunnerConfig(web_service=web_config)
+
+ # Act & Assert
+ assert config.web_service.host == "test-host"
+ assert config.web_service.port == 9999
+ assert config.web_service.log_level == logging.WARNING
+ assert config.web_service.max_request_size == 1024
+
+
+def test_should_be_immutable_frozen_dataclass():
+ """Test that WebRunnerConfig is immutable (frozen=True)."""
+ # Arrange
+ web_config = WebServiceConfig()
+ config = WebRunnerConfig(web_service=web_config)
+
+ # Act & Assert - attempting to modify should raise FrozenInstanceError
+ with pytest.raises(
+ AttributeError
+ ): # dataclass frozen raises AttributeError in Python 3.13+
+ config.lambda_endpoint = "http://new-endpoint:8000"
+
+ with pytest.raises(AttributeError):
+ config.web_service = WebServiceConfig(host="new-host")
+
+
+def test_should_support_equality_comparison():
+ """Test that WebRunnerConfig supports equality comparison."""
+ # Arrange
+ web_config1 = WebServiceConfig(host="host1", port=5000)
+ web_config2 = WebServiceConfig(host="host1", port=5000)
+ web_config3 = WebServiceConfig(host="host2", port=5000)
+
+ config1 = WebRunnerConfig(
+ web_service=web_config1,
+ lambda_endpoint="http://lambda:3001",
+ )
+ config2 = WebRunnerConfig(
+ web_service=web_config2,
+ lambda_endpoint="http://lambda:3001",
+ )
+ config3 = WebRunnerConfig(
+ web_service=web_config3,
+ lambda_endpoint="http://lambda:3001",
+ )
+
+ # Act & Assert
+ assert config1 == config2 # Same values should be equal
+ assert config1 != config3 # Different web_service should not be equal
+ assert config2 != config3 # Different web_service should not be equal
+
+
+def test_should_support_hash_for_use_in_sets_and_dicts():
+ """Test that WebRunnerConfig is hashable for use in sets and dicts."""
+ # Arrange
+ web_config = WebServiceConfig(host="test", port=8080)
+ config1 = WebRunnerConfig(web_service=web_config)
+ config2 = WebRunnerConfig(web_service=web_config)
+
+ # Act - should not raise exception
+ config_set = {config1, config2}
+ config_dict = {config1: "value1", config2: "value2"}
+
+ # Assert
+ assert len(config_set) == 1 # Same configs should deduplicate in set
+ assert len(config_dict) == 1 # Same configs should overwrite in dict
+
+
+def test_should_create_config_with_minimal_web_service():
+ """Test creating config with minimal WebServiceConfig using defaults."""
+ # Arrange
+ web_config = WebServiceConfig() # Uses all defaults
+
+ # Act
+ config = WebRunnerConfig(web_service=web_config)
+
+ # Assert
+ assert config.web_service.host == "localhost"
+ assert config.web_service.port == 5000
+ assert config.web_service.log_level == logging.INFO
+ assert config.web_service.max_request_size == 10 * 1024 * 1024
+
+
+def test_should_have_proper_type_annotations():
+ """Test that all fields have proper type annotations."""
+ # Arrange & Act
+ annotations = WebRunnerConfig.__annotations__
+
+ # Assert
+ assert "web_service" in annotations
+ assert "lambda_endpoint" in annotations
+ assert "local_runner_endpoint" in annotations
+ assert "local_runner_region" in annotations
+ assert "local_runner_mode" in annotations
+
+ # Check that the annotations are the expected string representations
+ assert annotations["web_service"] == "WebServiceConfig"
+ assert annotations["lambda_endpoint"] == "str"
+ assert annotations["local_runner_endpoint"] == "str"
+ assert annotations["local_runner_region"] == "str"
+ assert annotations["local_runner_mode"] == "str"
+
+
+def test_should_create_config_with_keyword_arguments():
+ """Test creating config using keyword arguments for all fields."""
+ # Arrange
+ web_config = WebServiceConfig(host="kw-host", port=7777)
+
+ # Act
+ config = WebRunnerConfig(
+ web_service=web_config,
+ lambda_endpoint="http://kw-lambda:2000",
+ local_runner_endpoint="http://kw-runner:3000",
+ local_runner_region="eu-west-1",
+ local_runner_mode="test",
+ )
+
+ # Assert
+ assert config.web_service == web_config
+ assert config.lambda_endpoint == "http://kw-lambda:2000"
+ assert config.local_runner_endpoint == "http://kw-runner:3000"
+ assert config.local_runner_region == "eu-west-1"
+ assert config.local_runner_mode == "test"
+
+
+def test_should_represent_config_as_string():
+ """Test string representation of WebRunnerConfig."""
+ # Arrange
+ web_config = WebServiceConfig(host="repr-host", port=1234)
+ config = WebRunnerConfig(
+ web_service=web_config,
+ lambda_endpoint="http://repr-lambda:5000",
+ )
+
+ # Act
+ config_str = str(config)
+
+ # Assert
+ assert "WebRunnerConfig" in config_str
+ assert "repr-host" in config_str
+ assert "1234" in config_str
+ assert "http://repr-lambda:5000" in config_str
+
+
+# WebRunner class tests
+
+
+def test_should_create_web_runner_with_config():
+ """Test creating WebRunner with WebRunnerConfig."""
+ # Arrange
+ web_config = WebServiceConfig(host="test-host", port=8080)
+ config = WebRunnerConfig(web_service=web_config)
+
+ # Act
+ runner = WebRunner(config)
+
+ # Assert - Test through public behavior only
+ assert isinstance(runner, WebRunner)
+ # Verify runner can be used as context manager (public API)
+ assert hasattr(runner, "__enter__")
+ assert hasattr(runner, "__exit__")
+ assert callable(runner.start)
+ assert callable(runner.stop)
+ assert callable(runner.serve_forever)
+
+
+def test_should_support_context_manager_protocol():
+ """Test WebRunner context manager protocol."""
+ # Arrange
+ web_config = WebServiceConfig()
+ config = WebRunnerConfig(web_service=web_config)
+
+ # Act & Assert - should not raise exception
+ with WebRunner(config) as runner:
+ assert isinstance(runner, WebRunner)
+ assert runner._config == config # noqa: SLF001
+
+
+def test_should_return_self_from_context_manager_enter():
+ """Test that __enter__ returns self."""
+ # Arrange
+ web_config = WebServiceConfig()
+ config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(config)
+
+ # Act
+ result = runner.__enter__()
+
+ # Assert
+ assert result is runner
+
+
+def test_should_call_start_and_stop_on_context_manager():
+ """Test that context manager calls start on entry and stop on exit."""
+ # Arrange
+ web_config = WebServiceConfig()
+ config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(config)
+
+ # Mock the start and stop methods to verify they're called
+ with (
+ patch.object(runner, "start") as mock_start,
+ patch.object(runner, "stop") as mock_stop,
+ ):
+ # Act
+ with runner as context_runner:
+ assert context_runner is runner
+ mock_start.assert_called_once()
+
+ # Assert
+ mock_stop.assert_called_once()
+
+
+def test_should_handle_context_manager_exit_with_exception():
+ """Test context manager exit with exception parameters."""
+ # Arrange
+ web_config = WebServiceConfig()
+ config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(config)
+
+ # Act & Assert - should not raise exception
+ runner.__exit__(ValueError, ValueError("test"), None)
+
+
+def test_should_have_proper_method_signatures():
+ """Test that WebRunner has all required methods with proper signatures."""
+ # Arrange
+ web_config = WebServiceConfig()
+ config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(config)
+
+ # Assert methods exist and are callable
+ assert callable(runner.start)
+ assert callable(runner.serve_forever)
+ assert callable(runner.stop)
+ assert callable(runner.__enter__)
+ assert callable(runner.__exit__)
+
+
+def test_should_initialize_runner_in_stopped_state():
+ """Test that WebRunner initializes in a stopped state."""
+ # Arrange
+ web_config = WebServiceConfig()
+ config = WebRunnerConfig(web_service=web_config)
+
+ # Act
+ runner = WebRunner(config)
+
+ # Assert - Test through public behavior
+ # Should raise DurableFunctionsLocalRunnerError when trying to serve before starting
+ with pytest.raises(DurableFunctionsLocalRunnerError, match="Server not started"):
+ runner.serve_forever()
+
+ # Should be safe to call stop multiple times (no-op when not started)
+ runner.stop()
+ runner.stop()
+
+
+def test_should_store_config_reference():
+ """Test that WebRunner can be created with config and used properly."""
+ # Arrange
+ web_config = WebServiceConfig(host="config-test", port=9999)
+ config = WebRunnerConfig(
+ web_service=web_config,
+ lambda_endpoint="http://test:1234",
+ )
+
+ # Act
+ runner = WebRunner(config)
+
+ # Assert - Test through public behavior
+ assert isinstance(runner, WebRunner)
+
+ # Verify the runner can be started and stopped (public behavior)
+ with (
+ patch("boto3.client") as mock_boto3_client,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.WebServer"
+ ) as mock_web_server_class,
+ ):
+ mock_client = Mock()
+ mock_server = Mock()
+ mock_boto3_client.return_value = mock_client
+ mock_web_server_class.return_value = mock_server
+
+ runner.start()
+
+ # Verify server was started (public behavior - no exception on serve_forever call)
+ runner.serve_forever()
+ mock_server.serve_forever.assert_called_once()
+
+ runner.stop()
+ mock_server.server_close.assert_called_once()
+
+
+# Integration Tests - Testing Public Behavior
+
+
+def test_should_handle_start_with_boto3_client_creation():
+ """Test that start() properly handles boto3 client creation through public API."""
+ # Arrange
+ web_config = WebServiceConfig(host="localhost", port=5000)
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Mock boto3.client to avoid actual client creation
+ with patch("boto3.client") as mock_boto3_client:
+ mock_client = Mock()
+ mock_boto3_client.return_value = mock_client
+
+ # Act - Test public behavior
+ runner.start()
+
+ # Assert - Verify public behavior
+ # Should be able to call serve_forever after start (public API)
+ with patch.object(runner, "serve_forever") as mock_serve:
+ runner.serve_forever()
+ mock_serve.assert_called_once()
+
+ # Should be able to stop after start (public API)
+ runner.stop()
+
+
+def test_should_handle_boto3_client_creation_with_custom_config():
+ """Test that start() uses custom configuration for boto3 client through public API."""
+ # Arrange
+ web_config = WebServiceConfig(host="localhost", port=5000)
+ runner_config = WebRunnerConfig(
+ web_service=web_config,
+ lambda_endpoint="http://custom-endpoint:8080",
+ local_runner_region="eu-west-1",
+ )
+ runner = WebRunner(runner_config)
+
+ with patch("boto3.client") as mock_boto3_client:
+ mock_client = Mock()
+ mock_boto3_client.return_value = mock_client
+
+ # Act - Test public behavior
+ runner.start()
+
+ # Assert - Verify boto3 client was called with correct parameters
+ mock_boto3_client.assert_called_once_with(
+ "lambda",
+ endpoint_url="http://custom-endpoint:8080",
+ region_name="eu-west-1",
+ config=_LAMBDA_CLIENT_CONFIG,
+ )
+
+ # Verify public behavior works
+ runner.stop()
+
+
+def test_should_handle_boto3_client_creation_with_defaults():
+ """Test that start() uses default configuration values through public API."""
+ # Arrange
+ web_config = WebServiceConfig(host="localhost", port=5000)
+ runner_config = WebRunnerConfig(web_service=web_config) # Use defaults
+ runner = WebRunner(runner_config)
+
+ with patch("boto3.client") as mock_boto3_client:
+ mock_client = Mock()
+ mock_boto3_client.return_value = mock_client
+
+ # Act - Test public behavior
+ runner.start()
+
+ # Assert - Verify boto3 client was called with default parameters
+ mock_boto3_client.assert_called_once_with(
+ "lambda",
+ endpoint_url="http://127.0.0.1:3001", # Default lambda_endpoint value
+ region_name="us-west-2", # Default value
+ config=_LAMBDA_CLIENT_CONFIG,
+ )
+
+ # Verify public behavior works
+ runner.stop()
+
+
+def test_should_propagate_boto3_client_creation_exceptions():
+ """Test that start() propagates boto3 client creation exceptions through public API."""
+ # Arrange
+ web_config = WebServiceConfig(host="localhost", port=5000)
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Mock boto3.client to raise an exception
+ with patch("boto3.client") as mock_boto3_client:
+ mock_boto3_client.side_effect = Exception("Connection failed")
+
+ # Act & Assert - Test public behavior
+ with pytest.raises(Exception, match="Connection failed"):
+ runner.start()
+
+
+def test_should_create_boto3_client_during_start():
+ """Test that start() creates boto3 client correctly through public API."""
+ # Arrange
+ web_config = WebServiceConfig(host="localhost", port=5000)
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ with patch("boto3.client") as mock_boto3_client:
+ mock_client = Mock()
+ mock_boto3_client.return_value = mock_client
+
+ # Act - Test public behavior
+ runner.start()
+
+ # Assert - Verify boto3 client was created
+ mock_boto3_client.assert_called_once()
+
+ # Verify public behavior works
+ runner.stop()
+
+
+# Error Condition Tests
+
+
+def test_should_raise_runtime_error_on_double_start():
+ """Test that calling start() twice raises DurableFunctionsLocalRunnerError."""
+ # Arrange
+ web_config = WebServiceConfig()
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Mock dependencies to allow first start
+ with (
+ patch("boto3.client") as mock_boto3_client,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.WebServer"
+ ) as mock_web_server_class,
+ ):
+ mock_client = Mock()
+ mock_server = Mock()
+ mock_boto3_client.return_value = mock_client
+ mock_web_server_class.return_value = mock_server
+
+ # First start should succeed
+ runner.start()
+
+ # Act & Assert - Second start should raise DurableFunctionsLocalRunnerError
+ with pytest.raises(
+ DurableFunctionsLocalRunnerError, match="Server is already running"
+ ):
+ runner.start()
+
+ # Cleanup
+ runner.stop()
+
+
+def test_should_raise_runtime_error_when_serve_before_start():
+ """Test that calling serve_forever() before start() raises DurableFunctionsLocalRunnerError."""
+ # Arrange
+ web_config = WebServiceConfig()
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Act & Assert - serve_forever before start should raise DurableFunctionsLocalRunnerError
+ with pytest.raises(DurableFunctionsLocalRunnerError, match="Server not started"):
+ runner.serve_forever()
+
+
+def test_should_propagate_boto3_client_creation_failures():
+ """Test that boto3 client creation failures are propagated as exceptions."""
+ # Arrange
+ web_config = WebServiceConfig()
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Mock boto3.client to raise various exceptions
+ test_cases = [
+ Exception("Connection refused"),
+ ConnectionError("Network error"),
+ ValueError("Invalid endpoint URL"),
+ RuntimeError("AWS credentials not found"),
+ ]
+
+ for exception in test_cases:
+ with patch("boto3.client") as mock_boto3_client:
+ mock_boto3_client.side_effect = exception
+
+ # Act & Assert - Exception should propagate
+ with pytest.raises(type(exception), match=str(exception)):
+ runner.start()
+
+
+def test_should_handle_web_server_creation_failures():
+ """Test that WebServer creation failures are propagated as exceptions."""
+ # Arrange
+ web_config = WebServiceConfig()
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Mock boto3 client to succeed but WebServer to fail
+ with (
+ patch("boto3.client") as mock_boto3_client,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.WebServer"
+ ) as mock_web_server_class,
+ ):
+ mock_client = Mock()
+ mock_boto3_client.return_value = mock_client
+ mock_web_server_class.side_effect = Exception("Failed to bind to port")
+
+ # Act & Assert - WebServer creation failure should propagate
+ with pytest.raises(Exception, match="Failed to bind to port"):
+ runner.start()
+
+
+def test_should_handle_scheduler_creation_failures():
+ """Test that Scheduler creation failures are propagated as exceptions."""
+ # Arrange
+ web_config = WebServiceConfig()
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Mock boto3 client to succeed but Scheduler to fail
+ with (
+ patch("boto3.client") as mock_boto3_client,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.Scheduler"
+ ) as mock_scheduler_class,
+ ):
+ mock_client = Mock()
+ mock_boto3_client.return_value = mock_client
+ mock_scheduler_class.side_effect = Exception("Scheduler initialization failed")
+
+ # Act & Assert - Scheduler creation failure should propagate
+ with pytest.raises(Exception, match="Scheduler initialization failed"):
+ runner.start()
+
+
+def test_should_handle_executor_creation_failures():
+ """Test that Executor creation failures are propagated as exceptions."""
+ # Arrange
+ web_config = WebServiceConfig()
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Mock dependencies to succeed but Executor to fail
+ with (
+ patch("boto3.client") as mock_boto3_client,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.Executor"
+ ) as mock_executor_class,
+ ):
+ mock_client = Mock()
+ mock_boto3_client.return_value = mock_client
+ mock_executor_class.side_effect = Exception("Executor initialization failed")
+
+ # Act & Assert - Executor creation failure should propagate
+ with pytest.raises(Exception, match="Executor initialization failed"):
+ runner.start()
+
+
+# Dependency Creation and Wiring Tests
+
+
+def test_should_create_all_required_dependencies_during_start():
+ """Test that start() creates all required dependencies with proper wiring."""
+ # Arrange
+ web_config = WebServiceConfig(host="test-host", port=8080)
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Mock all dependency classes
+ with (
+ patch("boto3.client") as mock_boto3_client,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.InMemoryExecutionStore"
+ ) as mock_store_class,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.Scheduler"
+ ) as mock_scheduler_class,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.LambdaInvoker"
+ ) as mock_invoker_class,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.Executor"
+ ) as mock_executor_class,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.WebServer"
+ ) as mock_web_server_class,
+ ):
+ # Setup mocks
+ mock_client = Mock()
+ mock_store = Mock()
+ mock_scheduler = Mock()
+ mock_invoker = Mock()
+ mock_executor = Mock()
+ mock_server = Mock()
+
+ mock_boto3_client.return_value = mock_client
+ mock_store_class.return_value = mock_store
+ mock_scheduler_class.return_value = mock_scheduler
+ mock_invoker_class.return_value = mock_invoker
+ mock_executor_class.return_value = mock_executor
+ mock_web_server_class.return_value = mock_server
+
+ # Act
+ runner.start()
+
+ # Assert - Verify all dependencies were created
+ mock_store_class.assert_called_once()
+ mock_scheduler_class.assert_called_once()
+ mock_invoker_class.assert_called_once_with(mock_client)
+ # Verify Executor was called with the expected parameters including checkpoint_processor
+ assert mock_executor_class.call_count == 1
+ call_args = mock_executor_class.call_args
+ assert call_args.kwargs["store"] == mock_store
+ assert call_args.kwargs["scheduler"] == mock_scheduler
+ assert call_args.kwargs["invoker"] == mock_invoker
+ assert "checkpoint_processor" in call_args.kwargs
+ mock_web_server_class.assert_called_once_with(
+ config=web_config, executor=mock_executor
+ )
+
+ # Verify scheduler was started
+ mock_scheduler.start.assert_called_once()
+
+ # Cleanup
+ runner.stop()
+
+
+def test_should_pass_correct_configuration_to_web_server():
+ """Test that WebServer receives correct configuration from WebRunnerConfig."""
+ # Arrange
+ web_config = WebServiceConfig(
+ host="custom-host", port=9999, log_level="WARNING", max_request_size=2048
+ )
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Mock dependencies
+ with (
+ patch("boto3.client") as mock_boto3_client,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.WebServer"
+ ) as mock_web_server_class,
+ ):
+ mock_client = Mock()
+ mock_server = Mock()
+ mock_boto3_client.return_value = mock_client
+ mock_web_server_class.return_value = mock_server
+
+ # Act
+ runner.start()
+
+ # Assert - Verify WebServer was created with correct config
+ mock_web_server_class.assert_called_once()
+ call_args = mock_web_server_class.call_args
+
+ # Verify the web service config was passed correctly
+ passed_config = call_args[1]["config"]
+ assert passed_config == web_config
+ assert passed_config.host == "custom-host"
+ assert passed_config.port == 9999
+ assert passed_config.log_level == "WARNING"
+ assert passed_config.max_request_size == 2048
+
+ # Cleanup
+ runner.stop()
+
+
+def test_should_pass_correct_boto3_client_to_lambda_invoker():
+ """Test that LambdaInvoker receives correct boto3 client configuration."""
+ # Arrange
+ web_config = WebServiceConfig()
+ runner_config = WebRunnerConfig(
+ web_service=web_config,
+ lambda_endpoint="http://test-endpoint:7777",
+ local_runner_region="ap-southeast-2",
+ )
+ runner = WebRunner(runner_config)
+
+ # Mock dependencies
+ with (
+ patch("boto3.client") as mock_boto3_client,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.LambdaInvoker"
+ ) as mock_invoker_class,
+ ):
+ mock_client = Mock()
+ mock_invoker = Mock()
+ mock_boto3_client.return_value = mock_client
+ mock_invoker_class.return_value = mock_invoker
+
+ # Act
+ runner.start()
+
+ # Assert - Verify boto3 client was created with correct parameters
+ mock_boto3_client.assert_called_once_with(
+ "lambda",
+ endpoint_url="http://test-endpoint:7777",
+ region_name="ap-southeast-2",
+ config=_LAMBDA_CLIENT_CONFIG,
+ )
+
+ # Verify LambdaInvoker was created with the client
+ mock_invoker_class.assert_called_once_with(mock_client)
+
+ # Cleanup
+ runner.stop()
+
+
+def test_should_wire_dependencies_correctly_in_executor():
+ """Test that Executor receives correctly wired dependencies."""
+ # Arrange
+ web_config = WebServiceConfig()
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Mock dependencies
+ with (
+ patch("boto3.client") as mock_boto3_client,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.InMemoryExecutionStore"
+ ) as mock_store_class,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.Scheduler"
+ ) as mock_scheduler_class,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.LambdaInvoker"
+ ) as mock_invoker_class,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.Executor"
+ ) as mock_executor_class,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.WebServer"
+ ) as mock_web_server_class,
+ ):
+ mock_client = Mock()
+ mock_store = Mock()
+ mock_scheduler = Mock()
+ mock_invoker = Mock()
+ mock_executor = Mock()
+ mock_web_server = Mock()
+
+ mock_boto3_client.return_value = mock_client
+ mock_store_class.return_value = mock_store
+ mock_scheduler_class.return_value = mock_scheduler
+ mock_invoker_class.return_value = mock_invoker
+ mock_executor_class.return_value = mock_executor
+ mock_web_server_class.return_value = mock_web_server
+
+ # Act
+ runner.start()
+
+ # Assert - Verify Executor was created with correct dependencies
+ assert mock_executor_class.call_count == 1
+ call_args = mock_executor_class.call_args
+ assert call_args.kwargs["store"] == mock_store
+ assert call_args.kwargs["scheduler"] == mock_scheduler
+ assert call_args.kwargs["invoker"] == mock_invoker
+ assert "checkpoint_processor" in call_args.kwargs
+
+ # Cleanup
+ runner.stop()
+
+
+# WebServer Lifecycle and Configuration Tests
+
+
+def test_should_delegate_serve_forever_to_web_server():
+ """Test that serve_forever() properly delegates to WebServer.serve_forever()."""
+ # Arrange
+ web_config = WebServiceConfig()
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Mock dependencies
+ with (
+ patch("boto3.client") as mock_boto3_client,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.WebServer"
+ ) as mock_web_server_class,
+ ):
+ mock_client = Mock()
+ mock_server = Mock()
+ mock_boto3_client.return_value = mock_client
+ mock_web_server_class.return_value = mock_server
+
+ # Start the runner
+ runner.start()
+
+ # Act
+ runner.serve_forever()
+
+ # Assert - Verify WebServer.serve_forever was called
+ mock_server.serve_forever.assert_called_once()
+
+ # Cleanup
+ runner.stop()
+
+
+def test_should_call_server_close_during_stop():
+ """Test that stop() calls server_close() on WebServer."""
+ # Arrange
+ web_config = WebServiceConfig()
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Mock dependencies
+ with (
+ patch("boto3.client") as mock_boto3_client,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.WebServer"
+ ) as mock_web_server_class,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.Scheduler"
+ ) as mock_scheduler_class,
+ ):
+ mock_client = Mock()
+ mock_server = Mock()
+ mock_scheduler = Mock()
+ mock_boto3_client.return_value = mock_client
+ mock_web_server_class.return_value = mock_server
+ mock_scheduler_class.return_value = mock_scheduler
+
+ # Start the runner
+ runner.start()
+
+ # Act
+ runner.stop()
+
+ # Assert - Verify cleanup methods were called
+ mock_server.server_close.assert_called_once()
+ mock_scheduler.stop.assert_called_once()
+
+
+def test_should_handle_web_server_serve_forever_exceptions():
+ """Test that exceptions from WebServer.serve_forever() are propagated."""
+ # Arrange
+ web_config = WebServiceConfig()
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Mock dependencies
+ with (
+ patch("boto3.client") as mock_boto3_client,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.WebServer"
+ ) as mock_web_server_class,
+ ):
+ mock_client = Mock()
+ mock_server = Mock()
+ mock_boto3_client.return_value = mock_client
+ mock_web_server_class.return_value = mock_server
+
+ # Make serve_forever raise an exception
+ mock_server.serve_forever.side_effect = Exception("Server error")
+
+ # Start the runner
+ runner.start()
+
+ # Act & Assert - Exception should propagate
+ with pytest.raises(Exception, match="Server error"):
+ runner.serve_forever()
+
+ # Cleanup
+ runner.stop()
+
+
+def test_should_handle_web_server_close_exceptions_gracefully():
+ """Test that exceptions from server_close() are handled gracefully."""
+ # Arrange
+ web_config = WebServiceConfig()
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Mock dependencies
+ with (
+ patch("boto3.client") as mock_boto3_client,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.WebServer"
+ ) as mock_web_server_class,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.Scheduler"
+ ) as mock_scheduler_class,
+ ):
+ mock_client = Mock()
+ mock_server = Mock()
+ mock_scheduler = Mock()
+ mock_boto3_client.return_value = mock_client
+ mock_web_server_class.return_value = mock_server
+ mock_scheduler_class.return_value = mock_scheduler
+
+ # Make server_close raise an exception
+ mock_server.server_close.side_effect = Exception("Close error")
+
+ # Start the runner
+ runner.start()
+
+ # Act - stop() should not raise exception despite server_close error
+ runner.stop()
+
+ # Assert - Verify both cleanup methods were attempted
+ mock_server.server_close.assert_called_once()
+ mock_scheduler.stop.assert_called_once()
+
+
+# Exception Handling Tests
+
+
+def test_should_handle_standard_runtime_errors():
+ """Test that standard RuntimeError exceptions are handled properly."""
+ # Arrange
+ web_config = WebServiceConfig()
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Test RuntimeError during start
+ with patch("boto3.client") as mock_boto3_client:
+ mock_boto3_client.side_effect = RuntimeError("Runtime error during start")
+
+ with pytest.raises(RuntimeError, match="Runtime error during start"):
+ runner.start()
+
+ # Test DurableFunctionsLocalRunnerError when serve_forever called before start
+ with pytest.raises(DurableFunctionsLocalRunnerError, match="Server not started"):
+ runner.serve_forever()
+
+
+def test_should_handle_value_errors_during_initialization():
+ """Test that ValueError exceptions during initialization are propagated."""
+ # Arrange
+ web_config = WebServiceConfig()
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Mock boto3 client to raise ValueError
+ with patch("boto3.client") as mock_boto3_client:
+ mock_boto3_client.side_effect = ValueError("Invalid configuration")
+
+ # Act & Assert
+ with pytest.raises(ValueError, match="Invalid configuration"):
+ runner.start()
+
+
+def test_should_handle_connection_errors_during_initialization():
+ """Test that ConnectionError exceptions during initialization are propagated."""
+ # Arrange
+ web_config = WebServiceConfig()
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Mock boto3 client to raise ConnectionError
+ with patch("boto3.client") as mock_boto3_client:
+ mock_boto3_client.side_effect = ConnectionError("Network connection failed")
+
+ # Act & Assert
+ with pytest.raises(ConnectionError, match="Network connection failed"):
+ runner.start()
+
+
+def test_should_handle_keyboard_interrupt_during_serve_forever():
+ """Test that KeyboardInterrupt during serve_forever is propagated."""
+ # Arrange
+ web_config = WebServiceConfig()
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Mock dependencies
+ with (
+ patch("boto3.client") as mock_boto3_client,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.WebServer"
+ ) as mock_web_server_class,
+ ):
+ mock_client = Mock()
+ mock_server = Mock()
+ mock_boto3_client.return_value = mock_client
+ mock_web_server_class.return_value = mock_server
+
+ # Make serve_forever raise KeyboardInterrupt
+ mock_server.serve_forever.side_effect = KeyboardInterrupt()
+
+ # Start the runner
+ runner.start()
+
+ # Act & Assert - KeyboardInterrupt should propagate
+ with pytest.raises(KeyboardInterrupt):
+ runner.serve_forever()
+
+ # Cleanup
+ runner.stop()
+
+
+# Lifecycle Management Tests
+
+
+def test_start_creates_dependencies_and_server():
+ """Test that start() creates all dependencies and WebServer through public API."""
+ # Arrange
+ web_config = WebServiceConfig(host="localhost", port=5000)
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Mock dependencies and WebServer
+ with (
+ patch("boto3.client") as mock_boto3_client,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.WebServer"
+ ) as mock_web_server_class,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.Scheduler"
+ ) as mock_scheduler_class,
+ ):
+ mock_client = Mock()
+ mock_server = Mock()
+ mock_scheduler = Mock()
+
+ mock_boto3_client.return_value = mock_client
+ mock_web_server_class.return_value = mock_server
+ mock_scheduler_class.return_value = mock_scheduler
+
+ # Act
+ runner.start()
+
+ # Assert - Test through public behavior
+ # Should be able to call serve_forever after start
+ runner.serve_forever()
+ mock_server.serve_forever.assert_called_once()
+
+ # Should be able to stop after start
+ runner.stop()
+ mock_server.server_close.assert_called_once()
+ mock_scheduler.stop.assert_called_once()
+
+
+def test_start_raises_runtime_error_if_already_started():
+ """Test that start() raises DurableFunctionsLocalRunnerError if server is already running."""
+ # Arrange
+ web_config = WebServiceConfig()
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Set server to simulate already started state
+ runner._server = Mock() # noqa: SLF001
+
+ # Act & Assert
+ with pytest.raises(
+ DurableFunctionsLocalRunnerError, match="Server is already running"
+ ):
+ runner.start()
+
+
+def test_serve_forever_delegates_to_web_server():
+ """Test that serve_forever() delegates to WebServer.serve_forever() through public API."""
+ # Arrange
+ web_config = WebServiceConfig()
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Mock dependencies to allow start
+ with (
+ patch("boto3.client") as mock_boto3_client,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.WebServer"
+ ) as mock_web_server_class,
+ ):
+ mock_client = Mock()
+ mock_server = Mock()
+ mock_boto3_client.return_value = mock_client
+ mock_web_server_class.return_value = mock_server
+
+ # Start the runner first (public API)
+ runner.start()
+
+ # Act
+ runner.serve_forever()
+
+ # Assert
+ mock_server.serve_forever.assert_called_once()
+
+ # Cleanup
+ runner.stop()
+
+
+def test_serve_forever_raises_runtime_error_if_not_started():
+ """Test that serve_forever() raises DurableFunctionsLocalRunnerError if server not started."""
+ # Arrange
+ web_config = WebServiceConfig()
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Ensure server is None (not started)
+ assert runner._server is None # noqa: SLF001
+
+ # Act & Assert
+ with pytest.raises(DurableFunctionsLocalRunnerError, match="Server not started"):
+ runner.serve_forever()
+
+
+def test_stop_cleans_up_server_and_scheduler():
+ """Test that stop() properly cleans up server and scheduler resources through public API."""
+ # Arrange
+ web_config = WebServiceConfig()
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Mock dependencies to allow start
+ with (
+ patch("boto3.client") as mock_boto3_client,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.WebServer"
+ ) as mock_web_server_class,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.Scheduler"
+ ) as mock_scheduler_class,
+ ):
+ mock_client = Mock()
+ mock_server = Mock()
+ mock_scheduler = Mock()
+
+ mock_boto3_client.return_value = mock_client
+ mock_web_server_class.return_value = mock_server
+ mock_scheduler_class.return_value = mock_scheduler
+
+ # Start the runner first (public API)
+ runner.start()
+
+ # Act
+ runner.stop()
+
+ # Assert - Verify cleanup was called
+ mock_server.server_close.assert_called_once()
+ mock_scheduler.stop.assert_called_once()
+
+ # Verify runner is back to stopped state (public behavior)
+ with pytest.raises(
+ DurableFunctionsLocalRunnerError, match="Server not started"
+ ):
+ runner.serve_forever()
+
+
+def test_stop_is_safe_to_call_multiple_times():
+ """Test that stop() can be called multiple times safely through public API."""
+ # Arrange
+ web_config = WebServiceConfig()
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Mock dependencies to allow start
+ with (
+ patch("boto3.client") as mock_boto3_client,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.WebServer"
+ ) as mock_web_server_class,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.Scheduler"
+ ) as mock_scheduler_class,
+ ):
+ mock_client = Mock()
+ mock_server = Mock()
+ mock_scheduler = Mock()
+
+ mock_boto3_client.return_value = mock_client
+ mock_web_server_class.return_value = mock_server
+ mock_scheduler_class.return_value = mock_scheduler
+
+ # Start the runner first (public API)
+ runner.start()
+
+ # Act - call stop multiple times
+ runner.stop()
+ runner.stop()
+ runner.stop()
+
+ # Assert - should only be called once (first time)
+ mock_server.server_close.assert_called_once()
+ mock_scheduler.stop.assert_called_once()
+
+ # Verify runner remains in stopped state (public behavior)
+ with pytest.raises(
+ DurableFunctionsLocalRunnerError, match="Server not started"
+ ):
+ runner.serve_forever()
+
+
+# Integration Tests - CLI to WebRunner Flow
+
+
+def test_should_integrate_with_cli_start_server_command():
+ """Test complete integration from CLI start-server command to WebRunner."""
+ # This test verifies the complete flow from CLI argument parsing
+ # through WebRunnerConfig creation to WebRunner execution
+
+ # Arrange
+ app = CliApp()
+
+ # Mock WebRunner to verify it receives correct configuration
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.WebRunner"
+ ) as mock_web_runner_class:
+ # Setup mock runner instance with context manager support
+ mock_runner = Mock()
+ mock_runner.__enter__ = Mock(return_value=mock_runner)
+ mock_runner.__exit__ = Mock(return_value=None)
+ mock_web_runner_class.return_value = mock_runner
+ mock_runner.serve_forever.side_effect = KeyboardInterrupt()
+
+ # Act - Run CLI command with custom arguments
+ exit_code = app.run(
+ [
+ "start-server",
+ "--host",
+ "integration-host",
+ "--port",
+ "7777",
+ "--log-level",
+ "WARNING",
+ "--lambda-endpoint",
+ "http://integration-lambda:4000",
+ "--local-runner-endpoint",
+ "http://integration-runner:8000",
+ "--local-runner-region",
+ "eu-central-1",
+ "--local-runner-mode",
+ "integration",
+ ]
+ )
+
+ # Assert - Verify CLI handled KeyboardInterrupt correctly
+ assert exit_code == 130
+
+ # Verify WebRunner was created with correct configuration
+ mock_web_runner_class.assert_called_once()
+ config = mock_web_runner_class.call_args[0][0]
+
+ # Verify web service configuration
+ assert config.web_service.host == "integration-host"
+ assert config.web_service.port == 7777
+ assert config.web_service.log_level == "WARNING"
+
+ # Verify Lambda service configuration
+ assert config.lambda_endpoint == "http://integration-lambda:4000"
+ assert config.local_runner_endpoint == "http://integration-runner:8000"
+ assert config.local_runner_region == "eu-central-1"
+ assert config.local_runner_mode == "integration"
+
+ # Verify context manager protocol was used
+ mock_runner.__enter__.assert_called_once()
+ mock_runner.__exit__.assert_called_once()
+ mock_runner.serve_forever.assert_called_once()
+
+
+def test_should_handle_cli_to_web_runner_startup_errors():
+ """Test integration error handling from CLI to WebRunner startup failures."""
+ # Arrange
+ app = CliApp()
+
+ # Mock WebRunner to raise exception during creation
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.WebRunner"
+ ) as mock_web_runner_class:
+ mock_web_runner_class.side_effect = Exception("WebRunner startup failed")
+
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.logger"
+ ) as mock_logger:
+ # Act
+ exit_code = app.run(["start-server"])
+
+ # Assert - Verify CLI handled WebRunner exception correctly
+ assert exit_code == 1
+ mock_logger.exception.assert_called_with("Failed to start server")
+
+
+def test_should_handle_cli_to_web_runner_context_manager_errors():
+ """Test integration error handling for WebRunner context manager failures."""
+ # Arrange
+ app = CliApp()
+
+ # Mock WebRunner context manager to raise exception
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.WebRunner"
+ ) as mock_web_runner_class:
+ mock_runner = Mock()
+ mock_runner.__enter__ = Mock(
+ side_effect=DurableFunctionsLocalRunnerError("Context manager failed")
+ )
+ mock_runner.__exit__ = Mock(return_value=None)
+ mock_web_runner_class.return_value = mock_runner
+
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.logger"
+ ) as mock_logger:
+ # Act
+ exit_code = app.run(["start-server"])
+
+ # Assert - Verify CLI handled context manager exception correctly
+ assert exit_code == 1
+ mock_logger.exception.assert_called_with("Failed to start server")
+
+
+def test_should_handle_cli_to_web_runner_serve_forever_errors():
+ """Test integration error handling for WebRunner serve_forever failures."""
+ # Arrange
+ app = CliApp()
+
+ # Mock WebRunner serve_forever to raise exception
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.WebRunner"
+ ) as mock_web_runner_class:
+ mock_runner = Mock()
+ mock_runner.__enter__ = Mock(return_value=mock_runner)
+ mock_runner.__exit__ = Mock(return_value=None)
+ mock_web_runner_class.return_value = mock_runner
+ mock_runner.serve_forever.side_effect = Exception("Server runtime error")
+
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.logger"
+ ) as mock_logger:
+ # Act
+ exit_code = app.run(["start-server"])
+
+ # Assert - Verify CLI handled serve_forever exception correctly
+ assert exit_code == 1
+ mock_logger.exception.assert_called_with("Failed to start server")
+
+
+def test_should_preserve_cli_configuration_through_web_runner():
+ """Test that CLI configuration is preserved through WebRunner creation."""
+ # This test verifies that all CLI arguments are correctly passed through
+ # the WebRunnerConfig to the WebRunner and its dependencies
+
+ # Arrange
+ app = CliApp()
+
+ # Mock all WebRunner dependencies to verify configuration flow
+ with (
+ patch(
+ "aws_durable_execution_sdk_python_testing.cli.WebRunner"
+ ) as mock_web_runner_class,
+ patch("boto3.client"),
+ patch("aws_durable_execution_sdk_python_testing.runner.WebServer"),
+ ):
+ # Setup mocks
+ mock_runner = Mock()
+
+ mock_runner.__enter__ = Mock(return_value=mock_runner)
+ mock_runner.__exit__ = Mock(return_value=None)
+ mock_web_runner_class.return_value = mock_runner
+ mock_runner.serve_forever.return_value = None
+
+ # No need to mock internal behavior, just verify configuration passing
+
+ # Act - Run CLI with comprehensive configuration
+ exit_code = app.run(
+ [
+ "start-server",
+ "--host",
+ "config-test-host",
+ "--port",
+ "9999",
+ "--log-level",
+ "ERROR", # ERROR level
+ "--lambda-endpoint",
+ "http://config-lambda:5000",
+ "--local-runner-endpoint",
+ "http://config-runner:9000",
+ "--local-runner-region",
+ "ap-northeast-1",
+ "--local-runner-mode",
+ "config-test",
+ ]
+ )
+
+ # Assert - Verify successful execution
+ assert exit_code == 0
+
+ # Verify WebRunner was created with correct configuration
+ mock_web_runner_class.assert_called_once()
+ config = mock_web_runner_class.call_args[0][0]
+
+ # Verify web service configuration
+ assert config.web_service.host == "config-test-host"
+ assert config.web_service.port == 9999
+ assert config.web_service.log_level == "ERROR"
+
+ # Verify Lambda service configuration
+ assert config.lambda_endpoint == "http://config-lambda:5000"
+ assert config.local_runner_endpoint == "http://config-runner:9000"
+ assert config.local_runner_region == "ap-northeast-1"
+ assert config.local_runner_mode == "config-test"
+
+ # Verify context manager protocol was used
+ mock_runner.__enter__.assert_called_once()
+ mock_runner.serve_forever.assert_called_once()
+ mock_runner.__exit__.assert_called_once()
+
+
+def test_should_handle_environment_variable_integration():
+ """Test integration with environment variables through CLI to WebRunner."""
+ # Set environment variables
+ env_vars = {
+ "AWS_DEX_HOST": "env-host",
+ "AWS_DEX_PORT": "8888",
+ "AWS_DEX_LOG_LEVEL": "CRITICAL", # CRITICAL level
+ "AWS_DEX_LAMBDA_ENDPOINT": "http://env-lambda:6000",
+ "AWS_DEX_LOCAL_RUNNER_ENDPOINT": "http://env-runner:7000",
+ "AWS_DEX_LOCAL_RUNNER_REGION": "sa-east-1",
+ "AWS_DEX_LOCAL_RUNNER_MODE": "env-test",
+ }
+
+ with patch.dict(os.environ, env_vars, clear=True):
+ app = CliApp()
+
+ # Mock WebRunner to verify environment configuration
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.WebRunner"
+ ) as mock_web_runner_class:
+ mock_runner = Mock()
+ mock_web_runner_class.return_value = mock_runner
+ mock_runner.__enter__ = Mock(return_value=mock_runner)
+ mock_runner.__exit__ = Mock(return_value=None)
+ mock_runner.serve_forever.return_value = None
+
+ # Act - Run CLI without arguments (should use environment)
+ exit_code = app.run(["start-server"])
+
+ # Assert - Verify successful execution
+ assert exit_code == 0
+
+ # Verify WebRunner was created with environment configuration
+ mock_web_runner_class.assert_called_once()
+ config = mock_web_runner_class.call_args[0][0]
+
+ # Verify environment variables were used
+ assert config.web_service.host == "env-host"
+ assert config.web_service.port == 8888
+ assert config.web_service.log_level == "CRITICAL"
+ assert config.lambda_endpoint == "http://env-lambda:6000"
+ assert config.local_runner_endpoint == "http://env-runner:7000"
+ assert config.local_runner_region == "sa-east-1"
+ assert config.local_runner_mode == "env-test"
+
+
+def test_should_handle_cli_argument_override_of_environment():
+ """Test that CLI arguments override environment variables in integration."""
+ # Set environment variables
+ env_vars = {
+ "AWS_DEX_HOST": "env-host",
+ "AWS_DEX_PORT": "8888",
+ }
+
+ with patch.dict(os.environ, env_vars, clear=True):
+ app = CliApp()
+
+ # Mock WebRunner to verify argument override
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.WebRunner"
+ ) as mock_web_runner_class:
+ mock_runner = Mock()
+ mock_web_runner_class.return_value = mock_runner
+ mock_runner.__enter__ = Mock(return_value=mock_runner)
+ mock_runner.__exit__ = Mock(return_value=None)
+ mock_runner.serve_forever.return_value = None
+
+ # Act - Run CLI with arguments that should override environment
+ exit_code = app.run(
+ [
+ "start-server",
+ "--host",
+ "cli-override-host",
+ "--port",
+ "7777",
+ ]
+ )
+
+ # Assert - Verify successful execution
+ assert exit_code == 0
+
+ # Verify CLI arguments overrode environment variables
+ config = mock_web_runner_class.call_args[0][0]
+ assert config.web_service.host == "cli-override-host" # CLI override
+ assert config.web_service.port == 7777 # CLI override
+
+
+def test_should_maintain_backward_compatibility_in_integration():
+ """Test that integration maintains backward compatibility with existing behavior."""
+ # This test ensures that the refactored CLI-to-WebRunner flow
+ # maintains the same external behavior as the original implementation
+ app = CliApp()
+
+ # Mock WebRunner to simulate successful operation
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.WebRunner"
+ ) as mock_web_runner_class:
+ mock_runner = Mock()
+ mock_web_runner_class.return_value = mock_runner
+ mock_runner.__enter__ = Mock(return_value=mock_runner)
+ mock_runner.__exit__ = Mock(return_value=None)
+ mock_runner.serve_forever.side_effect = KeyboardInterrupt()
+
+ # Mock logging to verify backward compatible messages
+ with patch(
+ "aws_durable_execution_sdk_python_testing.cli.logger"
+ ) as mock_logger:
+ # Act
+ exit_code = app.run(
+ ["start-server", "--host", "compat-host", "--port", "5555"]
+ )
+
+ # Assert - Verify backward compatible behavior
+ assert exit_code == 130 # KeyboardInterrupt exit code
+
+ # Verify backward compatible logging messages
+ mock_logger.info.assert_any_call(
+ "Starting Durable Functions Local Runner on %s:%s",
+ "compat-host",
+ 5555,
+ )
+ mock_logger.info.assert_any_call("Configuration:")
+ mock_logger.info.assert_any_call(" Host: %s", "compat-host")
+ mock_logger.info.assert_any_call(" Port: %s", 5555)
+ mock_logger.info.assert_any_call(
+ "Server started successfully. Press Ctrl+C to stop."
+ )
+ mock_logger.info.assert_any_call(
+ "Received shutdown signal, stopping server..."
+ )
+
+
+def test_stop_handles_unstarted_runner_gracefully():
+ """Test that stop() handles unstarted runner gracefully through public API."""
+ # Arrange
+ web_config = WebServiceConfig()
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Act & Assert - should not raise any exceptions when stopping unstarted runner
+ runner.stop()
+
+ # Verify runner remains in stopped state (public behavior)
+ with pytest.raises(DurableFunctionsLocalRunnerError, match="Server not started"):
+ runner.serve_forever()
+
+
+def test_complete_lifecycle_start_serve_stop():
+ """Test complete lifecycle: start -> serve_forever -> stop through public API."""
+ # Arrange
+ web_config = WebServiceConfig()
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Mock all dependencies
+ with (
+ patch("boto3.client") as mock_boto3_client,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.WebServer"
+ ) as mock_web_server_class,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.Scheduler"
+ ) as mock_scheduler_class,
+ ):
+ mock_client = Mock()
+ mock_server = Mock()
+ mock_scheduler = Mock()
+
+ mock_boto3_client.return_value = mock_client
+ mock_web_server_class.return_value = mock_server
+ mock_scheduler_class.return_value = mock_scheduler
+
+ # Act - complete lifecycle through public API
+ runner.start()
+ runner.serve_forever()
+ runner.stop()
+
+ # Assert - Verify all methods were called
+ mock_server.serve_forever.assert_called_once()
+ mock_server.server_close.assert_called_once()
+ mock_scheduler.start.assert_called_once()
+ mock_scheduler.stop.assert_called_once()
+
+ # Verify runner is back to stopped state (public behavior)
+ with pytest.raises(
+ DurableFunctionsLocalRunnerError, match="Server not started"
+ ):
+ runner.serve_forever()
+
+
+def test_context_manager_calls_start_and_stop():
+ """Test that context manager properly calls start() and stop()."""
+ # Arrange
+ web_config = WebServiceConfig()
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Mock start and stop to track calls
+ with (
+ patch.object(runner, "start") as mock_start,
+ patch.object(runner, "stop") as mock_stop,
+ ):
+ # Act
+ with runner as context_runner:
+ # Verify start was called and runner returned
+ mock_start.assert_called_once()
+ assert context_runner is runner
+
+ # Assert stop was called on exit
+ mock_stop.assert_called_once()
+
+
+def test_context_manager_calls_stop_on_exception():
+ """Test that context manager calls stop() even when exception occurs."""
+ # Arrange
+ web_config = WebServiceConfig()
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Mock start and stop
+ with (
+ patch.object(runner, "start") as mock_start,
+ patch.object(runner, "stop") as mock_stop,
+ ):
+ # Act & Assert
+ with pytest.raises(ValueError, match="Test exception"): # noqa: PT012
+ with runner:
+ mock_start.assert_called_once()
+ raise ValueError("Test exception") # noqa: TRY003, EM101
+
+ # Verify stop was still called despite exception
+ mock_stop.assert_called_once()
+
+
+def test_state_transitions_prevent_invalid_operations():
+ """Test that state checking prevents invalid operation sequences through public API."""
+ # Arrange
+ web_config = WebServiceConfig()
+ runner_config = WebRunnerConfig(web_service=web_config)
+ runner = WebRunner(runner_config)
+
+ # Test serve_forever before start
+ with pytest.raises(DurableFunctionsLocalRunnerError, match="Server not started"):
+ runner.serve_forever()
+
+ # Mock dependencies for start
+ with (
+ patch("boto3.client") as mock_boto3_client,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.WebServer"
+ ) as mock_web_server_class,
+ patch(
+ "aws_durable_execution_sdk_python_testing.runner.Scheduler"
+ ) as mock_scheduler_class,
+ ):
+ mock_client = Mock()
+ mock_server = Mock()
+ mock_scheduler = Mock()
+
+ mock_boto3_client.return_value = mock_client
+ mock_web_server_class.return_value = mock_server
+ mock_scheduler_class.return_value = mock_scheduler
+
+ # Start server
+ runner.start()
+
+ # Test double start
+ with pytest.raises(
+ DurableFunctionsLocalRunnerError, match="Server is already running"
+ ):
+ runner.start()
+
+ # Verify serve_forever works after start
+ runner.serve_forever()
+ mock_server.serve_forever.assert_called_once()
+
+ # Stop and verify serve_forever fails again
+ runner.stop()
+ with pytest.raises(
+ DurableFunctionsLocalRunnerError, match="Server not started"
+ ):
+ runner.serve_forever()
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/scheduler_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/scheduler_test.py
new file mode 100644
index 0000000..65db932
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/scheduler_test.py
@@ -0,0 +1,729 @@
+"""Unit tests for scheduler.py"""
+
+import threading
+import time
+from concurrent.futures import Future
+from unittest.mock import patch
+
+import pytest
+
+from aws_durable_execution_sdk_python_testing.scheduler import Event, Scheduler
+
+
+def wait_for_condition(condition_func, timeout_iterations=100):
+ """Wait for a condition to become true with polling."""
+ for _ in range(timeout_iterations):
+ if condition_func():
+ return True
+ time.sleep(0.001)
+ return False
+
+
+def test_scheduler_init():
+ """Test Scheduler initialization."""
+ scheduler = Scheduler()
+ assert not scheduler.is_started()
+ assert scheduler.event_count() == 0
+
+
+def test_scheduler_context_manager():
+ """Test Scheduler as context manager."""
+ with Scheduler() as scheduler:
+ assert scheduler.is_started()
+ assert not scheduler.is_started()
+
+
+def test_scheduler_start_stop():
+ """Test Scheduler start and stop methods."""
+ scheduler = Scheduler()
+
+ scheduler.start()
+ assert scheduler.is_started()
+
+ # Test start when already running
+ scheduler.start()
+ assert scheduler.is_started()
+
+ scheduler.stop()
+ assert not scheduler.is_started()
+
+ # Test stop when not running
+ scheduler.stop()
+ assert not scheduler.is_started()
+
+
+def test_scheduler_is_started():
+ """Test Scheduler is_started method."""
+ scheduler = Scheduler()
+
+ # Initially not started
+ assert not scheduler.is_started()
+
+ # After start
+ scheduler.start()
+ assert scheduler.is_started()
+
+ # After stop
+ scheduler.stop()
+ assert not scheduler.is_started()
+
+
+def test_scheduler_event_count():
+ """Test Scheduler event_count method."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ # Initially no events
+ assert scheduler.event_count() == 0
+
+ # Create events
+ event1 = scheduler.create_event()
+ assert scheduler.event_count() == 1
+
+ scheduler.create_event()
+ assert scheduler.event_count() == 2
+
+ # Remove event
+ event1.remove()
+ wait_for_condition(lambda: scheduler.event_count() == 1)
+ assert scheduler.event_count() == 1
+
+ scheduler.stop()
+
+
+def test_scheduler_task_count():
+ """Test Scheduler task_count method."""
+ scheduler = Scheduler()
+
+ # When not started, task count is 0
+ assert scheduler.task_count() == 0
+
+ scheduler.start()
+
+ # Create tasks with longer delay to ensure they're counted
+ future1 = scheduler.call_later(lambda: None, delay=0.5)
+ # Give a moment for the task to be created
+ time.sleep(0.01)
+ assert scheduler.task_count() >= 1
+
+ future2 = scheduler.call_later(lambda: None, delay=0.5)
+ time.sleep(0.01)
+ assert scheduler.task_count() >= 2
+
+ # Cancel tasks to clean up
+ future1.cancel()
+ future2.cancel()
+
+ # Wait for tasks to complete or be cancelled
+ wait_for_condition(lambda: scheduler.task_count() == 0, timeout_iterations=200)
+
+ scheduler.stop()
+
+
+def test_scheduler_call_later_sync_function():
+ """Test call_later with sync function."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ result = []
+
+ def sync_func():
+ result.append("executed")
+
+ future = scheduler.call_later(sync_func, delay=0.01)
+ wait_for_condition(lambda: future.done())
+
+ assert isinstance(future, Future)
+ assert result == ["executed"]
+ assert future.done()
+
+ scheduler.stop()
+
+
+def test_scheduler_call_later_async_function():
+ """Test call_later with async function."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ result = []
+
+ async def async_func():
+ result.append("async_executed")
+
+ future = scheduler.call_later(async_func, delay=0.01)
+ wait_for_condition(lambda: future.done())
+
+ assert isinstance(future, Future)
+ assert result == ["async_executed"]
+ assert future.done()
+
+ scheduler.stop()
+
+
+def test_scheduler_call_later_multiple_count():
+ """Test call_later with multiple executions."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ result = []
+
+ def func():
+ result.append("count")
+
+ # Note: Current implementation only executes once due to early return
+ future = scheduler.call_later(func, delay=0.01, count=3)
+ wait_for_condition(lambda: future.done())
+
+ # Current implementation only executes once
+ assert len(result) == 1
+ assert future.done()
+
+ scheduler.stop()
+
+
+def test_scheduler_call_later_infinite_count():
+ """Test call_later with infinite count."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ result = []
+
+ def func():
+ result.append("infinite")
+
+ # Note: Current implementation only executes once due to early return
+ future = scheduler.call_later(func, delay=0.01, count=None)
+ wait_for_condition(lambda: future.done())
+
+ # Current implementation only executes once
+ assert len(result) == 1
+ assert future.done()
+
+ scheduler.stop()
+
+
+def test_scheduler_call_later_function_exception():
+ """Test call_later with function that raises exception."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ def failing_func() -> None:
+ msg: str = "test error"
+
+ raise ValueError(msg)
+
+ with patch(
+ "aws_durable_execution_sdk_python_testing.scheduler.logger"
+ ) as mock_logger:
+ future = scheduler.call_later(failing_func, delay=0.01)
+ wait_for_condition(lambda: future.done())
+
+ assert future.done()
+ mock_logger.exception.assert_called()
+
+ scheduler.stop()
+
+
+def test_scheduler_create_event():
+ """Test create_event method."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ event = scheduler.create_event()
+
+ assert isinstance(event, Event)
+ assert scheduler.event_count() == 1
+
+ scheduler.stop()
+
+
+def test_task_cancel():
+ """Test Future cancel method."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ def func():
+ pass
+
+ future = scheduler.call_later(func, delay=0.1, count=None)
+ future.cancel()
+
+ # Wait briefly for cancellation to take effect
+ wait_for_condition(lambda: future.cancelled())
+
+ assert future.cancelled()
+
+ scheduler.stop()
+
+
+def test_task_is_done():
+ """Test Future done property."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ def quick_func():
+ pass
+
+ future = scheduler.call_later(quick_func, delay=0.01)
+ assert not future.done()
+
+ wait_for_condition(lambda: future.done())
+ assert future.done()
+
+ # Small delay to ensure coroutine cleanup completes
+ time.sleep(0.01)
+ scheduler.stop()
+
+
+def test_task_result():
+ """Test Future result method."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ def func():
+ return None
+
+ future = scheduler.call_later(func, delay=0.01)
+ wait_for_condition(lambda: future.done())
+
+ result = future.result()
+ assert result is None
+
+ scheduler.stop()
+
+
+def test_task_cancel_method():
+ """Test Future cancel method."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ # Create a future and cancel it immediately
+ future = scheduler.call_later(lambda: None, delay=0.01)
+ future.cancel()
+
+ # The cancel method should work without hanging
+ # We don't test the result here to avoid timing issues
+
+ scheduler.stop()
+
+
+def test_task_result_completed():
+ """Test Future result method when completed."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ def func():
+ return "test_result"
+
+ future = scheduler.call_later(func, delay=0.01)
+ wait_for_condition(lambda: future.done())
+ assert future.done()
+
+ # Small delay to ensure coroutine cleanup completes
+ time.sleep(0.01)
+ scheduler.stop()
+
+
+def test_event_set_and_wait_timeout():
+ """Test Event set and wait with timeout."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ event = scheduler.create_event()
+
+ # Test wait with timeout (should timeout)
+ result = event.wait(timeout=0.01, clear_on_set=False)
+ assert result is False
+
+ # Set the event
+ event.set()
+
+ # Wait should now succeed
+ result = event.wait(timeout=0.1, clear_on_set=True)
+ assert result is True
+
+ scheduler.stop()
+
+
+def test_event_wait_set_by_thread():
+ """Test Event wait when set by another thread."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ event = scheduler.create_event()
+ result_container = []
+ start_event = threading.Event()
+
+ def set_event():
+ start_event.wait() # Wait for signal to start
+ event.set()
+
+ def wait_for_event():
+ result = event.wait(timeout=1.0)
+ result_container.append(result)
+
+ set_thread = threading.Thread(target=set_event)
+ wait_thread = threading.Thread(target=wait_for_event)
+
+ set_thread.start()
+ wait_thread.start()
+ start_event.set() # Signal to start setting event
+
+ set_thread.join()
+ wait_thread.join()
+
+ assert result_container[0] is True
+
+ scheduler.stop()
+
+
+def test_event_wait_clear_on_set_false():
+ """Test Event wait with clear_on_set=False."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ event = scheduler.create_event()
+ event.set()
+
+ result = event.wait(clear_on_set=False)
+ assert result is True
+ assert scheduler.event_count() == 1
+
+ scheduler.stop()
+
+
+def test_event_remove():
+ """Test Event remove method."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ event = scheduler.create_event()
+ assert scheduler.event_count() == 1
+
+ event.remove()
+ wait_for_condition(lambda: scheduler.event_count() == 0)
+
+ assert scheduler.event_count() == 0
+
+ scheduler.stop()
+
+
+def test_event_wait_removed_event():
+ """Test Event wait on removed event."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ event = scheduler.create_event()
+ event.remove()
+ wait_for_condition(lambda: scheduler.event_count() == 0)
+
+ result = event.wait(timeout=0.01)
+ assert result is False
+
+ scheduler.stop()
+
+
+def test_event_set_removed_event():
+ """Test Event set on removed event."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ event = scheduler.create_event()
+ event.remove()
+ wait_for_condition(lambda: scheduler.event_count() == 0)
+
+ # Should not crash
+ event.set()
+
+ scheduler.stop()
+
+
+def test_scheduler_cleanup_on_stop():
+ """Test scheduler cleanup when stopped."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ # Create a future and event
+ scheduler.call_later(lambda: None, delay=0.1, count=1)
+ scheduler.create_event()
+
+ # Stop scheduler immediately
+ scheduler.stop()
+
+ # Events should be cleared (this is what we can reliably test)
+ assert scheduler.event_count() == 0
+ # Future state may vary due to timing, but scheduler should be stopped
+ assert not scheduler.is_started()
+
+
+def test_scheduler_multiple_events():
+ """Test scheduler with multiple events."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ event1 = scheduler.create_event()
+ event2 = scheduler.create_event()
+
+ assert scheduler.event_count() == 2
+
+ event1.set()
+ result1 = event1.wait(timeout=0.01)
+ assert result1 is True
+
+ result2 = event2.wait(timeout=0.01)
+ assert result2 is False
+
+ scheduler.stop()
+
+
+def test_task_properties_after_scheduler_stop():
+ """Test Future properties after scheduler is stopped."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ def func():
+ pass
+
+ future = scheduler.call_later(func, delay=0.01)
+ wait_for_condition(lambda: future.done())
+
+ scheduler.stop()
+
+ assert future.done()
+ assert not future.cancelled()
+
+
+def test_event_timeout_handling():
+ """Test Event timeout handling."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ event = scheduler.create_event()
+
+ start_time = time.time()
+ result = event.wait(timeout=0.05)
+ end_time = time.time()
+
+ assert result is False
+ assert 0.04 <= (end_time - start_time) <= 0.1
+
+ scheduler.stop()
+
+
+def test_scheduler_call_later_zero_delay():
+ """Test call_later with zero delay."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ result = []
+
+ def func():
+ result.append("zero_delay")
+
+ future = scheduler.call_later(func, delay=0)
+ wait_for_condition(lambda: future.done())
+
+ assert result == ["zero_delay"]
+ assert future.done()
+
+ scheduler.stop()
+
+
+def test_scheduler_call_later_default_parameters():
+ """Test call_later with default parameters."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ result = []
+
+ def func():
+ result.append("default")
+
+ future = scheduler.call_later(func)
+ wait_for_condition(lambda: future.done())
+
+ assert result == ["default"]
+ assert future.done()
+
+ scheduler.stop()
+
+
+def test_task_result_with_exception():
+ """Test Future result method when function raises exception."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ def failing_func() -> None:
+ msg: str = "test exception"
+
+ raise ValueError(msg)
+
+ # Test that user function exceptions are propagated through the Future
+ with patch(
+ "aws_durable_execution_sdk_python_testing.scheduler.logger"
+ ) as mock_logger:
+ future = scheduler.call_later(failing_func, delay=0.01)
+ wait_for_condition(lambda: future.done())
+
+ # Future should be done and exception should be logged
+ assert future.done()
+ mock_logger.exception.assert_called()
+
+ # Exception should be propagated through Future.result()
+ with pytest.raises(ValueError, match="test exception"):
+ future.result()
+
+ scheduler.stop()
+
+
+def test_get_task_result_exception_handling():
+ """Test Future result exception handling."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ def func():
+ pass
+
+ future = scheduler.call_later(func, delay=0.01)
+ wait_for_condition(lambda: future.done())
+
+ # Future result should work normally
+ result = future.result()
+ assert result is None
+
+ scheduler.stop()
+
+
+def test_call_later_with_sync_function():
+ """Test call_later correctly identifies and runs sync functions."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ result = []
+
+ def sync_function():
+ result.append("sync_executed")
+
+ future = scheduler.call_later(sync_function, delay=0.01)
+ wait_for_condition(lambda: future.done())
+
+ assert result == ["sync_executed"]
+ assert future.done()
+
+ scheduler.stop()
+
+
+def test_call_later_with_async_function():
+ """Test call_later correctly identifies and runs async functions."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ result = []
+
+ async def async_function():
+ result.append("async_executed")
+
+ future = scheduler.call_later(async_function, delay=0.01)
+ wait_for_condition(lambda: future.done())
+
+ assert result == ["async_executed"]
+ assert future.done()
+
+ scheduler.stop()
+
+
+def test_event_set_exception():
+ """Test Event set_exception method."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ event = scheduler.create_event()
+ test_exception = ValueError("test exception")
+
+ event.set_exception(test_exception)
+
+ with pytest.raises(ValueError, match="test exception"):
+ event.wait()
+
+ scheduler.stop()
+
+
+def test_call_later_with_completion_event_exception():
+ """Test call_later with completion_event when function raises exception."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ completion_event = scheduler.create_event()
+
+ def failing_func() -> None:
+ msg: str = "completion event test"
+
+ raise RuntimeError(msg)
+
+ scheduler.call_later(failing_func, delay=0.01, completion_event=completion_event)
+
+ # Wait for the completion event to be set with exception
+ with pytest.raises(RuntimeError, match="completion event test"):
+ completion_event.wait(timeout=1.0)
+
+ scheduler.stop()
+
+
+def test_call_later_multiple_iterations():
+ """Test call_later with multiple count iterations."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ result = []
+
+ def func():
+ result.append("iteration")
+ # Return early to test the loop behavior
+ if len(result) >= 2:
+ return "done"
+ return
+
+ # Use a very small delay and count=3 to test the loop
+ future = scheduler.call_later(func, delay=0.001, count=3)
+ wait_for_condition(lambda: future.done(), timeout_iterations=500)
+
+ # Should execute at least once
+ assert len(result) >= 1
+ assert future.done()
+
+ scheduler.stop()
+
+
+def test_wait_for_event_timeout_exception():
+ """Test _wait_for_event with timeout exception handling."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ event = scheduler.create_event()
+
+ # Test timeout behavior
+ result = event.wait(timeout=0.001)
+ assert result is False
+
+ scheduler.stop()
+
+
+def test_call_later_loop_exit_condition():
+ """Test call_later loop exit condition with count=0."""
+ scheduler = Scheduler()
+ scheduler.start()
+
+ result = []
+
+ def func():
+ result.append("should_not_execute")
+
+ # Test with count=0 to hit the loop exit condition
+ future = scheduler.call_later(func, delay=0.01, count=0)
+ wait_for_condition(lambda: future.done())
+
+ # Should not execute the function at all
+ assert len(result) == 0
+ assert future.done()
+
+ scheduler.stop()
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/stores/__init__.py b/packages/aws-durable-execution-sdk-python-testing/tests/stores/__init__.py
new file mode 100644
index 0000000..dbc7145
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/stores/__init__.py
@@ -0,0 +1 @@
+"""Tests for store implementations."""
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/stores/concurrent_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/stores/concurrent_test.py
new file mode 100644
index 0000000..8703d4f
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/stores/concurrent_test.py
@@ -0,0 +1,278 @@
+"""Concurrent access tests for execution stores."""
+
+import tempfile
+import threading
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+
+import pytest
+
+from aws_durable_execution_sdk_python_testing.execution import Execution
+from aws_durable_execution_sdk_python_testing.model import StartDurableExecutionInput
+from aws_durable_execution_sdk_python_testing.stores.filesystem import (
+ FileSystemExecutionStore,
+)
+from aws_durable_execution_sdk_python_testing.stores.memory import (
+ InMemoryExecutionStore,
+)
+from aws_durable_execution_sdk_python_testing.stores.sqlite import SQLiteExecutionStore
+
+
+def test_concurrent_save_load():
+ """Test concurrent save and load operations."""
+ store = InMemoryExecutionStore()
+ results = []
+ results_lock = threading.Lock()
+
+ def save_execution(i: int):
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name=f"test-{i}",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id=f"inv-{i}",
+ input=f'{{"test": {i}}}',
+ )
+ execution = Execution.new(input_data)
+ execution.durable_execution_arn = f"arn-{i}"
+ store.save(execution)
+ with results_lock:
+ results.append(f"saved-{i}")
+
+ def load_execution(i: int):
+ try:
+ execution = store.load(f"arn-{i}")
+ with results_lock:
+ results.append(f"loaded-{execution.start_input.execution_name}")
+ except KeyError:
+ with results_lock:
+ results.append(f"not-found-{i}")
+
+ with ThreadPoolExecutor(max_workers=10) as executor:
+ # Submit save operations first
+ futures = [executor.submit(save_execution, i) for i in range(5)]
+ # Wait for saves to complete
+ for future in as_completed(futures):
+ future.result()
+
+ # Then submit load operations
+ futures = []
+ for i in range(5):
+ futures.append(executor.submit(load_execution, i))
+ # Wait for loads to complete
+ for future in as_completed(futures):
+ future.result()
+
+ assert len(results) == 10
+
+
+def test_concurrent_update_list():
+ """Test concurrent update and list operations."""
+ store = InMemoryExecutionStore()
+ results = []
+ results_lock = threading.Lock()
+
+ # Pre-populate store
+ for i in range(3):
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name=f"test-{i}",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id=f"inv-{i}",
+ input=f'{{"test": {i}}}',
+ )
+ execution = Execution.new(input_data)
+ execution.durable_execution_arn = f"arn-{i}"
+ store.save(execution)
+
+ def update_execution(i: int):
+ execution = store.load(f"arn-{i}")
+ execution.is_complete = True
+ store.update(execution)
+ with results_lock:
+ results.append(f"updated-{i}")
+
+ def list_executions():
+ executions = store.list_all()
+ with results_lock:
+ results.append(f"listed-{len(executions)}")
+
+ with ThreadPoolExecutor(max_workers=6) as executor:
+ # Submit update operations
+ futures = [executor.submit(update_execution, i) for i in range(3)]
+ # Submit list operations
+ futures.extend([executor.submit(list_executions) for _ in range(3)])
+
+ # Wait for all operations to complete
+ for future in as_completed(futures):
+ future.result()
+
+ assert len(results) == 6
+ final_list = store.list_all()
+ assert len(final_list) == 3
+
+
+@pytest.fixture
+def temp_storage_dir():
+ """Create a temporary directory for testing."""
+ with tempfile.TemporaryDirectory() as temp_dir:
+ yield Path(temp_dir)
+
+
+@pytest.fixture
+def temp_db_path():
+ """Create a temporary database file for testing."""
+ with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as temp_file:
+ temp_path = Path(temp_file.name)
+ yield temp_path
+ if temp_path.exists():
+ temp_path.unlink()
+
+
+def test_concurrent_filesystem_save_load(temp_storage_dir):
+ """Test concurrent save and load operations with filesystem store."""
+ store = FileSystemExecutionStore.create(temp_storage_dir)
+ results = []
+ results_lock = threading.Lock()
+
+ def save_execution(i: int):
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name=f"test-{i}",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id=f"inv-{i}",
+ input=f'{{"test": {i}}}',
+ )
+ execution = Execution.new(input_data)
+ execution.durable_execution_arn = f"arn-{i}"
+ execution.start()
+ store.save(execution)
+ with results_lock:
+ results.append(f"saved-{i}")
+
+ def load_execution(i: int):
+ try:
+ execution = store.load(f"arn-{i}")
+ with results_lock:
+ results.append(f"loaded-{execution.start_input.execution_name}")
+ except KeyError:
+ with results_lock:
+ results.append(f"not-found-{i}")
+
+ with ThreadPoolExecutor(max_workers=8) as executor:
+ # Submit save operations first
+ futures = [executor.submit(save_execution, i) for i in range(4)]
+ for future in as_completed(futures):
+ future.result()
+
+ # Then submit load operations
+ futures = [executor.submit(load_execution, i) for i in range(4)]
+ for future in as_completed(futures):
+ future.result()
+
+ assert len(results) == 8
+
+
+def test_concurrent_sqlite_save_load(temp_db_path):
+ """Test concurrent save and load operations with SQLite store."""
+ store = SQLiteExecutionStore.create_and_initialize(temp_db_path)
+ results = []
+ results_lock = threading.Lock()
+
+ def save_execution(i: int):
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name=f"test-{i}",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id=f"inv-{i}",
+ input=f'{{"test": {i}}}',
+ )
+ execution = Execution.new(input_data)
+ execution.durable_execution_arn = f"arn-{i}"
+ execution.start()
+ store.save(execution)
+ with results_lock:
+ results.append(f"saved-{i}")
+
+ def load_execution(i: int):
+ try:
+ execution = store.load(f"arn-{i}")
+ with results_lock:
+ results.append(f"loaded-{execution.start_input.execution_name}")
+ except KeyError:
+ with results_lock:
+ results.append(f"not-found-{i}")
+
+ with ThreadPoolExecutor(max_workers=8) as executor:
+ # Submit save operations first
+ futures = [executor.submit(save_execution, i) for i in range(4)]
+ for future in as_completed(futures):
+ future.result()
+
+ # Then submit load operations
+ futures = [executor.submit(load_execution, i) for i in range(4)]
+ for future in as_completed(futures):
+ future.result()
+
+ assert len(results) == 8
+
+
+def test_concurrent_query_operations():
+ """Test concurrent query operations on memory store."""
+ store = InMemoryExecutionStore()
+ results = []
+ results_lock = threading.Lock()
+
+ # Pre-populate store with test data
+ for i in range(10):
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name=f"function-{i % 3}", # 3 different functions
+ function_qualifier="$LATEST",
+ execution_name=f"exec-{i}",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id=f"inv-{i}",
+ )
+ execution = Execution.new(input_data)
+ execution.start()
+ # Complete some executions
+ if i % 4 == 0:
+ execution.complete_success("success")
+ store.save(execution)
+
+ def query_store(query_type: str):
+ if query_type == "function":
+ executions, next_marker = store.query(function_name="function-1")
+ elif query_type == "status":
+ executions, next_marker = store.query(status_filter="SUCCEEDED")
+ elif query_type == "pagination":
+ executions, next_marker = store.query(limit=3, offset=2)
+ else:
+ executions, next_marker = store.query()
+
+ with results_lock:
+ results.append(f"{query_type}-{len(executions)}")
+
+ with ThreadPoolExecutor(max_workers=4) as executor:
+ futures = [
+ executor.submit(query_store, "function"),
+ executor.submit(query_store, "status"),
+ executor.submit(query_store, "pagination"),
+ executor.submit(query_store, "all"),
+ ]
+ for future in as_completed(futures):
+ future.result()
+
+ assert len(results) == 4
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/stores/filesystem_store_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/stores/filesystem_store_test.py
new file mode 100644
index 0000000..01da777
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/stores/filesystem_store_test.py
@@ -0,0 +1,420 @@
+"""Tests for FileSystemExecutionStore."""
+
+import tempfile
+from pathlib import Path
+
+import pytest
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ ResourceNotFoundException,
+)
+from aws_durable_execution_sdk_python_testing.execution import Execution
+from aws_durable_execution_sdk_python_testing.model import StartDurableExecutionInput
+from aws_durable_execution_sdk_python_testing.stores.filesystem import (
+ FileSystemExecutionStore,
+)
+
+from datetime import datetime, timezone
+
+
+@pytest.fixture
+def temp_storage_dir():
+ """Create a temporary directory for testing."""
+ with tempfile.TemporaryDirectory() as temp_dir:
+ yield Path(temp_dir)
+
+
+@pytest.fixture
+def store(temp_storage_dir):
+ """Create a FileSystemExecutionStore with temporary storage."""
+ return FileSystemExecutionStore.create(temp_storage_dir)
+
+
+@pytest.fixture
+def sample_execution():
+ """Create a sample execution for testing."""
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ return Execution.new(input_data)
+
+
+def test_filesystem_execution_store_save_and_load(store, sample_execution):
+ """Test saving and loading an execution."""
+ store.save(sample_execution)
+ loaded_execution = store.load(sample_execution.durable_execution_arn)
+
+ assert (
+ loaded_execution.durable_execution_arn == sample_execution.durable_execution_arn
+ )
+ assert (
+ loaded_execution.start_input.function_name
+ == sample_execution.start_input.function_name
+ )
+ assert (
+ loaded_execution.start_input.execution_name
+ == sample_execution.start_input.execution_name
+ )
+ assert loaded_execution.token_sequence == sample_execution.token_sequence
+ assert loaded_execution.is_complete == sample_execution.is_complete
+
+
+def test_filesystem_execution_store_load_nonexistent(store):
+ """Test loading a nonexistent execution raises ResourceNotFoundException."""
+ with pytest.raises(
+ ResourceNotFoundException, match="Execution nonexistent-arn not found"
+ ):
+ store.load("nonexistent-arn")
+
+
+def test_filesystem_execution_store_update(store, sample_execution):
+ """Test updating an execution."""
+ store.save(sample_execution)
+
+ sample_execution.is_complete = True
+ for _ in range(5):
+ sample_execution.get_new_checkpoint_token()
+ store.update(sample_execution)
+
+ loaded_execution = store.load(sample_execution.durable_execution_arn)
+ assert loaded_execution.is_complete is True
+ assert loaded_execution.token_sequence == 5
+
+
+def test_filesystem_execution_store_update_overwrites(store, temp_storage_dir):
+ """Test that update overwrites existing execution."""
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ )
+ execution1 = Execution.new(input_data)
+ execution2 = Execution.new(input_data)
+ execution2.durable_execution_arn = execution1.durable_execution_arn
+ for _ in range(10):
+ execution2.get_new_checkpoint_token()
+
+ store.save(execution1)
+ store.update(execution2)
+
+ loaded_execution = store.load(execution1.durable_execution_arn)
+ assert loaded_execution.token_sequence == 10
+
+
+def test_filesystem_execution_store_multiple_executions(store):
+ """Test storing multiple executions."""
+ input_data1 = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function-1",
+ function_qualifier="$LATEST",
+ execution_name="test-execution-1",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ )
+ input_data2 = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function-2",
+ function_qualifier="$LATEST",
+ execution_name="test-execution-2",
+ execution_timeout_seconds=600,
+ execution_retention_period_days=14,
+ )
+
+ execution1 = Execution.new(input_data1)
+ execution2 = Execution.new(input_data2)
+
+ store.save(execution1)
+ store.save(execution2)
+
+ loaded_execution1 = store.load(execution1.durable_execution_arn)
+ loaded_execution2 = store.load(execution2.durable_execution_arn)
+
+ assert loaded_execution1.durable_execution_arn == execution1.durable_execution_arn
+ assert loaded_execution2.durable_execution_arn == execution2.durable_execution_arn
+ assert loaded_execution1.start_input.function_name == "test-function-1"
+ assert loaded_execution2.start_input.function_name == "test-function-2"
+
+
+def test_filesystem_execution_store_list_all_empty(store):
+ """Test list_all method with empty store."""
+ result = store.list_all()
+ assert result == []
+
+
+def test_filesystem_execution_store_list_all_with_executions(store):
+ """Test list_all method with multiple executions."""
+ # Create test executions
+ input_data1 = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function-1",
+ function_qualifier="$LATEST",
+ execution_name="test-execution-1",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ )
+ input_data2 = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function-2",
+ function_qualifier="$LATEST",
+ execution_name="test-execution-2",
+ execution_timeout_seconds=600,
+ execution_retention_period_days=14,
+ )
+ input_data3 = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function-3",
+ function_qualifier="$LATEST",
+ execution_name="test-execution-3",
+ execution_timeout_seconds=900,
+ execution_retention_period_days=21,
+ )
+
+ execution1 = Execution.new(input_data1)
+ execution2 = Execution.new(input_data2)
+ execution3 = Execution.new(input_data3)
+
+ # Save executions
+ store.save(execution1)
+ store.save(execution2)
+ store.save(execution3)
+
+ # Test list_all
+ result = store.list_all()
+
+ assert len(result) == 3
+ arns = {execution.durable_execution_arn for execution in result}
+ assert execution1.durable_execution_arn in arns
+ assert execution2.durable_execution_arn in arns
+ assert execution3.durable_execution_arn in arns
+
+
+def test_filesystem_execution_store_file_path_generation(
+ store, sample_execution, temp_storage_dir
+):
+ """Test that file paths are generated correctly with safe filenames."""
+ arn_with_colons = "arn:aws:lambda:us-east-1:123456789012:durable-execution:test"
+ expected_filename = (
+ "arn_aws_lambda_us-east-1_123456789012_durable-execution_test.json"
+ )
+
+ # Test by saving and checking the file exists with expected name
+ sample_execution.durable_execution_arn = arn_with_colons
+ store.save(sample_execution)
+
+ expected_file = temp_storage_dir / expected_filename
+ assert expected_file.exists()
+
+
+def test_filesystem_execution_store_corrupted_file_handling(store, temp_storage_dir):
+ """Test that corrupted files are skipped during list_all."""
+ # Create a valid execution
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ )
+ execution = Execution.new(input_data)
+ store.save(execution)
+
+ # Create a corrupted file
+ corrupted_file = temp_storage_dir / "corrupted.json"
+ with open(corrupted_file, "w") as f:
+ f.write("invalid json content")
+
+ # list_all should skip the corrupted file and return only valid executions
+ result = store.list_all()
+ assert len(result) == 1
+ assert result[0].durable_execution_arn == execution.durable_execution_arn
+
+
+def test_filesystem_execution_store_custom_storage_dir():
+ """Test creating store with custom storage directory."""
+ with tempfile.TemporaryDirectory() as temp_dir:
+ custom_dir = Path(temp_dir) / "custom_storage"
+ FileSystemExecutionStore.create(custom_dir)
+
+ # Directory should be created
+ assert custom_dir.exists()
+ assert custom_dir.is_dir()
+
+
+def test_filesystem_execution_store_init_no_side_effects():
+ """Test that __init__ doesn't create directories (no side effects)."""
+ with tempfile.TemporaryDirectory() as temp_dir:
+ nonexistent_dir = Path(temp_dir) / "nonexistent"
+
+ # __init__ should not create the directory
+ FileSystemExecutionStore(nonexistent_dir)
+ assert not nonexistent_dir.exists()
+
+
+def test_filesystem_execution_store_thread_safety_basic(store, sample_execution):
+ """Basic test that operations work without locking (atomic file operations)."""
+ # Test that basic operations work - atomic file operations provide thread safety
+ store.save(sample_execution)
+ loaded = store.load(sample_execution.durable_execution_arn)
+ assert loaded.durable_execution_arn == sample_execution.durable_execution_arn
+
+
+def test_filesystem_execution_store_query_empty(store):
+ """Test query method with empty store."""
+ executions, next_marker = store.query()
+
+ assert executions == []
+ assert next_marker is None
+
+
+def test_filesystem_execution_store_query_by_function_name(store):
+ """Test query filtering by function name."""
+ # Create executions with different function names
+ input1 = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="function-a",
+ function_qualifier="$LATEST",
+ execution_name="exec-1",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="invocation-1",
+ )
+ input2 = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="function-b",
+ function_qualifier="$LATEST",
+ execution_name="exec-2",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="invocation-2",
+ )
+
+ exec1 = Execution.new(input1)
+ exec1.start()
+ exec2 = Execution.new(input2)
+ exec2.start()
+ store.save(exec1)
+ store.save(exec2)
+
+ # Query for function-a only
+ executions, next_marker = store.query(function_name="function-a")
+
+ assert len(executions) == 1
+ assert executions[0].durable_execution_arn == exec1.durable_execution_arn
+ assert next_marker is None
+
+
+def test_filesystem_execution_store_query_by_status(store):
+ """Test query filtering by status."""
+ # Create running execution
+ input1 = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="running-exec",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="invocation-1",
+ )
+ exec1 = Execution.new(input1)
+ exec1.start()
+
+ # Create completed execution
+ input2 = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="completed-exec",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="invocation-2",
+ )
+ exec2 = Execution.new(input2)
+ exec2.start()
+ exec2.complete_success("success result")
+
+ store.save(exec1)
+ store.save(exec2)
+
+ # Query for running executions
+ executions, next_marker = store.query(status_filter="RUNNING")
+
+ assert len(executions) == 1
+ assert executions[0].durable_execution_arn == exec1.durable_execution_arn
+
+ # Query for succeeded executions
+ executions, next_marker = store.query(status_filter="SUCCEEDED")
+
+ assert len(executions) == 1
+ assert executions[0].durable_execution_arn == exec2.durable_execution_arn
+
+
+def test_filesystem_execution_store_query_pagination(store):
+ """Test query pagination."""
+ # Create multiple executions
+ executions = []
+ for i in range(5):
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name=f"exec-{i}",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id=f"invocation-{i}",
+ )
+ exec_obj = Execution.new(input_data)
+ exec_obj.start()
+ executions.append(exec_obj)
+ store.save(exec_obj)
+
+ # Test first page
+ executions, next_marker = store.query(limit=2, offset=0)
+
+ assert len(executions) == 2
+ assert next_marker is not None
+
+ # Test last page
+ executions, next_marker = store.query(limit=2, offset=4)
+
+ assert len(executions) == 1
+ assert next_marker is None
+
+
+def test_filesystem_execution_store_query_corrupted_file_handling(
+ store, temp_storage_dir
+):
+ """Test that corrupted files are skipped during query."""
+ # Create a valid execution
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ execution = Execution.new(input_data)
+ execution.start()
+ store.save(execution)
+
+ # Create a corrupted file
+ corrupted_file = temp_storage_dir / "corrupted.json"
+ with open(corrupted_file, "w") as f:
+ f.write("invalid json content")
+
+ # Query should skip the corrupted file and return only valid executions
+ executions, next_marker = store.query()
+
+ assert len(executions) == 1
+ assert executions[0].durable_execution_arn == execution.durable_execution_arn
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/stores/memory_store_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/stores/memory_store_test.py
new file mode 100644
index 0000000..b4d5b3e
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/stores/memory_store_test.py
@@ -0,0 +1,494 @@
+"""Tests for InMemoryExecutionStore."""
+
+from datetime import UTC
+from unittest.mock import Mock
+
+import pytest
+
+from aws_durable_execution_sdk_python_testing.execution import Execution
+from aws_durable_execution_sdk_python_testing.model import StartDurableExecutionInput
+from aws_durable_execution_sdk_python_testing.stores.memory import (
+ InMemoryExecutionStore,
+)
+
+
+def test_in_memory_execution_store_save_and_load():
+ """Test saving and loading an execution."""
+ store = InMemoryExecutionStore()
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ execution = Execution.new(input_data)
+
+ store.save(execution)
+ loaded_execution = store.load(execution.durable_execution_arn)
+
+ assert loaded_execution is execution
+
+
+def test_in_memory_execution_store_load_nonexistent():
+ """Test loading a nonexistent execution raises KeyError."""
+ store = InMemoryExecutionStore()
+
+ with pytest.raises(KeyError):
+ store.load("nonexistent-arn")
+
+
+def test_in_memory_execution_store_update():
+ """Test updating an execution."""
+ store = InMemoryExecutionStore()
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ )
+ execution = Execution.new(input_data)
+ store.save(execution)
+
+ execution.is_complete = True
+ store.update(execution)
+
+ loaded_execution = store.load(execution.durable_execution_arn)
+ assert loaded_execution.is_complete is True
+
+
+def test_in_memory_execution_store_update_overwrites():
+ """Test that update overwrites existing execution."""
+ store = InMemoryExecutionStore()
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ )
+ execution1 = Execution.new(input_data)
+ execution2 = Execution.new(input_data)
+ execution2.durable_execution_arn = execution1.durable_execution_arn
+
+ store.save(execution1)
+ store.update(execution2)
+
+ loaded_execution = store.load(execution1.durable_execution_arn)
+ assert loaded_execution is execution2
+
+
+def test_in_memory_execution_store_multiple_executions():
+ """Test storing multiple executions."""
+ store = InMemoryExecutionStore()
+ input_data1 = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function-1",
+ function_qualifier="$LATEST",
+ execution_name="test-execution-1",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ )
+ input_data2 = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function-2",
+ function_qualifier="$LATEST",
+ execution_name="test-execution-2",
+ execution_timeout_seconds=600,
+ execution_retention_period_days=14,
+ )
+
+ execution1 = Execution.new(input_data1)
+ execution2 = Execution.new(input_data2)
+
+ store.save(execution1)
+ store.save(execution2)
+
+ loaded_execution1 = store.load(execution1.durable_execution_arn)
+ loaded_execution2 = store.load(execution2.durable_execution_arn)
+
+ assert loaded_execution1 is execution1
+ assert loaded_execution2 is execution2
+
+
+def test_in_memory_execution_store_list_all_empty():
+ """Test list_all method with empty store."""
+ store = InMemoryExecutionStore()
+
+ result = store.list_all()
+
+ assert result == []
+
+
+def test_in_memory_execution_store_list_all_with_executions():
+ """Test list_all method with multiple executions."""
+ store = InMemoryExecutionStore()
+
+ # Create test executions
+ execution1 = Mock()
+ execution1.durable_execution_arn = "arn1"
+ execution2 = Mock()
+ execution2.durable_execution_arn = "arn2"
+ execution3 = Mock()
+ execution3.durable_execution_arn = "arn3"
+
+ # Save executions
+ store.save(execution1)
+ store.save(execution2)
+ store.save(execution3)
+
+ # Test list_all
+ result = store.list_all()
+
+ assert len(result) == 3
+ assert execution1 in result
+ assert execution2 in result
+ assert execution3 in result
+
+
+def test_in_memory_execution_store_query_empty():
+ """Test query method with empty store."""
+ store = InMemoryExecutionStore()
+
+ executions, next_marker = store.query()
+
+ assert executions == []
+ assert next_marker is None
+
+
+def test_in_memory_execution_store_query_by_function_name():
+ """Test query filtering by function name."""
+ store = InMemoryExecutionStore()
+
+ # Create executions with different function names
+ input1 = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="function-a",
+ function_qualifier="$LATEST",
+ execution_name="exec-1",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="invocation-1",
+ )
+ input2 = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="function-b",
+ function_qualifier="$LATEST",
+ execution_name="exec-2",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="invocation-2",
+ )
+
+ exec1 = Execution.new(input1)
+ exec1.start()
+ exec2 = Execution.new(input2)
+ exec2.start()
+ store.save(exec1)
+ store.save(exec2)
+
+ # Query for function-a only
+ executions, next_marker = store.query(function_name="function-a")
+
+ assert len(executions) == 1
+ assert executions[0] is exec1
+ assert next_marker is None
+
+
+def test_in_memory_execution_store_query_by_execution_name():
+ """Test query filtering by execution name."""
+ store = InMemoryExecutionStore()
+
+ input1 = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="exec-alpha",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="invocation-1",
+ )
+ input2 = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="exec-beta",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="invocation-2",
+ )
+
+ exec1 = Execution.new(input1)
+ exec1.start()
+ exec2 = Execution.new(input2)
+ exec2.start()
+ store.save(exec1)
+ store.save(exec2)
+
+ executions, next_marker = store.query(execution_name="exec-beta")
+
+ assert len(executions) == 1
+ assert executions[0] is exec2
+
+
+def test_in_memory_execution_store_query_by_status():
+ """Test query filtering by status."""
+ store = InMemoryExecutionStore()
+
+ # Create running execution
+ input1 = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="running-exec",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="invocation-1",
+ )
+ exec1 = Execution.new(input1)
+ exec1.start()
+
+ # Create completed execution
+ input2 = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="completed-exec",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="invocation-2",
+ )
+ exec2 = Execution.new(input2)
+ exec2.start()
+ exec2.complete_success("success result")
+
+ store.save(exec1)
+ store.save(exec2)
+
+ # Query for running executions
+ executions, next_marker = store.query(status_filter="RUNNING")
+
+ assert len(executions) == 1
+ assert executions[0] is exec1
+
+ # Query for succeeded executions
+ executions, next_marker = store.query(status_filter="SUCCEEDED")
+
+ assert len(executions) == 1
+ assert executions[0] is exec2
+
+
+def test_in_memory_execution_store_query_pagination():
+ """Test query pagination."""
+ store = InMemoryExecutionStore()
+
+ # Create multiple executions
+ executions = []
+ for i in range(5):
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name=f"exec-{i}",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id=f"invocation-{i}",
+ )
+ exec_obj = Execution.new(input_data)
+ exec_obj.start()
+ executions.append(exec_obj)
+ store.save(exec_obj)
+
+ # Test first page
+ executions, next_marker = store.query(limit=2, offset=0)
+
+ assert len(executions) == 2
+ assert next_marker is not None
+
+ # Test second page
+ executions, next_marker = store.query(limit=2, offset=2)
+
+ assert len(executions) == 2
+ assert next_marker is not None
+
+ # Test last page
+ executions, next_marker = store.query(limit=2, offset=4)
+
+ assert len(executions) == 1
+ assert next_marker is None
+
+
+def test_in_memory_execution_store_query_sorting():
+ """Test query sorting by timestamp."""
+ store = InMemoryExecutionStore()
+
+ # Create executions - they will be sorted by creation order
+ executions = []
+ for i in range(3):
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name=f"exec-{i}",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id=f"invocation-{i}",
+ )
+ exec_obj = Execution.new(input_data)
+ exec_obj.start()
+ executions.append(exec_obj)
+ store.save(exec_obj)
+
+ # Test ascending order (default)
+ executions, next_marker = store.query(reverse_order=False)
+
+ assert len(executions) == 3
+
+ # Test descending order
+ executions, next_marker = store.query(reverse_order=True)
+
+ assert len(executions) == 3
+
+
+def test_in_memory_execution_store_query_combined_filters():
+ """Test query with multiple filters combined."""
+ store = InMemoryExecutionStore()
+
+ # Create various executions
+ inputs = [
+ StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="function-a",
+ function_qualifier="$LATEST",
+ execution_name="target-exec",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="invocation-1",
+ ),
+ StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="function-b",
+ function_qualifier="$LATEST",
+ execution_name="target-exec",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="invocation-2",
+ ),
+ StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="function-a",
+ function_qualifier="$LATEST",
+ execution_name="other-exec",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="invocation-3",
+ ),
+ ]
+
+ executions = []
+ for input_data in inputs:
+ exec_obj = Execution.new(input_data)
+ exec_obj.start()
+ executions.append(exec_obj)
+ store.save(exec_obj)
+
+ # Query with both function_name and execution_name filters
+ filtered_executions, next_marker = store.query(
+ function_name="function-a", execution_name="target-exec"
+ )
+
+ assert len(filtered_executions) == 1
+ assert filtered_executions[0] is executions[0]
+
+
+def test_time_filtering_logic():
+ """Test time filtering logic in process_query method."""
+ from datetime import datetime
+ from unittest.mock import Mock
+
+ store = InMemoryExecutionStore()
+
+ # Create mock executions with different timestamps
+ exec1 = Mock()
+ exec1.start_input.function_name = "test-function"
+ exec1.start_input.execution_name = "exec1"
+ exec1.status = "RUNNING"
+
+ exec2 = Mock()
+ exec2.start_input.function_name = "test-function"
+ exec2.start_input.execution_name = "exec2"
+ exec2.status = "RUNNING"
+
+ exec3 = Mock()
+ exec3.start_input.function_name = "test-function"
+ exec3.start_input.execution_name = "exec3"
+ exec3.status = "RUNNING"
+
+ # Use real datetime objects for timestamps
+ op1 = Mock()
+ op1.start_timestamp = datetime(2023, 1, 1, 12, 0, 0, tzinfo=UTC)
+
+ op2 = Mock()
+ op2.start_timestamp = datetime(2023, 1, 2, 12, 0, 0, tzinfo=UTC)
+
+ op3 = Mock()
+ op3.start_timestamp = datetime(2023, 1, 3, 12, 0, 0) # noqa: DTZ001
+
+ exec1.get_operation_execution_started.return_value = op1
+ exec2.get_operation_execution_started.return_value = op2
+ exec3.get_operation_execution_started.return_value = op3
+
+ executions = [exec1, exec2, exec3]
+
+ # Test time_after filtering
+ filtered, _ = store.process_query(
+ executions,
+ started_after="1672617600.0", # 2023-01-01 24:00:00 UTC (between exec1 and exec2)
+ )
+ assert len(filtered) == 2
+ assert exec2 in filtered
+ assert exec3 in filtered
+ assert exec1 not in filtered
+
+ # Test time_before filtering
+ filtered, _ = store.process_query(
+ executions,
+ started_before="1672617600.0", # 2023-01-01 24:00:00 UTC
+ )
+ assert len(filtered) == 1
+ assert exec1 in filtered
+ assert exec2 not in filtered
+ assert exec3 not in filtered
+
+ # Test both time_after and time_before
+ filtered, _ = store.process_query(
+ executions,
+ started_after="1672617600.0", # 2023-01-02 00:00:00 UTC (between exec1 and exec2)
+ started_before="1672704000.0", # 2023-01-03 00:00:00 UTC (between exec2 and exec3)
+ )
+ assert len(filtered) == 1
+ assert exec2 in filtered
+
+ # Test exception handling - exec with AttributeError
+ exec_error = Mock()
+ exec_error.start_input.function_name = "test-function"
+ exec_error.start_input.execution_name = "exec_error"
+ exec_error.status = "RUNNING"
+ exec_error.get_operation_execution_started.side_effect = AttributeError(
+ "No operation"
+ )
+
+ executions_with_error = [exec1, exec_error, exec2]
+ filtered, _ = store.process_query(
+ executions_with_error,
+ started_after="1672617600.0", # After exec1, before exec2
+ )
+ # exec_error should be filtered out due to exception, only exec2 should remain
+ assert len(filtered) == 1
+ assert exec2 in filtered
+ assert exec_error not in filtered
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/stores/sqlite_store_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/stores/sqlite_store_test.py
new file mode 100644
index 0000000..7c7feb4
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/stores/sqlite_store_test.py
@@ -0,0 +1,860 @@
+"""Tests for SQLiteExecutionStore."""
+
+import tempfile
+import time
+from datetime import datetime, UTC
+from pathlib import Path
+
+import pytest
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ ResourceNotFoundException,
+ InvalidParameterValueException,
+)
+from aws_durable_execution_sdk_python_testing.execution import (
+ ExecutionStatus,
+ Execution,
+)
+from aws_durable_execution_sdk_python_testing.model import StartDurableExecutionInput
+from aws_durable_execution_sdk_python_testing.stores.sqlite import SQLiteExecutionStore
+
+
+@pytest.fixture
+def temp_db_path():
+ """Create a temporary database file for testing."""
+ with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as temp_file:
+ temp_path = Path(temp_file.name)
+ yield temp_path
+ # Cleanup
+ if temp_path.exists():
+ temp_path.unlink()
+
+
+@pytest.fixture
+def store(temp_db_path):
+ """Create a SQLiteExecutionStore with temporary database."""
+ return SQLiteExecutionStore.create_and_initialize(temp_db_path)
+
+
+@pytest.fixture
+def sample_execution():
+ """Create a sample execution for testing."""
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ return Execution.new(input_data)
+
+
+def test_sqlite_execution_store_save_and_load(store, sample_execution):
+ """Test saving and loading an execution."""
+ sample_execution.start()
+ store.save(sample_execution)
+ loaded_execution = store.load(sample_execution.durable_execution_arn)
+
+ assert (
+ loaded_execution.durable_execution_arn == sample_execution.durable_execution_arn
+ )
+ assert (
+ loaded_execution.start_input.function_name
+ == sample_execution.start_input.function_name
+ )
+ assert (
+ loaded_execution.start_input.execution_name
+ == sample_execution.start_input.execution_name
+ )
+ assert loaded_execution.token_sequence == sample_execution.token_sequence
+ assert loaded_execution.is_complete == sample_execution.is_complete
+
+
+def test_sqlite_execution_store_load_nonexistent(store):
+ """Test loading a nonexistent execution raises KeyError."""
+ with pytest.raises(
+ ResourceNotFoundException, match="Execution nonexistent-arn not found"
+ ):
+ store.load("nonexistent-arn")
+
+
+def test_sqlite_execution_store_update(store, sample_execution):
+ """Test updating an execution."""
+ sample_execution.start()
+ store.save(sample_execution)
+
+ sample_execution.is_complete = True
+ sample_execution.close_status = ExecutionStatus.SUCCEEDED
+ for _ in range(5):
+ sample_execution.get_new_checkpoint_token()
+ store.update(sample_execution)
+
+ loaded_execution = store.load(sample_execution.durable_execution_arn)
+ assert loaded_execution.is_complete is True
+ assert loaded_execution.token_sequence == 5
+
+
+def test_sqlite_execution_store_update_overwrites(store):
+ """Test that update overwrites existing execution."""
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ execution1 = Execution.new(input_data)
+ execution1.start()
+ execution2 = Execution.new(input_data)
+ execution2.start()
+ execution2.durable_execution_arn = execution1.durable_execution_arn
+ for _ in range(10):
+ execution2.get_new_checkpoint_token()
+
+ store.save(execution1)
+ store.update(execution2)
+
+ loaded_execution = store.load(execution1.durable_execution_arn)
+ assert loaded_execution.token_sequence == 10
+
+
+def test_sqlite_execution_store_multiple_executions(store):
+ """Test storing multiple executions."""
+ input_data1 = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function-1",
+ function_qualifier="$LATEST",
+ execution_name="test-execution-1",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id-1",
+ )
+ input_data2 = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function-2",
+ function_qualifier="$LATEST",
+ execution_name="test-execution-2",
+ execution_timeout_seconds=600,
+ execution_retention_period_days=14,
+ invocation_id="test-invocation-id-2",
+ )
+
+ execution1 = Execution.new(input_data1)
+ execution1.start()
+ execution2 = Execution.new(input_data2)
+ execution2.start()
+
+ store.save(execution1)
+ store.save(execution2)
+
+ loaded_execution1 = store.load(execution1.durable_execution_arn)
+ loaded_execution2 = store.load(execution2.durable_execution_arn)
+
+ assert loaded_execution1.durable_execution_arn == execution1.durable_execution_arn
+ assert loaded_execution2.durable_execution_arn == execution2.durable_execution_arn
+ assert loaded_execution1.start_input.function_name == "test-function-1"
+ assert loaded_execution2.start_input.function_name == "test-function-2"
+
+
+def test_sqlite_execution_store_list_all_empty(store):
+ """Test list_all method with empty store."""
+ result = store.list_all()
+ assert result == []
+
+
+def test_sqlite_execution_store_list_all_with_executions(store):
+ """Test list_all method with multiple executions."""
+ # Create test executions
+ executions = []
+ for i in range(3):
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name=f"test-function-{i}",
+ function_qualifier="$LATEST",
+ execution_name=f"test-execution-{i}",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id=f"test-invocation-id-{i}",
+ )
+ execution = Execution.new(input_data)
+ execution.start()
+ executions.append(execution)
+ store.save(execution)
+
+ # Test list_all
+ result = store.list_all()
+
+ assert len(result) == 3
+ arns = {execution.durable_execution_arn for execution in result}
+ for execution in executions:
+ assert execution.durable_execution_arn in arns
+
+
+def test_sqlite_execution_store_query_empty(store):
+ """Test query method with empty store."""
+ executions, next_marker = store.query()
+
+ assert executions == []
+ assert next_marker is None
+
+
+def test_sqlite_execution_store_query_by_function_name(store):
+ """Test query filtering by function name."""
+ # Create executions with different function names
+ input1 = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="function-a",
+ function_qualifier="$LATEST",
+ execution_name="exec-1",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="invocation-1",
+ )
+ input2 = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="function-b",
+ function_qualifier="$LATEST",
+ execution_name="exec-2",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="invocation-2",
+ )
+
+ exec1 = Execution.new(input1)
+ exec1.start()
+ exec2 = Execution.new(input2)
+ exec2.start()
+ store.save(exec1)
+ store.save(exec2)
+
+ # Query for function-a only
+ executions, next_marker = store.query(function_name="function-a")
+
+ assert len(executions) == 1
+ assert executions[0].durable_execution_arn == exec1.durable_execution_arn
+ assert next_marker is None
+
+
+def test_sqlite_execution_store_query_by_execution_name(store):
+ """Test query filtering by execution name."""
+ input1 = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="exec-alpha",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="invocation-1",
+ )
+ input2 = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="exec-beta",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="invocation-2",
+ )
+
+ exec1 = Execution.new(input1)
+ exec1.start()
+ exec2 = Execution.new(input2)
+ exec2.start()
+ store.save(exec1)
+ store.save(exec2)
+
+ executions, next_marker = store.query(execution_name="exec-beta")
+
+ assert len(executions) == 1
+ assert executions[0].durable_execution_arn == exec2.durable_execution_arn
+
+
+def test_sqlite_execution_store_query_by_status(store):
+ """Test query filtering by status."""
+ # Create running execution
+ input1 = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="running-exec",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="invocation-1",
+ )
+ exec1 = Execution.new(input1)
+ exec1.start()
+
+ # Create completed execution
+ input2 = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="completed-exec",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="invocation-2",
+ )
+ exec2 = Execution.new(input2)
+ exec2.start()
+ exec2.complete_success("success result")
+
+ store.save(exec1)
+ store.save(exec2)
+
+ # Query for running executions
+ executions, next_marker = store.query(status_filter="RUNNING")
+
+ assert len(executions) == 1
+ assert executions[0].durable_execution_arn == exec1.durable_execution_arn
+
+ # Query for succeeded executions
+ executions, next_marker = store.query(status_filter="SUCCEEDED")
+
+ assert len(executions) == 1
+ assert executions[0].durable_execution_arn == exec2.durable_execution_arn
+
+
+def test_sqlite_execution_store_query_pagination(store):
+ """Test query pagination."""
+ # Create multiple executions
+ executions = []
+ for i in range(5):
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name=f"exec-{i}",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id=f"invocation-{i}",
+ )
+ exec_obj = Execution.new(input_data)
+ exec_obj.start()
+ executions.append(exec_obj)
+ store.save(exec_obj)
+
+ # Test first page
+ executions, next_marker = store.query(limit=2, offset=0)
+
+ assert len(executions) == 2
+ assert next_marker is not None
+
+ # Test second page
+ executions, next_marker = store.query(limit=2, offset=2)
+
+ assert len(executions) == 2
+ assert next_marker is not None
+
+ # Test last page
+ executions, next_marker = store.query(limit=2, offset=4)
+
+ assert len(executions) == 1
+ assert next_marker is None
+
+
+def test_sqlite_execution_store_query_sorting(store):
+ """Test query sorting by timestamp."""
+ # Create executions - they will be sorted by creation order
+ executions = []
+ for i in range(3):
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name=f"exec-{i}",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id=f"invocation-{i}",
+ )
+ exec_obj = Execution.new(input_data)
+ exec_obj.start()
+ executions.append(exec_obj)
+ store.save(exec_obj)
+
+ # Test ascending order (default)
+ executions, next_marker = store.query(reverse_order=False)
+
+ assert len(executions) == 3
+
+ # Test descending order
+ executions, next_marker = store.query(reverse_order=True)
+
+ assert len(executions) == 3
+
+
+def test_sqlite_execution_store_query_combined_filters(store):
+ """Test query with multiple filters combined."""
+ # Create various executions
+ inputs = [
+ StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="function-a",
+ function_qualifier="$LATEST",
+ execution_name="target-exec",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="invocation-1",
+ ),
+ StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="function-b",
+ function_qualifier="$LATEST",
+ execution_name="target-exec",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="invocation-2",
+ ),
+ StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="function-a",
+ function_qualifier="$LATEST",
+ execution_name="other-exec",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="invocation-3",
+ ),
+ ]
+
+ executions = []
+ for input_data in inputs:
+ exec_obj = Execution.new(input_data)
+ exec_obj.start()
+ executions.append(exec_obj)
+ store.save(exec_obj)
+
+ # Query with both function_name and execution_name filters
+ filtered_executions, next_marker = store.query(
+ function_name="function-a", execution_name="target-exec"
+ )
+
+ assert len(filtered_executions) == 1
+ assert (
+ filtered_executions[0].durable_execution_arn
+ == executions[0].durable_execution_arn
+ )
+
+
+def test_sqlite_execution_store_database_initialization(temp_db_path):
+ """Test that database is properly initialized with schema."""
+ store = SQLiteExecutionStore.create_and_initialize(temp_db_path)
+
+ # Verify database file exists
+ assert temp_db_path.exists()
+
+ # Verify we can perform basic operations
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ execution = Execution.new(input_data)
+ execution.start()
+
+ store.save(execution)
+ loaded = store.load(execution.durable_execution_arn)
+ assert loaded.durable_execution_arn == execution.durable_execution_arn
+
+
+def test_sqlite_execution_store_custom_db_path():
+ """Test creating store with custom database path."""
+ with tempfile.TemporaryDirectory() as temp_dir:
+ custom_path = Path(temp_dir) / "custom" / "executions.db"
+ store = SQLiteExecutionStore.create_and_initialize(custom_path)
+
+ # Directory should be created
+ assert custom_path.parent.exists()
+ assert custom_path.exists()
+
+ # Verify functionality
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ execution = Execution.new(input_data)
+ execution.start()
+
+ store.save(execution)
+ loaded = store.load(execution.durable_execution_arn)
+ assert loaded.durable_execution_arn == execution.durable_execution_arn
+
+
+def test_sqlite_execution_store_failed_execution_status(store):
+ """Test that failed executions are properly stored and queried."""
+ from aws_durable_execution_sdk_python.lambda_service import ErrorObject
+
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="failed-exec",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ execution = Execution.new(input_data)
+ execution.start()
+
+ # Complete with failure
+ error = ErrorObject(
+ type="TestError", message="Test failure", data=None, stack_trace=None
+ )
+ execution.complete_fail(error)
+
+ store.save(execution)
+
+ # Query for failed executions
+ executions, next_marker = store.query(status_filter="FAILED")
+
+ assert len(executions) == 1
+ assert executions[0].durable_execution_arn == execution.durable_execution_arn
+ assert executions[0].is_complete is True
+
+
+def test_sqlite_execution_store_error_handling(temp_db_path):
+ """Test error handling for database operations."""
+ store = SQLiteExecutionStore.create_and_initialize(temp_db_path)
+
+ # Test with corrupted database by removing the file after creation
+ temp_db_path.unlink()
+
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ execution = Execution.new(input_data)
+ execution.start()
+
+ # Should raise RuntimeError for database operations
+ with pytest.raises(RuntimeError, match="Failed to save execution"):
+ store.save(execution)
+
+
+def test_sqlite_execution_store_invalid_execution_data(store):
+ """Test handling of invalid execution data."""
+ # Create execution and start it
+ execution = Execution.new(
+ StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ )
+ execution.start()
+
+ # Corrupt the execution object to trigger serialization error
+ execution.start_input = None
+
+ with pytest.raises(ValueError, match="Invalid execution data"):
+ store.save(execution)
+
+
+def test_sqlite_execution_store_sql_injection_protection(store):
+ """Test SQL injection protection in query parameters."""
+ # Create test execution
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ execution = Execution.new(input_data)
+ execution.start()
+ store.save(execution)
+
+ # Try SQL injection attempts - should be safely parameterized
+ malicious_inputs = [
+ "'; DROP TABLE executions; --",
+ "test' OR '1'='1",
+ "test'; DELETE FROM executions; --",
+ "test' UNION SELECT * FROM executions --",
+ ]
+
+ for malicious_input in malicious_inputs:
+ # These should return empty results, not cause SQL errors
+ executions, _ = store.query(function_name=malicious_input)
+ assert executions == []
+
+ executions, _ = store.query(execution_name=malicious_input)
+ assert executions == []
+
+ executions, _ = store.query(status_filter=malicious_input)
+ assert executions == []
+
+
+def test_sqlite_execution_store_time_filtering(store):
+ """Test time-based filtering with edge cases."""
+
+ # Create executions at different times
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+
+ execution1 = Execution.new(input_data)
+ execution1.start()
+ store.save(execution1)
+
+ # Small delay to ensure different timestamps
+ time.sleep(0.01)
+
+ execution2 = Execution.new(input_data)
+ execution2.start()
+ store.save(execution2)
+
+ # Get timestamps as ISO strings
+ start_time_iso = (
+ execution1.get_operation_execution_started().start_timestamp.isoformat()
+ )
+ mid_time = (
+ execution1.get_operation_execution_started().start_timestamp.timestamp() + 0.005
+ )
+ mid_time_iso = datetime.fromtimestamp(mid_time, tz=UTC).isoformat()
+ end_time_iso = datetime.fromtimestamp(
+ execution2.get_operation_execution_started().start_timestamp.timestamp() + 1,
+ tz=UTC,
+ ).isoformat()
+
+ # Test started_after filter
+ executions, _ = store.query(started_after=mid_time_iso)
+ assert len(executions) == 1
+
+ # Test started_before filter
+ executions, _ = store.query(started_before=mid_time_iso)
+ assert len(executions) == 1
+
+ # Test both filters
+ executions, _ = store.query(
+ started_after=start_time_iso, started_before=end_time_iso
+ )
+ assert len(executions) == 2
+
+
+def test_sqlite_execution_store_corrupted_data_handling(store, temp_db_path):
+ """Test handling of corrupted JSON data in database."""
+ import sqlite3
+
+ # Insert corrupted JSON data directly
+ with sqlite3.connect(temp_db_path) as conn:
+ conn.execute(
+ """
+ INSERT INTO executions
+ (durable_execution_arn, function_name, execution_name, status, start_timestamp, end_timestamp, data)
+ VALUES (?, ?, ?, ?, ?, ?, ?)
+ """,
+ (
+ "corrupted-arn",
+ "test-function",
+ "test-execution",
+ "RUNNING",
+ 1234567890.0,
+ None,
+ "invalid json data {{{",
+ ),
+ )
+
+ # Loading corrupted data should raise ValueError
+ with pytest.raises(ValueError, match="Corrupted execution data"):
+ store.load("corrupted-arn")
+
+ # Query should skip corrupted records and continue
+ executions, _ = store.query()
+ # Should not include the corrupted record
+ assert all(exec.durable_execution_arn != "corrupted-arn" for exec in executions)
+
+
+def test_sqlite_execution_store_get_execution_metadata(store):
+ """Test get_execution_metadata method."""
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ execution = Execution.new(input_data)
+ execution.start()
+ store.save(execution)
+
+ # Test existing execution
+ metadata = store.get_execution_metadata(execution.durable_execution_arn)
+ assert metadata is not None
+ assert metadata["durable_execution_arn"] == execution.durable_execution_arn
+ assert metadata["function_name"] == "test-function"
+ assert metadata["execution_name"] == "test-execution"
+ assert metadata["status"] == "RUNNING"
+ assert metadata["start_timestamp"] is not None
+
+ # Test nonexistent execution
+ metadata = store.get_execution_metadata("nonexistent-arn")
+ assert metadata is None
+
+
+def test_sqlite_execution_store_database_init_error():
+ """Test database initialization error handling."""
+ # Try to create database in non-existent directory without permission
+ invalid_path = Path("/invalid/path/that/does/not/exist/test.db")
+
+ with pytest.raises(RuntimeError, match="Failed to initialize database"):
+ store = SQLiteExecutionStore(invalid_path)
+ store._init_db()
+
+
+def test_sqlite_execution_store_query_invalid_parameters(store):
+ """Test query with invalid parameters."""
+ # Test with invalid time parameters
+ with pytest.raises(
+ InvalidParameterValueException, match="Invalid query parameters"
+ ):
+ store.query(started_after="invalid_timestamp")
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Invalid query parameters"
+ ):
+ store.query(started_before="not_a_number")
+
+
+def test_sqlite_execution_store_query_no_limit_no_offset(store):
+ """Test query without limit and offset parameters."""
+ # Create test execution
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ execution = Execution.new(input_data)
+ execution.start()
+ store.save(execution)
+
+ # Query without limit should use different code path
+ executions, next_marker = store.query()
+ assert len(executions) == 1
+ assert next_marker is None
+
+
+def test_sqlite_execution_store_query_with_end_timestamp(store):
+ """Test execution with end timestamp."""
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ execution = Execution.new(input_data)
+ execution.start()
+ execution.complete_success("result") # This should set end_timestamp
+ store.save(execution)
+
+ loaded = store.load(execution.durable_execution_arn)
+ assert loaded.is_complete is True
+
+
+def test_sqlite_execution_store_metadata_error_handling(temp_db_path):
+ """Test metadata retrieval error handling."""
+ store = SQLiteExecutionStore.create_and_initialize(temp_db_path)
+
+ # Remove database file to trigger error
+ temp_db_path.unlink()
+
+ with pytest.raises(RuntimeError, match="Failed to get metadata"):
+ store.get_execution_metadata("test-arn")
+
+
+def test_sqlite_execution_store_load_error_handling(temp_db_path):
+ """Test load error handling."""
+ store = SQLiteExecutionStore.create_and_initialize(temp_db_path)
+
+ # Remove database file to trigger error
+ temp_db_path.unlink()
+
+ with pytest.raises(RuntimeError, match="Failed to load execution"):
+ store.load("test-arn")
+
+
+def test_sqlite_execution_store_query_with_corrupted_data_warning(
+ store, temp_db_path, capsys
+):
+ """Test that corrupted data in query results prints warning and continues."""
+ import sqlite3
+
+ # Create a valid execution first
+ input_data = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-function",
+ function_qualifier="$LATEST",
+ execution_name="test-execution",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="test-invocation-id",
+ )
+ execution = Execution.new(input_data)
+ execution.start()
+ store.save(execution)
+
+ # Insert corrupted JSON data directly
+ with sqlite3.connect(temp_db_path) as conn:
+ conn.execute(
+ """
+ INSERT INTO executions
+ (durable_execution_arn, function_name, execution_name, status, start_timestamp, end_timestamp, data)
+ VALUES (?, ?, ?, ?, ?, ?, ?)
+ """,
+ (
+ "corrupted-arn-2",
+ "test-function",
+ "test-execution",
+ "RUNNING",
+ 1234567890.0,
+ None,
+ "invalid json data {{{",
+ ),
+ )
+
+ # Query should skip corrupted records and print warning
+ executions, _ = store.query()
+
+ # Should get the valid execution, skip the corrupted one
+ assert len(executions) == 1
+ assert executions[0].durable_execution_arn == execution.durable_execution_arn
+
+ # Check that warning was printed
+ captured = capsys.readouterr()
+ assert "Warning: Skipping corrupted execution corrupted-arn-2" in captured.out
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/token_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/token_test.py
new file mode 100644
index 0000000..714d8c9
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/token_test.py
@@ -0,0 +1,132 @@
+"""Unit tests for token models."""
+
+import base64
+import json
+
+import pytest
+
+from aws_durable_execution_sdk_python_testing.token import (
+ CallbackToken,
+ CheckpointToken,
+)
+
+
+def test_checkpoint_token_init():
+ """Test CheckpointToken initialization."""
+ token = CheckpointToken("arn:aws:states:us-east-1:123456789012:execution:test", 42)
+
+ assert token.execution_arn == "arn:aws:states:us-east-1:123456789012:execution:test"
+ assert token.token_sequence == 42
+
+
+def test_checkpoint_token_to_str():
+ """Test CheckpointToken serialization to string."""
+ token = CheckpointToken("arn:aws:states:us-east-1:123456789012:execution:test", 42)
+
+ result = token.to_str()
+
+ # Decode and verify the structure
+ decoded = base64.b64decode(result).decode()
+ data = json.loads(decoded)
+ assert data["arn"] == "arn:aws:states:us-east-1:123456789012:execution:test"
+ assert data["seq"] == 42
+
+
+def test_checkpoint_token_from_str():
+ """Test CheckpointToken deserialization from string."""
+ data = {"arn": "arn:aws:states:us-east-1:123456789012:execution:test", "seq": 42}
+ json_str = json.dumps(data, separators=(",", ":"))
+ token_str = base64.b64encode(json_str.encode()).decode()
+
+ token = CheckpointToken.from_str(token_str)
+
+ assert token.execution_arn == "arn:aws:states:us-east-1:123456789012:execution:test"
+ assert token.token_sequence == 42
+
+
+def test_checkpoint_token_round_trip():
+ """Test CheckpointToken serialization and deserialization round trip."""
+ original = CheckpointToken(
+ "arn:aws:states:us-east-1:123456789012:execution:test", 123
+ )
+
+ token_str = original.to_str()
+ restored = CheckpointToken.from_str(token_str)
+
+ assert restored == original
+
+
+def test_checkpoint_token_frozen_dataclass():
+ """Test that CheckpointToken is immutable."""
+ token = CheckpointToken("arn:aws:states:us-east-1:123456789012:execution:test", 42)
+
+ with pytest.raises(AttributeError):
+ token.execution_arn = "new-arn"
+
+ with pytest.raises(AttributeError):
+ token.token_sequence = 999
+
+
+def test_callback_token_init():
+ """Test CallbackToken initialization."""
+ token = CallbackToken(
+ "arn:aws:states:us-east-1:123456789012:execution:test", "op-123"
+ )
+
+ assert token.execution_arn == "arn:aws:states:us-east-1:123456789012:execution:test"
+ assert token.operation_id == "op-123"
+
+
+def test_callback_token_to_str():
+ """Test CallbackToken serialization to string."""
+ token = CallbackToken(
+ "arn:aws:states:us-east-1:123456789012:execution:test", "op-123"
+ )
+
+ result = token.to_str()
+
+ # Decode and verify the structure
+ decoded = base64.b64decode(result).decode()
+ data = json.loads(decoded)
+ assert data["arn"] == "arn:aws:states:us-east-1:123456789012:execution:test"
+ assert data["op"] == "op-123"
+
+
+def test_callback_token_from_str():
+ """Test CallbackToken deserialization from string."""
+ data = {
+ "arn": "arn:aws:states:us-east-1:123456789012:execution:test",
+ "op": "op-123",
+ }
+ json_str = json.dumps(data, separators=(",", ":"))
+ token_str = base64.b64encode(json_str.encode()).decode()
+
+ token = CallbackToken.from_str(token_str)
+
+ assert token.execution_arn == "arn:aws:states:us-east-1:123456789012:execution:test"
+ assert token.operation_id == "op-123"
+
+
+def test_callback_token_round_trip():
+ """Test CallbackToken serialization and deserialization round trip."""
+ original = CallbackToken(
+ "arn:aws:states:us-east-1:123456789012:execution:test", "callback-op"
+ )
+
+ token_str = original.to_str()
+ restored = CallbackToken.from_str(token_str)
+
+ assert restored == original
+
+
+def test_callback_token_frozen_dataclass():
+ """Test that CallbackToken is immutable."""
+ token = CallbackToken(
+ "arn:aws:states:us-east-1:123456789012:execution:test", "op-123"
+ )
+
+ with pytest.raises(AttributeError):
+ token.execution_arn = "new-arn"
+
+ with pytest.raises(AttributeError):
+ token.operation_id = "new-op"
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/web/__init__.py b/packages/aws-durable-execution-sdk-python-testing/tests/web/__init__.py
new file mode 100644
index 0000000..5a4e39e
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/web/__init__.py
@@ -0,0 +1 @@
+"""Tests for web server module."""
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/web/e2e/__init__.py b/packages/aws-durable-execution-sdk-python-testing/tests/web/e2e/__init__.py
new file mode 100644
index 0000000..cbd1352
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/web/e2e/__init__.py
@@ -0,0 +1 @@
+"""End-to-end integration tests for web components."""
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/web/e2e/routes_arn_encoding_int_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/web/e2e/routes_arn_encoding_int_test.py
new file mode 100644
index 0000000..4b9c2a5
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/web/e2e/routes_arn_encoding_int_test.py
@@ -0,0 +1,238 @@
+"""Integration test: WebServer route layer URL-decodes DurableExecutionArn.
+
+Drives a real ``boto3`` Lambda client against a live ``WebServer`` and asserts
+that ``DurableExecutionArn`` values containing characters that boto
+percent-encodes in URI labels (e.g. ``/`` -> ``%2F``) round-trip correctly so
+the store lookup hits.
+"""
+
+from __future__ import annotations
+
+import threading
+import time
+from typing import Any
+
+import boto3 # type: ignore
+import pytest
+from botocore.config import Config # type: ignore
+from botocore.exceptions import ClientError # type: ignore
+
+from aws_durable_execution_sdk_python_testing.checkpoint.processor import (
+ CheckpointProcessor,
+)
+from aws_durable_execution_sdk_python_testing.execution import Execution
+from aws_durable_execution_sdk_python_testing.executor import Executor
+from aws_durable_execution_sdk_python_testing.model import (
+ StartDurableExecutionInput,
+)
+from aws_durable_execution_sdk_python_testing.scheduler import Scheduler
+from aws_durable_execution_sdk_python_testing.stores.memory import (
+ InMemoryExecutionStore,
+)
+from aws_durable_execution_sdk_python_testing.web.server import (
+ WebServer,
+ WebServiceConfig,
+)
+
+
+class _NoOpInvoker:
+ """Satisfies the Invoker protocol without invoking anything.
+
+ The route-layer regression doesn't depend on actually executing the
+ function; the executor just needs *some* invoker to construct it.
+ """
+
+ def create_invocation_input(self, execution: Any) -> Any: # noqa: ARG002
+ return None
+
+ def invoke(self, *args: Any, **kwargs: Any) -> Any: # noqa: ARG002
+ return None
+
+ def update_endpoint(self, *args: Any, **kwargs: Any) -> None: # noqa: ARG002
+ return None
+
+
+def _assert_no_percent_encoding_in_error(exc: ClientError, arn: str) -> None:
+ """Fail the test if a ResourceNotFoundException carries a %2F-form ARN.
+
+ Other errors (e.g. invalid checkpoint token, wrong state) are fine; this
+ test is narrowly about whether the route layer decoded the path segment.
+ """
+ msg = str(exc)
+ assert "%2F" not in msg, (
+ f"WebServer route layer did not URL-decode DurableExecutionArn. "
+ f"Original ARN: {arn!r}. Error: {msg}"
+ )
+
+
+@pytest.fixture
+def server_with_slash_arn():
+ """Yield ``(boto_client, arn, executor, store)`` for a live WebServer.
+
+ The yielded ARN contains a literal ``/`` matching the v1.2.0+ format
+ produced by ``Execution.new()``. The Execution is pre-started and saved
+ so read paths have something to find.
+ """
+ store = InMemoryExecutionStore()
+ scheduler = Scheduler()
+ checkpoint_processor = CheckpointProcessor(store=store, scheduler=scheduler)
+ executor = Executor(
+ store=store,
+ scheduler=scheduler,
+ invoker=_NoOpInvoker(),
+ checkpoint_processor=checkpoint_processor,
+ )
+ checkpoint_processor.add_execution_observer(executor)
+ scheduler.start()
+
+ # Hand-build a started Execution whose ARN contains '/' so we control
+ # the format under test without going through executor.start_execution
+ # (which schedules a real invoke + timeout).
+ start_input = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name="test-fn",
+ function_qualifier="$LATEST",
+ execution_name="test-exec",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="inv-12345",
+ input='"hi"',
+ )
+ execution = Execution.new(start_input)
+ execution.start()
+ store.save(execution)
+ arn = execution.durable_execution_arn
+ assert "/" in arn, "regression precondition: ARN must contain literal '/'"
+
+ config = WebServiceConfig(host="127.0.0.1", port=0)
+ server = WebServer(config, executor)
+ port = server.server_address[1]
+ server_thread = threading.Thread(target=server.serve_forever, daemon=True)
+ server_thread.start()
+ # Give the listener a beat to come up before the boto client connects.
+ time.sleep(0.05)
+
+ client = boto3.client(
+ "lambda",
+ endpoint_url=f"http://127.0.0.1:{port}",
+ region_name="us-east-1",
+ aws_access_key_id="x", # noqa: S106 - test stub
+ aws_secret_access_key="y", # noqa: S106 - test stub
+ config=Config(parameter_validation=False, retries={"max_attempts": 0}),
+ )
+
+ try:
+ yield client, arn, executor, store
+ finally:
+ server.shutdown()
+ server.server_close()
+ scheduler.stop()
+
+
+def test_get_durable_execution_decodes_slash_in_arn(server_with_slash_arn):
+ """GetDurableExecution: %2F must be decoded so the store lookup hits."""
+ client, arn, _executor, _store = server_with_slash_arn
+
+ response = client.get_durable_execution(DurableExecutionArn=arn)
+
+ assert response["DurableExecutionArn"] == arn
+
+
+def test_get_durable_execution_state_decodes_slash_in_arn(server_with_slash_arn):
+ """GetDurableExecutionState: %2F must be decoded so the store lookup hits."""
+ client, arn, _executor, _store = server_with_slash_arn
+
+ response = client.get_durable_execution_state(
+ DurableExecutionArn=arn,
+ CheckpointToken="ignored-by-route-layer", # noqa: S106 - test stub
+ )
+
+ # Response shape varies; the only assertion this test cares about is
+ # that we got past route resolution.
+ assert response is not None
+
+
+def test_get_durable_execution_history_decodes_slash_in_arn(server_with_slash_arn):
+ """GetDurableExecutionHistory: %2F must be decoded so the store lookup hits."""
+ client, arn, _executor, _store = server_with_slash_arn
+
+ response = client.get_durable_execution_history(DurableExecutionArn=arn)
+
+ assert response is not None
+
+
+def test_checkpoint_durable_execution_decodes_slash_in_arn(server_with_slash_arn):
+ """CheckpointDurableExecution: %2F must be decoded so the store lookup hits.
+
+ A checkpoint with no operation updates may still trip secondary
+ validation; we only assert the failure (if any) is not the
+ %2F-in-message 404 that indicates the route layer dropped the ball.
+ """
+ client, arn, _executor, store = server_with_slash_arn
+ execution = store.load(arn)
+ token = execution.get_new_checkpoint_token()
+
+ try:
+ client.checkpoint_durable_execution(
+ DurableExecutionArn=arn,
+ CheckpointToken=token,
+ Updates=[],
+ )
+ except ClientError as exc:
+ _assert_no_percent_encoding_in_error(exc, arn)
+
+
+def test_stop_durable_execution_decodes_slash_in_arn(server_with_slash_arn):
+ """StopDurableExecution: %2F must be decoded so the store lookup hits."""
+ client, arn, _executor, _store = server_with_slash_arn
+
+ try:
+ client.stop_durable_execution(DurableExecutionArn=arn)
+ except ClientError as exc:
+ _assert_no_percent_encoding_in_error(exc, arn)
+
+
+def test_list_durable_executions_by_function_decodes_colon_in_name(
+ server_with_slash_arn,
+):
+ """ListDurableExecutionsByFunction: %3A/%24 in FunctionName must be decoded.
+
+ boto percent-encodes ``:`` and ``$`` in the non-greedy ``{FunctionName}``
+ URI label, so a realistic value like ``MyFunction:$LATEST`` arrives as
+ ``MyFunction%3A%24LATEST``. The route layer must decode the segment so
+ the store's exact-match filter on ``function_name`` returns the expected
+ execution.
+
+ Pre-fix behavior: handler filters on the encoded string, response has
+ no executions. Post-fix: handler filters on the decoded string, response
+ returns the seeded execution.
+ """
+ client, _arn, _executor, store = server_with_slash_arn
+
+ # Seed an execution whose function_name contains characters boto encodes.
+ realistic_function_name = "MyFunction:$LATEST"
+ seed = StartDurableExecutionInput(
+ account_id="123456789012",
+ function_name=realistic_function_name,
+ function_qualifier="$LATEST",
+ execution_name="encoded-fn-exec",
+ execution_timeout_seconds=300,
+ execution_retention_period_days=7,
+ invocation_id="inv-encoded-fn",
+ input='"hi"',
+ )
+ seeded = Execution.new(seed)
+ seeded.start()
+ store.save(seeded)
+
+ response = client.list_durable_executions_by_function(
+ FunctionName=realistic_function_name,
+ )
+
+ arns = [e["DurableExecutionArn"] for e in response.get("DurableExecutions", [])]
+ assert seeded.durable_execution_arn in arns, (
+ f"WebServer route layer did not URL-decode FunctionName. "
+ f"Seeded function_name {realistic_function_name!r} produced arn "
+ f"{seeded.durable_execution_arn!r}, but list response contained "
+ f"{arns!r}."
+ )
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/web/e2e/server_int_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/web/e2e/server_int_test.py
new file mode 100644
index 0000000..0186606
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/web/e2e/server_int_test.py
@@ -0,0 +1,101 @@
+"""Integration tests for web server routing and handler integration."""
+
+from __future__ import annotations
+
+from unittest.mock import Mock
+
+from aws_durable_execution_sdk_python_testing.web.models import (
+ HTTPRequest,
+ HTTPResponse,
+)
+from aws_durable_execution_sdk_python_testing.web.routes import (
+ HealthRoute,
+ StartExecutionRoute,
+)
+from aws_durable_execution_sdk_python_testing.web.server import (
+ WebServer,
+ WebServiceConfig,
+)
+
+
+def test_web_server_router_integration():
+ """Test that router can find routes and handlers can handle them."""
+ executor = Mock()
+ config = WebServiceConfig(port=0) # Use port 0 to get any available port
+
+ server = WebServer(config, executor)
+
+ try:
+ # Test router can find a route
+ route = server.router.find_route("/health", "GET")
+ assert isinstance(route, HealthRoute)
+
+ # Test handler exists for the route
+ handler = server.endpoint_handlers.get(type(route))
+ assert handler is not None
+
+ # Test handler can handle the route
+ request = HTTPRequest(
+ method="GET", path=route, headers={}, query_params={}, body={}
+ )
+
+ response = handler.handle(route, request)
+ assert isinstance(response, HTTPResponse)
+ assert response.status_code == 200
+ assert response.body == {"status": "healthy"}
+ finally:
+ server.server_close()
+
+
+def test_web_server_start_execution_route_integration():
+ """Test that start execution route is properly integrated."""
+ executor = Mock()
+ config = WebServiceConfig(port=0) # Use port 0 to get any available port
+
+ server = WebServer(config, executor)
+
+ try:
+ # Test router can find start execution route
+ route = server.router.find_route("/start-durable-execution", "POST")
+ assert isinstance(route, StartExecutionRoute)
+
+ # Test handler exists for the route
+ handler = server.endpoint_handlers.get(type(route))
+ assert handler is not None
+
+ # Test handler returns 400 for invalid input (now implemented)
+ request = HTTPRequest(
+ method="POST",
+ path=route,
+ headers={},
+ query_params={},
+ body={"test": "data"}, # Invalid input - missing required fields
+ )
+
+ response = handler.handle(route, request)
+ assert isinstance(response, HTTPResponse)
+ assert response.status_code == 400 # Bad request for invalid input
+ finally:
+ server.server_close()
+
+
+def test_web_server_context_manager_with_integration():
+ """Test that WebServer context manager works with integrated components."""
+ executor = Mock()
+ config = WebServiceConfig(port=0) # Use port 0 to get any available port
+
+ with WebServer(config, executor) as server:
+ # Verify server is properly initialized
+ assert server.router is not None
+ assert server.endpoint_handlers is not None
+
+ # Test a simple route resolution
+ route = server.router.find_route("/health", "GET")
+ handler = server.endpoint_handlers[type(route)]
+
+ request = HTTPRequest(
+ method="GET", path=route, headers={}, query_params={}, body={}
+ )
+
+ response = handler.handle(route, request)
+ assert response.status_code == 200
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/web/handlers_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/web/handlers_test.py
new file mode 100644
index 0000000..3cb84a8
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/web/handlers_test.py
@@ -0,0 +1,2635 @@
+"""Tests for HTTP endpoint handlers."""
+
+from __future__ import annotations
+
+import base64
+import json
+from typing import TYPE_CHECKING, Any
+from unittest.mock import Mock
+
+import pytest
+from aws_durable_execution_sdk_python.lambda_service import (
+ ErrorObject,
+ Operation,
+ OperationStatus,
+ OperationType,
+)
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ AwsApiException,
+ IllegalArgumentException,
+ IllegalStateException,
+ InvalidParameterValueException,
+ ResourceNotFoundException,
+)
+
+
+if TYPE_CHECKING:
+ from aws_durable_execution_sdk_python_testing.executor import Executor
+from aws_durable_execution_sdk_python_testing.model import (
+ CheckpointDurableExecutionResponse,
+ Event,
+ ExecutionStartedDetails,
+ GetDurableExecutionHistoryResponse,
+ GetDurableExecutionResponse,
+ GetDurableExecutionStateResponse,
+ ListDurableExecutionsByFunctionResponse,
+ ListDurableExecutionsResponse,
+ SendDurableExecutionCallbackFailureRequest,
+ SendDurableExecutionCallbackFailureResponse,
+ SendDurableExecutionCallbackHeartbeatRequest,
+ SendDurableExecutionCallbackHeartbeatResponse,
+ SendDurableExecutionCallbackSuccessRequest,
+ SendDurableExecutionCallbackSuccessResponse,
+ StartDurableExecutionInput,
+ StartDurableExecutionOutput,
+ StopDurableExecutionResponse,
+)
+from aws_durable_execution_sdk_python_testing.model import (
+ Execution as ExecutionSummary,
+)
+from aws_durable_execution_sdk_python_testing.web import handlers
+from aws_durable_execution_sdk_python_testing.web.handlers import (
+ CheckpointDurableExecutionHandler,
+ EndpointHandler,
+ GetDurableExecutionHandler,
+ GetDurableExecutionHistoryHandler,
+ GetDurableExecutionStateHandler,
+ HealthHandler,
+ ListDurableExecutionsByFunctionHandler,
+ ListDurableExecutionsHandler,
+ MetricsHandler,
+ SendDurableExecutionCallbackFailureHandler,
+ SendDurableExecutionCallbackHeartbeatHandler,
+ SendDurableExecutionCallbackSuccessHandler,
+ StartExecutionHandler,
+ StopDurableExecutionHandler,
+)
+from aws_durable_execution_sdk_python_testing.web.models import (
+ HTTPRequest,
+ HTTPResponse,
+)
+from aws_durable_execution_sdk_python_testing.web.routes import (
+ CallbackFailureRoute,
+ CallbackHeartbeatRoute,
+ CallbackSuccessRoute,
+ GetDurableExecutionRoute,
+ ListDurableExecutionsRoute,
+ Route,
+ Router,
+ StartExecutionRoute,
+)
+
+
+class MockableEndpointHandler(EndpointHandler):
+ """Test-specific handler that exposes private methods for testing."""
+
+ def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse:
+ """Handle request - test implementation."""
+ return self._success_response({"test": "data"})
+
+ # Public methods that expose private functionality for testing
+ def parse_json_body(self, request: HTTPRequest) -> dict[str, Any]:
+ """Public wrapper for _parse_json_body."""
+ return self._parse_json_body(request)
+
+ def json_response(
+ self,
+ status_code: int,
+ data: dict[str, Any],
+ additional_headers: dict[str, str] | None = None,
+ ) -> HTTPResponse:
+ """Public wrapper for _json_response."""
+ return self._json_response(status_code, data, additional_headers)
+
+ def success_response(
+ self, data: dict[str, Any], additional_headers: dict[str, str] | None = None
+ ) -> HTTPResponse:
+ """Public wrapper for _success_response."""
+ return self._success_response(data, additional_headers)
+
+ def created_response(
+ self, data: dict[str, Any], additional_headers: dict[str, str] | None = None
+ ) -> HTTPResponse:
+ """Public wrapper for _created_response."""
+ return self._created_response(data, additional_headers)
+
+ def no_content_response(
+ self, additional_headers: dict[str, str] | None = None
+ ) -> HTTPResponse:
+ """Public wrapper for _no_content_response."""
+ return self._no_content_response(additional_headers)
+
+ def parse_query_param(self, request: HTTPRequest, param_name: str) -> str | None:
+ """Public wrapper for _parse_query_param."""
+ return self._parse_query_param(request, param_name)
+
+ def parse_query_param_list(
+ self, request: HTTPRequest, param_name: str
+ ) -> list[str]:
+ """Public wrapper for _parse_query_param_list."""
+ return self._parse_query_param_list(request, param_name)
+
+ def validate_required_fields(
+ self, data: dict[str, Any], required_fields: list[str]
+ ) -> None:
+ """Public wrapper for _validate_required_fields."""
+ return self._validate_required_fields(data, required_fields)
+
+
+def test_endpoint_handler_initialization():
+ """Test EndpointHandler initialization."""
+ executor = Mock()
+ handler = MockableEndpointHandler(executor)
+ assert handler.executor == executor
+
+
+def test_endpoint_handler_parse_json_body_valid():
+ """Test parse_json_body with valid JSON."""
+ executor = Mock()
+ handler = MockableEndpointHandler(executor)
+
+ request = HTTPRequest(
+ method="POST",
+ path=Route.from_string("/test"),
+ headers={"Content-Type": "application/json"},
+ query_params={},
+ body={"key": "value"},
+ )
+
+ result = handler.parse_json_body(request)
+ assert result == {"key": "value"}
+
+
+def test_endpoint_handler_parse_json_body_empty():
+ """Test parse_json_body with empty body."""
+ executor = Mock()
+ handler = MockableEndpointHandler(executor)
+
+ request = HTTPRequest(
+ method="POST",
+ path=Route.from_string("/test"),
+ headers={"Content-Type": "application/json"},
+ query_params={},
+ body={},
+ )
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Request body is required"
+ ):
+ handler.parse_json_body(request)
+
+
+def test_endpoint_handler_parse_json_body_invalid():
+ """Test parse_json_body with invalid JSON - now this test is not applicable since body is already a dict."""
+ executor = Mock()
+ handler = MockableEndpointHandler(executor)
+
+ # Since body is now a dict, this test case doesn't apply anymore
+ # The validation happens during HTTPRequest.from_bytes() deserialization
+ request = HTTPRequest(
+ method="POST",
+ path=Route.from_string("/test"),
+ headers={"Content-Type": "application/json"},
+ query_params={},
+ body={"valid": "json"}, # Body is always valid dict now
+ )
+
+ # This should work fine now since body is already parsed
+ result = handler.parse_json_body(request)
+ assert result == {"valid": "json"}
+
+
+def test_endpoint_handler_json_response():
+ """Test json_response method."""
+ executor = Mock()
+ handler = MockableEndpointHandler(executor)
+
+ response = handler.json_response(200, {"test": "data"})
+ assert response.status_code == 200
+ assert response.headers["Content-Type"] == "application/json"
+ assert response.body == {"test": "data"}
+
+ # Verify serialization to bytes works
+ body_bytes = response.body_to_bytes()
+ assert b'"test":"data"' in body_bytes
+
+
+def test_endpoint_handler_success_response():
+ """Test success_response method."""
+ executor = Mock()
+ handler = MockableEndpointHandler(executor)
+
+ response = handler.success_response({"test": "data"})
+ assert response.status_code == 200
+ assert response.headers["Content-Type"] == "application/json"
+
+
+def test_endpoint_handler_created_response():
+ """Test created_response method."""
+ executor = Mock()
+ handler = MockableEndpointHandler(executor)
+
+ response = handler.created_response({"test": "data"})
+ assert response.status_code == 201
+ assert response.headers["Content-Type"] == "application/json"
+
+
+def test_endpoint_handler_no_content_response():
+ """Test no_content_response method."""
+ executor = Mock()
+ handler = MockableEndpointHandler(executor)
+
+ response = handler.no_content_response()
+ assert response.status_code == 204
+ assert response.body == {}
+
+
+def test_endpoint_handler_error_response():
+ """Test error response creation using HTTPResponse.create_error_from_exception."""
+ # Test that we can create error responses using the new method
+ exception = InvalidParameterValueException("Bad request")
+
+ response = HTTPResponse.create_error_from_exception(exception)
+ assert response.status_code == 400
+ assert response.headers["Content-Type"] == "application/json"
+
+ # The new format doesn't wrap in an "error" object
+ # InvalidParameterValueException uses lowercase "message" per Smithy definition
+ expected_body = {
+ "Type": "InvalidParameterValueException",
+ "message": "Bad request",
+ }
+ assert response.body == expected_body
+
+ # Verify serialization to bytes works
+ body_bytes = response.body_to_bytes()
+ assert b'"message":"Bad request"' in body_bytes
+ assert b'"Type":"InvalidParameterValueException"' in body_bytes
+
+
+def test_endpoint_handler_parse_query_param():
+ """Test parse_query_param method."""
+ executor = Mock()
+ handler = MockableEndpointHandler(executor)
+
+ request = HTTPRequest(
+ method="GET",
+ path=Route.from_string("/test"),
+ headers={},
+ query_params={"param1": ["value1"], "param2": ["value2a", "value2b"]},
+ body={},
+ )
+
+ assert handler.parse_query_param(request, "param1") == "value1"
+ assert handler.parse_query_param(request, "param2") == "value2a" # First value
+ assert handler.parse_query_param(request, "nonexistent") is None
+
+
+def test_endpoint_handler_parse_query_param_list():
+ """Test parse_query_param_list method."""
+ executor = Mock()
+ handler = MockableEndpointHandler(executor)
+
+ request = HTTPRequest(
+ method="GET",
+ path=Route.from_string("/test"),
+ headers={},
+ query_params={"param1": ["value1"], "param2": ["value2a", "value2b"]},
+ body={},
+ )
+
+ assert handler.parse_query_param_list(request, "param1") == ["value1"]
+ assert handler.parse_query_param_list(request, "param2") == ["value2a", "value2b"]
+ assert handler.parse_query_param_list(request, "nonexistent") == []
+
+
+def test_endpoint_handler_validate_required_fields_valid():
+ """Test validate_required_fields with valid data."""
+ executor = Mock()
+ handler = MockableEndpointHandler(executor)
+
+ data = {"field1": "value1", "field2": "value2", "field3": "value3"}
+ required_fields = ["field1", "field2"]
+
+ # Should not raise an exception
+ handler.validate_required_fields(data, required_fields)
+
+
+def test_endpoint_handler_validate_required_fields_missing():
+ """Test validate_required_fields with missing fields."""
+ executor = Mock()
+ handler = MockableEndpointHandler(executor)
+
+ data = {"field1": "value1"}
+ required_fields = ["field1", "field2", "field3"]
+
+ with pytest.raises(
+ InvalidParameterValueException, match="Missing required fields: field2, field3"
+ ):
+ handler.validate_required_fields(data, required_fields)
+
+
+def test_start_execution_handler_success():
+ """Test StartExecutionHandler with successful execution start."""
+ executor = Mock()
+ handler = StartExecutionHandler(executor)
+
+ # Mock successful executor response
+ mock_output = StartDurableExecutionOutput(execution_arn="test-execution-arn")
+ executor.start_execution.return_value = mock_output
+
+ # Create request with valid input data
+ request_data = {
+ "AccountId": "123456789012",
+ "FunctionName": "test-function",
+ "FunctionQualifier": "$LATEST",
+ "ExecutionName": "test-execution",
+ "ExecutionTimeoutSeconds": 300,
+ "ExecutionRetentionPeriodDays": 7,
+ "Input": '{"test": "data"}',
+ }
+
+ request = HTTPRequest(
+ method="POST",
+ path=StartExecutionRoute.from_string("/start-durable-execution"),
+ headers={"Content-Type": "application/json"},
+ query_params={},
+ body=request_data,
+ )
+
+ route = StartExecutionRoute.from_string("/start-durable-execution")
+ response = handler.handle(route, request)
+
+ # Verify response
+ assert response.status_code == 201
+ assert response.headers["Content-Type"] == "application/json"
+ assert response.body == {"ExecutionArn": "test-execution-arn"}
+
+ # Verify executor was called with correct input
+ executor.start_execution.assert_called_once()
+ call_args = executor.start_execution.call_args[0][0]
+ assert isinstance(call_args, StartDurableExecutionInput)
+ assert call_args.account_id == "123456789012"
+ assert call_args.function_name == "test-function"
+ assert call_args.execution_name == "test-execution"
+
+
+def test_start_execution_handler_empty_body():
+ """Test StartExecutionHandler with empty request body."""
+ executor = Mock()
+ handler = StartExecutionHandler(executor)
+
+ request = HTTPRequest(
+ method="POST",
+ path=StartExecutionRoute.from_string("/start-durable-execution"),
+ headers={"Content-Type": "application/json"},
+ query_params={},
+ body={},
+ )
+
+ route = StartExecutionRoute.from_string("/start-durable-execution")
+ response = handler.handle(route, request)
+
+ # Should return 400 Bad Request for empty body with AWS-compliant format
+ assert response.status_code == 400
+ assert response.body["Type"] == "InvalidParameterValueException"
+ assert "Request body is required" in response.body["message"]
+
+
+def test_start_execution_handler_missing_required_fields():
+ """Test StartExecutionHandler with missing required fields."""
+ executor = Mock()
+ handler = StartExecutionHandler(executor)
+
+ # Request missing required fields
+ request_data = {
+ "AccountId": "123456789012",
+ "FunctionName": "test-function",
+ # Missing other required fields
+ }
+
+ request = HTTPRequest(
+ method="POST",
+ path=StartExecutionRoute.from_string("/start-durable-execution"),
+ headers={"Content-Type": "application/json"},
+ query_params={},
+ body=request_data,
+ )
+
+ route = StartExecutionRoute.from_string("/start-durable-execution")
+ response = handler.handle(route, request)
+
+ # Should return 400 Bad Request for missing fields with AWS-compliant format
+ assert response.status_code == 400
+ assert response.body["Type"] == "InvalidParameterValueException"
+ assert "FunctionQualifier" in response.body["message"]
+
+
+def test_start_execution_handler_invalid_parameter_error():
+ """Test StartExecutionHandler with IllegalArgumentException from executor."""
+
+ executor = Mock()
+ handler = StartExecutionHandler(executor)
+
+ # Mock executor to raise IllegalArgumentException
+ executor.start_execution.side_effect = IllegalArgumentException(
+ "Invalid timeout value"
+ )
+
+ request_data = {
+ "AccountId": "123456789012",
+ "FunctionName": "test-function",
+ "FunctionQualifier": "$LATEST",
+ "ExecutionName": "test-execution",
+ "ExecutionTimeoutSeconds": -1, # Invalid value
+ "ExecutionRetentionPeriodDays": 7,
+ }
+
+ request = HTTPRequest(
+ method="POST",
+ path=StartExecutionRoute.from_string("/start-durable-execution"),
+ headers={"Content-Type": "application/json"},
+ query_params={},
+ body=request_data,
+ )
+
+ route = StartExecutionRoute.from_string("/start-durable-execution")
+ response = handler.handle(route, request)
+
+ # Should return 400 Bad Request with AWS-compliant format
+ assert response.status_code == 400
+ assert response.body["Type"] == "InvalidParameterValueException"
+ assert response.body["message"] == "Invalid timeout value"
+
+
+def test_start_execution_handler_execution_already_exists():
+ """Test StartExecutionHandler with execution already exists error."""
+
+ executor = Mock()
+ handler = StartExecutionHandler(executor)
+
+ # Mock executor to raise IllegalStateException (execution already exists)
+ executor.start_execution.side_effect = IllegalStateException(
+ "Execution with name 'test-execution' already exists"
+ )
+
+ request_data = {
+ "AccountId": "123456789012",
+ "FunctionName": "test-function",
+ "FunctionQualifier": "$LATEST",
+ "ExecutionName": "test-execution",
+ "ExecutionTimeoutSeconds": 300,
+ "ExecutionRetentionPeriodDays": 7,
+ }
+
+ request = HTTPRequest(
+ method="POST",
+ path=StartExecutionRoute.from_string("/start-durable-execution"),
+ headers={"Content-Type": "application/json"},
+ query_params={},
+ body=request_data,
+ )
+
+ route = StartExecutionRoute.from_string("/start-durable-execution")
+ response = handler.handle(route, request)
+
+ # Should return 409 Conflict with AWS-compliant format (ExecutionAlreadyStartedException has no Type field)
+ assert response.status_code == 409
+ assert "already exists" in response.body["message"]
+ assert (
+ response.body["DurableExecutionArn"]
+ == "arn:aws:lambda:us-east-1:123456789012:function:test"
+ )
+ assert (
+ "Type" not in response.body
+ ) # ExecutionAlreadyStartedException doesn't have Type field
+
+
+def test_start_execution_handler_unexpected_error():
+ """Test StartExecutionHandler with unexpected error from executor."""
+ executor = Mock()
+ handler = StartExecutionHandler(executor)
+
+ # Mock executor to raise unexpected error
+ executor.start_execution.side_effect = RuntimeError("Unexpected database error")
+
+ request_data = {
+ "AccountId": "123456789012",
+ "FunctionName": "test-function",
+ "FunctionQualifier": "$LATEST",
+ "ExecutionName": "test-execution",
+ "ExecutionTimeoutSeconds": 300,
+ "ExecutionRetentionPeriodDays": 7,
+ }
+
+ request = HTTPRequest(
+ method="POST",
+ path=StartExecutionRoute.from_string("/start-durable-execution"),
+ headers={"Content-Type": "application/json"},
+ query_params={},
+ body=request_data,
+ )
+
+ route = StartExecutionRoute.from_string("/start-durable-execution")
+ response = handler.handle(route, request)
+
+ # Should return 500 Internal Server Error with AWS-compliant format
+ assert response.status_code == 500
+ assert response.body["Type"] == "ServiceException"
+ assert response.body["Message"] == "Unexpected database error"
+
+
+def test_start_execution_handler_with_optional_fields():
+ """Test StartExecutionHandler with optional fields included."""
+
+ executor = Mock()
+ handler = StartExecutionHandler(executor)
+
+ # Mock successful executor response
+ mock_output = StartDurableExecutionOutput(execution_arn="test-execution-arn")
+ executor.start_execution.return_value = mock_output
+
+ # Create request with optional fields
+ request_data = {
+ "AccountId": "123456789012",
+ "FunctionName": "test-function",
+ "FunctionQualifier": "$LATEST",
+ "ExecutionName": "test-execution",
+ "ExecutionTimeoutSeconds": 300,
+ "ExecutionRetentionPeriodDays": 7,
+ "InvocationId": "test-invocation-id",
+ "TraceFields": {"traceId": "test-trace"},
+ "TenantId": "test-tenant",
+ "Input": '{"test": "data"}',
+ }
+
+ request = HTTPRequest(
+ method="POST",
+ path=StartExecutionRoute.from_string("/start-durable-execution"),
+ headers={"Content-Type": "application/json"},
+ query_params={},
+ body=request_data,
+ )
+
+ route = StartExecutionRoute.from_string("/start-durable-execution")
+ response = handler.handle(route, request)
+
+ # Verify response
+ assert response.status_code == 201
+ assert response.body == {"ExecutionArn": "test-execution-arn"}
+
+ # Verify executor was called with correct input including optional fields
+ executor.start_execution.assert_called_once()
+ call_args = executor.start_execution.call_args[0][0]
+ assert isinstance(call_args, StartDurableExecutionInput)
+ assert call_args.invocation_id == "test-invocation-id"
+ assert call_args.trace_fields == {"traceId": "test-trace"}
+ assert call_args.tenant_id == "test-tenant"
+ assert call_args.input == '{"test": "data"}'
+
+
+def test_get_durable_execution_handler_success():
+ """Test GetDurableExecutionHandler with successful execution retrieval."""
+
+ executor = Mock()
+ handler = GetDurableExecutionHandler(executor)
+
+ # Mock the executor response
+ mock_response = GetDurableExecutionResponse(
+ durable_execution_arn="test-arn",
+ durable_execution_name="test-execution",
+ function_arn="arn:aws:lambda:us-east-1:123456789012:function:test-function",
+ status="SUCCEEDED",
+ start_timestamp="2023-01-01T00:00:00Z",
+ input_payload="test-input",
+ result="test-result",
+ error=None,
+ end_timestamp="2023-01-01T00:01:00Z",
+ version="1.0",
+ )
+ executor.get_execution_details.return_value = mock_response
+
+ # Create strongly-typed route
+ base_route = Route.from_string("/2025-12-01/durable-executions/test-arn")
+ typed_route = GetDurableExecutionRoute.from_route(base_route)
+
+ request = HTTPRequest(
+ method="GET",
+ path=typed_route,
+ headers={},
+ query_params={},
+ body={},
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify response
+ assert response.status_code == 200
+ expected_body = {
+ "DurableExecutionArn": "test-arn",
+ "DurableExecutionName": "test-execution",
+ "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:test-function",
+ "Status": "SUCCEEDED",
+ "StartTimestamp": "2023-01-01T00:00:00Z",
+ "InputPayload": "test-input",
+ "Result": "test-result",
+ "EndTimestamp": "2023-01-01T00:01:00Z",
+ "Version": "1.0",
+ }
+ assert response.body == expected_body
+
+ # Verify executor was called with correct ARN
+ executor.get_execution_details.assert_called_once_with("test-arn")
+
+
+def test_get_durable_execution_handler_resource_not_found():
+ """Test GetDurableExecutionHandler with ResourceNotFoundException."""
+
+ executor = Mock()
+ handler = GetDurableExecutionHandler(executor)
+
+ # Mock executor to raise ResourceNotFoundException
+ executor.get_execution_details.side_effect = ResourceNotFoundException(
+ "Execution not-found-arn not found"
+ )
+
+ # Create strongly-typed route
+ base_route = Route.from_string("/2025-12-01/durable-executions/not-found-arn")
+ typed_route = GetDurableExecutionRoute.from_route(base_route)
+
+ request = HTTPRequest(
+ method="GET",
+ path=typed_route,
+ headers={},
+ query_params={},
+ body={},
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify error response with AWS-compliant format
+ assert response.status_code == 404
+ assert response.body["Type"] == "ResourceNotFoundException"
+ assert response.body["Message"] == "Execution not-found-arn not found"
+
+ # Verify executor was called
+ executor.get_execution_details.assert_called_once_with("not-found-arn")
+
+
+def test_get_durable_execution_handler_invalid_parameter():
+ """Test GetDurableExecutionHandler with IllegalArgumentException."""
+
+ executor = Mock()
+ handler = GetDurableExecutionHandler(executor)
+
+ # Mock executor to raise IllegalArgumentException
+ executor.get_execution_details.side_effect = IllegalArgumentException(
+ "Invalid execution ARN format"
+ )
+
+ # Create strongly-typed route
+ base_route = Route.from_string("/2025-12-01/durable-executions/invalid-arn")
+ typed_route = GetDurableExecutionRoute.from_route(base_route)
+
+ request = HTTPRequest(
+ method="GET",
+ path=typed_route,
+ headers={},
+ query_params={},
+ body={},
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify error response with AWS-compliant format
+ assert response.status_code == 400
+ assert response.body["Type"] == "InvalidParameterValueException"
+ assert response.body["message"] == "Invalid execution ARN format"
+
+ # Verify executor was called
+ executor.get_execution_details.assert_called_once_with("invalid-arn")
+
+
+def test_get_durable_execution_handler_unexpected_error():
+ """Test GetDurableExecutionHandler with unexpected error."""
+
+ executor = Mock()
+ handler = GetDurableExecutionHandler(executor)
+
+ # Mock executor to raise unexpected error
+ executor.get_execution_details.side_effect = RuntimeError("Unexpected error")
+
+ # Create strongly-typed route
+ base_route = Route.from_string("/2025-12-01/durable-executions/test-arn")
+ typed_route = GetDurableExecutionRoute.from_route(base_route)
+
+ request = HTTPRequest(
+ method="GET",
+ path=typed_route,
+ headers={},
+ query_params={},
+ body={},
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify error response with AWS-compliant format
+ assert response.status_code == 500
+ assert response.body["Type"] == "ServiceException"
+ assert response.body["Message"] == "Unexpected error"
+
+ # Verify executor was called
+ executor.get_execution_details.assert_called_once_with("test-arn")
+
+
+def test_checkpoint_durable_execution_handler_success():
+ """Test CheckpointDurableExecutionHandler with successful checkpoint processing."""
+
+ executor = Mock()
+ handler = CheckpointDurableExecutionHandler(executor)
+
+ # Mock the executor response
+ mock_response = CheckpointDurableExecutionResponse(
+ checkpoint_token="new-token-123", # noqa: S106
+ new_execution_state=None,
+ )
+ executor.checkpoint_execution.return_value = mock_response
+
+ # Create request with proper checkpoint data
+ request_body = {
+ "CheckpointToken": "current-token-123",
+ "Updates": [
+ {"Id": "op-1", "Type": "STEP", "Action": "SUCCEED", "SubType": "Step"}
+ ],
+ "ClientToken": "client-token-123",
+ }
+
+ # Create strongly-typed route using Router
+ router = Router()
+ typed_route = router.find_route(
+ "/2025-12-01/durable-executions/test-arn/checkpoint", "POST"
+ )
+
+ request = HTTPRequest(
+ method="POST",
+ path=typed_route,
+ headers={"Content-Type": "application/json"},
+ query_params={},
+ body=request_body,
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify response
+ assert response.status_code == 200
+ assert response.body == {
+ "CheckpointToken": "new-token-123",
+ }
+
+ # Verify executor was called with correct parameters
+ executor.checkpoint_execution.assert_called_once()
+ call_args = executor.checkpoint_execution.call_args
+ assert call_args[0][0] == "test-arn" # execution_arn
+ assert call_args[0][1] == "current-token-123" # checkpoint_token
+ assert call_args[0][3] == "client-token-123" # client_token
+
+ # Verify the updates parameter
+ updates = call_args[0][2]
+ assert len(updates) == 1
+ assert updates[0].operation_id == "op-1"
+
+
+def test_checkpoint_durable_execution_handler_invalid_request():
+ """Test CheckpointDurableExecutionHandler with invalid request body."""
+
+ executor = Mock()
+ handler = CheckpointDurableExecutionHandler(executor)
+
+ # Create strongly-typed route using Router
+ router = Router()
+ typed_route = router.find_route(
+ "/2025-12-01/durable-executions/test-arn/checkpoint", "POST"
+ )
+
+ request = HTTPRequest(
+ method="POST",
+ path=typed_route,
+ headers={},
+ query_params={},
+ body={},
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify AWS-compliant error format
+ assert response.status_code == 400
+ assert response.body["Type"] == "InvalidParameterValueException"
+ assert "Request body is required" in response.body["message"]
+
+
+def test_checkpoint_durable_execution_handler_invalid_checkpoint_exception():
+ """Test CheckpointDurableExecutionHandler with IllegalStateException mapping to ServiceException."""
+
+ executor = Mock()
+ handler = CheckpointDurableExecutionHandler(executor)
+
+ # Mock executor to raise IllegalStateException
+ executor.checkpoint_execution.side_effect = IllegalStateException(
+ "Invalid checkpoint token"
+ )
+
+ request_body = {
+ "CheckpointToken": "invalid-token",
+ }
+
+ # Create strongly-typed route using Router
+ router = Router()
+ typed_route = router.find_route(
+ "/2025-12-01/durable-executions/test-arn/checkpoint", "POST"
+ )
+
+ request = HTTPRequest(
+ method="POST",
+ path=typed_route,
+ headers={"Content-Type": "application/json"},
+ query_params={},
+ body=request_body,
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify IllegalStateException maps to ServiceException in AWS-compliant format
+ assert response.status_code == 500
+ assert response.body["Type"] == "ServiceException"
+ assert response.body["Message"] == "Invalid checkpoint token"
+
+
+def test_stop_durable_execution_handler_success():
+ """Test StopDurableExecutionHandler with successful execution stop."""
+
+ executor = Mock()
+ handler = StopDurableExecutionHandler(executor)
+
+ # Mock the executor response
+ mock_response = StopDurableExecutionResponse(stop_timestamp="2023-01-01T00:01:00Z")
+ executor.stop_execution.return_value = mock_response
+
+ # Create request with proper stop data
+ request_body = {
+ "DurableExecutionArn": "test-arn",
+ "Error": {
+ "ErrorMessage": "User requested stop",
+ "ErrorType": "UserStop",
+ },
+ }
+
+ # Create strongly-typed route using Router
+ router = Router()
+ typed_route = router.find_route(
+ "/2025-12-01/durable-executions/test-arn/stop", "POST"
+ )
+
+ request = HTTPRequest(
+ method="POST",
+ path=typed_route,
+ headers={"Content-Type": "application/json"},
+ query_params={},
+ body=request_body,
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify response
+ assert response.status_code == 200
+ assert response.body == {"StopTimestamp": "2023-01-01T00:01:00Z"}
+
+ # Verify executor was called with correct parameters
+ executor.stop_execution.assert_called_once()
+ call_args = executor.stop_execution.call_args
+ assert call_args[0][0] == "test-arn" # execution_arn
+
+
+def test_stop_durable_execution_handler_execution_already_stopped():
+ """Test StopDurableExecutionHandler with execution already stopped returns idempotent response."""
+
+ executor = Mock()
+ handler = StopDurableExecutionHandler(executor)
+
+ # Mock executor to return stop response with timestamp
+ stop_timestamp = "2023-01-01T00:01:00Z"
+ executor.stop_execution.return_value = StopDurableExecutionResponse(
+ stop_timestamp=stop_timestamp
+ )
+
+ request_body = {
+ "DurableExecutionArn": "test-arn",
+ }
+
+ # Create strongly-typed route using Router
+ router = Router()
+ typed_route = router.find_route(
+ "/2025-12-01/durable-executions/test-arn/stop", "POST"
+ )
+
+ request = HTTPRequest(
+ method="POST",
+ path=typed_route,
+ headers={"Content-Type": "application/json"},
+ query_params={},
+ body=request_body,
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify idempotent response with stop timestamp
+ assert response.status_code == 200
+ assert response.body["StopTimestamp"] == stop_timestamp
+
+
+def test_stop_durable_execution_handler_resource_not_found():
+ """Test StopDurableExecutionHandler with ResourceNotFoundException."""
+
+ executor = Mock()
+ handler = StopDurableExecutionHandler(executor)
+
+ # Mock executor to raise ResourceNotFoundException
+ executor.stop_execution.side_effect = ResourceNotFoundException(
+ "Execution not-found-arn not found"
+ )
+
+ request_body = {
+ "DurableExecutionArn": "not-found-arn",
+ }
+
+ # Create strongly-typed route using Router
+ router = Router()
+ typed_route = router.find_route(
+ "/2025-12-01/durable-executions/not-found-arn/stop", "POST"
+ )
+
+ request = HTTPRequest(
+ method="POST",
+ path=typed_route,
+ headers={"Content-Type": "application/json"},
+ query_params={},
+ body=request_body,
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify error response with AWS-compliant format
+ assert response.status_code == 404
+ assert response.body["Type"] == "ResourceNotFoundException"
+ assert response.body["Message"] == "Execution not-found-arn not found"
+
+
+def test_get_durable_execution_state_handler_success():
+ """Test GetDurableExecutionStateHandler with successful state retrieval."""
+
+ executor = Mock()
+ handler = GetDurableExecutionStateHandler(executor)
+
+ # Mock the executor response with operations
+
+ mock_operations = [
+ Operation(
+ operation_id="op-1",
+ operation_type=OperationType.STEP,
+ status=OperationStatus.SUCCEEDED,
+ name="test-step",
+ )
+ ]
+ mock_response = GetDurableExecutionStateResponse(
+ operations=mock_operations, next_marker=None
+ )
+ executor.get_execution_state.return_value = mock_response
+
+ # Create strongly-typed route using Router
+ router = Router()
+ typed_route = router.find_route(
+ "/2025-12-01/durable-executions/test-arn/state", "GET"
+ )
+
+ request = HTTPRequest(
+ method="GET",
+ path=typed_route,
+ headers={},
+ query_params={},
+ body={},
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify response
+ assert response.status_code == 200
+ assert "Operations" in response.body
+ assert len(response.body["Operations"]) == 1
+ assert response.body["Operations"][0]["Id"] == "op-1"
+ assert response.body["Operations"][0]["Type"] == "STEP"
+
+ # Verify executor was called with correct ARN
+ executor.get_execution_state.assert_called_once_with("test-arn")
+
+
+def test_get_durable_execution_state_handler_resource_not_found():
+ """Test GetDurableExecutionStateHandler with ResourceNotFoundException."""
+
+ executor = Mock()
+ handler = GetDurableExecutionStateHandler(executor)
+
+ # Mock executor to raise ResourceNotFoundException
+ executor.get_execution_state.side_effect = ResourceNotFoundException(
+ "Execution not-found-arn not found"
+ )
+
+ # Create strongly-typed route using Router
+ router = Router()
+ typed_route = router.find_route(
+ "/2025-12-01/durable-executions/not-found-arn/state", "GET"
+ )
+
+ request = HTTPRequest(
+ method="GET",
+ path=typed_route,
+ headers={},
+ query_params={},
+ body={},
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify error response with AWS-compliant format
+ assert response.status_code == 404
+ assert response.body["Type"] == "ResourceNotFoundException"
+ assert response.body["Message"] == "Execution not-found-arn not found"
+
+
+def test_get_durable_execution_state_handler_invalid_parameter():
+ """Test GetDurableExecutionStateHandler with IllegalArgumentException."""
+
+ executor = Mock()
+ handler = GetDurableExecutionStateHandler(executor)
+
+ # Mock executor to raise IllegalArgumentException
+ executor.get_execution_state.side_effect = IllegalArgumentException(
+ "Invalid checkpoint token"
+ )
+
+ # Create strongly-typed route using Router
+ router = Router()
+ typed_route = router.find_route(
+ "/2025-12-01/durable-executions/test-arn/state", "GET"
+ )
+
+ request = HTTPRequest(
+ method="GET",
+ path=typed_route,
+ headers={},
+ query_params={},
+ body={},
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify error response with AWS-compliant format
+ assert response.status_code == 400
+ assert response.body["Type"] == "InvalidParameterValueException"
+ assert response.body["message"] == "Invalid checkpoint token"
+
+
+def test_get_durable_execution_history_handler_success():
+ """Test GetDurableExecutionHistoryHandler with successful history retrieval."""
+
+ executor = Mock()
+ handler = GetDurableExecutionHistoryHandler(executor)
+
+ # Mock the executor response with events
+ mock_events = [
+ Event(
+ event_type="ExecutionStarted",
+ event_timestamp="2023-01-01T00:00:00Z",
+ event_id=1,
+ operation_id="exec-1",
+ execution_started_details=ExecutionStartedDetails(),
+ )
+ ]
+ mock_response = GetDurableExecutionHistoryResponse(
+ events=mock_events, next_marker=None
+ )
+ executor.get_execution_history.return_value = mock_response
+
+ # Create strongly-typed route using Router
+ router = Router()
+ typed_route = router.find_route(
+ "/2025-12-01/durable-executions/test-arn/history", "GET"
+ )
+
+ request = HTTPRequest(
+ method="GET",
+ path=typed_route,
+ headers={},
+ query_params={"MaxItems": ["10"], "Marker": ["token-123"]},
+ body={},
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify response
+ assert response.status_code == 200
+ assert "Events" in response.body
+ assert len(response.body["Events"]) == 1
+ assert response.body["Events"][0]["EventType"] == "ExecutionStarted"
+ assert response.body["Events"][0]["EventId"] == 1
+
+ # Verify executor was called with correct parameters
+ executor.get_execution_history.assert_called_once_with(
+ "test-arn",
+ include_execution_data=False,
+ reverse_order=False,
+ marker="token-123",
+ max_items=10,
+ )
+
+
+def test_get_durable_execution_history_handler_resource_not_found():
+ """Test GetDurableExecutionHistoryHandler with ResourceNotFoundException."""
+
+ executor = Mock()
+ handler = GetDurableExecutionHistoryHandler(executor)
+
+ # Mock executor to raise ResourceNotFoundException
+ executor.get_execution_history.side_effect = ResourceNotFoundException(
+ "Execution not-found-arn not found"
+ )
+
+ # Create strongly-typed route using Router
+ router = Router()
+ typed_route = router.find_route(
+ "/2025-12-01/durable-executions/not-found-arn/history", "GET"
+ )
+
+ request = HTTPRequest(
+ method="GET",
+ path=typed_route,
+ headers={},
+ query_params={},
+ body={},
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify error response with AWS-compliant format
+ assert response.status_code == 404
+ assert response.body["Type"] == "ResourceNotFoundException"
+ assert response.body["Message"] == "Execution not-found-arn not found"
+
+
+def test_get_durable_execution_history_handler_with_query_params():
+ """Test GetDurableExecutionHistoryHandler with query parameters."""
+
+ executor = Mock()
+ handler = GetDurableExecutionHistoryHandler(executor)
+
+ # Mock the executor response
+ mock_response = GetDurableExecutionHistoryResponse(events=[], next_marker=None)
+ executor.get_execution_history.return_value = mock_response
+
+ # Create strongly-typed route using Router
+ router = Router()
+ typed_route = router.find_route(
+ "/2025-12-01/durable-executions/test-arn/history", "GET"
+ )
+
+ request = HTTPRequest(
+ method="GET",
+ path=typed_route,
+ headers={},
+ query_params={"MaxItems": ["25"]},
+ body={},
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify response
+ assert response.status_code == 200
+ assert response.body == {"Events": []}
+
+ # Verify executor was called with correct parameters
+ executor.get_execution_history.assert_called_once_with(
+ "test-arn",
+ include_execution_data=False,
+ reverse_order=False,
+ marker=None,
+ max_items=25,
+ )
+
+
+def test_get_durable_execution_history_handler_with_include_execution_data():
+ """Test GetDurableExecutionHistoryHandler with IncludeExecutionData parameter."""
+
+ executor = Mock()
+ handler = GetDurableExecutionHistoryHandler(executor)
+
+ # Mock the executor response
+ mock_response = GetDurableExecutionHistoryResponse(events=[], next_marker=None)
+ executor.get_execution_history.return_value = mock_response
+
+ # Create strongly-typed route using Router
+ router = Router()
+ typed_route = router.find_route(
+ "/2025-12-01/durable-executions/test-arn/history", "GET"
+ )
+
+ request = HTTPRequest(
+ method="GET",
+ path=typed_route,
+ headers={},
+ query_params={"IncludeExecutionData": ["true"], "MaxItems": ["1000"]},
+ body={},
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify response
+ assert response.status_code == 200
+ assert response.body == {"Events": []}
+
+ # Verify executor was called with include_execution_data=True
+ executor.get_execution_history.assert_called_once_with(
+ "test-arn",
+ include_execution_data=True,
+ reverse_order=False,
+ marker=None,
+ max_items=1000,
+ )
+
+
+def test_get_durable_execution_history_handler_with_include_execution_data_false():
+ """Test GetDurableExecutionHistoryHandler with IncludeExecutionData=false."""
+
+ executor = Mock()
+ handler = GetDurableExecutionHistoryHandler(executor)
+
+ # Mock the executor response
+ mock_response = GetDurableExecutionHistoryResponse(events=[], next_marker=None)
+ executor.get_execution_history.return_value = mock_response
+
+ # Create strongly-typed route using Router
+ router = Router()
+ typed_route = router.find_route(
+ "/2025-12-01/durable-executions/test-arn/history", "GET"
+ )
+
+ request = HTTPRequest(
+ method="GET",
+ path=typed_route,
+ headers={},
+ query_params={"IncludeExecutionData": ["false"]},
+ body={},
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify response
+ assert response.status_code == 200
+ assert response.body == {"Events": []}
+
+ # Verify executor was called with include_execution_data=False
+ executor.get_execution_history.assert_called_once_with(
+ "test-arn",
+ include_execution_data=False,
+ reverse_order=False,
+ marker=None,
+ max_items=None,
+ )
+
+
+def test_list_durable_executions_handler_success():
+ """Test ListDurableExecutionsHandler with successful execution listing."""
+ executor = Mock()
+ handler = ListDurableExecutionsHandler(executor)
+
+ # Mock the executor response
+ mock_executions = [
+ ExecutionSummary(
+ durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:test-function:execution:test-1",
+ durable_execution_name="test-execution-1",
+ function_arn="arn:aws:lambda:us-east-1:123456789012:function:test-function",
+ status="SUCCEEDED",
+ start_timestamp="2023-01-01T00:00:00Z",
+ end_timestamp="2023-01-01T00:01:00Z",
+ ),
+ ExecutionSummary(
+ durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:test-function:execution:test-2",
+ durable_execution_name="test-execution-2",
+ function_arn="arn:aws:lambda:us-east-1:123456789012:function:test-function",
+ status="RUNNING",
+ start_timestamp="2023-01-01T00:02:00Z",
+ end_timestamp=None,
+ ),
+ ]
+
+ mock_response = ListDurableExecutionsResponse(
+ durable_executions=mock_executions,
+ next_marker=None,
+ )
+ executor.list_executions.return_value = mock_response
+
+ # Create strongly-typed route
+ base_route = Route.from_string("/2025-12-01/durable-executions")
+ typed_route = ListDurableExecutionsRoute.from_route(base_route)
+
+ request = HTTPRequest(
+ method="GET",
+ path=typed_route,
+ headers={},
+ query_params={},
+ body={},
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify response
+ assert response.status_code == 200
+ expected_body = {
+ "DurableExecutions": [
+ {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test-function:execution:test-1",
+ "DurableExecutionName": "test-execution-1",
+ "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:test-function",
+ "Status": "SUCCEEDED",
+ "StartTimestamp": "2023-01-01T00:00:00Z",
+ "EndTimestamp": "2023-01-01T00:01:00Z",
+ },
+ {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test-function:execution:test-2",
+ "DurableExecutionName": "test-execution-2",
+ "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:test-function",
+ "Status": "RUNNING",
+ "StartTimestamp": "2023-01-01T00:02:00Z",
+ },
+ ]
+ }
+ assert response.body == expected_body
+
+ # Verify executor was called with correct parameters (all None for no filters)
+ executor.list_executions.assert_called_once_with(
+ function_name=None,
+ function_version=None,
+ execution_name=None,
+ status_filter=None,
+ started_after=None,
+ started_before=None,
+ marker=None,
+ max_items=None,
+ reverse_order=False,
+ )
+
+
+def test_list_durable_executions_handler_with_filters():
+ """Test ListDurableExecutionsHandler with query parameter filters."""
+ executor = Mock()
+ handler = ListDurableExecutionsHandler(executor)
+
+ # Mock the executor response
+ mock_executions = [
+ ExecutionSummary(
+ durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:test-function:execution:filtered-1",
+ durable_execution_name="filtered-execution",
+ function_arn="arn:aws:lambda:us-east-1:123456789012:function:test-function",
+ status="SUCCEEDED",
+ start_timestamp="2023-01-01T00:00:00Z",
+ end_timestamp="2023-01-01T00:01:00Z",
+ ),
+ ]
+
+ mock_response = ListDurableExecutionsResponse(
+ durable_executions=mock_executions,
+ next_marker="next-page-token",
+ )
+ executor.list_executions.return_value = mock_response
+
+ # Create strongly-typed route
+ base_route = Route.from_string("/2025-12-01/durable-executions")
+ typed_route = ListDurableExecutionsRoute.from_route(base_route)
+
+ # Create request with query parameters
+ request = HTTPRequest(
+ method="GET",
+ path=typed_route,
+ headers={},
+ query_params={
+ "FunctionName": ["test-function"],
+ "FunctionVersion": ["$LATEST"],
+ "DurableExecutionName": ["filtered-execution"],
+ "StatusFilter": ["SUCCEEDED"],
+ "StartedAfter": ["2023-01-01T00:00:00Z"],
+ "StartedBefore": ["2023-01-01T23:59:59Z"],
+ "Marker": ["start-token"],
+ "MaxItems": ["10"],
+ "ReverseOrder": ["true"],
+ },
+ body={},
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify response
+ assert response.status_code == 200
+ expected_body = {
+ "DurableExecutions": [
+ {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test-function:execution:filtered-1",
+ "DurableExecutionName": "filtered-execution",
+ "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:test-function",
+ "Status": "SUCCEEDED",
+ "StartTimestamp": "2023-01-01T00:00:00Z",
+ "EndTimestamp": "2023-01-01T00:01:00Z",
+ },
+ ],
+ "NextMarker": "next-page-token",
+ }
+ assert response.body == expected_body
+
+ # Verify executor was called with correct filtered parameters
+ executor.list_executions.assert_called_once_with(
+ function_name="test-function",
+ function_version="$LATEST",
+ execution_name="filtered-execution",
+ status_filter="SUCCEEDED",
+ started_after="2023-01-01T00:00:00Z",
+ started_before="2023-01-01T23:59:59Z",
+ marker="start-token",
+ max_items=10,
+ reverse_order=True,
+ )
+
+
+def test_list_durable_executions_handler_pagination():
+ """Test ListDurableExecutionsHandler with pagination support."""
+ executor = Mock()
+ handler = ListDurableExecutionsHandler(executor)
+
+ # Mock the executor response with pagination
+ mock_executions = [
+ ExecutionSummary(
+ durable_execution_arn=f"arn:aws:lambda:us-east-1:123456789012:function:test-function:execution:page-{i}",
+ durable_execution_name=f"page-execution-{i}",
+ function_arn="arn:aws:lambda:us-east-1:123456789012:function:test-function",
+ status="SUCCEEDED",
+ start_timestamp=f"2023-01-0{i}T00:00:00Z",
+ end_timestamp=f"2023-01-0{i}T00:01:00Z",
+ )
+ for i in range(1, 4) # 3 executions
+ ]
+
+ mock_response = ListDurableExecutionsResponse(
+ durable_executions=mock_executions,
+ next_marker="next-page-marker",
+ )
+ executor.list_executions.return_value = mock_response
+
+ # Create strongly-typed route
+ base_route = Route.from_string("/2025-12-01/durable-executions")
+ typed_route = ListDurableExecutionsRoute.from_route(base_route)
+
+ # Create request with pagination parameters
+ request = HTTPRequest(
+ method="GET",
+ path=typed_route,
+ headers={},
+ query_params={
+ "MaxItems": ["3"],
+ "Marker": ["current-page-marker"],
+ },
+ body={},
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify response includes pagination
+ assert response.status_code == 200
+ assert len(response.body["DurableExecutions"]) == 3
+ assert response.body["NextMarker"] == "next-page-marker"
+
+ # Verify executor was called with pagination parameters
+ executor.list_executions.assert_called_once_with(
+ function_name=None,
+ function_version=None,
+ execution_name=None,
+ status_filter=None,
+ started_after=None,
+ started_before=None,
+ marker="current-page-marker",
+ max_items=3,
+ reverse_order=False,
+ )
+
+
+def test_list_durable_executions_handler_empty_results():
+ """Test ListDurableExecutionsHandler with no executions found."""
+
+ executor = Mock()
+ handler = ListDurableExecutionsHandler(executor)
+
+ # Mock empty executor response
+ mock_response = ListDurableExecutionsResponse(
+ durable_executions=[],
+ next_marker=None,
+ )
+ executor.list_executions.return_value = mock_response
+
+ # Create strongly-typed route
+ base_route = Route.from_string("/2025-12-01/durable-executions")
+ typed_route = ListDurableExecutionsRoute.from_route(base_route)
+
+ request = HTTPRequest(
+ method="GET",
+ path=typed_route,
+ headers={},
+ query_params={},
+ body={},
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify response
+ assert response.status_code == 200
+ assert response.body == {"DurableExecutions": []}
+
+ # Verify executor was called
+ executor.list_executions.assert_called_once()
+
+
+def test_list_durable_executions_handler_dataclass_serialization():
+ """Test ListDurableExecutionsHandler uses from_dict/to_dict methods for serialization."""
+ executor = Mock()
+ handler = ListDurableExecutionsHandler(executor)
+
+ # Mock the executor response
+ mock_executions = [
+ ExecutionSummary(
+ durable_execution_arn="test-arn",
+ durable_execution_name="test-execution",
+ function_arn="test-function-arn",
+ status="SUCCEEDED",
+ start_timestamp="2023-01-01T00:00:00Z",
+ end_timestamp="2023-01-01T00:01:00Z",
+ ),
+ ]
+
+ mock_response = ListDurableExecutionsResponse(
+ durable_executions=mock_executions,
+ next_marker=None,
+ )
+ executor.list_executions.return_value = mock_response
+
+ # Create strongly-typed route
+ base_route = Route.from_string("/2025-12-01/durable-executions")
+ typed_route = ListDurableExecutionsRoute.from_route(base_route)
+
+ # Create request with query parameters to test from_dict
+ request = HTTPRequest(
+ method="GET",
+ path=typed_route,
+ headers={},
+ query_params={
+ "FunctionName": ["test-function"],
+ "MaxItems": ["5"],
+ },
+ body={},
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify response uses to_dict() serialization
+ assert response.status_code == 200
+ assert "DurableExecutions" in response.body
+ assert isinstance(response.body["DurableExecutions"], list)
+
+ # Verify the response structure matches to_dict() output
+ execution_data = response.body["DurableExecutions"][0]
+ assert execution_data["DurableExecutionArn"] == "test-arn"
+ assert execution_data["DurableExecutionName"] == "test-execution"
+ assert execution_data["Status"] == "SUCCEEDED"
+
+ # Verify executor was called (implicitly tests from_dict was used for request parsing)
+ executor.list_executions.assert_called_once_with(
+ function_name="test-function",
+ function_version=None,
+ execution_name=None,
+ status_filter=None,
+ started_after=None,
+ started_before=None,
+ marker=None,
+ max_items=5,
+ reverse_order=False,
+ )
+
+
+def test_list_durable_executions_handler_invalid_parameter_error():
+ """Test ListDurableExecutionsHandler with IllegalArgumentException from executor."""
+
+ executor = Mock()
+ handler = ListDurableExecutionsHandler(executor)
+
+ # Mock executor to raise IllegalArgumentException
+ executor.list_executions.side_effect = IllegalArgumentException(
+ "Invalid MaxItems value"
+ )
+
+ # Create strongly-typed route
+ base_route = Route.from_string("/2025-12-01/durable-executions")
+ typed_route = ListDurableExecutionsRoute.from_route(base_route)
+
+ request = HTTPRequest(
+ method="GET",
+ path=typed_route,
+ headers={},
+ query_params={
+ "MaxItems": ["-1"], # Invalid value
+ },
+ body={},
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify error response with AWS-compliant format
+ assert response.status_code == 400
+ assert response.body["Type"] == "InvalidParameterValueException"
+ assert response.body["message"] == "Invalid MaxItems value"
+
+
+def test_list_durable_executions_handler_unexpected_error():
+ """Test ListDurableExecutionsHandler with unexpected error from executor."""
+
+ executor = Mock()
+ handler = ListDurableExecutionsHandler(executor)
+
+ # Mock executor to raise unexpected error
+ executor.list_executions.side_effect = RuntimeError("Database connection failed")
+
+ # Create strongly-typed route
+ base_route = Route.from_string("/2025-12-01/durable-executions")
+ typed_route = ListDurableExecutionsRoute.from_route(base_route)
+
+ request = HTTPRequest(
+ method="GET",
+ path=typed_route,
+ headers={},
+ query_params={},
+ body={},
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify error response with AWS-compliant format
+ assert response.status_code == 500
+ assert response.body["Type"] == "ServiceException"
+ assert response.body["Message"] == "Database connection failed"
+
+
+def test_list_durable_executions_handler_common_exception_handling():
+ """Test ListDurableExecutionsHandler uses base class _handle_common_exceptions method."""
+
+ executor = Mock()
+ handler = ListDurableExecutionsHandler(executor)
+
+ # Mock executor to raise ResourceNotFoundException
+ executor.list_executions.side_effect = ResourceNotFoundException(
+ "Function not found"
+ )
+
+ # Create strongly-typed route
+ base_route = Route.from_string("/2025-12-01/durable-executions")
+ typed_route = ListDurableExecutionsRoute.from_route(base_route)
+
+ request = HTTPRequest(
+ method="GET",
+ path=typed_route,
+ headers={},
+ query_params={},
+ body={},
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify error response uses common exception handling with AWS-compliant format
+ assert response.status_code == 404
+ assert response.body["Type"] == "ResourceNotFoundException"
+ assert response.body["Message"] == "Function not found"
+
+
+def test_list_durable_executions_by_function_handler_success():
+ """Test ListDurableExecutionsByFunctionHandler with successful execution listing."""
+
+ executor = Mock()
+ handler = ListDurableExecutionsByFunctionHandler(executor)
+
+ # Mock the executor response
+ mock_executions = [
+ ExecutionSummary(
+ durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:test-function:execution:func-1",
+ durable_execution_name="function-execution-1",
+ function_arn="arn:aws:lambda:us-east-1:123456789012:function:test-function",
+ status="SUCCEEDED",
+ start_timestamp="2023-01-01T00:00:00Z",
+ end_timestamp="2023-01-01T00:01:00Z",
+ ),
+ ExecutionSummary(
+ durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:test-function:execution:func-2",
+ durable_execution_name="function-execution-2",
+ function_arn="arn:aws:lambda:us-east-1:123456789012:function:test-function",
+ status="RUNNING",
+ start_timestamp="2023-01-01T00:02:00Z",
+ end_timestamp=None,
+ ),
+ ]
+
+ mock_response = ListDurableExecutionsByFunctionResponse(
+ durable_executions=mock_executions,
+ next_marker=None,
+ )
+ executor.list_executions_by_function.return_value = mock_response
+
+ # Create strongly-typed route using Router
+ router = Router()
+ typed_route = router.find_route(
+ "/2025-12-01/functions/test-function/durable-executions", "GET"
+ )
+
+ request = HTTPRequest(
+ method="GET",
+ path=typed_route,
+ headers={},
+ query_params={},
+ body={},
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify response
+ assert response.status_code == 200
+ expected_body = {
+ "DurableExecutions": [
+ {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test-function:execution:func-1",
+ "DurableExecutionName": "function-execution-1",
+ "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:test-function",
+ "Status": "SUCCEEDED",
+ "StartTimestamp": "2023-01-01T00:00:00Z",
+ "EndTimestamp": "2023-01-01T00:01:00Z",
+ },
+ {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test-function:execution:func-2",
+ "DurableExecutionName": "function-execution-2",
+ "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:test-function",
+ "Status": "RUNNING",
+ "StartTimestamp": "2023-01-01T00:02:00Z",
+ },
+ ]
+ }
+ assert response.body == expected_body
+
+ # Verify executor was called with correct function name
+ executor.list_executions_by_function.assert_called_once_with(
+ function_name="test-function",
+ qualifier=None,
+ execution_name=None,
+ status_filter=None,
+ started_after=None,
+ started_before=None,
+ marker=None,
+ max_items=None,
+ reverse_order=False,
+ )
+
+
+def test_list_durable_executions_by_function_handler_with_filters():
+ """Test ListDurableExecutionsByFunctionHandler with query parameter filters."""
+
+ executor = Mock()
+ handler = ListDurableExecutionsByFunctionHandler(executor)
+
+ # Mock the executor response
+ mock_executions = [
+ ExecutionSummary(
+ durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:test-function:execution:filtered",
+ durable_execution_name="filtered-execution",
+ function_arn="arn:aws:lambda:us-east-1:123456789012:function:test-function",
+ status="SUCCEEDED",
+ start_timestamp="2023-01-01T00:00:00Z",
+ end_timestamp="2023-01-01T00:01:00Z",
+ ),
+ ]
+
+ mock_response = ListDurableExecutionsByFunctionResponse(
+ durable_executions=mock_executions,
+ next_marker="next-page-token",
+ )
+ executor.list_executions_by_function.return_value = mock_response
+
+ # Create strongly-typed route using Router
+ router = Router()
+ typed_route = router.find_route(
+ "/2025-12-01/functions/test-function/durable-executions", "GET"
+ )
+
+ # Create request with query parameters
+ request = HTTPRequest(
+ method="GET",
+ path=typed_route,
+ headers={},
+ query_params={
+ "functionVersion": ["$LATEST"],
+ "executionName": ["filtered-execution"],
+ "statusFilter": ["SUCCEEDED"],
+ "startedAfter": ["2023-01-01T00:00:00Z"],
+ "startedBefore": ["2023-01-01T23:59:59Z"],
+ "marker": ["start-token"],
+ "maxItems": ["5"],
+ "reverseOrder": ["true"],
+ },
+ body={},
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify response
+ assert response.status_code == 200
+ expected_body = {
+ "DurableExecutions": [
+ {
+ "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test-function:execution:filtered",
+ "DurableExecutionName": "filtered-execution",
+ "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:test-function",
+ "Status": "SUCCEEDED",
+ "StartTimestamp": "2023-01-01T00:00:00Z",
+ "EndTimestamp": "2023-01-01T00:01:00Z",
+ },
+ ],
+ "NextMarker": "next-page-token",
+ }
+ assert response.body == expected_body
+
+ # Verify executor was called with correct filtered parameters
+ executor.list_executions_by_function.assert_called_once_with(
+ function_name="test-function",
+ qualifier="$LATEST",
+ execution_name="filtered-execution",
+ status_filter="SUCCEEDED",
+ started_after="2023-01-01T00:00:00Z",
+ started_before="2023-01-01T23:59:59Z",
+ marker="start-token",
+ max_items=5,
+ reverse_order=True,
+ )
+
+
+def test_list_durable_executions_by_function_handler_dataclass_serialization():
+ """Test ListDurableExecutionsByFunctionHandler uses from_dict/to_dict methods for serialization."""
+
+ executor = Mock()
+ handler = ListDurableExecutionsByFunctionHandler(executor)
+
+ # Mock the executor response
+ mock_executions = [
+ ExecutionSummary(
+ durable_execution_arn="test-arn",
+ durable_execution_name="test-execution",
+ function_arn="test-function-arn",
+ status="SUCCEEDED",
+ start_timestamp="2023-01-01T00:00:00Z",
+ end_timestamp="2023-01-01T00:01:00Z",
+ ),
+ ]
+
+ mock_response = ListDurableExecutionsByFunctionResponse(
+ durable_executions=mock_executions,
+ next_marker=None,
+ )
+ executor.list_executions_by_function.return_value = mock_response
+
+ # Create strongly-typed route using Router
+ router = Router()
+ typed_route = router.find_route(
+ "/2025-12-01/functions/test-function/durable-executions", "GET"
+ )
+
+ # Create request with query parameters to test from_dict
+ request = HTTPRequest(
+ method="GET",
+ path=typed_route,
+ headers={},
+ query_params={
+ "functionVersion": ["$LATEST"],
+ "maxItems": ["10"],
+ },
+ body={},
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify response uses to_dict() serialization
+ assert response.status_code == 200
+ assert "DurableExecutions" in response.body
+ assert isinstance(response.body["DurableExecutions"], list)
+
+ # Verify the response structure matches to_dict() output
+ execution_data = response.body["DurableExecutions"][0]
+ assert execution_data["DurableExecutionArn"] == "test-arn"
+ assert execution_data["DurableExecutionName"] == "test-execution"
+ assert execution_data["Status"] == "SUCCEEDED"
+
+ # Verify executor was called (implicitly tests from_dict was used for request parsing)
+ executor.list_executions_by_function.assert_called_once_with(
+ function_name="test-function",
+ qualifier="$LATEST",
+ execution_name=None,
+ status_filter=None,
+ started_after=None,
+ started_before=None,
+ marker=None,
+ max_items=10,
+ reverse_order=False,
+ )
+
+
+def test_list_durable_executions_by_function_handler_resource_not_found():
+ """Test ListDurableExecutionsByFunctionHandler with ResourceNotFoundException."""
+
+ executor = Mock()
+ handler = ListDurableExecutionsByFunctionHandler(executor)
+
+ # Mock executor to raise ResourceNotFoundException
+ executor.list_executions_by_function.side_effect = ResourceNotFoundException(
+ "Function not-found-function not found"
+ )
+
+ # Create strongly-typed route using Router
+ router = Router()
+ typed_route = router.find_route(
+ "/2025-12-01/functions/not-found-function/durable-executions", "GET"
+ )
+
+ request = HTTPRequest(
+ method="GET",
+ path=typed_route,
+ headers={},
+ query_params={},
+ body={},
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify error response uses common exception handling with AWS-compliant format
+ assert response.status_code == 404
+ assert response.body["Type"] == "ResourceNotFoundException"
+ assert response.body["Message"] == "Function not-found-function not found"
+
+
+def test_list_durable_executions_by_function_handler_common_exception_handling():
+ """Test ListDurableExecutionsByFunctionHandler uses base class _handle_common_exceptions method."""
+
+ executor = Mock()
+ handler = ListDurableExecutionsByFunctionHandler(executor)
+
+ # Mock executor to raise IllegalArgumentException
+ executor.list_executions_by_function.side_effect = IllegalArgumentException(
+ "Invalid function name format"
+ )
+
+ # Create strongly-typed route using Router
+ router = Router()
+ typed_route = router.find_route(
+ "/2025-12-01/functions/invalid-function/durable-executions", "GET"
+ )
+
+ request = HTTPRequest(
+ method="GET",
+ path=typed_route,
+ headers={},
+ query_params={},
+ body={},
+ )
+
+ response = handler.handle(typed_route, request)
+
+ # Verify error response uses common exception handling with AWS-compliant format
+ assert response.status_code == 400
+ assert response.body["Type"] == "InvalidParameterValueException"
+ assert response.body["message"] == "Invalid function name format"
+
+
+def test_send_durable_execution_callback_success_handler():
+ """Test SendDurableExecutionCallbackSuccessHandler with valid request."""
+
+ executor = Mock()
+ executor.send_callback_success.return_value = (
+ SendDurableExecutionCallbackSuccessResponse()
+ )
+ handler = SendDurableExecutionCallbackSuccessHandler(executor)
+
+ # Create route using Router
+ router = Router()
+ route = router.find_route(
+ "/2025-12-01/durable-execution-callbacks/test-callback-id/succeed", "POST"
+ )
+ assert isinstance(route, CallbackSuccessRoute)
+ assert route.callback_id == "test-callback-id"
+
+ # Result is sent as raw binary body
+ request_body = b"success-result"
+
+ request = HTTPRequest(
+ method="POST",
+ path=route,
+ headers={},
+ query_params={},
+ body=request_body,
+ )
+
+ response = handler.handle(route, request)
+
+ # Verify successful response
+ assert response.status_code == 200
+ assert response.body == {}
+
+ # Verify executor was called with correct parameters
+ executor.send_callback_success.assert_called_once_with(
+ callback_id="test-callback-id", result=b"success-result"
+ )
+
+
+def test_send_durable_execution_callback_success_handler_empty_body():
+ """Test SendDurableExecutionCallbackSuccessHandler with empty body."""
+ executor = Mock()
+ executor.send_callback_success.return_value = (
+ SendDurableExecutionCallbackSuccessResponse()
+ )
+ handler = SendDurableExecutionCallbackSuccessHandler(executor)
+
+ base_route = Route.from_string(
+ "/2025-12-01/durable-execution-callbacks/test-id/succeed"
+ )
+ callback_route = CallbackSuccessRoute.from_route(base_route)
+
+ request = HTTPRequest(
+ method="POST",
+ path=callback_route,
+ headers={},
+ query_params={},
+ body=b"",
+ )
+
+ response = handler.handle(callback_route, request)
+ # Handler should accept empty body (Result is optional) and return 200
+ assert response.status_code == 200
+ assert response.body == {}
+
+ # Verify executor was called with empty result
+ executor.send_callback_success.assert_called_once_with(
+ callback_id="test-id", result=b""
+ )
+
+
+def test_send_durable_execution_callback_failure_handler():
+ """Test SendDurableExecutionCallbackFailureHandler with valid request."""
+
+ executor = Mock()
+ executor.send_callback_failure.return_value = (
+ SendDurableExecutionCallbackFailureResponse()
+ )
+ handler = SendDurableExecutionCallbackFailureHandler(executor)
+
+ # Create route using Router
+ router = Router()
+ route = router.find_route(
+ "/2025-12-01/durable-execution-callbacks/test-callback-id/fail", "POST"
+ )
+ assert isinstance(route, CallbackFailureRoute)
+ assert route.callback_id == "test-callback-id"
+
+ # Test with valid request body including error
+ error_data = {
+ "ErrorMessage": "Test error",
+ "ErrorType": "TestException",
+ "ErrorData": None,
+ "StackTrace": None,
+ }
+ request = HTTPRequest(
+ method="POST",
+ path=route,
+ headers={"Content-Type": "application/json"},
+ query_params={},
+ body=error_data, # Pass error data directly as body
+ )
+ response = handler.handle(route, request)
+
+ # Verify successful response
+ assert response.status_code == 200
+ assert response.body == {}
+
+ # Verify executor was called with correct parameters
+ executor.send_callback_failure.assert_called_once()
+ call_args = executor.send_callback_failure.call_args
+ assert call_args[1]["callback_id"] == "test-callback-id"
+ assert isinstance(call_args[1]["error"], ErrorObject)
+ assert call_args[1]["error"].message == "Test error"
+
+
+def test_update_lambda_endpoint_handler_success():
+ """Test UpdateLambdaEndpointHandler with valid request."""
+ from aws_durable_execution_sdk_python_testing.invoker import LambdaInvoker
+ from aws_durable_execution_sdk_python_testing.web.handlers import (
+ UpdateLambdaEndpointHandler,
+ )
+ from aws_durable_execution_sdk_python_testing.web.routes import (
+ UpdateLambdaEndpointRoute,
+ )
+
+ executor = Mock()
+ lambda_invoker = Mock(spec=LambdaInvoker)
+ executor._invoker = lambda_invoker # noqa: SLF001
+ handler = UpdateLambdaEndpointHandler(executor)
+
+ base_route = Route.from_string("/lambda-endpoint")
+ update_route = UpdateLambdaEndpointRoute.from_route(base_route)
+
+ request = HTTPRequest(
+ method="PUT",
+ path=update_route,
+ headers={"Content-Type": "application/json"},
+ query_params={},
+ body={"EndpointUrl": "http://localhost:8080", "RegionName": "us-west-2"},
+ )
+
+ response = handler.handle(update_route, request)
+
+ assert response.status_code == 200
+ assert response.body == {"message": "Lambda endpoint updated successfully"}
+ lambda_invoker.update_endpoint.assert_called_once_with(
+ "http://localhost:8080", "us-west-2"
+ )
+
+
+def test_update_lambda_endpoint_handler_missing_endpoint_url():
+ """Test UpdateLambdaEndpointHandler with missing EndpointUrl."""
+ from aws_durable_execution_sdk_python_testing.web.handlers import (
+ UpdateLambdaEndpointHandler,
+ )
+ from aws_durable_execution_sdk_python_testing.web.routes import (
+ UpdateLambdaEndpointRoute,
+ )
+
+ executor = Mock()
+ handler = UpdateLambdaEndpointHandler(executor)
+
+ base_route = Route.from_string("/lambda-endpoint")
+ update_route = UpdateLambdaEndpointRoute.from_route(base_route)
+
+ request = HTTPRequest(
+ method="PUT",
+ path=update_route,
+ headers={"Content-Type": "application/json"},
+ query_params={},
+ body={"RegionName": "us-west-2"},
+ )
+
+ response = handler.handle(update_route, request)
+
+ assert response.status_code == 400
+ assert response.body["Type"] == "InvalidParameterValueException"
+ assert response.body["message"] == "EndpointUrl is required"
+
+
+def test_update_lambda_endpoint_handler_default_region():
+ """Test UpdateLambdaEndpointHandler uses default region when not specified."""
+ from aws_durable_execution_sdk_python_testing.invoker import LambdaInvoker
+ from aws_durable_execution_sdk_python_testing.web.handlers import (
+ UpdateLambdaEndpointHandler,
+ )
+ from aws_durable_execution_sdk_python_testing.web.routes import (
+ UpdateLambdaEndpointRoute,
+ )
+
+ executor = Mock()
+ lambda_invoker = Mock(spec=LambdaInvoker)
+ executor._invoker = lambda_invoker # noqa: SLF001
+ handler = UpdateLambdaEndpointHandler(executor)
+
+ base_route = Route.from_string("/lambda-endpoint")
+ update_route = UpdateLambdaEndpointRoute.from_route(base_route)
+
+ request = HTTPRequest(
+ method="PUT",
+ path=update_route,
+ headers={"Content-Type": "application/json"},
+ query_params={},
+ body={"EndpointUrl": "http://localhost:8080"},
+ )
+
+ response = handler.handle(update_route, request)
+
+ assert response.status_code == 200
+ lambda_invoker.update_endpoint.assert_called_once_with(
+ "http://localhost:8080", "us-east-1"
+ )
+
+
+def test_send_durable_execution_callback_failure_handler_empty_body():
+ """Test SendDurableExecutionCallbackFailureHandler with empty body."""
+ executor = Mock()
+ handler = SendDurableExecutionCallbackFailureHandler(executor)
+
+ base_route = Route.from_string(
+ "/2025-12-01/durable-execution-callbacks/test-id/fail"
+ )
+ callback_route = CallbackFailureRoute.from_route(base_route)
+
+ request = HTTPRequest(
+ method="POST",
+ path=callback_route,
+ headers={},
+ query_params={},
+ body={},
+ )
+
+ response = handler.handle(callback_route, request)
+ # Handler should accept empty body for failure requests
+ assert response.status_code == 200
+
+
+def test_send_durable_execution_callback_heartbeat_handler():
+ """Test SendDurableExecutionCallbackHeartbeatHandler with valid request."""
+
+ executor = Mock()
+ executor.send_callback_heartbeat.return_value = (
+ SendDurableExecutionCallbackHeartbeatResponse()
+ )
+ handler = SendDurableExecutionCallbackHeartbeatHandler(executor)
+
+ # Create route using Router
+ router = Router()
+ route = router.find_route(
+ "/2025-12-01/durable-execution-callbacks/test-callback-id/heartbeat", "POST"
+ )
+ assert isinstance(route, CallbackHeartbeatRoute)
+ assert route.callback_id == "test-callback-id"
+
+ # Test with valid request body
+ request = HTTPRequest(
+ method="POST",
+ path=route,
+ headers={"Content-Type": "application/json"},
+ query_params={},
+ body={"CallbackId": "test-callback-id"},
+ )
+ response = handler.handle(route, request)
+
+ # Verify successful response
+ assert response.status_code == 200
+ assert response.body == {}
+
+ # Verify executor was called with correct parameters
+ executor.send_callback_heartbeat.assert_called_once_with(
+ callback_id="test-callback-id"
+ )
+
+
+def test_send_durable_execution_callback_heartbeat_handler_empty_body():
+ """Test SendDurableExecutionCallbackHeartbeatHandler with empty body."""
+ executor = Mock()
+ handler = SendDurableExecutionCallbackHeartbeatHandler(executor)
+
+ base_route = Route.from_string(
+ "/2025-12-01/durable-execution-callbacks/test-id/heartbeat"
+ )
+ callback_route = CallbackHeartbeatRoute.from_route(base_route)
+
+ request = HTTPRequest(
+ method="POST",
+ path=callback_route,
+ headers={},
+ query_params={},
+ body={},
+ )
+
+ response = handler.handle(callback_route, request)
+ # Handler should accept empty body for heartbeat requests
+ assert response.status_code == 200
+
+
+def test_health_handler():
+ """Test HealthHandler returns healthy status."""
+ executor = Mock()
+ handler = HealthHandler(executor)
+
+ request = HTTPRequest(
+ method="GET",
+ path=Route.from_string("/health"),
+ headers={},
+ query_params={},
+ body={},
+ )
+
+ response = handler.handle(Route.from_string("/health"), request)
+ assert response.status_code == 200
+ assert response.body == {"status": "healthy"}
+
+
+def test_metrics_handler():
+ """Test MetricsHandler returns empty metrics."""
+ executor = Mock()
+ handler = MetricsHandler(executor)
+
+ request = HTTPRequest(
+ method="GET",
+ path=Route.from_string("/metrics"),
+ headers={},
+ query_params={},
+ body={},
+ )
+
+ response = handler.handle(Route.from_string("/metrics"), request)
+ assert response.status_code == 200
+ assert response.body == {"metrics": {}}
+
+
+def test_handler_naming_matches_smithy_operations():
+ """Test that handler names match the Smithy operation names."""
+ # Verify that all handlers are named after their corresponding Smithy operations
+ handler_names = [
+ "StartExecutionHandler", # Note: This one doesn't have "Durable" prefix in Smithy
+ "GetDurableExecutionHandler",
+ "CheckpointDurableExecutionHandler",
+ "StopDurableExecutionHandler",
+ "GetDurableExecutionStateHandler",
+ "GetDurableExecutionHistoryHandler",
+ "ListDurableExecutionsHandler",
+ "ListDurableExecutionsByFunctionHandler",
+ "SendDurableExecutionCallbackSuccessHandler",
+ "SendDurableExecutionCallbackFailureHandler",
+ "SendDurableExecutionCallbackHeartbeatHandler",
+ "HealthHandler",
+ "MetricsHandler",
+ ]
+
+ # Import the handlers module to check all classes exist
+
+ for handler_name in handler_names:
+ assert hasattr(handlers, handler_name), f"Handler {handler_name} not found"
+ handler_class = getattr(handlers, handler_name)
+ assert issubclass(handler_class, EndpointHandler), (
+ f"{handler_name} should inherit from EndpointHandler"
+ )
+
+
+def test_all_handlers_have_executor():
+ """Test that all handlers store the executor reference."""
+ executor = Mock()
+
+ handlers_to_test = [
+ StartExecutionHandler,
+ GetDurableExecutionHandler,
+ CheckpointDurableExecutionHandler,
+ StopDurableExecutionHandler,
+ GetDurableExecutionStateHandler,
+ GetDurableExecutionHistoryHandler,
+ ListDurableExecutionsHandler,
+ ListDurableExecutionsByFunctionHandler,
+ SendDurableExecutionCallbackSuccessHandler,
+ SendDurableExecutionCallbackFailureHandler,
+ SendDurableExecutionCallbackHeartbeatHandler,
+ HealthHandler,
+ MetricsHandler,
+ ]
+
+ for handler_class in handlers_to_test:
+ handler = handler_class(executor)
+ assert handler.executor == executor, (
+ f"{handler_class.__name__} should store executor reference"
+ )
+
+
+class MockExceptionHandler(EndpointHandler):
+ """Test handler that can trigger specific exception types for testing."""
+
+ def __init__(
+ self, executor: Executor, exception_to_raise: Exception | None = None
+ ) -> None:
+ super().__init__(executor)
+ self.exception_to_raise = exception_to_raise
+
+ def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse:
+ """Handle request by raising the configured exception."""
+ if self.exception_to_raise:
+ if isinstance(self.exception_to_raise, AwsApiException):
+ return self._handle_aws_exception(self.exception_to_raise)
+
+ return self._handle_framework_exception(self.exception_to_raise)
+ return self._success_response({"status": "ok"})
+
+
+def test_framework_exception_handling():
+ """Test the framework exception handling through public API."""
+
+ executor = Mock()
+
+ # Test ValueError handling - maps to InvalidParameterValueException
+ handler = MockExceptionHandler(executor, ValueError("Invalid input"))
+ response = handler.handle(Mock(), Mock())
+ assert response.status_code == 400
+ assert response.body["Type"] == "InvalidParameterValueException"
+ assert response.body["message"] == "Invalid input"
+
+ # Test KeyError handling - maps to InvalidParameterValueException
+ handler = MockExceptionHandler(executor, KeyError("missing_field"))
+ response = handler.handle(Mock(), Mock())
+ assert response.status_code == 400
+ assert response.body["Type"] == "InvalidParameterValueException"
+ assert response.body["message"] == "'missing_field'"
+
+ # Test unexpected exception handling - maps to ServiceException
+ handler = MockExceptionHandler(executor, RuntimeError("Unexpected error"))
+ response = handler.handle(Mock(), Mock())
+ assert response.status_code == 500
+ assert response.body["Type"] == "ServiceException"
+ assert response.body["Message"] == "Unexpected error"
+
+
+def test_aws_exception_handling():
+ """Test the AWS exception handling through public API."""
+
+ executor = Mock()
+
+ # Test ResourceNotFoundException handling
+ handler = MockExceptionHandler(
+ executor, ResourceNotFoundException("Resource not found")
+ )
+ response = handler.handle(Mock(), Mock())
+ assert response.status_code == 404
+ assert response.body["Type"] == "ResourceNotFoundException"
+ assert response.body["Message"] == "Resource not found"
+
+ # Test IllegalArgumentException handling
+ handler = MockExceptionHandler(
+ executor, IllegalArgumentException("Invalid parameter")
+ )
+ response = handler.handle(Mock(), Mock())
+ assert response.status_code == 400
+ assert response.body["Type"] == "InvalidParameterValueException"
+ assert response.body["message"] == "Invalid parameter"
+
+
+def test_send_durable_execution_callback_success_handler_invalid_callback_id():
+ """Test SendDurableExecutionCallbackSuccessHandler with invalid callback ID."""
+
+ executor = Mock()
+ executor.send_callback_success.side_effect = IllegalArgumentException(
+ "callback_id is required"
+ )
+ handler = SendDurableExecutionCallbackSuccessHandler(executor)
+
+ # Create route using Router
+ router = Router()
+ route = router.find_route(
+ "/2025-12-01/durable-execution-callbacks/test-callback-id/succeed", "POST"
+ )
+
+ request = HTTPRequest(
+ method="POST",
+ path=route,
+ headers={"Content-Type": "application/json"},
+ query_params={},
+ body={"CallbackId": "test-callback-id"},
+ )
+
+ response = handler.handle(route, request)
+
+ # Verify error response with AWS-compliant format
+ assert response.status_code == 400
+ assert response.body["Type"] == "InvalidParameterValueException"
+ assert "callback_id is required" in response.body["message"]
+
+
+def test_send_durable_execution_callback_success_handler_callback_state_conflict():
+ """Test SendDurableExecutionCallbackSuccessHandler with callback state conflict."""
+
+ executor = Mock()
+ executor.send_callback_success.side_effect = IllegalStateException(
+ "Callback already completed"
+ )
+ handler = SendDurableExecutionCallbackSuccessHandler(executor)
+
+ # Create route using Router
+ router = Router()
+ route = router.find_route(
+ "/2025-12-01/durable-execution-callbacks/test-callback-id/succeed", "POST"
+ )
+
+ request = HTTPRequest(
+ method="POST",
+ path=route,
+ headers={"Content-Type": "application/json"},
+ query_params={},
+ body={"CallbackId": "test-callback-id"},
+ )
+
+ response = handler.handle(route, request)
+
+ # Verify error response - IllegalStateException in callback context maps to ExecutionConflictException
+ assert response.status_code == 409
+ assert response.body["Type"] == "ExecutionConflictException"
+ assert response.body["message"] == "Callback already completed"
+
+
+def test_send_durable_execution_callback_failure_handler_callback_state_conflict():
+ """Test SendDurableExecutionCallbackFailureHandler with callback state conflict."""
+
+ executor = Mock()
+ executor.send_callback_failure.side_effect = IllegalStateException(
+ "Callback already completed"
+ )
+ handler = SendDurableExecutionCallbackFailureHandler(executor)
+
+ # Create route using Router
+ router = Router()
+ route = router.find_route(
+ "/2025-12-01/durable-execution-callbacks/test-callback-id/fail", "POST"
+ )
+
+ request = HTTPRequest(
+ method="POST",
+ path=route,
+ headers={"Content-Type": "application/json"},
+ query_params={},
+ body={"CallbackId": "test-callback-id"},
+ )
+
+ response = handler.handle(route, request)
+
+ # Verify error response - IllegalStateException in callback context maps to ExecutionConflictException
+ assert response.status_code == 409
+ assert response.body["Type"] == "ExecutionConflictException"
+ assert response.body["message"] == "Callback already completed"
+
+
+def test_send_durable_execution_callback_heartbeat_handler_callback_state_conflict():
+ """Test SendDurableExecutionCallbackHeartbeatHandler with callback state conflict."""
+
+ executor = Mock()
+ executor.send_callback_heartbeat.side_effect = IllegalStateException(
+ "Callback already completed"
+ )
+ handler = SendDurableExecutionCallbackHeartbeatHandler(executor)
+
+ # Create route using Router
+ router = Router()
+ route = router.find_route(
+ "/2025-12-01/durable-execution-callbacks/test-callback-id/heartbeat", "POST"
+ )
+
+ request = HTTPRequest(
+ method="POST",
+ path=route,
+ headers={"Content-Type": "application/json"},
+ query_params={},
+ body={"CallbackId": "test-callback-id"},
+ )
+
+ response = handler.handle(route, request)
+
+ # Verify error response - IllegalStateException in callback context maps to ExecutionConflictException
+ assert response.status_code == 409
+ assert response.body["Type"] == "ExecutionConflictException"
+ assert response.body["message"] == "Callback already completed"
+
+
+def test_callback_handlers_use_dataclass_serialization():
+ """Test that all callback handlers use dataclass from_dict/to_dict methods."""
+
+ # Test that all callback request dataclasses have from_dict/to_dict methods
+ success_request = SendDurableExecutionCallbackSuccessRequest.from_dict(
+ {"CallbackId": "test-id", "Result": "test-result"}
+ )
+ assert success_request.callback_id == "test-id"
+ assert success_request.result == "test-result"
+ assert success_request.to_dict() == {
+ "CallbackId": "test-id",
+ "Result": "test-result",
+ }
+
+ failure_request = SendDurableExecutionCallbackFailureRequest.from_dict(
+ {}, "test-id"
+ )
+ assert failure_request.callback_id == "test-id"
+ assert failure_request.error is None
+ assert failure_request.to_dict() == {"CallbackId": "test-id"}
+
+ heartbeat_request = SendDurableExecutionCallbackHeartbeatRequest.from_dict(
+ {"CallbackId": "test-id"}
+ )
+ assert heartbeat_request.callback_id == "test-id"
+ assert heartbeat_request.to_dict() == {"CallbackId": "test-id"}
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/web/models_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/web/models_test.py
new file mode 100644
index 0000000..8148736
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/web/models_test.py
@@ -0,0 +1,796 @@
+"""Tests for HTTP request/response data models and utilities."""
+
+from __future__ import annotations
+
+import datetime
+import json
+from unittest.mock import Mock, patch
+
+import pytest
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ CallbackTimeoutException,
+ ExecutionAlreadyStartedException,
+ IllegalArgumentException,
+ IllegalStateException,
+ InvalidParameterValueException,
+ ResourceNotFoundException,
+ ServiceException,
+ TooManyRequestsException,
+)
+from aws_durable_execution_sdk_python_testing.web.models import (
+ HTTPRequest,
+ HTTPResponse,
+ OperationHandler,
+)
+from aws_durable_execution_sdk_python_testing.web.routes import Route
+
+
+def test_http_request_creation() -> None:
+ """Test HTTPRequest dataclass creation."""
+ path = Route.from_string("/test/path")
+ request = HTTPRequest(
+ method="GET",
+ path=path,
+ headers={"Content-Type": "application/json"},
+ query_params={"param1": ["value1"], "param2": ["value2a", "value2b"]},
+ body={"test": "data"},
+ )
+
+ assert request.method == "GET"
+ assert request.path == path
+ assert request.headers == {"Content-Type": "application/json"}
+ assert request.query_params == {
+ "param1": ["value1"],
+ "param2": ["value2a", "value2b"],
+ }
+ assert request.body == {"test": "data"}
+
+
+def test_http_request_immutable() -> None:
+ """Test that HTTPRequest is immutable."""
+ path = Route.from_string("/test/path")
+ request = HTTPRequest(method="GET", path=path, headers={}, query_params={}, body={})
+
+ # Should not be able to modify fields
+ with pytest.raises(AttributeError):
+ request.method = "POST" # type: ignore
+
+
+def test_http_response_creation() -> None:
+ """Test HTTPResponse dataclass creation."""
+ response = HTTPResponse(
+ status_code=200,
+ headers={"Content-Type": "application/json"},
+ body={"result": "success"},
+ )
+
+ assert response.status_code == 200
+ assert response.headers == {"Content-Type": "application/json"}
+ assert response.body == {"result": "success"}
+
+
+def test_http_response_immutable() -> None:
+ """Test that HTTPResponse is immutable."""
+ response = HTTPResponse(status_code=200, headers={}, body={})
+
+ # Should not be able to modify fields
+ with pytest.raises(AttributeError):
+ response.status_code = 404 # type: ignore
+
+
+def test_http_response_json_basic() -> None:
+ """Test creating basic JSON response."""
+ data = {"message": "success", "id": 123}
+ response = HTTPResponse.create_json(200, data)
+
+ assert response.status_code == 200
+ assert response.headers["Content-Type"] == "application/json"
+
+ # Verify the body is stored as dict
+ assert response.body == data
+
+ # Verify serialization to bytes works
+ body_bytes = response.body_to_bytes()
+ parsed_body = json.loads(body_bytes.decode("utf-8"))
+ assert parsed_body == data
+
+
+def test_http_response_json_with_additional_headers() -> None:
+ """Test creating JSON response with additional headers."""
+ data = {"result": "ok"}
+ additional_headers = {
+ "X-Custom-Header": "custom-value",
+ "Cache-Control": "no-cache",
+ }
+
+ response = HTTPResponse.create_json(201, data, additional_headers)
+
+ assert response.status_code == 201
+ assert response.headers["Content-Type"] == "application/json"
+ assert response.headers["X-Custom-Header"] == "custom-value"
+ assert response.headers["Cache-Control"] == "no-cache"
+
+ # Verify the body is stored as dict
+ assert response.body == data
+
+ # Verify serialization to bytes works
+ body_bytes = response.body_to_bytes()
+ parsed_body = json.loads(body_bytes.decode("utf-8"))
+ assert parsed_body == data
+
+
+def test_http_response_json_compact_serialization() -> None:
+ """Test that JSON response uses compact serialization."""
+ data = {"key": "value", "nested": {"inner": "data"}}
+ response = HTTPResponse.create_json(200, data)
+
+ # Verify the body is stored as dict
+ assert response.body == data
+
+ # Verify serialization to bytes uses compact format
+ body_bytes = response.body_to_bytes()
+ body_str = body_bytes.decode("utf-8")
+ assert " " not in body_str # No spaces after separators
+ assert "\n" not in body_str # No newlines
+
+
+# Removed deprecated tests for create_error method
+
+
+def test_http_response_empty_basic() -> None:
+ """Test creating basic empty response."""
+ response = HTTPResponse.create_empty(204)
+
+ assert response.status_code == 204
+ assert response.headers == {}
+ assert response.body == {}
+
+
+def test_http_response_empty_with_headers() -> None:
+ """Test creating empty response with additional headers."""
+ additional_headers = {"Location": "/new-resource", "X-Request-ID": "123"}
+ response = HTTPResponse.create_empty(201, additional_headers)
+
+ assert response.status_code == 201
+ assert response.headers == additional_headers
+ assert response.body == {}
+
+
+def test_operation_handler_protocol() -> None:
+ """Test that OperationHandler protocol works correctly."""
+
+ class TestHandler:
+ def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse:
+ return HTTPResponse(
+ status_code=200,
+ headers={"Content-Type": "text/plain"},
+ body={"message": "handled"},
+ )
+
+ # Should be able to use as OperationHandler
+ handler: OperationHandler = TestHandler()
+
+ path = Route.from_string("/test")
+ request = HTTPRequest(method="GET", path=path, headers={}, query_params={}, body={})
+
+ response = handler.handle(path, request)
+ assert response.status_code == 200
+ assert response.body == {"message": "handled"}
+
+
+def test_operation_handler_protocol_type_checking() -> None:
+ """Test that OperationHandler protocol enforces correct signature."""
+
+ class InvalidHandler:
+ def handle(self, wrong_params: str) -> str: # Wrong signature
+ return "invalid"
+
+ # This should work at runtime but would fail type checking
+ # We can't test static type checking in unit tests, but this documents the expected behavior
+ invalid_handler = InvalidHandler()
+
+ # The protocol is structural, so this would work at runtime
+ # but mypy would catch the type mismatch
+ assert hasattr(invalid_handler, "handle")
+
+
+def test_http_response_edge_cases() -> None:
+ """Test edge cases for HTTP response factory methods."""
+
+ # Test with empty data
+ response = HTTPResponse.create_json(200, {})
+ assert response.body == {}
+
+ # Test with complex nested data
+ complex_data = {
+ "list": [1, 2, 3],
+ "nested": {"deep": {"value": True}},
+ "null": None,
+ "unicode": "🚀",
+ }
+ response = HTTPResponse.create_json(200, complex_data)
+ assert response.body == complex_data
+
+ # Verify serialization to bytes works
+ body_bytes = response.body_to_bytes()
+ parsed = json.loads(body_bytes.decode("utf-8"))
+ assert parsed == complex_data
+
+
+def test_http_request_with_empty_collections() -> None:
+ """Test HTTPRequest with empty collections."""
+ path = Route.from_string("/empty")
+ request = HTTPRequest(method="GET", path=path, headers={}, query_params={}, body={})
+
+ assert request.headers == {}
+ assert request.query_params == {}
+ assert request.body == {}
+
+
+def test_http_response_with_empty_collections() -> None:
+ """Test HTTPResponse with empty collections."""
+ response = HTTPResponse(status_code=204, headers={}, body={})
+
+ assert response.headers == {}
+ assert response.body == {}
+
+
+# Tests for HTTPRequest.from_bytes method
+
+
+def test_http_request_from_bytes_standard_json() -> None:
+ """Test HTTPRequest.from_bytes with standard JSON deserialization."""
+ test_data = {"key": "value", "number": 42}
+ body_bytes = json.dumps(test_data).encode("utf-8")
+
+ path = Route.from_string("/test")
+ request = HTTPRequest.from_bytes(
+ body_bytes=body_bytes,
+ method="POST",
+ path=path,
+ headers={"Content-Type": "application/json"},
+ query_params={"param": ["value"]},
+ )
+
+ assert request.method == "POST"
+ assert request.path == path
+ assert request.headers == {"Content-Type": "application/json"}
+ assert request.query_params == {"param": ["value"]}
+ assert request.body == test_data
+
+
+def test_http_get_request_from_bytes_ignore_body() -> None:
+ """Test HTTPRequest.from_bytes with standard JSON deserialization."""
+ test_data = {"key": "value", "number": 42}
+ body_bytes = json.dumps(test_data).encode("utf-8")
+
+ path = Route.from_string("/test")
+ request = HTTPRequest.from_bytes(
+ body_bytes=body_bytes,
+ method="GET",
+ path=path,
+ headers={"Content-Type": "application/json"},
+ query_params={"param": ["value"]},
+ )
+
+ assert request.method == "GET"
+ assert request.path == path
+ assert request.headers == {"Content-Type": "application/json"}
+ assert request.query_params == {"param": ["value"]}
+ assert request.body == {}
+
+
+def test_http_request_from_bytes_minimal_params() -> None:
+ """Test HTTPRequest.from_bytes with minimal parameters."""
+ test_data = {"message": "hello"}
+ body_bytes = json.dumps(test_data).encode("utf-8")
+
+ request = HTTPRequest.from_bytes(body_bytes=body_bytes)
+
+ assert request.method == "POST" # Default
+ assert request.path.raw_path == "" # Default empty route
+ assert request.headers == {} # Default
+ assert request.query_params == {} # Default
+ assert request.body == test_data
+
+
+def test_http_request_from_bytes_aws_operation_fallback() -> None:
+ """Test HTTPRequest.from_bytes with AWS operation that falls back to JSON."""
+ test_data = {"Input": "test-input", "ExecutionName": "test-execution"}
+ body_bytes = json.dumps(test_data).encode("utf-8")
+
+ # Use a non-existent operation name to trigger fallback
+ request = HTTPRequest.from_bytes(
+ body_bytes=body_bytes,
+ operation_name="NonExistentOperation",
+ method="POST",
+ )
+
+ assert request.method == "POST"
+ assert request.body == test_data
+
+
+def test_http_request_from_bytes_invalid_json() -> None:
+ """Test HTTPRequest.from_bytes with invalid JSON raises InvalidParameterValueException."""
+ invalid_json = b'{"invalid": json}'
+
+ with pytest.raises(
+ InvalidParameterValueException, match="JSON deserialization failed"
+ ):
+ HTTPRequest.from_bytes(body_bytes=invalid_json)
+
+
+def test_http_request_from_bytes_invalid_utf8() -> None:
+ """Test HTTPRequest.from_bytes with invalid UTF-8 raises InvalidParameterValueException."""
+ invalid_utf8 = b'\xff\xfe{"test": "data"}' # Invalid UTF-8 BOM
+
+ with pytest.raises(
+ InvalidParameterValueException, match="JSON deserialization failed"
+ ):
+ HTTPRequest.from_bytes(body_bytes=invalid_utf8)
+
+
+def test_http_request_from_bytes_empty_body() -> None:
+ """Test HTTPRequest.from_bytes with empty body."""
+ empty_body = b"{}"
+
+ request = HTTPRequest.from_bytes(body_bytes=empty_body)
+
+ assert request.body == {}
+
+
+def test_http_request_from_bytes_complex_json() -> None:
+ """Test HTTPRequest.from_bytes with complex nested JSON."""
+ complex_data = {
+ "list": [1, 2, 3],
+ "nested": {"deep": {"value": True}},
+ "null": None,
+ "unicode": "🚀",
+ }
+ body_bytes = json.dumps(complex_data).encode("utf-8")
+
+ request = HTTPRequest.from_bytes(body_bytes=body_bytes)
+
+ assert request.body == complex_data
+
+
+def test_http_request_from_bytes_aws_operation_success() -> None:
+ """Test HTTPRequest.from_bytes with valid AWS operation (if available)."""
+ # This test will use AWS deserialization if available, otherwise fall back to JSON
+ test_data = {
+ "Input": "test-input",
+ "ExecutionName": "test-execution",
+ "FunctionName": "test-function",
+ }
+ body_bytes = json.dumps(test_data).encode("utf-8")
+
+ # Try with a real AWS operation name
+ request = HTTPRequest.from_bytes(
+ body_bytes=body_bytes,
+ operation_name="StartDurableExecution",
+ method="POST",
+ )
+
+ assert request.method == "POST"
+ assert request.body is not None
+ # The exact structure may vary depending on AWS deserialization vs JSON fallback
+ # but we should get some valid dict data
+
+
+def test_http_request_from_bytes_preserves_field_names() -> None:
+ """Test that from_bytes preserves field names from the input."""
+ # Test with AWS-style PascalCase field names
+ aws_style_data = {
+ "ExecutionName": "test-execution",
+ "FunctionName": "my-function",
+ "Input": {"Key": "Value"},
+ }
+ body_bytes = json.dumps(aws_style_data).encode("utf-8")
+
+ request = HTTPRequest.from_bytes(body_bytes=body_bytes)
+
+ # Field names should be preserved as-is
+ assert isinstance(request.body, dict)
+ assert "ExecutionName" in request.body
+ assert "FunctionName" in request.body
+ assert request.body["ExecutionName"] == "test-execution"
+ assert request.body["FunctionName"] == "my-function"
+ assert request.body["Input"]["Key"] == "Value"
+
+
+# Tests for HTTPResponse.body_to_bytes method
+
+
+def test_http_response_body_to_bytes_standard_json() -> None:
+ """Test HTTPResponse.body_to_bytes with standard JSON serialization."""
+ test_data = {"message": "success", "id": 123}
+ response = HTTPResponse(
+ status_code=200,
+ headers={"Content-Type": "application/json"},
+ body=test_data,
+ )
+
+ body_bytes = response.body_to_bytes()
+
+ # Verify it's bytes
+ assert isinstance(body_bytes, bytes)
+
+ # Verify content is correct
+ parsed_data = json.loads(body_bytes.decode("utf-8"))
+ assert parsed_data == test_data
+
+
+def test_http_response_body_to_bytes_compact_format() -> None:
+ """Test that body_to_bytes uses compact JSON format."""
+ test_data = {"key": "value", "nested": {"inner": "data"}}
+ response = HTTPResponse(status_code=200, headers={}, body=test_data)
+
+ body_bytes = response.body_to_bytes()
+ body_str = body_bytes.decode("utf-8")
+
+ # Should not contain extra whitespace
+ assert " " not in body_str # No spaces after separators
+ assert "\n" not in body_str # No newlines
+
+
+def test_http_response_body_to_bytes_empty_body() -> None:
+ """Test body_to_bytes with empty body."""
+ response = HTTPResponse(status_code=204, headers={}, body={})
+
+ body_bytes = response.body_to_bytes()
+
+ assert body_bytes == b"{}"
+
+
+def test_http_response_body_to_bytes_complex_data() -> None:
+ """Test body_to_bytes with complex nested data."""
+ complex_data = {
+ "list": [1, 2, 3],
+ "nested": {"deep": {"value": True}},
+ "null": None,
+ "unicode": "🚀",
+ }
+ response = HTTPResponse(status_code=200, headers={}, body=complex_data)
+
+ body_bytes = response.body_to_bytes()
+ parsed_data = json.loads(body_bytes.decode("utf-8"))
+
+ assert parsed_data == complex_data
+
+
+# Tests for HTTPResponse.from_dict method
+
+
+def test_http_response_from_dict_basic() -> None:
+ """Test HTTPResponse.from_dict with basic parameters."""
+ test_data = {"message": "success", "id": 123}
+
+ response = HTTPResponse.from_dict(test_data)
+
+ assert response.status_code == 200 # Default
+ assert response.headers == {} # Default
+ assert response.body == test_data
+
+
+def test_http_response_from_dict_with_status_code() -> None:
+ """Test HTTPResponse.from_dict with custom status code."""
+ test_data = {"error": "not found"}
+
+ response = HTTPResponse.from_dict(test_data, status_code=404)
+
+ assert response.status_code == 404
+ assert response.headers == {}
+ assert response.body == test_data
+
+
+def test_http_response_from_dict_with_headers() -> None:
+ """Test HTTPResponse.from_dict with custom headers."""
+ test_data = {"result": "ok"}
+ headers = {"Content-Type": "application/json", "X-Custom": "value"}
+
+ response = HTTPResponse.from_dict(test_data, headers=headers)
+
+ assert response.status_code == 200
+ assert response.headers == headers
+ assert response.body == test_data
+
+
+def test_http_response_from_dict_with_all_params() -> None:
+ """Test HTTPResponse.from_dict with all parameters."""
+ test_data = {"data": "test"}
+ headers = {"Content-Type": "application/json"}
+
+ response = HTTPResponse.from_dict(test_data, status_code=201, headers=headers)
+
+ assert response.status_code == 201
+ assert response.headers == headers
+ assert response.body == test_data
+
+
+def test_http_response_from_dict_empty_data() -> None:
+ """Test HTTPResponse.from_dict with empty data."""
+ response = HTTPResponse.from_dict({})
+
+ assert response.status_code == 200
+ assert response.headers == {}
+ assert response.body == {}
+
+
+def test_http_response_from_dict_complex_data() -> None:
+ """Test HTTPResponse.from_dict with complex nested data."""
+ complex_data = {
+ "list": [1, 2, 3],
+ "nested": {"deep": {"value": True}},
+ "null": None,
+ "unicode": "🚀",
+ }
+
+ response = HTTPResponse.from_dict(complex_data)
+
+ assert response.body == complex_data
+
+
+def test_http_response_from_dict_immutable() -> None:
+ """Test that HTTPResponse.from_dict creates immutable response."""
+ test_data = {"key": "value"}
+ response = HTTPResponse.from_dict(test_data)
+
+ # Should not be able to modify fields
+ with pytest.raises(AttributeError):
+ response.status_code = 404 # type: ignore
+
+
+def test_http_response_from_dict_integration_with_body_to_bytes() -> None:
+ """Test that from_dict works with body_to_bytes method."""
+ test_data = {"message": "integration test", "success": True}
+
+ response = HTTPResponse.from_dict(test_data, status_code=201)
+ body_bytes = response.body_to_bytes()
+
+ # Verify round-trip serialization
+ parsed_data = json.loads(body_bytes.decode("utf-8"))
+ assert parsed_data == test_data
+ assert response.status_code == 201
+
+
+def test_http_request_from_bytes_aws_deserialization_success() -> None:
+ """Test HTTPRequest.from_bytes with successful AWS deserialization."""
+ test_data = {"ExecutionName": "test-execution", "Input": "test-input"}
+ body_bytes = json.dumps(test_data).encode("utf-8")
+
+ # Mock successful AWS deserialization
+ mock_deserializer = Mock()
+ mock_deserializer.from_bytes.return_value = test_data
+
+ with patch(
+ "aws_durable_execution_sdk_python_testing.web.models.AwsRestJsonDeserializer.create",
+ return_value=mock_deserializer,
+ ):
+ request = HTTPRequest.from_bytes(
+ body_bytes=body_bytes, operation_name="StartDurableExecution"
+ )
+
+ assert request.body == test_data
+ mock_deserializer.from_bytes.assert_called_once_with(body_bytes)
+
+
+def test_http_request_from_bytes_aws_deserialization_fallback_error() -> None:
+ """Test HTTPRequest.from_bytes when both AWS and JSON deserialization fail."""
+
+ invalid_bytes = b"invalid json data"
+
+ # Mock AWS deserialization failure
+ mock_deserializer = Mock()
+ mock_deserializer.from_bytes.side_effect = InvalidParameterValueException(
+ "AWS failed"
+ )
+
+ with patch(
+ "aws_durable_execution_sdk_python_testing.web.models.AwsRestJsonDeserializer.create",
+ return_value=mock_deserializer,
+ ):
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Both AWS and JSON deserialization failed",
+ ):
+ HTTPRequest.from_bytes(
+ body_bytes=invalid_bytes, operation_name="StartDurableExecution"
+ )
+
+
+def test_http_response_body_to_bytes_serialization_error() -> None:
+ """Test HTTPResponse.body_to_bytes when JSON serialization fail."""
+
+ # Create data that can't be JSON serialized
+ class CustomObject:
+ pass
+
+ test_data = {"custom": CustomObject()}
+ response = HTTPResponse(status_code=200, headers={}, body=test_data)
+
+ with pytest.raises(
+ InvalidParameterValueException,
+ match="Failed to serialize data to JSON: Object of type CustomObject is not JSON serializable",
+ ):
+ response.body_to_bytes()
+
+
+# Tests for HTTPResponse.create_error_from_exception method
+
+
+def test_create_error_from_exception_invalid_parameter_value() -> None:
+ """Test create_error_from_exception with InvalidParameterValueException."""
+
+ exception = InvalidParameterValueException("Parameter 'name' is required")
+ response = HTTPResponse.create_error_from_exception(exception)
+
+ assert response.status_code == 400
+ assert response.headers["Content-Type"] == "application/json"
+
+ expected_body = {
+ "Type": "InvalidParameterValueException",
+ "message": "Parameter 'name' is required",
+ }
+ assert response.body == expected_body
+
+
+def test_create_error_from_exception_resource_not_found() -> None:
+ """Test create_error_from_exception with ResourceNotFoundException."""
+
+ exception = ResourceNotFoundException("Execution not found")
+ response = HTTPResponse.create_error_from_exception(exception)
+
+ assert response.status_code == 404
+ assert response.headers["Content-Type"] == "application/json"
+
+ expected_body = {
+ "Type": "ResourceNotFoundException",
+ "Message": "Execution not found",
+ }
+ assert response.body == expected_body
+
+
+def test_create_error_from_exception_service_exception() -> None:
+ """Test create_error_from_exception with ServiceException."""
+
+ exception = ServiceException("Internal server error")
+ response = HTTPResponse.create_error_from_exception(exception)
+
+ assert response.status_code == 500
+ assert response.headers["Content-Type"] == "application/json"
+
+ expected_body = {"Type": "ServiceException", "Message": "Internal server error"}
+ assert response.body == expected_body
+
+
+def test_create_error_from_exception_execution_already_started() -> None:
+ """Test create_error_from_exception with ExecutionAlreadyStartedException."""
+
+ arn = "arn:aws:lambda:us-east-1:123456789012:function:test"
+ exception = ExecutionAlreadyStartedException("Execution already exists", arn)
+ response = HTTPResponse.create_error_from_exception(exception)
+
+ assert response.status_code == 409
+ assert response.headers["Content-Type"] == "application/json"
+
+ # ExecutionAlreadyStartedException has no Type field per Smithy definition
+ expected_body = {"message": "Execution already exists", "DurableExecutionArn": arn}
+ assert response.body == expected_body
+
+
+def test_create_error_from_exception_callback_timeout() -> None:
+ """Test create_error_from_exception with CallbackTimeoutException."""
+
+ exception = CallbackTimeoutException("Callback timed out")
+ response = HTTPResponse.create_error_from_exception(exception)
+
+ assert response.status_code == 408
+ assert response.headers["Content-Type"] == "application/json"
+
+ expected_body = {
+ "Type": "CallbackTimeoutException",
+ "message": "Callback timed out",
+ }
+ assert response.body == expected_body
+
+
+def test_create_error_from_exception_too_many_requests() -> None:
+ """Test create_error_from_exception with TooManyRequestsException."""
+
+ exception = TooManyRequestsException("Rate limit exceeded")
+ response = HTTPResponse.create_error_from_exception(exception)
+
+ assert response.status_code == 429
+ assert response.headers["Content-Type"] == "application/json"
+
+ expected_body = {
+ "Type": "TooManyRequestsException",
+ "message": "Rate limit exceeded",
+ }
+ assert response.body == expected_body
+
+
+def test_create_error_from_exception_illegal_state() -> None:
+ """Test create_error_from_exception with IllegalStateException (unmapped)."""
+
+ exception = IllegalStateException("Invalid state transition")
+ response = HTTPResponse.create_error_from_exception(exception)
+
+ assert response.status_code == 500
+ assert response.headers["Content-Type"] == "application/json"
+
+ # IllegalStateException maps to ServiceException when serialized
+ expected_body = {"Type": "ServiceException", "Message": "Invalid state transition"}
+ assert response.body == expected_body
+
+
+def test_create_error_from_exception_runtime_exception() -> None:
+ """Test create_error_from_exception with RuntimeException (unmapped)."""
+
+ exception = IllegalArgumentException("Invalid argument provided")
+ response = HTTPResponse.create_error_from_exception(exception)
+
+ assert response.status_code == 400
+ assert response.headers["Content-Type"] == "application/json"
+
+ # IllegalArgumentException maps to InvalidParameterValueException when serialized
+ expected_body = {
+ "Type": "InvalidParameterValueException",
+ "message": "Invalid argument provided",
+ }
+ assert response.body == expected_body
+
+
+def test_create_error_from_exception_type_validation() -> None:
+ """Test create_error_from_exception with non-AwsApiException raises TypeError."""
+ # Test with regular Exception
+ regular_exception = Exception("Not an AWS exception")
+
+ with pytest.raises(
+ TypeError, match="Expected AwsApiException, got "
+ ):
+ HTTPResponse.create_error_from_exception(regular_exception) # type: ignore
+
+ # Test with AWS API exception (should work fine)
+
+ framework_exception = InvalidParameterValueException("Framework error")
+
+ # This should NOT raise an error since InvalidParameterValueException is an AwsApiException
+ response = HTTPResponse.create_error_from_exception(framework_exception)
+ assert response.status_code == 400
+
+
+def test_create_error_from_exception_no_wrapper_object() -> None:
+ """Test that create_error_from_exception doesn't add wrapper 'error' object."""
+
+ exception = InvalidParameterValueException("Test message")
+ response = HTTPResponse.create_error_from_exception(exception)
+
+ # Should NOT have wrapper "error" object like the old create_error method
+ assert "error" not in response.body
+
+ # Should have direct AWS-compliant structure
+ assert "Type" in response.body
+ assert "message" in response.body
+ assert response.body["Type"] == "InvalidParameterValueException"
+ assert response.body["message"] == "Test message"
+
+
+def test_create_error_from_exception_serialization_round_trip() -> None:
+ """Test that create_error_from_exception produces serializable responses."""
+
+ exception = ResourceNotFoundException("Resource not found")
+ response = HTTPResponse.create_error_from_exception(exception)
+
+ # Should be able to serialize to bytes
+ body_bytes = response.body_to_bytes()
+
+ # Should be valid JSON
+ parsed_body = json.loads(body_bytes.decode("utf-8"))
+
+ expected_body = {
+ "Type": "ResourceNotFoundException",
+ "Message": "Resource not found",
+ }
+ assert parsed_body == expected_body
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/web/routes_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/web/routes_test.py
new file mode 100644
index 0000000..cc1be9b
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/web/routes_test.py
@@ -0,0 +1,1187 @@
+"""Tests for the strongly-typed route parsing system."""
+
+from __future__ import annotations
+
+import threading
+import time
+from urllib.parse import quote
+
+import pytest
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ UnknownRouteError,
+)
+from aws_durable_execution_sdk_python_testing.web.routes import (
+ CallbackFailureRoute,
+ CallbackHeartbeatRoute,
+ CallbackSuccessRoute,
+ CheckpointDurableExecutionRoute,
+ GetDurableExecutionHistoryRoute,
+ GetDurableExecutionRoute,
+ GetDurableExecutionStateRoute,
+ HealthRoute,
+ ListDurableExecutionsByFunctionRoute,
+ ListDurableExecutionsRoute,
+ MetricsRoute,
+ Route,
+ Router,
+ StartExecutionRoute,
+ StopDurableExecutionRoute,
+)
+
+
+def test_route_from_string_basic():
+ """Test basic route creation from string."""
+ route = Route.from_string("/test/path")
+ assert route.raw_path == "/test/path"
+ assert route.segments == ["test", "path"]
+
+
+def test_route_from_string_with_leading_trailing_slashes():
+ """Test route creation handles leading and trailing slashes."""
+ route = Route.from_string("///test/path///")
+ assert route.raw_path == "///test/path///"
+ assert route.segments == ["test", "path"]
+
+
+def test_route_from_string_empty_segments():
+ """Test route creation filters out empty segments."""
+ route = Route.from_string("/test//path/")
+ assert route.raw_path == "/test//path/"
+ assert route.segments == ["test", "path"]
+
+
+def test_route_from_string_root():
+ """Test route creation for root path."""
+ route = Route.from_string("/")
+ assert route.raw_path == "/"
+ assert route.segments == []
+
+
+def test_route_matches_pattern_exact():
+ """Test pattern matching with exact segments."""
+ route = Route.from_string("/test/path")
+ assert route.matches_pattern(["test", "path"]) is True
+ assert route.matches_pattern(["test", "other"]) is False
+
+
+def test_route_matches_pattern_wildcard():
+ """Test pattern matching with wildcards."""
+ route = Route.from_string("/test/123/path")
+ assert route.matches_pattern(["test", "*", "path"]) is True
+ assert route.matches_pattern(["test", "*", "other"]) is False
+
+
+def test_route_matches_pattern_length_mismatch():
+ """Test pattern matching fails with different lengths."""
+ route = Route.from_string("/test/path")
+ assert route.matches_pattern(["test"]) is False
+ assert route.matches_pattern(["test", "path", "extra"]) is False
+
+
+def test_start_execution_route_is_match():
+ """Test StartExecutionRoute pattern matching."""
+ route = Route.from_string("/start-durable-execution")
+ assert StartExecutionRoute.is_match(route, "POST") is True
+ assert StartExecutionRoute.is_match(route, "GET") is False
+
+ route = Route.from_string("/start-execution")
+ assert StartExecutionRoute.is_match(route, "POST") is False
+
+
+def test_start_execution_route_from_route():
+ """Test StartExecutionRoute creation from base route."""
+ base_route = Route.from_string("/start-durable-execution")
+ start_route = StartExecutionRoute.from_route(base_route)
+
+ assert start_route.raw_path == "/start-durable-execution"
+ assert start_route.segments == ["start-durable-execution"]
+
+
+# Removed test_start_execution_route_from_route_invalid - from_route() no longer validates
+# Call is_match() first to ensure route is valid
+
+
+def test_get_durable_execution_route_is_match():
+ """Test GetDurableExecutionRoute pattern matching."""
+ route = Route.from_string(
+ "/2025-12-01/durable-executions/arn:aws:lambda:us-east-1:123456789012:function:my-function"
+ )
+ assert GetDurableExecutionRoute.is_match(route, "GET") is True
+ assert GetDurableExecutionRoute.is_match(route, "POST") is False
+
+ route = Route.from_string("/2025-12-01/executions/some-arn")
+ assert GetDurableExecutionRoute.is_match(route, "GET") is False
+
+ route = Route.from_string("/2025-12-01/durable-executions")
+ assert GetDurableExecutionRoute.is_match(route, "GET") is False
+
+
+def test_get_durable_execution_route_from_route():
+ """Test GetDurableExecutionRoute creation from base route."""
+ arn = "arn:aws:lambda:us-east-1:123456789012:function:my-function"
+ base_route = Route.from_string(f"/2025-12-01/durable-executions/{arn}")
+ get_route = GetDurableExecutionRoute.from_route(base_route)
+
+ assert get_route.raw_path == f"/2025-12-01/durable-executions/{arn}"
+ assert get_route.segments == ["2025-12-01", "durable-executions", arn]
+ assert get_route.arn == arn
+
+
+# Removed test_get_durable_execution_route_from_route_invalid - from_route() no longer validates
+# Call is_match() first to ensure route is valid
+
+
+def test_checkpoint_durable_execution_route_is_match():
+ """Test CheckpointDurableExecutionRoute pattern matching."""
+ route = Route.from_string("/2025-12-01/durable-executions/some-arn/checkpoint")
+ assert CheckpointDurableExecutionRoute.is_match(route, "POST") is True
+ assert CheckpointDurableExecutionRoute.is_match(route, "GET") is False
+
+ route = Route.from_string("/2025-12-01/durable-executions/some-arn/stop")
+ assert CheckpointDurableExecutionRoute.is_match(route, "POST") is False
+
+
+def test_checkpoint_durable_execution_route_from_route():
+ """Test CheckpointDurableExecutionRoute creation from base route."""
+ arn = "test-arn"
+ base_route = Route.from_string(f"/2025-12-01/durable-executions/{arn}/checkpoint")
+ checkpoint_route = CheckpointDurableExecutionRoute.from_route(base_route)
+
+ assert (
+ checkpoint_route.raw_path == f"/2025-12-01/durable-executions/{arn}/checkpoint"
+ )
+ assert checkpoint_route.segments == [
+ "2025-12-01",
+ "durable-executions",
+ arn,
+ "checkpoint",
+ ]
+ assert checkpoint_route.arn == arn
+
+
+# Removed test_checkpoint_durable_execution_route_from_route_invalid - from_route() no longer validates
+# Call is_match() first to ensure route is valid
+
+
+def test_stop_durable_execution_route_is_match():
+ """Test StopDurableExecutionRoute pattern matching."""
+ route = Route.from_string("/2025-12-01/durable-executions/some-arn/stop")
+ assert StopDurableExecutionRoute.is_match(route, "POST") is True
+ assert StopDurableExecutionRoute.is_match(route, "GET") is False
+
+ route = Route.from_string("/2025-12-01/durable-executions/some-arn/checkpoint")
+ assert StopDurableExecutionRoute.is_match(route, "POST") is False
+
+
+def test_stop_durable_execution_route_from_route():
+ """Test StopDurableExecutionRoute creation from base route."""
+ arn = "test-arn"
+ base_route = Route.from_string(f"/2025-12-01/durable-executions/{arn}/stop")
+ stop_route = StopDurableExecutionRoute.from_route(base_route)
+
+ assert stop_route.raw_path == f"/2025-12-01/durable-executions/{arn}/stop"
+ assert stop_route.segments == ["2025-12-01", "durable-executions", arn, "stop"]
+ assert stop_route.arn == arn
+
+
+# Removed test_stop_durable_execution_route_from_route_invalid - from_route() no longer validates
+# Call is_match() first to ensure route is valid
+
+
+def test_get_durable_execution_state_route_is_match():
+ """Test GetDurableExecutionStateRoute pattern matching."""
+ route = Route.from_string("/2025-12-01/durable-executions/some-arn/state")
+ assert GetDurableExecutionStateRoute.is_match(route, "GET") is True
+ assert GetDurableExecutionStateRoute.is_match(route, "POST") is False
+
+ route = Route.from_string("/2025-12-01/durable-executions/some-arn/history")
+ assert GetDurableExecutionStateRoute.is_match(route, "GET") is False
+
+
+def test_get_durable_execution_state_route_from_route():
+ """Test GetDurableExecutionStateRoute creation from base route."""
+ arn = "test-arn"
+ base_route = Route.from_string(f"/2025-12-01/durable-executions/{arn}/state")
+ state_route = GetDurableExecutionStateRoute.from_route(base_route)
+
+ assert state_route.raw_path == f"/2025-12-01/durable-executions/{arn}/state"
+ assert state_route.segments == ["2025-12-01", "durable-executions", arn, "state"]
+ assert state_route.arn == arn
+
+
+# Removed test_get_durable_execution_state_route_from_route_invalid - from_route() no longer validates
+# Call is_match() first to ensure route is valid
+
+
+def test_get_durable_execution_history_route_is_match():
+ """Test GetDurableExecutionHistoryRoute pattern matching."""
+ route = Route.from_string("/2025-12-01/durable-executions/some-arn/history")
+ assert GetDurableExecutionHistoryRoute.is_match(route, "GET") is True
+ assert GetDurableExecutionHistoryRoute.is_match(route, "POST") is False
+
+ route = Route.from_string("/2025-12-01/durable-executions/some-arn/state")
+ assert GetDurableExecutionHistoryRoute.is_match(route, "GET") is False
+
+
+def test_get_durable_execution_history_route_from_route():
+ """Test GetDurableExecutionHistoryRoute creation from base route."""
+ arn = "test-arn"
+ base_route = Route.from_string(f"/2025-12-01/durable-executions/{arn}/history")
+ history_route = GetDurableExecutionHistoryRoute.from_route(base_route)
+
+ assert history_route.raw_path == f"/2025-12-01/durable-executions/{arn}/history"
+ assert history_route.segments == [
+ "2025-12-01",
+ "durable-executions",
+ arn,
+ "history",
+ ]
+ assert history_route.arn == arn
+
+
+# Removed test_get_durable_execution_history_route_from_route_invalid - from_route() no longer validates
+# Call is_match() first to ensure route is valid
+
+
+def test_list_durable_executions_route_is_match():
+ """Test ListDurableExecutionsRoute pattern matching."""
+ route = Route.from_string("/2025-12-01/durable-executions")
+ assert ListDurableExecutionsRoute.is_match(route, "GET") is True
+ assert ListDurableExecutionsRoute.is_match(route, "POST") is False
+
+ route = Route.from_string("/2025-12-01/durable-executions/some-arn")
+ assert ListDurableExecutionsRoute.is_match(route, "GET") is False
+
+
+def test_list_durable_executions_route_from_route():
+ """Test ListDurableExecutionsRoute creation from base route."""
+ base_route = Route.from_string("/2025-12-01/durable-executions")
+ list_route = ListDurableExecutionsRoute.from_route(base_route)
+
+ assert list_route.raw_path == "/2025-12-01/durable-executions"
+ assert list_route.segments == ["2025-12-01", "durable-executions"]
+
+
+# Removed test_list_durable_executions_route_from_route_invalid - from_route() no longer validates
+# Call is_match() first to ensure route is valid
+
+
+def test_list_durable_executions_by_function_route_is_match():
+ """Test ListDurableExecutionsByFunctionRoute pattern matching."""
+ route = Route.from_string("/2025-12-01/functions/my-function/durable-executions")
+ assert ListDurableExecutionsByFunctionRoute.is_match(route, "GET") is True
+ assert ListDurableExecutionsByFunctionRoute.is_match(route, "POST") is False
+
+ route = Route.from_string("/2025-12-01/functions/my-function")
+ assert ListDurableExecutionsByFunctionRoute.is_match(route, "GET") is False
+
+
+def test_list_durable_executions_by_function_route_from_route():
+ """Test ListDurableExecutionsByFunctionRoute creation from base route."""
+ function_name = "my-function"
+ base_route = Route.from_string(
+ f"/2025-12-01/functions/{function_name}/durable-executions"
+ )
+ list_route = ListDurableExecutionsByFunctionRoute.from_route(base_route)
+
+ assert (
+ list_route.raw_path
+ == f"/2025-12-01/functions/{function_name}/durable-executions"
+ )
+ assert list_route.segments == [
+ "2025-12-01",
+ "functions",
+ function_name,
+ "durable-executions",
+ ]
+ assert list_route.function_name == function_name
+
+
+# Removed test_list_durable_executions_by_function_route_from_route_invalid - from_route() no longer validates
+# Call is_match() first to ensure route is valid
+
+
+def test_callback_success_route_is_match():
+ """Test CallbackSuccessRoute pattern matching."""
+ route = Route.from_string(
+ "/2025-12-01/durable-execution-callbacks/callback-123/succeed"
+ )
+ assert CallbackSuccessRoute.is_match(route, "POST") is True
+ assert CallbackSuccessRoute.is_match(route, "GET") is False
+
+ route = Route.from_string(
+ "/2025-12-01/durable-execution-callbacks/callback-123/fail"
+ )
+ assert CallbackSuccessRoute.is_match(route, "POST") is False
+
+
+def test_callback_success_route_from_route():
+ """Test CallbackSuccessRoute creation from base route."""
+ callback_id = "callback-123"
+ base_route = Route.from_string(
+ f"/2025-12-01/durable-execution-callbacks/{callback_id}/succeed"
+ )
+ callback_route = CallbackSuccessRoute.from_route(base_route)
+
+ assert (
+ callback_route.raw_path
+ == f"/2025-12-01/durable-execution-callbacks/{callback_id}/succeed"
+ )
+ assert callback_route.segments == [
+ "2025-12-01",
+ "durable-execution-callbacks",
+ callback_id,
+ "succeed",
+ ]
+ assert callback_route.callback_id == callback_id
+
+
+# Removed test_callback_success_route_from_route_invalid - from_route() no longer validates
+# Call is_match() first to ensure route is valid
+
+
+def test_callback_failure_route_is_match():
+ """Test CallbackFailureRoute pattern matching."""
+ route = Route.from_string(
+ "/2025-12-01/durable-execution-callbacks/callback-123/fail"
+ )
+ assert CallbackFailureRoute.is_match(route, "POST") is True
+ assert CallbackFailureRoute.is_match(route, "GET") is False
+
+ route = Route.from_string(
+ "/2025-12-01/durable-execution-callbacks/callback-123/succeed"
+ )
+ assert CallbackFailureRoute.is_match(route, "POST") is False
+
+
+def test_callback_failure_route_from_route():
+ """Test CallbackFailureRoute creation from base route."""
+ callback_id = "callback-123"
+ base_route = Route.from_string(
+ f"/2025-12-01/durable-execution-callbacks/{callback_id}/fail"
+ )
+ callback_route = CallbackFailureRoute.from_route(base_route)
+
+ assert (
+ callback_route.raw_path
+ == f"/2025-12-01/durable-execution-callbacks/{callback_id}/fail"
+ )
+ assert callback_route.segments == [
+ "2025-12-01",
+ "durable-execution-callbacks",
+ callback_id,
+ "fail",
+ ]
+ assert callback_route.callback_id == callback_id
+
+
+# Removed test_callback_failure_route_from_route_invalid - from_route() no longer validates
+# Call is_match() first to ensure route is valid
+
+
+def test_callback_heartbeat_route_is_match():
+ """Test CallbackHeartbeatRoute pattern matching."""
+ route = Route.from_string(
+ "/2025-12-01/durable-execution-callbacks/callback-123/heartbeat"
+ )
+ assert CallbackHeartbeatRoute.is_match(route, "POST") is True
+ assert CallbackHeartbeatRoute.is_match(route, "GET") is False
+
+ route = Route.from_string(
+ "/2025-12-01/durable-execution-callbacks/callback-123/succeed"
+ )
+ assert CallbackHeartbeatRoute.is_match(route, "POST") is False
+
+
+def test_callback_heartbeat_route_from_route():
+ """Test CallbackHeartbeatRoute creation from base route."""
+ callback_id = "callback-123"
+ base_route = Route.from_string(
+ f"/2025-12-01/durable-execution-callbacks/{callback_id}/heartbeat"
+ )
+ callback_route = CallbackHeartbeatRoute.from_route(base_route)
+
+ assert (
+ callback_route.raw_path
+ == f"/2025-12-01/durable-execution-callbacks/{callback_id}/heartbeat"
+ )
+ assert callback_route.segments == [
+ "2025-12-01",
+ "durable-execution-callbacks",
+ callback_id,
+ "heartbeat",
+ ]
+ assert callback_route.callback_id == callback_id
+
+
+# Removed test_callback_heartbeat_route_from_route_invalid - from_route() no longer validates
+# Call is_match() first to ensure route is valid
+
+
+def test_health_route_is_match():
+ """Test HealthRoute pattern matching."""
+ route = Route.from_string("/health")
+ assert HealthRoute.is_match(route, "GET") is True
+ assert HealthRoute.is_match(route, "POST") is False
+
+ route = Route.from_string("/metrics")
+ assert HealthRoute.is_match(route, "GET") is False
+
+
+def test_health_route_from_route():
+ """Test HealthRoute creation from base route."""
+ base_route = Route.from_string("/health")
+ health_route = HealthRoute.from_route(base_route)
+
+ assert health_route.raw_path == "/health"
+ assert health_route.segments == ["health"]
+
+
+# Removed test_health_route_from_route_invalid - from_route() no longer validates
+# Call is_match() first to ensure route is valid
+
+
+def test_metrics_route_is_match():
+ """Test MetricsRoute pattern matching."""
+ route = Route.from_string("/metrics")
+ assert MetricsRoute.is_match(route, "GET") is True
+ assert MetricsRoute.is_match(route, "POST") is False
+
+ route = Route.from_string("/health")
+ assert MetricsRoute.is_match(route, "GET") is False
+
+
+def test_metrics_route_from_route():
+ """Test MetricsRoute creation from base route."""
+ base_route = Route.from_string("/metrics")
+ metrics_route = MetricsRoute.from_route(base_route)
+
+ assert metrics_route.raw_path == "/metrics"
+ assert metrics_route.segments == ["metrics"]
+
+
+# Removed test_metrics_route_from_route_invalid - from_route() no longer validates
+# Call is_match() first to ensure route is valid
+
+
+def test_route_immutability():
+ """Test that route objects are immutable (frozen dataclasses)."""
+ route = StartExecutionRoute.from_route(
+ Route.from_string("/start-durable-execution")
+ )
+
+ # Should not be able to modify frozen dataclass
+ with pytest.raises(AttributeError):
+ route.raw_path = "/modified" # type: ignore[misc]
+
+ with pytest.raises(AttributeError):
+ route.segments = ["modified"] # type: ignore[misc]
+
+
+def test_route_with_special_characters():
+ """Test route parsing with special characters in ARNs and IDs.
+
+ URL-decoding happens once in ``Route.from_string`` so every captured
+ path segment (``segments[N]`` and any named field that mirrors it,
+ such as ``arn`` or ``callback_id``) carries the literal value the
+ caller passed to boto. ``raw_path`` keeps the original wire string.
+ """
+ # ARN with %20-encoded spaces should round-trip back to a literal space.
+ encoded_arn = (
+ "arn:aws:lambda:us-east-1:123456789012:function:my-function%20with%20spaces"
+ )
+ decoded_arn = (
+ "arn:aws:lambda:us-east-1:123456789012:function:my-function with spaces"
+ )
+ raw_path = f"/2025-12-01/durable-executions/{encoded_arn}"
+ router = Router()
+ route = router.find_route(raw_path, "GET")
+ assert isinstance(route, GetDurableExecutionRoute)
+ assert route.arn == decoded_arn
+ assert route.segments[2] == decoded_arn
+ # raw_path is preserved as the original wire form for logging/debugging.
+ assert route.raw_path == raw_path
+
+ # Test with callback ID containing special characters
+ callback_id = "callback-123-abc_def"
+ route = router.find_route(
+ f"/2025-12-01/durable-execution-callbacks/{callback_id}/succeed", "POST"
+ )
+ assert isinstance(route, CallbackSuccessRoute)
+ assert route.callback_id == callback_id
+
+
+def test_route_edge_cases():
+ """Test route parsing edge cases."""
+ router = Router()
+
+ # Empty path
+ with pytest.raises(UnknownRouteError, match="Unknown path pattern"):
+ router.find_route("", "GET")
+
+ # Root path
+ with pytest.raises(UnknownRouteError, match="Unknown path pattern"):
+ router.find_route("/", "GET")
+
+ # Path with only slashes
+ with pytest.raises(UnknownRouteError, match="Unknown path pattern"):
+ router.find_route("///", "GET")
+
+
+def test_route_case_sensitivity():
+ """Test that route matching is case-sensitive."""
+ router = Router()
+
+ # Should not match due to case difference
+ with pytest.raises(UnknownRouteError, match="Unknown path pattern"):
+ router.find_route("/START-DURABLE-EXECUTION", "POST")
+
+ with pytest.raises(UnknownRouteError, match="Unknown path pattern"):
+ router.find_route("/Health", "GET")
+
+
+def test_router_find_route_method_validation():
+ """Test that Router.find_route validates HTTP methods correctly."""
+ router = Router()
+
+ # Valid method combinations
+ route = router.find_route("/start-durable-execution", "POST")
+ assert isinstance(route, StartExecutionRoute)
+
+ route = router.find_route("/2025-12-01/durable-executions/test-arn", "GET")
+ assert isinstance(route, GetDurableExecutionRoute)
+
+ # Invalid method combinations
+ with pytest.raises(
+ UnknownRouteError,
+ match="Unknown path pattern: GET /start-durable-execution",
+ ):
+ router.find_route("/start-durable-execution", "GET")
+
+ with pytest.raises(
+ UnknownRouteError,
+ match="Unknown path pattern: POST /2025-12-01/durable-executions/test-arn",
+ ):
+ router.find_route("/2025-12-01/durable-executions/test-arn", "POST")
+
+ with pytest.raises(UnknownRouteError, match="Unknown path pattern: DELETE /health"):
+ router.find_route("/health", "DELETE")
+
+
+def test_router_find_route_method_case_sensitivity():
+ """Test that HTTP method matching is case-sensitive."""
+ router = Router()
+
+ # Should work with uppercase methods
+ route = router.find_route("/start-durable-execution", "POST")
+ assert isinstance(route, StartExecutionRoute)
+
+ # Should not work with lowercase methods
+ with pytest.raises(
+ UnknownRouteError,
+ match="Unknown path pattern: post /start-durable-execution",
+ ):
+ router.find_route("/start-durable-execution", "post")
+
+ with pytest.raises(UnknownRouteError, match="Unknown path pattern: get /health"):
+ router.find_route("/health", "get")
+
+
+def test_router_initialization_default():
+ """Test Router initialization with default route types."""
+ router = Router()
+
+ # Should work with default route types
+ route = router.find_route("/start-durable-execution", "POST")
+ assert isinstance(route, StartExecutionRoute)
+
+ route = router.find_route("/health", "GET")
+ assert isinstance(route, HealthRoute)
+
+
+def test_router_initialization_custom_route_types():
+ """Test Router initialization with custom route types."""
+ # Create router with only health and metrics routes
+ custom_route_types = [HealthRoute, MetricsRoute]
+ router = Router(route_types=custom_route_types)
+
+ # Should work with included route types
+ route = router.find_route("/health", "GET")
+ assert isinstance(route, HealthRoute)
+
+ route = router.find_route("/metrics", "GET")
+ assert isinstance(route, MetricsRoute)
+
+ # Should not work with excluded route types
+ with pytest.raises(
+ UnknownRouteError,
+ match="Unknown path pattern: POST /start-durable-execution",
+ ):
+ router.find_route("/start-durable-execution", "POST")
+
+
+def test_router_initialization_empty_route_types():
+ """Test Router initialization with empty route types list."""
+ router = Router(route_types=[])
+
+ # Should not match any routes
+ with pytest.raises(
+ UnknownRouteError,
+ match="Unknown path pattern: GET /health",
+ ):
+ router.find_route("/health", "GET")
+
+
+def test_router_find_route_basic():
+ """Test Router.find_route with basic routes."""
+ router = Router()
+
+ # Test various route types
+ route = router.find_route("/start-durable-execution", "POST")
+ assert isinstance(route, StartExecutionRoute)
+
+ arn = "test-arn"
+ route = router.find_route(f"/2025-12-01/durable-executions/{arn}", "GET")
+ assert isinstance(route, GetDurableExecutionRoute)
+ assert route.arn == arn
+
+ route = router.find_route("/health", "GET")
+ assert isinstance(route, HealthRoute)
+
+
+def test_router_find_route_with_parameters():
+ """Test Router.find_route extracts route parameters correctly."""
+ router = Router()
+
+ # Test ARN extraction
+ arn = "arn:aws:lambda:us-east-1:123456789012:function:my-function"
+ route = router.find_route(
+ f"/2025-12-01/durable-executions/{arn}/checkpoint", "POST"
+ )
+ assert isinstance(route, CheckpointDurableExecutionRoute)
+ assert route.arn == arn
+
+ # Test function name extraction
+ function_name = "my-test-function"
+ route = router.find_route(
+ f"/2025-12-01/functions/{function_name}/durable-executions", "GET"
+ )
+ assert isinstance(route, ListDurableExecutionsByFunctionRoute)
+ assert route.function_name == function_name
+
+ # Test callback ID extraction
+ callback_id = "callback-123-abc"
+ route = router.find_route(
+ f"/2025-12-01/durable-execution-callbacks/{callback_id}/succeed", "POST"
+ )
+ assert isinstance(route, CallbackSuccessRoute)
+ assert route.callback_id == callback_id
+
+
+def test_router_find_route_unknown_route():
+ """Test Router.find_route with unknown route patterns."""
+ router = Router()
+
+ with pytest.raises(
+ UnknownRouteError,
+ match="Unknown path pattern: GET /unknown/path",
+ ):
+ router.find_route("/unknown/path", "GET")
+
+ with pytest.raises(
+ UnknownRouteError,
+ match="Unknown path pattern: DELETE /health",
+ ):
+ router.find_route("/health", "DELETE")
+
+
+def test_unknown_route_error_attributes():
+ """Test UnknownRouteError provides structured access to method and path."""
+ router = Router()
+ with pytest.raises(UnknownRouteError) as exc_info:
+ router.find_route("/unknown/path", "POST")
+
+ e = exc_info.value
+ assert e.method == "POST"
+ assert e.path == "/unknown/path"
+ assert str(e) == "Unknown path pattern: POST /unknown/path"
+
+
+def test_router_find_route_priority_order():
+ """Test Router.find_route respects priority order for overlapping patterns."""
+ router = Router()
+
+ # Test that more specific patterns are matched before general ones
+ # This tests the order in DEFAULT_ROUTE_TYPES registry
+
+ # Should match GetDurableExecutionRoute, not ListDurableExecutionsRoute
+ route = router.find_route("/2025-12-01/durable-executions/some-arn", "GET")
+ assert isinstance(route, GetDurableExecutionRoute)
+ assert route.arn == "some-arn"
+
+ # Should match ListDurableExecutionsRoute
+ route = router.find_route("/2025-12-01/durable-executions", "GET")
+ assert isinstance(route, ListDurableExecutionsRoute)
+
+
+def test_router_multiple_instances():
+ """Test that multiple Router instances work independently."""
+ # Create two routers with different route types
+ health_router = Router(route_types=[HealthRoute])
+ full_router = Router() # Uses default route types
+
+ # Health router should only handle health routes
+ route = health_router.find_route("/health", "GET")
+ assert isinstance(route, HealthRoute)
+
+ with pytest.raises(UnknownRouteError):
+ health_router.find_route("/metrics", "GET")
+
+ # Full router should handle all routes
+ route = full_router.find_route("/health", "GET")
+ assert isinstance(route, HealthRoute)
+
+ route = full_router.find_route("/metrics", "GET")
+ assert isinstance(route, MetricsRoute)
+
+ route = full_router.find_route("/start-durable-execution", "POST")
+ assert isinstance(route, StartExecutionRoute)
+
+
+def test_router_find_route_start_execution():
+ """Test Router.find_route with start execution route."""
+ router = Router()
+ route = router.find_route("/start-durable-execution", "POST")
+ assert isinstance(route, StartExecutionRoute)
+ assert route.raw_path == "/start-durable-execution"
+
+
+def test_router_find_route_get_durable_execution():
+ """Test Router.find_route with get durable execution route."""
+ router = Router()
+ arn = "test-arn"
+ route = router.find_route(f"/2025-12-01/durable-executions/{arn}", "GET")
+ assert isinstance(route, GetDurableExecutionRoute)
+ assert route.arn == arn
+
+
+def test_router_find_route_checkpoint_durable_execution():
+ """Test Router.find_route with checkpoint durable execution route."""
+ router = Router()
+ arn = "test-arn"
+ route = router.find_route(
+ f"/2025-12-01/durable-executions/{arn}/checkpoint", "POST"
+ )
+ assert isinstance(route, CheckpointDurableExecutionRoute)
+ assert route.arn == arn
+
+
+def test_router_find_route_stop_durable_execution():
+ """Test Router.find_route with stop durable execution route."""
+ router = Router()
+ arn = "test-arn"
+ route = router.find_route(f"/2025-12-01/durable-executions/{arn}/stop", "POST")
+ assert isinstance(route, StopDurableExecutionRoute)
+ assert route.arn == arn
+
+
+def test_router_find_route_get_durable_execution_state():
+ """Test Router.find_route with get durable execution state route."""
+ router = Router()
+ arn = "test-arn"
+ route = router.find_route(f"/2025-12-01/durable-executions/{arn}/state", "GET")
+ assert isinstance(route, GetDurableExecutionStateRoute)
+ assert route.arn == arn
+
+
+def test_router_find_route_get_durable_execution_history():
+ """Test Router.find_route with get durable execution history route."""
+ router = Router()
+ arn = "test-arn"
+ route = router.find_route(f"/2025-12-01/durable-executions/{arn}/history", "GET")
+ assert isinstance(route, GetDurableExecutionHistoryRoute)
+ assert route.arn == arn
+
+
+def test_router_find_route_list_durable_executions():
+ """Test Router.find_route with list durable executions route."""
+ router = Router()
+ route = router.find_route("/2025-12-01/durable-executions", "GET")
+ assert isinstance(route, ListDurableExecutionsRoute)
+
+
+def test_router_find_route_list_durable_executions_by_function():
+ """Test Router.find_route with list durable executions by function route."""
+ router = Router()
+ function_name = "my-function"
+ route = router.find_route(
+ f"/2025-12-01/functions/{function_name}/durable-executions", "GET"
+ )
+ assert isinstance(route, ListDurableExecutionsByFunctionRoute)
+ assert route.function_name == function_name
+
+
+def test_router_find_route_callback_success():
+ """Test Router.find_route with callback success route."""
+ router = Router()
+ callback_id = "callback-123"
+ route = router.find_route(
+ f"/2025-12-01/durable-execution-callbacks/{callback_id}/succeed", "POST"
+ )
+ assert isinstance(route, CallbackSuccessRoute)
+ assert route.callback_id == callback_id
+
+
+def test_router_find_route_callback_failure():
+ """Test Router.find_route with callback failure route."""
+ router = Router()
+ callback_id = "callback-123"
+ route = router.find_route(
+ f"/2025-12-01/durable-execution-callbacks/{callback_id}/fail", "POST"
+ )
+ assert isinstance(route, CallbackFailureRoute)
+ assert route.callback_id == callback_id
+
+
+def test_router_find_route_callback_heartbeat():
+ """Test Router.find_route with callback heartbeat route."""
+ router = Router()
+ callback_id = "callback-123"
+ route = router.find_route(
+ f"/2025-12-01/durable-execution-callbacks/{callback_id}/heartbeat", "POST"
+ )
+ assert isinstance(route, CallbackHeartbeatRoute)
+ assert route.callback_id == callback_id
+
+
+def test_router_find_route_health():
+ """Test Router.find_route with health route."""
+ router = Router()
+ route = router.find_route("/health", "GET")
+ assert isinstance(route, HealthRoute)
+
+
+def test_router_find_route_metrics():
+ """Test Router.find_route with metrics route."""
+ router = Router()
+ route = router.find_route("/metrics", "GET")
+ assert isinstance(route, MetricsRoute)
+
+
+def test_router_find_route_unknown():
+ """Test Router.find_route with unknown route pattern."""
+ router = Router()
+ with pytest.raises(
+ UnknownRouteError,
+ match="Unknown path pattern: GET /unknown/path",
+ ):
+ router.find_route("/unknown/path", "GET")
+
+
+def test_router_constructor_with_all_default_route_types():
+ """Test Router constructor includes all expected default route types."""
+ router = Router()
+
+ # Test that all route types are included by trying to match each one
+ test_cases = [
+ ("/start-durable-execution", "POST", StartExecutionRoute),
+ ("/2025-12-01/durable-executions/test-arn", "GET", GetDurableExecutionRoute),
+ (
+ "/2025-12-01/durable-executions/test-arn/checkpoint",
+ "POST",
+ CheckpointDurableExecutionRoute,
+ ),
+ (
+ "/2025-12-01/durable-executions/test-arn/stop",
+ "POST",
+ StopDurableExecutionRoute,
+ ),
+ (
+ "/2025-12-01/durable-executions/test-arn/state",
+ "GET",
+ GetDurableExecutionStateRoute,
+ ),
+ (
+ "/2025-12-01/durable-executions/test-arn/history",
+ "GET",
+ GetDurableExecutionHistoryRoute,
+ ),
+ ("/2025-12-01/durable-executions", "GET", ListDurableExecutionsRoute),
+ (
+ "/2025-12-01/functions/test-func/durable-executions",
+ "GET",
+ ListDurableExecutionsByFunctionRoute,
+ ),
+ (
+ "/2025-12-01/durable-execution-callbacks/test-id/succeed",
+ "POST",
+ CallbackSuccessRoute,
+ ),
+ (
+ "/2025-12-01/durable-execution-callbacks/test-id/fail",
+ "POST",
+ CallbackFailureRoute,
+ ),
+ (
+ "/2025-12-01/durable-execution-callbacks/test-id/heartbeat",
+ "POST",
+ CallbackHeartbeatRoute,
+ ),
+ ("/health", "GET", HealthRoute),
+ ("/metrics", "GET", MetricsRoute),
+ ]
+
+ for path, method, expected_type in test_cases:
+ route = router.find_route(path, method)
+ assert isinstance(route, expected_type), (
+ f"Expected {expected_type.__name__} for {method} {path}"
+ )
+
+
+def test_router_constructor_with_subset_of_route_types():
+ """Test Router constructor with a subset of route types."""
+ # Create router with only callback routes
+ callback_route_types = [
+ CallbackSuccessRoute,
+ CallbackFailureRoute,
+ CallbackHeartbeatRoute,
+ ]
+ router = Router(route_types=callback_route_types)
+
+ # Should work with callback routes
+ route = router.find_route(
+ "/2025-12-01/durable-execution-callbacks/test-id/succeed", "POST"
+ )
+ assert isinstance(route, CallbackSuccessRoute)
+
+ route = router.find_route(
+ "/2025-12-01/durable-execution-callbacks/test-id/fail", "POST"
+ )
+ assert isinstance(route, CallbackFailureRoute)
+
+ route = router.find_route(
+ "/2025-12-01/durable-execution-callbacks/test-id/heartbeat", "POST"
+ )
+ assert isinstance(route, CallbackHeartbeatRoute)
+
+ # Should not work with other route types
+ with pytest.raises(UnknownRouteError):
+ router.find_route("/health", "GET")
+
+ with pytest.raises(UnknownRouteError):
+ router.find_route("/start-durable-execution", "POST")
+
+
+def test_router_constructor_with_single_route_type():
+ """Test Router constructor with a single route type."""
+ router = Router(route_types=[HealthRoute])
+
+ # Should work with the single route type
+ route = router.find_route("/health", "GET")
+ assert isinstance(route, HealthRoute)
+
+ # Should not work with any other route types
+ with pytest.raises(UnknownRouteError):
+ router.find_route("/metrics", "GET")
+
+ with pytest.raises(UnknownRouteError):
+ router.find_route("/start-durable-execution", "POST")
+
+
+def test_router_constructor_with_duplicate_route_types():
+ """Test Router constructor handles duplicate route types gracefully."""
+ # Include HealthRoute twice
+ duplicate_route_types = [HealthRoute, MetricsRoute, HealthRoute]
+ router = Router(route_types=duplicate_route_types)
+
+ # Should still work correctly (first match wins)
+ route = router.find_route("/health", "GET")
+ assert isinstance(route, HealthRoute)
+
+ route = router.find_route("/metrics", "GET")
+ assert isinstance(route, MetricsRoute)
+
+
+def test_router_find_route_error_handling_comprehensive():
+ """Test Router.find_route error handling with various invalid inputs."""
+ router = Router()
+
+ # Test various invalid path/method combinations
+ invalid_cases = [
+ ("", "GET", "Unknown path pattern: GET "),
+ ("/", "GET", "Unknown path pattern: GET /"),
+ ("///", "GET", "Unknown path pattern: GET ///"),
+ ("/unknown", "GET", "Unknown path pattern: GET /unknown"),
+ (
+ "/start-durable-execution",
+ "GET",
+ "Unknown path pattern: GET /start-durable-execution",
+ ),
+ ("/health", "POST", "Unknown path pattern: POST /health"),
+ ("/metrics", "DELETE", "Unknown path pattern: DELETE /metrics"),
+ (
+ "/2025-12-01/durable-executions/test-arn",
+ "POST",
+ "Unknown path pattern: POST /2025-12-01/durable-executions/test-arn",
+ ),
+ (
+ "/2025-12-01/durable-executions/test-arn/checkpoint",
+ "GET",
+ "Unknown path pattern: GET /2025-12-01/durable-executions/test-arn/checkpoint",
+ ),
+ ]
+
+ for path, method, expected_message in invalid_cases:
+ with pytest.raises(UnknownRouteError, match=expected_message):
+ router.find_route(path, method)
+
+
+def test_router_find_route_with_complex_parameters():
+ """Test Router.find_route with complex parameter extraction."""
+ router = Router()
+
+ # Test with complex ARN
+ complex_arn = "arn:aws:lambda:us-west-2:123456789012:function:my-complex-function-name_with_underscores-and-dashes"
+ route = router.find_route(f"/2025-12-01/durable-executions/{complex_arn}", "GET")
+ assert isinstance(route, GetDurableExecutionRoute)
+ assert route.arn == complex_arn
+
+ # Test with complex function name
+ complex_function_name = "my_complex-function.name123"
+ route = router.find_route(
+ f"/2025-12-01/functions/{complex_function_name}/durable-executions", "GET"
+ )
+ assert isinstance(route, ListDurableExecutionsByFunctionRoute)
+ assert route.function_name == complex_function_name
+
+ # Test with complex callback ID
+ complex_callback_id = "callback-123_abc-def.456"
+ route = router.find_route(
+ f"/2025-12-01/durable-execution-callbacks/{complex_callback_id}/succeed", "POST"
+ )
+ assert isinstance(route, CallbackSuccessRoute)
+ assert route.callback_id == complex_callback_id
+
+
+def test_router_find_route_order_dependency():
+ """Test that Router.find_route respects route type ordering for disambiguation."""
+ router = Router()
+
+ # These paths could potentially match multiple patterns if ordering is wrong
+ # The more specific patterns should match first
+
+ # Should match GetDurableExecutionRoute, not ListDurableExecutionsRoute
+ route = router.find_route("/2025-12-01/durable-executions/specific-arn", "GET")
+ assert isinstance(route, GetDurableExecutionRoute)
+ assert route.arn == "specific-arn"
+
+ # Should match ListDurableExecutionsRoute
+ route = router.find_route("/2025-12-01/durable-executions", "GET")
+ assert isinstance(route, ListDurableExecutionsRoute)
+
+ # Should match CheckpointDurableExecutionRoute, not GetDurableExecutionRoute
+ route = router.find_route(
+ "/2025-12-01/durable-executions/test-arn/checkpoint", "POST"
+ )
+ assert isinstance(route, CheckpointDurableExecutionRoute)
+ assert route.arn == "test-arn"
+
+
+def test_router_thread_safety():
+ """Test that Router instances are thread-safe for concurrent access."""
+
+ router = Router()
+ results = []
+ errors = []
+
+ def worker(worker_id: int):
+ try:
+ for i in range(10):
+ # Test different route types to ensure no interference
+ route = router.find_route(
+ f"/2025-12-01/durable-executions/arn-{worker_id}-{i}", "GET"
+ )
+ assert isinstance(route, GetDurableExecutionRoute)
+ assert route.arn == f"arn-{worker_id}-{i}"
+
+ route = router.find_route("/health", "GET")
+ assert isinstance(route, HealthRoute)
+
+ time.sleep(0.001) # Small delay to increase chance of race conditions
+
+ results.append(f"Worker {worker_id} completed successfully")
+ except (UnknownRouteError, AssertionError) as e:
+ errors.append(f"Worker {worker_id} failed: {e}")
+
+ # Create multiple threads
+ threads = []
+ for i in range(5):
+ thread = threading.Thread(target=worker, args=(i,))
+ threads.append(thread)
+ thread.start()
+
+ # Wait for all threads to complete
+ for thread in threads:
+ thread.join()
+
+ # Check results
+ assert len(errors) == 0, f"Thread safety test failed with errors: {errors}"
+ assert len(results) == 5, f"Expected 5 successful workers, got {len(results)}"
+
+
+def test_callback_routes_url_decoding():
+ """Test that callback routes properly URL-decode callback IDs."""
+ # Test callback ID with special characters that need URL encoding
+ callback_id = "eyJhcm4iOiJhcm4iLCJvcCI6ImVhNjZjMDZjMWUxYzA1ZmEifQ=="
+ encoded_callback_id = quote(callback_id, safe="")
+
+ # Test CallbackSuccessRoute
+ base_route = Route.from_string(
+ f"/2025-12-01/durable-execution-callbacks/{encoded_callback_id}/succeed"
+ )
+ success_route = CallbackSuccessRoute.from_route(base_route)
+ assert success_route.callback_id == callback_id # Should be decoded
+
+ # Test CallbackFailureRoute
+ base_route = Route.from_string(
+ f"/2025-12-01/durable-execution-callbacks/{encoded_callback_id}/fail"
+ )
+ failure_route = CallbackFailureRoute.from_route(base_route)
+ assert failure_route.callback_id == callback_id # Should be decoded
+
+ # Test CallbackHeartbeatRoute
+ base_route = Route.from_string(
+ f"/2025-12-01/durable-execution-callbacks/{encoded_callback_id}/heartbeat"
+ )
+ heartbeat_route = CallbackHeartbeatRoute.from_route(base_route)
+ assert heartbeat_route.callback_id == callback_id # Should be decoded
+
+
+def test_router_callback_routes_url_decoding():
+ """Test Router properly handles URL-encoded callback IDs."""
+ router = Router()
+ callback_id = "eyJhcm4iOiJhcm4iLCJvcCI6ImVhNjZjMDZjMWUxYzA1ZmEifQ=="
+ encoded_callback_id = quote(callback_id, safe="")
+
+ # Test success route
+ route = router.find_route(
+ f"/2025-12-01/durable-execution-callbacks/{encoded_callback_id}/succeed", "POST"
+ )
+ assert isinstance(route, CallbackSuccessRoute)
+ assert route.callback_id == callback_id # Should be decoded
+
+ # Test failure route
+ route = router.find_route(
+ f"/2025-12-01/durable-execution-callbacks/{encoded_callback_id}/fail", "POST"
+ )
+ assert isinstance(route, CallbackFailureRoute)
+ assert route.callback_id == callback_id # Should be decoded
+
+ # Test heartbeat route
+ route = router.find_route(
+ f"/2025-12-01/durable-execution-callbacks/{encoded_callback_id}/heartbeat",
+ "POST",
+ )
+ assert isinstance(route, CallbackHeartbeatRoute)
+ assert route.callback_id == callback_id # Should be decoded
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/web/serialization_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/web/serialization_test.py
new file mode 100644
index 0000000..7c981f3
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/web/serialization_test.py
@@ -0,0 +1,576 @@
+"""Tests for serialization interfaces and AWS boto integration."""
+
+from __future__ import annotations
+
+from unittest.mock import Mock, patch
+
+import pytest
+import json
+from datetime import datetime, timezone
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ InvalidParameterValueException,
+)
+from aws_durable_execution_sdk_python_testing.web.serialization import (
+ JSONSerializer,
+ AwsRestJsonDeserializer,
+ AwsRestJsonSerializer,
+)
+
+
+def test_aws_rest_json_serializer_should_initialize_and_serialize_data():
+ """Test that serializer initializes and can serialize data through public API."""
+ # Arrange
+ operation_name = "StartDurableExecution"
+ mock_serializer = Mock()
+ mock_operation_model = Mock()
+ mock_serializer.serialize_to_request.return_value = {"body": '{"test": "data"}'}
+
+ # Act
+ serializer = AwsRestJsonSerializer(
+ operation_name, mock_serializer, mock_operation_model
+ )
+ result = serializer.to_bytes({"test": "data"})
+
+ # Assert - Test public behavior only
+ assert isinstance(result, bytes)
+ assert result == b'{"test": "data"}'
+ mock_serializer.serialize_to_request.assert_called_once_with(
+ {"test": "data"}, mock_operation_model
+ )
+
+
+@patch("aws_durable_execution_sdk_python_testing.web.serialization.create_serializer")
+@patch("aws_durable_execution_sdk_python_testing.web.serialization.ServiceModel")
+@patch(
+ "aws_durable_execution_sdk_python_testing.web.serialization.botocore.loaders.Loader"
+)
+@patch("aws_durable_execution_sdk_python_testing.web.serialization.os.path.dirname")
+def test_aws_rest_json_serializer_should_create_serializer_with_boto_components(
+ mock_dirname,
+ mock_loader_class,
+ mock_service_model_class,
+ mock_create_serializer,
+):
+ """Test that create method sets up boto components correctly."""
+ # Arrange
+ operation_name = "StartDurableExecution"
+ mock_package_path = "/path/to/package"
+ mock_dirname.return_value = mock_package_path
+
+ mock_loader = Mock()
+ mock_loader_class.return_value = mock_loader
+ mock_raw_model = {"operations": {}}
+ mock_loader.load_service_model.return_value = mock_raw_model
+
+ mock_service_model = Mock()
+ mock_service_model_class.return_value = mock_service_model
+ mock_operation_model = Mock()
+ mock_service_model.operation_model.return_value = mock_operation_model
+
+ mock_serializer = Mock()
+ mock_create_serializer.return_value = mock_serializer
+
+ # Act
+ result = AwsRestJsonSerializer.create(operation_name)
+
+ # Assert - Test public behavior only
+ assert isinstance(result, AwsRestJsonSerializer)
+
+ # Test that the created serializer can actually serialize data
+ mock_serializer.serialize_to_request.return_value = {"body": '{"test": "value"}'}
+ serialized_data = result.to_bytes({"test": "value"})
+ assert isinstance(serialized_data, bytes)
+ assert serialized_data == b'{"test": "value"}'
+
+ # Verify boto setup calls
+ mock_loader.load_service_model.assert_called_once_with("lambda", "service-2")
+ mock_service_model_class.assert_called_once_with(mock_raw_model)
+ mock_create_serializer.assert_called_once_with("rest-json", include_validation=True)
+ mock_service_model.operation_model.assert_called_once_with(operation_name)
+
+
+@patch("aws_durable_execution_sdk_python_testing.web.serialization.create_serializer")
+def test_aws_rest_json_serializer_should_raise_serialization_error_when_create_fails(
+ mock_create_serializer,
+):
+ """Test that create method raises InvalidParameterValueException when boto setup fails."""
+ # Arrange
+ operation_name = "StartDurableExecution"
+ mock_create_serializer.side_effect = Exception("Boto error")
+
+ # Act & Assert
+ with pytest.raises(InvalidParameterValueException) as exc_info:
+ AwsRestJsonSerializer.create(operation_name)
+
+ assert "Failed to create serializer for StartDurableExecution" in str(
+ exc_info.value
+ )
+
+
+def test_aws_rest_json_serializer_should_serialize_data_to_bytes():
+ """Test that to_bytes method serializes data using boto serializer."""
+ # Arrange
+ mock_serializer = Mock()
+ mock_operation_model = Mock()
+ serializer = AwsRestJsonSerializer("test", mock_serializer, mock_operation_model)
+
+ test_data = {"key": "value"}
+ serialized_response = {"body": '{"key": "value"}'}
+ mock_serializer.serialize_to_request.return_value = serialized_response
+
+ # Act
+ result = serializer.to_bytes(test_data)
+
+ # Assert
+ assert result == b'{"key": "value"}'
+ mock_serializer.serialize_to_request.assert_called_once_with(
+ test_data, mock_operation_model
+ )
+
+
+def test_aws_rest_json_serializer_should_handle_bytes_body_in_serialization():
+ """Test that to_bytes method handles bytes body from boto serializer."""
+ # Arrange
+ mock_serializer = Mock()
+ mock_operation_model = Mock()
+ serializer = AwsRestJsonSerializer("test", mock_serializer, mock_operation_model)
+
+ test_data = {"key": "value"}
+ serialized_response = {"body": b'{"key": "value"}'}
+ mock_serializer.serialize_to_request.return_value = serialized_response
+
+ # Act
+ result = serializer.to_bytes(test_data)
+
+ # Assert
+ assert result == b'{"key": "value"}'
+
+
+def test_aws_rest_json_serializer_should_handle_empty_body_in_serialization():
+ """Test that to_bytes method handles empty body from boto serializer."""
+ # Arrange
+ mock_serializer = Mock()
+ mock_operation_model = Mock()
+ serializer = AwsRestJsonSerializer("test", mock_serializer, mock_operation_model)
+
+ test_data = {"key": "value"}
+ serialized_response = {}
+ mock_serializer.serialize_to_request.return_value = serialized_response
+
+ # Act
+ result = serializer.to_bytes(test_data)
+
+ # Assert
+ assert result == b""
+
+
+def test_aws_rest_json_serializer_should_raise_error_when_serializer_not_initialized():
+ """Test that to_bytes raises error when serializer is not initialized."""
+ # Arrange
+ serializer = AwsRestJsonSerializer("test", None, None)
+ test_data = {"key": "value"}
+
+ # Act & Assert
+ with pytest.raises(InvalidParameterValueException) as exc_info:
+ serializer.to_bytes(test_data)
+
+ assert "Serializer not initialized for test" in str(exc_info.value)
+
+
+def test_aws_rest_json_serializer_should_raise_error_when_serialization_fails():
+ """Test that to_bytes raises InvalidParameterValueException when boto serialization fails."""
+ # Arrange
+ mock_serializer = Mock()
+ mock_operation_model = Mock()
+ serializer = AwsRestJsonSerializer("test", mock_serializer, mock_operation_model)
+
+ test_data = {"key": "value"}
+ mock_serializer.serialize_to_request.side_effect = Exception("Serialization failed")
+
+ # Act & Assert
+ with pytest.raises(InvalidParameterValueException) as exc_info:
+ serializer.to_bytes(test_data)
+
+ assert "Failed to serialize data for test" in str(exc_info.value)
+
+
+def test_aws_rest_json_deserializer_should_initialize_and_deserialize_data():
+ """Test that deserializer initializes and can deserialize data through public API."""
+ # Arrange
+ operation_name = "StartDurableExecution"
+ mock_parser = Mock()
+ mock_operation_model = Mock()
+ mock_output_shape = Mock()
+ mock_operation_model.output_shape = mock_output_shape
+ mock_parser.parse.return_value = {"test": "data"}
+
+ # Act
+ deserializer = AwsRestJsonDeserializer(
+ operation_name, mock_parser, mock_operation_model
+ )
+ result = deserializer.from_bytes(b'{"test": "data"}')
+
+ # Assert - Test public behavior only
+ assert isinstance(result, dict)
+ assert result == {"test": "data"}
+ expected_response_dict = {
+ "body": b'{"test": "data"}',
+ "headers": {"content-type": "application/json"},
+ "status_code": 200,
+ }
+ mock_parser.parse.assert_called_once_with(expected_response_dict, mock_output_shape)
+
+
+@patch("aws_durable_execution_sdk_python_testing.web.serialization.create_parser")
+@patch("aws_durable_execution_sdk_python_testing.web.serialization.ServiceModel")
+@patch(
+ "aws_durable_execution_sdk_python_testing.web.serialization.botocore.loaders.Loader"
+)
+@patch("aws_durable_execution_sdk_python_testing.web.serialization.os.path.dirname")
+def test_aws_rest_json_deserializer_should_create_deserializer_with_boto_components(
+ mock_dirname,
+ mock_loader_class,
+ mock_service_model_class,
+ mock_create_parser,
+):
+ """Test that create method sets up boto components correctly."""
+ # Arrange
+ operation_name = "StartDurableExecution"
+ mock_package_path = "/path/to/package"
+ mock_dirname.return_value = mock_package_path
+
+ mock_loader = Mock()
+ mock_loader_class.return_value = mock_loader
+ mock_raw_model = {"operations": {}}
+ mock_loader.load_service_model.return_value = mock_raw_model
+
+ mock_service_model = Mock()
+ mock_service_model_class.return_value = mock_service_model
+ mock_operation_model = Mock()
+ mock_service_model.operation_model.return_value = mock_operation_model
+
+ mock_parser = Mock()
+ mock_create_parser.return_value = mock_parser
+
+ # Act
+ result = AwsRestJsonDeserializer.create(operation_name)
+
+ # Assert - Test public behavior only
+ assert isinstance(result, AwsRestJsonDeserializer)
+
+ # Test that the created deserializer can actually deserialize data
+ mock_output_shape = Mock()
+ mock_operation_model.output_shape = mock_output_shape
+ mock_parser.parse.return_value = {"test": "value"}
+ deserialized_data = result.from_bytes(b'{"test": "value"}')
+ assert isinstance(deserialized_data, dict)
+ assert deserialized_data == {"test": "value"}
+
+ # Verify boto setup calls
+ mock_loader.load_service_model.assert_called_once_with("lambda", "service-2")
+ mock_service_model_class.assert_called_once_with(mock_raw_model)
+ mock_create_parser.assert_called_once_with("rest-json")
+ mock_service_model.operation_model.assert_called_once_with(operation_name)
+
+
+@patch("aws_durable_execution_sdk_python_testing.web.serialization.create_parser")
+def test_aws_rest_json_deserializer_should_raise_serialization_error_when_create_fails(
+ mock_create_parser,
+):
+ """Test that create method raises InvalidParameterValueException when boto setup fails."""
+ # Arrange
+ operation_name = "StartDurableExecution"
+ mock_create_parser.side_effect = Exception("Boto error")
+
+ # Act & Assert
+ with pytest.raises(InvalidParameterValueException) as exc_info:
+ AwsRestJsonDeserializer.create(operation_name)
+
+ assert "Failed to create deserializer for StartDurableExecution" in str(
+ exc_info.value
+ )
+
+
+def test_aws_rest_json_deserializer_should_deserialize_bytes_with_output_shape():
+ """Test that from_bytes method deserializes data using boto parser with output shape."""
+ # Arrange
+ mock_parser = Mock()
+ mock_operation_model = Mock()
+ mock_output_shape = Mock()
+ mock_operation_model.output_shape = mock_output_shape
+ deserializer = AwsRestJsonDeserializer("test", mock_parser, mock_operation_model)
+
+ test_bytes = b'{"key": "value"}'
+ parsed_data = {"key": "value"}
+ mock_parser.parse.return_value = parsed_data
+
+ # Act
+ result = deserializer.from_bytes(test_bytes)
+
+ # Assert
+ assert result == parsed_data
+ expected_response_dict = {
+ "body": test_bytes,
+ "headers": {"content-type": "application/json"},
+ "status_code": 200,
+ }
+ mock_parser.parse.assert_called_once_with(expected_response_dict, mock_output_shape)
+
+
+def test_aws_rest_json_deserializer_should_deserialize_bytes_without_output_shape():
+ """Test that from_bytes method falls back to JSON parsing when no output shape."""
+ # Arrange
+ mock_parser = Mock()
+ mock_operation_model = Mock()
+ mock_operation_model.output_shape = None
+ deserializer = AwsRestJsonDeserializer("test", mock_parser, mock_operation_model)
+
+ test_bytes = b'{"key": "value"}'
+
+ # Act
+ result = deserializer.from_bytes(test_bytes)
+
+ # Assert
+ assert result == {"key": "value"}
+ mock_parser.parse.assert_not_called()
+
+
+def test_aws_rest_json_deserializer_should_raise_error_when_parser_not_initialized():
+ """Test that from_bytes raises error when parser is not initialized."""
+ # Arrange
+ deserializer = AwsRestJsonDeserializer("test", None, None)
+ test_bytes = b'{"key": "value"}'
+
+ # Act & Assert
+ with pytest.raises(InvalidParameterValueException) as exc_info:
+ deserializer.from_bytes(test_bytes)
+
+ assert "Parser not initialized for test" in str(exc_info.value)
+
+
+def test_aws_rest_json_deserializer_should_raise_error_when_deserialization_fails():
+ """Test that from_bytes raises InvalidParameterValueException when boto parsing fails."""
+ # Arrange
+ mock_parser = Mock()
+ mock_operation_model = Mock()
+ mock_output_shape = Mock()
+ mock_operation_model.output_shape = mock_output_shape
+ deserializer = AwsRestJsonDeserializer("test", mock_parser, mock_operation_model)
+
+ test_bytes = b'{"key": "value"}'
+ mock_parser.parse.side_effect = Exception("Parsing failed")
+
+ # Act & Assert
+ with pytest.raises(InvalidParameterValueException) as exc_info:
+ deserializer.from_bytes(test_bytes)
+
+ assert "Failed to deserialize data for test" in str(exc_info.value)
+
+
+def test_aws_rest_json_deserializer_should_raise_error_when_json_parsing_fails():
+ """Test that from_bytes raises InvalidParameterValueException when JSON parsing fails."""
+ # Arrange
+ mock_parser = Mock()
+ mock_operation_model = Mock()
+ mock_operation_model.output_shape = None
+ deserializer = AwsRestJsonDeserializer("test", mock_parser, mock_operation_model)
+
+ test_bytes = b"invalid json"
+
+ # Act & Assert
+ with pytest.raises(InvalidParameterValueException) as exc_info:
+ deserializer.from_bytes(test_bytes)
+
+ assert "Failed to deserialize data for test" in str(exc_info.value)
+
+
+def test_serialize_simple_dict():
+ """Test serialization of simple dictionary."""
+ serializer = JSONSerializer()
+ data = {"key": "value", "number": 42}
+ result = serializer.to_bytes(data)
+
+ expected = b'{"key":"value","number":42}'
+ assert result == expected
+ assert isinstance(result, bytes)
+ assert json.loads(result.decode("utf-8")) == data
+
+
+def test_serialize_datetime():
+ """Test serialization of datetime objects."""
+ serializer = JSONSerializer()
+ now = datetime(2025, 11, 5, 16, 30, 9, 895000, tzinfo=timezone.utc)
+ data = {"timestamp": now}
+
+ result = serializer.to_bytes(data)
+ expected = b'{"timestamp":1762360209.895}'
+
+ assert result == expected
+ assert isinstance(result, bytes)
+
+ deserialized = json.loads(result.decode("utf-8"))
+ assert deserialized["timestamp"] == now.timestamp()
+
+
+def test_serialize_nested_datetime():
+ """Test serialization of nested structures with datetime."""
+ serializer = JSONSerializer()
+ now = datetime(2025, 11, 5, 16, 30, 9, tzinfo=timezone.utc)
+ data = {
+ "event": "user_login",
+ "timestamp": now,
+ "metadata": {"created_at": now, "updated_at": now},
+ }
+
+ result = serializer.to_bytes(data)
+ expected = (
+ b'{"event":"user_login",'
+ b'"timestamp":1762360209.0,'
+ b'"metadata":{"created_at":1762360209.0,'
+ b'"updated_at":1762360209.0}}'
+ )
+
+ assert result == expected
+
+ deserialized = json.loads(result.decode("utf-8"))
+ assert deserialized["timestamp"] == now.timestamp()
+ assert deserialized["metadata"]["created_at"] == now.timestamp()
+
+
+def test_serialize_list_with_datetime():
+ """Test serialization of list containing datetime."""
+ serializer = JSONSerializer()
+ now = datetime(2025, 11, 5, 16, 30, 9, tzinfo=timezone.utc)
+ data = {
+ "events": [{"time": now, "action": "login"}, {"time": now, "action": "logout"}]
+ }
+
+ result = serializer.to_bytes(data)
+ expected = (
+ b'{"events":['
+ b'{"time":1762360209.0,"action":"login"},'
+ b'{"time":1762360209.0,"action":"logout"}'
+ b"]}"
+ )
+
+ assert result == expected
+
+ deserialized = json.loads(result.decode("utf-8"))
+ assert deserialized["events"][0]["time"] == now.timestamp()
+ assert deserialized["events"][1]["time"] == now.timestamp()
+
+
+def test_serialize_mixed_types():
+ """Test serialization of mixed data types."""
+ serializer = JSONSerializer()
+ now = datetime(2025, 11, 5, 16, 30, 9, tzinfo=timezone.utc)
+ data = {
+ "string": "test",
+ "number": 42,
+ "float": 3.14,
+ "boolean": True,
+ "null": None,
+ "list": [1, 2, 3],
+ "datetime": now,
+ }
+
+ result = serializer.to_bytes(data)
+ expected = (
+ b'{"string":"test",'
+ b'"number":42,'
+ b'"float":3.14,'
+ b'"boolean":true,'
+ b'"null":null,'
+ b'"list":[1,2,3],'
+ b'"datetime":1762360209.0}'
+ )
+
+ assert result == expected
+
+ deserialized = json.loads(result.decode("utf-8"))
+ assert deserialized["string"] == "test"
+ assert deserialized["number"] == 42
+ assert deserialized["float"] == 3.14
+ assert deserialized["boolean"] is True
+ assert deserialized["null"] is None
+ assert deserialized["list"] == [1, 2, 3]
+ assert deserialized["datetime"] == now.timestamp()
+
+
+def test_serialize_returns_bytes():
+ """Test that serialization returns bytes."""
+ serializer = JSONSerializer()
+ data = {"test": "value"}
+ result = serializer.to_bytes(data)
+ expected = b'{"test":"value"}'
+
+ assert result == expected
+ assert isinstance(result, bytes)
+
+
+def test_serialize_non_serializable_object_raises_exception():
+ """Test that non-serializable objects raise InvalidParameterValueException."""
+ serializer = JSONSerializer()
+
+ class CustomObject:
+ pass
+
+ data = {"custom": CustomObject()}
+
+ with pytest.raises(InvalidParameterValueException) as exc_info:
+ serializer.to_bytes(data)
+
+ assert (
+ "Failed to serialize data to JSON: Object of type CustomObject is not JSON serializable"
+ in str(exc_info.value)
+ )
+
+
+def test_serialize_circular_reference_raises_exception():
+ """Test that circular references raise InvalidParameterValueException."""
+ serializer = JSONSerializer()
+ data = {"key": "value"}
+ data["self"] = data # Create circular reference
+
+ with pytest.raises(InvalidParameterValueException) as exc_info:
+ serializer.to_bytes(data)
+
+ assert "Failed to serialize data to JSON" in str(exc_info.value)
+
+
+def test_serialize_datetime_with_microseconds():
+ """Test serialization of datetime with microseconds."""
+ serializer = JSONSerializer()
+ now = datetime(2025, 11, 5, 16, 30, 9, 123456, tzinfo=timezone.utc)
+ data = {"timestamp": now}
+
+ result = serializer.to_bytes(data)
+ expected = b'{"timestamp":1762360209.123456}'
+
+ assert result == expected
+
+
+def test_serialize_datetime_without_microseconds():
+ """Test serialization of datetime without microseconds."""
+ serializer = JSONSerializer()
+ now = datetime(2025, 11, 5, 16, 30, 9, tzinfo=timezone.utc)
+ data = {"timestamp": now}
+
+ result = serializer.to_bytes(data)
+ expected = b'{"timestamp":1762360209.0}'
+
+ assert result == expected
+
+
+def test_serialize_multiple_datetimes():
+ """Test multiple datetime objects."""
+ serializer = JSONSerializer()
+ dt1 = datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
+ dt2 = datetime(2025, 12, 31, 23, 59, 59, tzinfo=timezone.utc)
+
+ data = {"start": dt1, "end": dt2}
+ result = serializer.to_bytes(data)
+ expected = b'{"start":1735689600.0,"end":1767225599.0}'
+
+ assert result == expected
diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/web/server_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/web/server_test.py
new file mode 100644
index 0000000..8a15990
--- /dev/null
+++ b/packages/aws-durable-execution-sdk-python-testing/tests/web/server_test.py
@@ -0,0 +1,252 @@
+"""Tests for web server implementation."""
+
+from __future__ import annotations
+
+import logging
+import threading
+import time
+from unittest.mock import Mock, patch
+
+import pytest
+
+from aws_durable_execution_sdk_python_testing.exceptions import (
+ IllegalStateException,
+ InvalidParameterValueException,
+ ResourceNotFoundException,
+ SerializationError,
+ UnknownRouteError,
+)
+from aws_durable_execution_sdk_python_testing.web.models import HTTPResponse
+from aws_durable_execution_sdk_python_testing.web.routes import (
+ GetDurableExecutionRoute,
+ HealthRoute,
+ Router,
+ StartExecutionRoute,
+)
+from aws_durable_execution_sdk_python_testing.web.server import (
+ RequestHandler,
+ WebServer,
+ WebServiceConfig,
+)
+
+
+def test_web_service_config_default_values():
+ """Test that default configuration values are correct."""
+
+ config = WebServiceConfig()
+
+ assert config.host == "localhost"
+ assert config.port == 5000
+ assert config.log_level == logging.INFO
+ assert config.max_request_size == 10 * 1024 * 1024
+
+
+def test_web_service_config_custom_values():
+ """Test that custom configuration values are set correctly."""
+
+ config = WebServiceConfig(
+ host="127.0.0.1",
+ port=9000,
+ log_level=logging.DEBUG,
+ max_request_size=5 * 1024 * 1024,
+ )
+
+ assert config.host == "127.0.0.1"
+ assert config.port == 9000
+ assert config.log_level == logging.DEBUG
+ assert config.max_request_size == 5 * 1024 * 1024
+
+
+def test_web_service_config_frozen_dataclass():
+ """Test that WebServiceConfig is immutable."""
+ config = WebServiceConfig()
+
+ with pytest.raises(AttributeError):
+ config.port = 9000
+
+
+def test_web_server_initialization():
+ """Test that WebServer initializes correctly."""
+ config = WebServiceConfig(port=0) # Use port 0 for testing
+ executor = Mock()
+
+ with WebServer(config, executor) as server:
+ assert server.config == config
+ assert server.executor == executor
+
+
+def test_web_server_context_manager():
+ """Test that WebServer works as a context manager."""
+ config = WebServiceConfig(port=0)
+ executor = Mock()
+
+ # Test context manager entry and exit
+ with WebServer(config, executor) as server:
+ assert isinstance(server, WebServer)
+ assert server.config == config
+ assert server.executor == executor
+
+ # Server should be cleaned up after context exit
+
+
+def test_web_server_background_usage():
+ """Test that server can be used in background thread for testing."""
+ config = WebServiceConfig(port=0)
+ executor = Mock()
+
+ with WebServer(config, executor) as server:
+ # Start server in background thread
+ server_thread = threading.Thread(target=server.serve_forever, daemon=True)
+ server_thread.start()
+
+ # Give it a moment to start
+ time.sleep(0.1)
+ assert server_thread.is_alive()
+
+ # Stop the server
+ server.shutdown()
+
+ # Give it a moment to shutdown
+ time.sleep(0.1)
+ server_thread.join(timeout=1)
+ assert not server_thread.is_alive()
+
+
+def test_web_server_has_executor_reference():
+ """Test that WebServer stores executor reference correctly."""
+ config = WebServiceConfig(port=0)
+ executor = Mock()
+
+ with WebServer(config, executor) as server:
+ # Verify server has executor reference
+ assert server.executor == executor
+
+ # Verify RequestHandler class is set correctly
+
+ assert server.RequestHandlerClass == RequestHandler
+
+
+def test_web_server_has_router_and_handlers():
+ """Test that WebServer creates router and handlers correctly."""
+
+ executor = Mock()
+ config = WebServiceConfig(port=0) # Use port 0 to get any available port
+
+ with WebServer(config, executor) as server:
+ # Verify router is created
+ assert server.router is not None
+ assert isinstance(server.router, Router)
+
+ # Verify handlers are created
+ assert server.endpoint_handlers is not None
+ assert len(server.endpoint_handlers) > 0
+
+ # Verify specific handlers exist
+ assert StartExecutionRoute in server.endpoint_handlers
+ assert HealthRoute in server.endpoint_handlers
+
+ # Verify handlers have executor reference
+ start_handler = server.endpoint_handlers[StartExecutionRoute]
+ assert start_handler.executor is executor
+
+
+def test_web_server_all_routes_have_handlers():
+ """Test that all routes in the router have corresponding handlers."""
+ executor = Mock()
+ config = WebServiceConfig(port=0) # Use port 0 to get any available port
+
+ with WebServer(config, executor) as server:
+ # Test that router can find routes for all handler types
+ handler_route_types = set(server.endpoint_handlers.keys())
+
+ # Test a sample of routes to verify router functionality
+ test_routes = [
+ ("/start-durable-execution", "POST", StartExecutionRoute),
+ ("/health", "GET", HealthRoute),
+ (
+ "/2025-12-01/durable-executions/test-arn",
+ "GET",
+ GetDurableExecutionRoute,
+ ),
+ ]
+
+ for path, method, expected_route_type in test_routes:
+ # Verify router can find the route (tests public API)
+ found_route = server.router.find_route(path, method)
+ assert isinstance(found_route, expected_route_type)
+
+ # Verify handler exists for this route type
+ assert expected_route_type in handler_route_types
+
+
+def test_request_handler_exception_mapping():
+ """Test that RequestHandler has proper exception handling capabilities."""
+
+ # Verify that all the required exception types are available for import
+ assert SerializationError is not None
+ assert InvalidParameterValueException is not None
+ assert ResourceNotFoundException is not None
+ assert IllegalStateException is not None
+ assert UnknownRouteError is not None
+
+
+def test_http_response_create_error_from_exception():
+ """Test HTTPResponse.create_error_from_exception method directly."""
+ test_exception = InvalidParameterValueException("Test error message")
+ response = HTTPResponse.create_error_from_exception(test_exception)
+
+ assert response.status_code == 400
+ assert response.headers["Content-Type"] == "application/json"
+
+ # AWS-compliant format without wrapper
+ expected_body = {
+ "Type": "InvalidParameterValueException",
+ "message": "Test error message",
+ }
+ assert response.body == expected_body
+
+
+def test_request_handler_error_response_through_public_api():
+ """Test error response handling through public do_POST method."""
+ import io
+ from unittest.mock import MagicMock
+
+ # Create a mock request handler with minimal setup
+ mock_server = MagicMock()
+ mock_server.executor = Mock()
+ mock_server.router = Mock()
+ mock_server.endpoint_handlers = {}
+
+ # Mock the router to raise an exception
+ mock_server.router.find_route.side_effect = InvalidParameterValueException(
+ "Test error message"
+ )
+
+ # Create handler instance
+ with patch.object(RequestHandler, "__init__", return_value=None):
+ handler = RequestHandler.__new__(RequestHandler)
+ handler.executor = mock_server.executor
+ handler.router = mock_server.router
+ handler.endpoint_handlers = mock_server.endpoint_handlers
+ handler.path = "/test-path"
+ handler.headers = {"Content-Length": "0"}
+ handler.rfile = io.BytesIO(b"")
+
+ # Mock the response sending
+ with patch.object(handler, "_send_response") as mock_send_response:
+ # Call the public method that should trigger error handling
+ handler.do_POST()
+
+ # Verify _send_response was called with correct error response
+ mock_send_response.assert_called_once()
+ response = mock_send_response.call_args[0][0]
+
+ assert response.status_code == 400
+ assert response.headers["Content-Type"] == "application/json"
+
+ # AWS-compliant format without wrapper
+ expected_body = {
+ "Type": "InvalidParameterValueException",
+ "message": "Test error message",
+ }
+ assert response.body == expected_body
diff --git a/pyproject.toml b/pyproject.toml
index b10df01..c744ed5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,7 +13,7 @@ dependencies = [
"pytest",
"pytest-cov",
"opentelemetry-sdk>=1.20.0",
- "aws_durable_execution_sdk_python_testing"
+ "aws-durable-execution-sdk-python-testing",
]
[tool.hatch.envs.test.scripts]
@@ -27,6 +27,7 @@ addopts = "-v --strict-markers --import-mode=importlib"
testpaths = [
"packages/aws-durable-execution-sdk-python/tests",
"packages/aws-durable-execution-sdk-python-otel/tests",
+ "packages/aws-durable-execution-sdk-python-testing/tests",
"packages/aws-durable-execution-sdk-python-examples/test",
]
markers = [
@@ -77,11 +78,29 @@ test = "pytest packages/aws-durable-execution-sdk-python-otel/tests {args}"
cov = "pytest --cov-report=term-missing --cov-config=packages/aws-durable-execution-sdk-python-otel/pyproject.toml --cov=packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel packages/aws-durable-execution-sdk-python-otel/tests {args}"
typecheck = "mypy --install-types --non-interactive packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel packages/aws-durable-execution-sdk-python-otel/tests"
+[tool.hatch.envs.dev-testing]
+workspace.members = [
+ "packages/aws-durable-execution-sdk-python",
+ "packages/aws-durable-execution-sdk-python-testing",
+]
+dependencies = [
+ "pytest",
+ "pytest-cov",
+ "coverage[toml]",
+ "mypy>=1.0.0",
+]
+
+[tool.hatch.envs.dev-testing.scripts]
+test = "pytest packages/aws-durable-execution-sdk-python-testing/tests {args}"
+cov = "pytest --cov-report=term-missing --cov-config=packages/aws-durable-execution-sdk-python-testing/pyproject.toml --cov=packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing --cov-fail-under=95 packages/aws-durable-execution-sdk-python-testing/tests {args}"
+typecheck = "mypy --install-types --non-interactive packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing packages/aws-durable-execution-sdk-python-testing/tests"
+
[tool.hatch.envs.dev-examples]
workspace.members = [
"packages/aws-durable-execution-sdk-python",
"packages/aws-durable-execution-sdk-python-examples",
"packages/aws-durable-execution-sdk-python-otel",
+ "packages/aws-durable-execution-sdk-python-testing",
]
dependencies = [
"pytest",
@@ -149,7 +168,11 @@ preview = true
select = ["TID252"] # Enforce absolute imports (ban relative imports)
[tool.ruff.lint.isort]
-known-first-party = ["aws_durable_execution_sdk_python", "aws_durable_execution_sdk_python_otel"]
+known-first-party = [
+ "aws_durable_execution_sdk_python",
+ "aws_durable_execution_sdk_python_otel",
+ "aws_durable_execution_sdk_python_testing",
+]
force-single-line = false
lines-after-imports = 2