diff --git a/.github/workflows/deploy-examples.yml b/.github/workflows/deploy-examples.yml index c90de6a..fc19378 100644 --- a/.github/workflows/deploy-examples.yml +++ b/.github/workflows/deploy-examples.yml @@ -59,7 +59,7 @@ jobs: run: pip install hatch - name: Build examples run: | - hatch run -- examples:pip install -e packages/aws-durable-execution-sdk-python packages/aws-durable-execution-sdk-python-otel + hatch run -- examples:pip install -e packages/aws-durable-execution-sdk-python packages/aws-durable-execution-sdk-python-otel packages/aws-durable-execution-sdk-python-testing hatch run examples:build - name: Deploy Lambda function - ${{ matrix.example.name }} diff --git a/.github/workflows/ecr-release.yml b/.github/workflows/ecr-release.yml new file mode 100644 index 0000000..55ffa02 --- /dev/null +++ b/.github/workflows/ecr-release.yml @@ -0,0 +1,130 @@ +name: Upload Testing SDK Emulator Image + +on: + release: + types: [published] + +permissions: + contents: read + id-token: write + +env: + package_path: packages/aws-durable-execution-sdk-python-testing + aws_region: us-east-1 + ecr_repository_name: durable-functions/aws-durable-execution-emulator + +jobs: + build-and-upload-image-to-ecr: + runs-on: ubuntu-latest + outputs: + full_image_arm64: ${{ steps.build-publish.outputs.full_image_arm64 }} + full_image_x86_64: ${{ steps.build-publish.outputs.full_image_x86_64 }} + ecr_registry_repository: ${{ steps.build-publish.outputs.ecr_registry_repository }} + version: ${{ steps.version.outputs.VERSION }} + strategy: + matrix: + include: + - arch: x86_64 + platform: linux/amd64 + - arch: arm64 + platform: linux/arm64 + + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + ref: ${{ github.event.release.tag_name }} + + - name: Set up Python + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 + with: + python-version: "3.13" + + - name: Install Hatch + run: python -m pip install --upgrade hatch==1.16.5 + + - name: Set up QEMU for multi-platform builds + if: matrix.arch == 'arm64' + uses: docker/setup-qemu-action@v3 + with: + platforms: arm64 + + - name: Build distribution + working-directory: ${{ env.package_path }} + run: hatch build + + - name: Get version from __about__.py + id: version + run: | + VERSION=$(grep "^__version__" "${{ env.package_path }}/src/aws_durable_execution_sdk_python_testing/__about__.py" | cut -d'"' -f2) + echo "VERSION=$VERSION" + echo "VERSION=${VERSION}" >> "$GITHUB_OUTPUT" + + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.ECR_UPLOAD_IAM_ROLE_ARN }} + aws-region: ${{ env.aws_region }} + + - name: Login to Amazon ECR + id: login-ecr-public + uses: aws-actions/amazon-ecr-login@v2 + with: + registry-type: public + + - name: Build, tag, and push image to Amazon ECR + id: build-publish + shell: bash + env: + ECR_REGISTRY: ${{ steps.login-ecr-public.outputs.registry }} + ECR_REPOSITORY: ${{ env.ecr_repository_name }} + PER_ARCH_IMAGE_TAG: v${{ steps.version.outputs.VERSION }}-${{ matrix.arch }} + run: | + docker build --platform "${{ matrix.platform }}" --provenance false "${{ env.package_path }}" -f "${{ env.package_path }}/Dockerfile" -t "$ECR_REGISTRY/$ECR_REPOSITORY:$PER_ARCH_IMAGE_TAG" + docker push "$ECR_REGISTRY/$ECR_REPOSITORY:$PER_ARCH_IMAGE_TAG" + echo "ecr_registry_repository=$ECR_REGISTRY/$ECR_REPOSITORY" >> "$GITHUB_OUTPUT" + echo "full_image_${{ matrix.arch }}=$ECR_REGISTRY/$ECR_REPOSITORY:$PER_ARCH_IMAGE_TAG" >> "$GITHUB_OUTPUT" + + create-ecr-manifest-per-arch: + runs-on: ubuntu-latest + needs: [build-and-upload-image-to-ecr] + steps: + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.ECR_UPLOAD_IAM_ROLE_ARN }} + aws-region: ${{ env.aws_region }} + + - name: Login to Amazon ECR + uses: aws-actions/amazon-ecr-login@v2 + with: + registry-type: public + + - name: Create and push explicit version manifest + run: | + docker manifest create "${{ needs.build-and-upload-image-to-ecr.outputs.ecr_registry_repository }}:v${{ needs.build-and-upload-image-to-ecr.outputs.version }}" \ + "${{ needs.build-and-upload-image-to-ecr.outputs.full_image_x86_64 }}" \ + "${{ needs.build-and-upload-image-to-ecr.outputs.full_image_arm64 }}" + docker manifest annotate "${{ needs.build-and-upload-image-to-ecr.outputs.ecr_registry_repository }}:v${{ needs.build-and-upload-image-to-ecr.outputs.version }}" \ + "${{ needs.build-and-upload-image-to-ecr.outputs.full_image_arm64 }}" \ + --arch arm64 \ + --os linux + docker manifest annotate "${{ needs.build-and-upload-image-to-ecr.outputs.ecr_registry_repository }}:v${{ needs.build-and-upload-image-to-ecr.outputs.version }}" \ + "${{ needs.build-and-upload-image-to-ecr.outputs.full_image_x86_64 }}" \ + --arch amd64 \ + --os linux + docker manifest push "${{ needs.build-and-upload-image-to-ecr.outputs.ecr_registry_repository }}:v${{ needs.build-and-upload-image-to-ecr.outputs.version }}" + + - name: Create and push latest manifest + run: | + docker manifest create "${{ needs.build-and-upload-image-to-ecr.outputs.ecr_registry_repository }}" \ + "${{ needs.build-and-upload-image-to-ecr.outputs.full_image_arm64 }}" \ + "${{ needs.build-and-upload-image-to-ecr.outputs.full_image_x86_64 }}" + docker manifest annotate "${{ needs.build-and-upload-image-to-ecr.outputs.ecr_registry_repository }}" \ + "${{ needs.build-and-upload-image-to-ecr.outputs.full_image_arm64 }}" \ + --arch arm64 \ + --os linux + docker manifest annotate "${{ needs.build-and-upload-image-to-ecr.outputs.ecr_registry_repository }}" \ + "${{ needs.build-and-upload-image-to-ecr.outputs.full_image_x86_64 }}" \ + --arch amd64 \ + --os linux + docker manifest push "${{ needs.build-and-upload-image-to-ecr.outputs.ecr_registry_repository }}" diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 1b3176c..e038913 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -22,12 +22,6 @@ jobs: with: path: language-sdk - - name: Checkout the latest Testing SDK - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - repository: aws/aws-durable-execution-sdk-python-testing - path: language-sdk/packages/testing-sdk - - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 with: @@ -40,7 +34,6 @@ jobs: working-directory: language-sdk run: | echo "Running SDK tests..." - hatch run -- test:pip install -e packages/testing-sdk hatch run types:check hatch run test:cov @@ -61,12 +54,6 @@ jobs: with: path: language-sdk - - name: Checkout the latest Testing SDK - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - repository: aws/aws-durable-execution-sdk-python-testing - path: language-sdk/packages/testing-sdk - - name: Set up Python 3.13 uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 with: @@ -105,7 +92,6 @@ jobs: KMS_KEY_ARN: ${{ secrets.KMS_KEY_ARN }} run: | echo "Building examples..." - hatch run -- examples:pip install -e packages/testing-sdk hatch run examples:build # Get first integration example for testing diff --git a/.github/workflows/pypi-publish.yml b/.github/workflows/pypi-publish.yml index 9e0fa22..42c66f8 100644 --- a/.github/workflows/pypi-publish.yml +++ b/.github/workflows/pypi-publish.yml @@ -26,6 +26,8 @@ jobs: path: packages/aws-durable-execution-sdk-python - name: aws-durable-execution-sdk-python-otel path: packages/aws-durable-execution-sdk-python-otel + - name: aws-durable-execution-sdk-python-testing + path: packages/aws-durable-execution-sdk-python-testing steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 @@ -58,6 +60,7 @@ jobs: package: - name: aws-durable-execution-sdk-python - name: aws-durable-execution-sdk-python-otel + - name: aws-durable-execution-sdk-python-testing permissions: id-token: write diff --git a/packages/aws-durable-execution-sdk-python-examples/cli.py b/packages/aws-durable-execution-sdk-python-examples/cli.py index ff2e839..5a5fa71 100755 --- a/packages/aws-durable-execution-sdk-python-examples/cli.py +++ b/packages/aws-durable-execution-sdk-python-examples/cli.py @@ -46,24 +46,12 @@ def build_examples(): shutil.rmtree(build_dir) build_dir.mkdir() - # Copy testing library from current environment - try: - import aws_durable_execution_sdk_python_testing - - sdk_path = Path(aws_durable_execution_sdk_python_testing.__file__).parent - logger.info("Copying SDK from %s", sdk_path) - shutil.copytree( - sdk_path, build_dir / "aws_durable_execution_sdk_python_testing" - ) - except (ImportError, OSError): - logger.exception("Failed to copy testing library") - return False - # Install local packages so their runtime dependencies are included in # the Lambda deployment package. runtime_packages = [ packages_dir / "aws-durable-execution-sdk-python", packages_dir / "aws-durable-execution-sdk-python-otel", + packages_dir / "aws-durable-execution-sdk-python-testing", ] try: subprocess.run( diff --git a/packages/aws-durable-execution-sdk-python-testing/Dockerfile b/packages/aws-durable-execution-sdk-python-testing/Dockerfile new file mode 100644 index 0000000..30d3a7f --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/Dockerfile @@ -0,0 +1,20 @@ +FROM python:3.13-slim + +# Copy and install the wheel +COPY dist/*.whl /tmp/ +RUN pip install --no-cache-dir /tmp/*.whl && rm -rf /tmp/*.whl + +# AWS credentials (required for boto3) +ENV AWS_ACCESS_KEY_ID=foo \ + AWS_SECRET_ACCESS_KEY=bar \ + AWS_DEFAULT_REGION=us-east-1 + +EXPOSE 9014 + +CMD ["dex-local-runner", "start-server", \ + "--host", "0.0.0.0", \ + "--port", "9014", \ + "--log-level", "DEBUG", \ + "--lambda-endpoint", "http://host.docker.internal:3001", \ + "--store-type", "sqlite", \ + "--store-path", "/tmp/.durable-executions-local/durable-executions.db"] diff --git a/packages/aws-durable-execution-sdk-python-testing/LICENSE b/packages/aws-durable-execution-sdk-python-testing/LICENSE new file mode 100644 index 0000000..67db858 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/LICENSE @@ -0,0 +1,175 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. diff --git a/packages/aws-durable-execution-sdk-python-testing/NOTICE b/packages/aws-durable-execution-sdk-python-testing/NOTICE new file mode 100644 index 0000000..616fc58 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/NOTICE @@ -0,0 +1 @@ +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. diff --git a/packages/aws-durable-execution-sdk-python-testing/README.md b/packages/aws-durable-execution-sdk-python-testing/README.md new file mode 100644 index 0000000..223f07c --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/README.md @@ -0,0 +1,199 @@ +# AWS Durable Execution Testing SDK for Python + +[![PyPI - Version](https://img.shields.io/pypi/v/aws-durable-execution-sdk-python-testing.svg)](https://pypi.org/project/aws-durable-execution-sdk-python-testing) +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/aws-durable-execution-sdk-python-testing.svg)](https://pypi.org/project/aws-durable-execution-sdk-python-testing) + + +[![OpenSSF Scorecard](https://api.scorecard.dev/projects/github.com/aws/aws-durable-execution-sdk-python-testing/badge)](https://scorecard.dev/viewer/?uri=github.com/aws/aws-durable-execution-sdk-python-testing) + +----- + +## Table of Contents + +- [Installation](#installation) +- [Quick Start](#quick-start) +- [Architecture](#architecture) +- [Documentation](#documentation) +- [Developer Guide](#developers) +- [License](#license) + +## Installation + +```console +pip install aws-durable-execution-sdk-python-testing +``` + +## Overview + +Use the AWS Durable Execution Testing SDK for Python to test your Python durable functions locally. + +The test framework contains a local runner, so you can run and test your durable function locally +before you deploy it. + +## Quick Start + +### A durable function under test + +```python +from aws_durable_execution_sdk_python.context import ( + DurableContext, + durable_step, + durable_with_child_context, +) +from aws_durable_execution_sdk_python.execution import durable_execution +from aws_durable_execution_sdk_python.config import Duration + + +@durable_step +def one(a: int, b: int) -> str: + return f"{a} {b}" + +@durable_step +def two_1(a: int, b: int) -> str: + return f"{a} {b}" + +@durable_step +def two_2(a: int, b: int) -> str: + return f"{b} {a}" + +@durable_with_child_context +def two(ctx: DurableContext, a: int, b: int) -> str: + two_1_result: str = ctx.step(two_1(a, b)) + two_2_result: str = ctx.step(two_2(a, b)) + return f"{two_1_result} {two_2_result}" + +@durable_step +def three(a: int, b: int) -> str: + return f"{a} {b}" + +@durable_execution +def function_under_test(event: Any, context: DurableContext) -> list[str]: + results: list[str] = [] + + result_one: str = context.step(one(1, 2)) + results.append(result_one) + + context.wait(Duration.from_seconds(1)) + + result_two: str = context.run_in_child_context(two(3, 4)) + results.append(result_two) + + result_three: str = context.step(three(5, 6)) + results.append(result_three) + + return results +``` + +### Your test code + +```python +from aws_durable_execution_sdk_python.execution import InvocationStatus +from aws_durable_execution_sdk_python_testing.runner import ( + ContextOperation, + DurableFunctionTestResult, + DurableFunctionTestRunner, + StepOperation, +) + +def test_my_durable_functions(): + with DurableFunctionTestRunner(handler=function_under_test) as runner: + result: DurableFunctionTestResult = runner.run(input="input str", timeout=10) + + assert result.status is InvocationStatus.SUCCEEDED + assert result.result == '["1 2", "3 4 4 3", "5 6"]' + + one_result: StepOperation = result.get_step("one") + assert one_result.result == '"1 2"' + + two_result: ContextOperation = result.get_context("two") + assert two_result.result == '"3 4 4 3"' + + three_result: StepOperation = result.get_step("three") + assert three_result.result == '"5 6"' +``` +## Architecture +![Durable Functions Python Test Framework Architecture](assets/dar-python-test-framework-architecture.svg) + +## Event Flow +![Event Flow Sequence Diagram](assets/dar-python-test-framework-event-flow.svg) + +1. **DurableTestRunner** starts execution via **Executor** +2. **Executor** creates **Execution** and schedules initial invocation +3. During execution, checkpoints are processed by **CheckpointProcessor** +4. **Individual Processors** transform operation updates and may trigger events +5. **ExecutionNotifier** broadcasts events to **Executor** (observer) +6. **Executor** updates **Execution** state based on events +7. **Execution** completion triggers final event notifications +8. **DurableTestRunner** run() blocks until it receives completion event, and then returns `DurableFunctionTestResult`. + +## Major Components + +### Core Execution Flow +- **DurableTestRunner** - Main entry point that orchestrates test execution +- **Executor** - Manages execution lifecycle. Mutates Execution. +- **Execution** - Represents the state and operations of a single durable execution + +### Service Client Integration +- **InMemoryServiceClient** - Replaces AWS Lambda service client for local testing. Injected into SDK via `DurableExecutionInvocationInputWithClient` + +### Checkpoint Processing Pipeline +- **CheckpointProcessor** - Orchestrates operation transformations and validation +- **Individual Validators** - Validate operation updates and state transitions +- **Individual Processors** - Transform operation updates into operations (step, wait, callback, context, execution) + +### Execution status changes (Observer Pattern) +- **ExecutionNotifier** - Notifies observers of execution events +- **ExecutionObserver** - Interface for receiving execution lifecycle events +- **Executor** implements `ExecutionObserver` to handle completion events + +## Component Relationships + +### 1. DurableTestRunner → Executor → Execution +- **DurableTestRunner** serves as the main API entry point and sets up all components +- **Executor** manages the execution lifecycle, handling invocations and state transitions +- **Execution** maintains the state of operations and completion status + +### 2. Service Client Injection +- **DurableTestRunner** creates **InMemoryServiceClient** with **CheckpointProcessor** +- **InProcessInvoker** injects the service client into SDK via `DurableExecutionInvocationInputWithClient` +- When durable functions call checkpoint operations, they're intercepted by **InMemoryServiceClient** +- **InMemoryServiceClient** delegates to **CheckpointProcessor** for local processing + +### 3. CheckpointProcessor → Individual Validators → Individual Processors +- **CheckpointProcessor** orchestrates the checkpoint processing pipeline +- **Individual Validators** (CheckpointValidator, TransitionsValidator, and operation-specific validators) ensure operation updates are valid +- **Individual Processors** (StepProcessor, WaitProcessor, etc.) transform `OperationUpdate` into `Operation` + +### 4. Observer Pattern Flow +The observer pattern enables loose coupling between checkpoint processing and execution management: + +1. **CheckpointProcessor** processes operation updates +2. **Individual Processors** detect state changes (completion, failures, timer scheduling) +3. **ExecutionNotifier** broadcasts events to registered observers +4. **Executor** (as ExecutionObserver) receives notifications and updates **Execution** state +5. **Execution** complete_* methods finalize the execution state + + +## Documentation + +### Error Handling + +The testing framework implements AWS-compliant error responses that match the exact format expected by boto3 and AWS services. For detailed information about error response formats, exception types, and troubleshooting, see: + +- [Error Response Documentation](docs/error-responses.md) + +Key features: +- **AWS-compliant JSON format**: Matches boto3 expectations exactly +- **Smithy model compliance**: Field names follow AWS Smithy definitions +- **HTTP status code mapping**: Standard AWS service status codes +- **Boto3 compatibility**: Seamless integration with boto3 error handling + +## Developers +Please see [CONTRIBUTING.md](CONTRIBUTING.md). It contains the testing guide, sample commands and instructions +for how to contribute to this package. + +tldr; use `hatch` and it will manage virtual envs and dependencies for you, so you don't have to do it manually. + +## License + +This project is licensed under the [Apache-2.0 License](LICENSE). diff --git a/packages/aws-durable-execution-sdk-python-testing/assets/dar-python-test-framework-architecture.svg b/packages/aws-durable-execution-sdk-python-testing/assets/dar-python-test-framework-architecture.svg new file mode 100644 index 0000000..0d8fd6d --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/assets/dar-python-test-framework-architecture.svg @@ -0,0 +1 @@ +Service ClientExecution LifecycleCheckpoint ProcessingOperation Processors (Strategy Pattern)Operation Validators (Strategy Pattern)Observer PatternDurableServiceClientcheckpoint()get_execution_state()stop()checkpoint()get_execution_state()stop()InMemoryServiceClientcheckpoint_processor: CheckpointProcessorcheckpoint_processor: CheckpointProcessorcheckpoint()get_execution_state()stop()checkpoint()get_execution_state()stop()InProcessInvokerhandler: Callableservice_client: InMemoryServiceClienthandler: Callableservice_client: InMemoryServiceClientcreate_invocation_input()invoke()create_invocation_input()invoke()Executorstore: ExecutionStorescheduler: Schedulerinvoker: Invokerstart_execution()complete_execution()fail_execution()on_completed()on_failed()on_wait_timer_scheduled()on_step_retry_scheduled()Executiondurable_execution_arn: stroperations: list[Operation]is_complete: boolstart()complete_success()complete_fail()complete_wait()complete_retry()Schedulercall_later()create_event()CheckpointProcessorstore: ExecutionStorescheduler: Schedulernotifier: ExecutionNotifiertransformer: OperationTransformerprocess_checkpoint()add_execution_observer()Processes operation updatesthrough individual processorsand validators, then notifiesobservers of state changesCheckpointValidatorvalidate_input()TransitionsValidatorvalidate_transitions()OperationProcessor«note: Translates OperationUpdate to Operation»process()StepProcessorWaitProcessorCallbackProcessorContextProcessorExecutionProcessorOperationValidatorvalidate()Strategy Pattern: Each validatorimplements specific validationlogic for different operation typesStepValidatorWaitValidatorCallbackValidatorContextValidatorExecutionValidatorInvokeValidatorExecutionObserveron_completed()on_failed()on_wait_timer_scheduled()on_step_retry_scheduled()ExecutionNotifierobservers: list[ExecutionObserver]add_observer()notify_completed()notify_failed()notify_wait_timer_scheduled()notify_step_retry_scheduled()DurableTestRunnerhandler: Callableservice_client: InMemoryServiceClientexecutor: Executorrun()close()InMemoryServiceClientReplaces AWS Lambda service clientfor local testing. Injected intoSDK via DurableExecutionInvocationInputWithClientto intercept checkpoint callscreatesusesmanagescomplete_success()complete_fail()usesimplementsimplementsdelegates toinjects into SDKusesusesusesusesusesusescall_later/create_eventnotifiesnotifies via ExecutionNotifiernotify_completed()notify_failed()notify_wait_timer_scheduled()notify_step_retry_scheduled() \ No newline at end of file diff --git a/packages/aws-durable-execution-sdk-python-testing/assets/dar-python-test-framework-event-flow.svg b/packages/aws-durable-execution-sdk-python-testing/assets/dar-python-test-framework-event-flow.svg new file mode 100644 index 0000000..fbd55ab --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/assets/dar-python-test-framework-event-flow.svg @@ -0,0 +1 @@ +DurableTestRunnerDurableTestRunnerExecutorExecutorExecutionExecutionCheckpointProcessorCheckpointProcessorIndividual ProcessorsIndividual ProcessorsExecutionNotifierExecutionNotifier1. start execution2. create & schedule invocation3. process checkpoints4. transform operation updates4. trigger events5. broadcast events (observer)6. update state based on events7. completion triggers final notifications7. final event notifications8. DurableFunctionTestResult \ No newline at end of file diff --git a/packages/aws-durable-execution-sdk-python-testing/docs/error-responses.md b/packages/aws-durable-execution-sdk-python-testing/docs/error-responses.md new file mode 100644 index 0000000..44a3567 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/docs/error-responses.md @@ -0,0 +1,311 @@ +# AWS-Compliant Error Response Documentation + +This document describes the AWS-compliant error response format used by the Durable Executions Testing Library. + +## Overview + +The testing library implements AWS-compliant error responses that match the exact format expected by boto3 and AWS services. All error responses follow Smithy model definitions for structure and field naming. + +## Error Response Format + +### HTTP Response Structure + +All error responses use the following HTTP structure: + +``` +HTTP/1.1 +Content-Type: application/json + + +``` + +### JSON Body Format + +The JSON body format varies by exception type based on Smithy model definitions: + +#### Standard Format (Most Exceptions) + +```json +{ + "Type": "ExceptionName", + "message": "Detailed error message" +} +``` + +**Used by:** +- `InvalidParameterValueException` +- `CallbackTimeoutException` + +#### Capital Message Format + +```json +{ + "Type": "ExceptionName", + "Message": "Detailed error message" +} +``` + +**Used by:** +- `ResourceNotFoundException` +- `ServiceException` + +#### Special Format (ExecutionAlreadyStartedException) + +```json +{ + "message": "Detailed error message", + "DurableExecutionArn": "arn:aws:states:region:account:execution:name" +} +``` + +**Note:** This exception has no "Type" field per AWS Smithy definition. + +## Exception Types and Examples + +### InvalidParameterValueException (HTTP 400) + +**When:** Invalid parameter values are provided to API operations. + +**Example Response:** +```json +{ + "Type": "InvalidParameterValueException", + "message": "The parameter 'executionName' cannot be empty" +} +``` + +**Common Causes:** +- Empty or null required parameters +- Invalid parameter formats +- Parameter values outside allowed ranges + +### ResourceNotFoundException (HTTP 404) + +**When:** Requested resource does not exist. + +**Example Response:** +```json +{ + "Type": "ResourceNotFoundException", + "Message": "Execution with ID 'exec-123' not found" +} +``` + +**Common Causes:** +- Non-existent execution IDs +- Deleted or expired resources +- Incorrect resource identifiers + +### ServiceException (HTTP 500) + +**When:** Internal service errors occur. + +**Example Response:** +```json +{ + "Type": "ServiceException", + "Message": "An internal error occurred while processing the request" +} +``` + +**Common Causes:** +- Unexpected internal errors +- System unavailability +- Configuration issues + +### CallbackTimeoutException (HTTP 408) + +**When:** Callback operations timeout. + +**Example Response:** +```json +{ + "Type": "CallbackTimeoutException", + "message": "Callback operation timed out after 30 seconds" +} +``` + +**Common Causes:** +- Callback not received within timeout period +- Network connectivity issues +- Client-side delays + +### ExecutionAlreadyStartedException (HTTP 409) + +**When:** Attempting to start an execution that is already running. + +**Example Response:** +```json +{ + "message": "Execution is already started", + "DurableExecutionArn": "arn:aws:states:us-east-1:123456789012:execution:MyExecution:abc123" +} +``` + +**Common Causes:** +- Duplicate start execution requests +- Race conditions in execution management +- Client retry logic issues + +## HTTP Status Code Mapping + +| Exception | HTTP Status | Description | +|-----------|-------------|-------------| +| InvalidParameterValueException | 400 | Bad Request - Invalid input parameters | +| ResourceNotFoundException | 404 | Not Found - Resource does not exist | +| CallbackTimeoutException | 408 | Request Timeout - Operation timed out | +| ExecutionAlreadyStartedException | 409 | Conflict - Resource already exists | +| ServiceException | 500 | Internal Server Error - System error | + +## Field Name Conventions + +Field names strictly follow Smithy model definitions: + +- **lowercase "message"**: InvalidParameterValueException, CallbackTimeoutException, ExecutionAlreadyStartedException +- **capital "Message"**: ResourceNotFoundException, ServiceException +- **"Type"**: Present in all exceptions except ExecutionAlreadyStartedException +- **"DurableExecutionArn"**: Only in ExecutionAlreadyStartedException + +## Boto3 Compatibility + +All error responses are designed for boto3 compatibility: + +### Client Error Handling + +```python +import boto3 +from botocore.exceptions import ClientError + +try: + # API call that might fail + response = client.some_operation() +except ClientError as e: + error_code = e.response['Error']['Code'] + error_message = e.response['Error']['Message'] + + if error_code == 'InvalidParameterValueException': + # Handle invalid parameter + pass + elif error_code == 'ResourceNotFoundException': + # Handle not found + pass +``` + +### Error Response Structure + +The testing library's error responses match the structure boto3 expects: + +```python +# What boto3 receives +{ + 'Error': { + 'Code': 'InvalidParameterValueException', + 'Message': 'The parameter cannot be empty' + }, + 'ResponseMetadata': { + 'HTTPStatusCode': 400, + 'HTTPHeaders': {...} + } +} +``` + +## Migration from Legacy Format + +### Old Format (Deprecated) +```json +{ + "error": { + "type": "InvalidParameterError", + "message": "Error message", + "code": "INVALID_PARAMETER", + "requestId": "req-123" + } +} +``` + +### New AWS-Compliant Format +```json +{ + "Type": "InvalidParameterValueException", + "message": "Error message" +} +``` + +### Key Changes +1. **No wrapper object**: Direct JSON structure, no "error" wrapper +2. **AWS exception names**: Use official AWS exception names +3. **Smithy field names**: Follow exact Smithy model field naming +4. **Simplified structure**: Only essential fields per AWS standards +5. **Consistent HTTP codes**: Match AWS service status codes + +## Testing Error Responses + +### Unit Testing + +```python +from aws_durable_execution_sdk_python_testing.exceptions import InvalidParameterValueException +from aws_durable_execution_sdk_python_testing.web.models import HTTPResponse + +def test_error_response(): + exception = InvalidParameterValueException("Test error") + response = HTTPResponse.create_error_from_exception(exception) + + assert response.status_code == 400 + assert response.body == { + "Type": "InvalidParameterValueException", + "message": "Test error" + } +``` + +### Integration Testing + +```python +import requests + +def test_api_error_response(): + response = requests.post('http://localhost:8080/invalid-endpoint') + + assert response.status_code == 404 + error_data = response.json() + assert error_data['Type'] == 'ResourceNotFoundException' + assert 'Message' in error_data +``` + +## Best Practices + +### Error Message Guidelines + +1. **Be specific**: Include relevant details about what went wrong +2. **Be actionable**: Suggest how to fix the issue when possible +3. **Be consistent**: Use consistent terminology across similar errors +4. **Avoid sensitive data**: Don't include passwords, tokens, or PII + +### Exception Selection + +1. **InvalidParameterValueException**: For all input validation errors +2. **ResourceNotFoundException**: When requested resources don't exist +3. **ServiceException**: For unexpected internal errors only +4. **CallbackTimeoutException**: Specifically for callback timeouts +5. **ExecutionAlreadyStartedException**: Only for duplicate execution starts + +### HTTP Status Codes + +1. **Use standard codes**: Follow HTTP and AWS conventions +2. **Be consistent**: Same error types should use same status codes +3. **Client vs Server**: 4xx for client errors, 5xx for server errors + +## Troubleshooting + +### Common Issues + +1. **Wrong field names**: Ensure "message" vs "Message" matches exception type +2. **Missing Type field**: All exceptions except ExecutionAlreadyStartedException need "Type" +3. **Wrong status codes**: Verify HTTP status matches exception type +4. **JSON serialization**: Ensure all fields are JSON-serializable + +### Debugging Tips + +1. **Check exception type**: Verify you're using the correct AWS exception +2. **Validate JSON structure**: Use `to_dict()` to see exact output +3. **Test with boto3**: Verify compatibility with actual boto3 client +4. **Compare with AWS**: Match format with real AWS service responses \ No newline at end of file diff --git a/packages/aws-durable-execution-sdk-python-testing/pyproject.toml b/packages/aws-durable-execution-sdk-python-testing/pyproject.toml new file mode 100644 index 0000000..5153fea --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/pyproject.toml @@ -0,0 +1,132 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "aws-durable-execution-sdk-python-testing" +dynamic = ["version"] +description = 'AWS Durable Execution Testing SDK for Python' +readme = "README.md" +requires-python = ">=3.11" +license = "Apache-2.0" +keywords = [] +authors = [{ name = "AWS durable-execution-dev", email = "durable-execution-dev@amazon.com" }] +classifiers = [ + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = [ + "boto3>=1.42.1", + "aws-durable-execution-sdk-python>=1.0.0", +] + +[project.urls] +Documentation = "https://github.com/aws/aws-durable-execution-sdk-python-testing#readme" +Issues = "https://github.com/aws/aws-durable-execution-sdk-python-testing/issues" +Source = "https://github.com/aws/aws-durable-execution-sdk-python-testing" + +[project.scripts] +dex-local-runner = "aws_durable_execution_sdk_python_testing.cli:main" + +[tool.hatch.build.targets.sdist] +packages = ["src/aws_durable_execution_sdk_python_testing"] + +[tool.hatch.build.targets.wheel] +packages = ["src/aws_durable_execution_sdk_python_testing"] + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.hatch.version] +path = "src/aws_durable_execution_sdk_python_testing/__about__.py" + +# [tool.hatch.envs.default] +# dependencies=["pytest"] + +# [tool.hatch.envs.default.scripts] +# test="pytest" + +[tool.hatch.envs.test] +dependencies = [ + "coverage[toml]", + "pytest", + "pytest-cov", + "ruff", + "aws-durable-execution-sdk-python>=1.0.0", +] + +[tool.hatch.envs.test.scripts] +test = "pytest tests/ -v" +cov = "pytest --cov-report=term-missing --cov-config=pyproject.toml --cov=src/aws_durable_execution_sdk_python_testing --cov-fail-under=95" + +[tool.hatch.envs.types] +extra-dependencies = ["mypy>=1.0.0", "pytest"] +[tool.hatch.envs.types.scripts] +check = "mypy --install-types --non-interactive {args:src/aws_durable_execution_sdk_python_testing tests}" + +[tool.coverage.run] +source_pkgs = ["aws_durable_execution_sdk_python_testing"] +branch = true +parallel = true +omit = ["src/aws_durable_execution_sdk_python_testing/__about__.py", "tests/*"] + +[tool.coverage.paths] +aws_durable_execution_sdk_python_testing = [ + "src/aws_durable_execution_sdk_python_testing", + "*/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing", +] +tests = ["tests", "*/aws-durable-execution-sdk-python-testing/tests"] + +[tool.coverage.report] +exclude_lines = [ + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", + "@abstractmethod", +] + +[tool.ruff] +line-length = 88 +target-version = "py311" + +[tool.ruff.lint] +preview = false +select = ["TID252"] # Enforce absolute imports (ban relative imports) + +[tool.ruff.lint.isort] +known-first-party = ["aws_durable_execution_sdk_python_testing"] +force-single-line = false +lines-after-imports = 2 + +[tool.ruff.lint.per-file-ignores] +"tests/**" = [ + "ARG001", + "ARG002", + "ARG005", + "S101", + "PLR2004", + "SIM117", + "TRY301", +] +"src/aws_durable_execution_sdk_python_testing/invoker.py" = [ + "A002", # Argument `input` is shadowing a Python builtin +] + +[tool.pytest.ini_options] +# Declare custom markers to avoid warnings with --strict-markers +markers = [ + # Used for test selection with -m example + "example: marks tests as example tests (deselect with '-m \"not example\"')", + # Used for configuration - passes handler and lambda_function_name to durable_runner fixture + "durable_execution: marks tests that use the durable_runner fixture (not used for test selection)", +] +# Default test discovery paths +testpaths = ["tests"] +# Default options for all test runs +addopts = "-v --strict-markers" diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/__about__.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/__about__.py new file mode 100644 index 0000000..698e248 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/__about__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: 2025-present Amazon.com, Inc. or its affiliates. +# +# SPDX-License-Identifier: Apache-2.0 +__version__ = "1.2.1" diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/__init__.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/__init__.py new file mode 100644 index 0000000..c25db1c --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/__init__.py @@ -0,0 +1,23 @@ +"""DurableExecutionsPythonTestingLibrary module.""" + +from aws_durable_execution_sdk_python_testing.runner import ( + DurableChildContextTestRunner, + DurableFunctionCloudTestRunner, + DurableFunctionTestResult, + DurableFunctionTestRunner, + WebRunner, + WebRunnerConfig, +) + +from aws_durable_execution_sdk_python_testing.__about__ import __version__ + + +__all__ = [ + "DurableChildContextTestRunner", + "DurableFunctionCloudTestRunner", + "DurableFunctionTestResult", + "DurableFunctionTestRunner", + "WebRunner", + "WebRunnerConfig", + "__version__", +] diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/__init__.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/__init__.py new file mode 100644 index 0000000..8128bfb --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/__init__.py @@ -0,0 +1 @@ +"""Checkpoint processing module for handling OperationUpdate transformations.""" diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processor.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processor.py new file mode 100644 index 0000000..04b991c --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processor.py @@ -0,0 +1,101 @@ +"""Main checkpoint processor that orchestrates operation transformations.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from aws_durable_execution_sdk_python.lambda_service import ( + CheckpointOutput, + CheckpointUpdatedExecutionState, + OperationUpdate, + StateOutput, +) + +from aws_durable_execution_sdk_python_testing.checkpoint.transformer import ( + OperationTransformer, +) +from aws_durable_execution_sdk_python_testing.checkpoint.validators.checkpoint import ( + CheckpointValidator, +) +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) +from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier +from aws_durable_execution_sdk_python_testing.token import CheckpointToken + + +if TYPE_CHECKING: + from aws_durable_execution_sdk_python_testing.execution import Execution + from aws_durable_execution_sdk_python_testing.scheduler import Scheduler + from aws_durable_execution_sdk_python_testing.stores.base import ExecutionStore + + +class CheckpointProcessor: + """Handle OperationUpdate transformations and execution state updates.""" + + def __init__(self, store: ExecutionStore, scheduler: Scheduler): + self._store = store + self._scheduler = scheduler + self._notifier = ExecutionNotifier() + self._transformer = OperationTransformer() + + def add_execution_observer(self, observer) -> None: + """Add observer for execution events.""" + self._notifier.add_observer(observer) + + def process_checkpoint( + self, + checkpoint_token: str, + updates: list[OperationUpdate], + client_token: str | None, # noqa: ARG002 + ) -> CheckpointOutput: + """Process checkpoint updates and return result with updated execution state.""" + # 1. Get current execution state + token: CheckpointToken = CheckpointToken.from_str(checkpoint_token) + execution: Execution = self._store.load(token.execution_arn) + + # 2. Validate checkpoint token + if execution.is_complete or token.token_sequence != execution.token_sequence: + msg: str = "Invalid checkpoint token" + + raise InvalidParameterValueException(msg) + + # 3. Validate all updates, state transitions are valid, sizes etc. + CheckpointValidator.validate_input(updates, execution) + + # 4. Transform OperationUpdate -> Operation and schedule future replays + updated_operations, all_updates = self._transformer.process_updates( + updates=updates, + current_operations=execution.operations, + notifier=self._notifier, + execution_arn=token.execution_arn, + ) + + # 5. Generate a new checkpoint token and save updated operations + new_checkpoint_token = execution.get_new_checkpoint_token() + execution.operations = updated_operations + execution.updates.extend(all_updates) + self._store.update(execution) + + # 6. Return checkpoint result + return CheckpointOutput( + checkpoint_token=new_checkpoint_token, + new_execution_state=CheckpointUpdatedExecutionState( + operations=execution.get_navigable_operations(), next_marker=None + ), + ) + + def get_execution_state( + self, + checkpoint_token: str, + next_marker: str, # noqa: ARG002 + max_items: int = 1000, # noqa: ARG002 + ) -> StateOutput: + """Get current execution state.""" + token: CheckpointToken = CheckpointToken.from_str(checkpoint_token) + execution: Execution = self._store.load(token.execution_arn) + + # TODO: paging when size or max + return StateOutput( + operations=execution.get_navigable_operations(), next_marker=None + ) diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/__init__.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/__init__.py new file mode 100644 index 0000000..0e52f40 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/__init__.py @@ -0,0 +1 @@ +"""Checkpoint processors module.""" diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/base.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/base.py new file mode 100644 index 0000000..56933d5 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/base.py @@ -0,0 +1,199 @@ +"""Base processor class for operation transformations.""" + +from __future__ import annotations + +import datetime +from datetime import timedelta +from typing import TYPE_CHECKING + +from aws_durable_execution_sdk_python.lambda_service import ( + CallbackDetails, + ChainedInvokeDetails, + ContextDetails, + ExecutionDetails, + Operation, + OperationAction, + OperationStatus, + OperationType, + OperationUpdate, + StepDetails, + WaitDetails, +) + + +if TYPE_CHECKING: + from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier + + +class OperationProcessor: + """Base class for processing OperationUpdate to Operation transformations.""" + + def process( + self, + update: OperationUpdate, + current_op: Operation | None, + notifier: ExecutionNotifier, + execution_arn: str, + ) -> Operation | None: + """Process an operation update and return the transformed operation.""" + raise NotImplementedError + + def _get_start_time( + self, current_operation: Operation | None + ) -> datetime.datetime | None: + start_time: datetime.datetime | None = ( + current_operation.start_timestamp + if current_operation + else datetime.datetime.now(tz=datetime.UTC) + ) + return start_time + + def _get_end_time( + self, current_operation: Operation | None, status: OperationStatus + ) -> datetime.datetime | None: + """Get end timestamp for operation based on current state and status.""" + if current_operation and current_operation.end_timestamp: + return current_operation.end_timestamp + if status in { + OperationStatus.SUCCEEDED, + OperationStatus.FAILED, + OperationStatus.CANCELLED, + OperationStatus.TIMED_OUT, + OperationStatus.STOPPED, + }: + return datetime.datetime.now(tz=datetime.UTC) + return None + + def _create_execution_details( + self, update: OperationUpdate + ) -> ExecutionDetails | None: + """Create ExecutionDetails from OperationUpdate.""" + return ( + ExecutionDetails(input_payload=update.payload) + if update.operation_type == OperationType.EXECUTION + else None + ) + + def _create_context_details(self, update: OperationUpdate) -> ContextDetails | None: + """Create ContextDetails from OperationUpdate.""" + return ( + ContextDetails( + result=update.payload, + error=update.error, + replay_children=update.context_options.replay_children + if update.context_options + else False, + ) + if update.operation_type == OperationType.CONTEXT + else None + ) + + def _create_step_details( + self, + update: OperationUpdate, + current_operation: Operation | None = None, + ) -> StepDetails | None: + """Create StepDetails from OperationUpdate. + + Automatically increments attempt count for RETRY, SUCCEED, and FAIL actions. + """ + + attempt: int = 0 + next_attempt_timestamp: datetime.datetime | None = None + + if update.operation_type is OperationType.STEP: + if current_operation and current_operation.step_details: + attempt = current_operation.step_details.attempt + next_attempt_timestamp = ( + current_operation.step_details.next_attempt_timestamp + ) + # Increment attempt for RETRY, SUCCEED, and FAIL actions + if update.action in { + OperationAction.RETRY, + OperationAction.SUCCEED, + OperationAction.FAIL, + }: + attempt += 1 + return StepDetails( + attempt=attempt, + next_attempt_timestamp=next_attempt_timestamp, + result=update.payload, + error=update.error, + ) + + return None + + def _create_callback_details( + self, update: OperationUpdate + ) -> CallbackDetails | None: + """Create CallbackDetails from OperationUpdate.""" + return ( + CallbackDetails( + callback_id="placeholder", result=update.payload, error=update.error + ) + if update.operation_type == OperationType.CALLBACK + else None + ) + + def _create_invoke_details( + self, update: OperationUpdate + ) -> ChainedInvokeDetails | None: + """Create ChainedInvokeDetails from OperationUpdate.""" + if ( + update.operation_type == OperationType.CHAINED_INVOKE + and update.chained_invoke_options + ): + return ChainedInvokeDetails(result=update.payload, error=update.error) + return None + + def _translate_update_to_operation( + self, + update: OperationUpdate, + current_operation: Operation | None, + status: OperationStatus, + ) -> Operation: + """Transform OperationUpdate to Operation, always creating new Operation.""" + start_time: datetime.datetime | None = self._get_start_time(current_operation) + end_time: datetime.datetime | None = self._get_end_time( + current_operation, status + ) + + execution_details = self._create_execution_details(update) + context_details = self._create_context_details(update) + step_details = self._create_step_details(update, current_operation) + callback_details = self._create_callback_details(update) + invoke_details = self._create_invoke_details(update) + wait_details = self._create_wait_details(update, current_operation) + + return Operation( + operation_id=update.operation_id, + parent_id=update.parent_id, + name=update.name, + start_timestamp=start_time, + end_timestamp=end_time, + operation_type=update.operation_type, + status=status, + sub_type=update.sub_type, + execution_details=execution_details, + context_details=context_details, + step_details=step_details, + callback_details=callback_details, + chained_invoke_details=invoke_details, + wait_details=wait_details, + ) + + def _create_wait_details( + self, update: OperationUpdate, current_operation: Operation | None + ) -> WaitDetails | None: + """Create WaitDetails from OperationUpdate.""" + if update.operation_type == OperationType.WAIT and update.wait_options: + if current_operation and current_operation.wait_details: + scheduled_end_timestamp = ( + current_operation.wait_details.scheduled_end_timestamp + ) + else: + scheduled_end_timestamp = datetime.datetime.now( + tz=datetime.UTC + ) + timedelta(seconds=update.wait_options.wait_seconds) + return WaitDetails(scheduled_end_timestamp=scheduled_end_timestamp) + return None diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/callback.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/callback.py new file mode 100644 index 0000000..48c2f01 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/callback.py @@ -0,0 +1,89 @@ +"""Callback operation processor for handling CALLBACK operation updates.""" + +from __future__ import annotations + +import datetime +from typing import TYPE_CHECKING + +from aws_durable_execution_sdk_python.lambda_service import ( + Operation, + OperationAction, + OperationStatus, + OperationUpdate, + CallbackDetails, + OperationType, + CallbackOptions, +) +from aws_durable_execution_sdk_python_testing.checkpoint.processors.base import ( + OperationProcessor, +) +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) +from aws_durable_execution_sdk_python_testing.token import CallbackToken + +if TYPE_CHECKING: + from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier + + +class CallbackProcessor(OperationProcessor): + """Processes CALLBACK operation updates with activity scheduling.""" + + def process( + self, + update: OperationUpdate, + current_op: Operation | None, + notifier: ExecutionNotifier, # noqa: ARG002 + execution_arn: str, # noqa: ARG002 + ) -> Operation: + """Process CALLBACK operation update with scheduler integration for activities.""" + match update.action: + case OperationAction.START: + callback_token: CallbackToken = CallbackToken( + execution_arn=execution_arn, + operation_id=update.operation_id, + ) + + callback_id: str = callback_token.to_str() + + callback_details: CallbackDetails | None = ( + CallbackDetails( + callback_id=callback_id, + result=update.payload, + error=update.error, + ) + if update.operation_type == OperationType.CALLBACK + else None + ) + + status: OperationStatus = OperationStatus.STARTED + + start_time: datetime.datetime | None = self._get_start_time(current_op) + + end_time: datetime.datetime | None = self._get_end_time( + current_op, status + ) + + operation: Operation = Operation( + operation_id=update.operation_id, + parent_id=update.parent_id, + name=update.name, + start_timestamp=start_time, + end_timestamp=end_time, + operation_type=update.operation_type, + status=status, + sub_type=update.sub_type, + callback_details=callback_details, + ) + callback_options: CallbackOptions | None = update.callback_options + + notifier.notify_callback_created( + execution_arn=execution_arn, + operation_id=update.operation_id, + callback_options=callback_options, + callback_token=callback_token, + ) + return operation + case _: + msg: str = "Invalid action for CALLBACK operation." + raise InvalidParameterValueException(msg) diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/context.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/context.py new file mode 100644 index 0000000..182bf91 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/context.py @@ -0,0 +1,59 @@ +"""Context operation processor for handling CONTEXT operation updates.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from aws_durable_execution_sdk_python.lambda_service import ( + Operation, + OperationAction, + OperationStatus, + OperationUpdate, +) + +from aws_durable_execution_sdk_python_testing.checkpoint.processors.base import ( + OperationProcessor, +) +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) + + +if TYPE_CHECKING: + from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier + + +class ContextProcessor(OperationProcessor): + """Processes CONTEXT operation updates for execution context management.""" + + def process( + self, + update: OperationUpdate, + current_op: Operation | None, + notifier: ExecutionNotifier, # noqa: ARG002 + execution_arn: str, # noqa: ARG002 + ) -> Operation: + """Process CONTEXT operation update for context state transitions.""" + match update.action: + case OperationAction.START: + # TODO: check for "Cannot start a CONTEXT operation that already exists." + return self._translate_update_to_operation( + update=update, + current_operation=current_op, + status=OperationStatus.STARTED, + ) + case OperationAction.SUCCEED: + return self._translate_update_to_operation( + update=update, + current_operation=current_op, + status=OperationStatus.SUCCEEDED, + ) + case OperationAction.FAIL: + return self._translate_update_to_operation( + update=update, + current_operation=current_op, + status=OperationStatus.FAILED, + ) + case _: + msg: str = "Invalid action for CONTEXT operation." + raise InvalidParameterValueException(msg) diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/execution.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/execution.py new file mode 100644 index 0000000..e8ad2ef --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/execution.py @@ -0,0 +1,52 @@ +"""Execution operation processor for handling EXECUTION operation updates.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, + Operation, + OperationAction, + OperationUpdate, +) + +from aws_durable_execution_sdk_python_testing.checkpoint.processors.base import ( + OperationProcessor, +) + + +if TYPE_CHECKING: + from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier + + +class ExecutionProcessor(OperationProcessor): + """Processes EXECUTION operation updates for workflow completion.""" + + def process( + self, + update: OperationUpdate, + current_op: Operation | None, # noqa: ARG002 + notifier: ExecutionNotifier, + execution_arn: str, + ) -> Operation | None: + """Process EXECUTION operation update for workflow completion/failure.""" + match update.action: + case OperationAction.SUCCEED: + notifier.notify_completed( + execution_arn=execution_arn, result=update.payload + ) + case _: + # intentional. actual service will fail any EXECUTION update that is not SUCCEED. + error = ( + update.error + if update.error + else ErrorObject.from_message( + "There is no error details but EXECUTION checkpoint action is not SUCCEED." + ) + ) + # All EXECUTION failures go through normal fail path + # Timeout/Stop status is set by executor based on the operation that caused it + notifier.notify_failed(execution_arn=execution_arn, error=error) + # TODO: Svc doesn't actually create checkpoint for EXECUTION. might have to for localrunner though. + return None diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/step.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/step.py new file mode 100644 index 0000000..0db5a0b --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/step.py @@ -0,0 +1,124 @@ +"""Step operation processor for handling STEP operation updates.""" + +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from typing import TYPE_CHECKING + +from aws_durable_execution_sdk_python.lambda_service import ( + Operation, + OperationAction, + OperationStatus, + OperationUpdate, + StepDetails, +) + +from aws_durable_execution_sdk_python_testing.checkpoint.processors.base import ( + OperationProcessor, +) +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) + + +if TYPE_CHECKING: + from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier + + +class StepProcessor(OperationProcessor): + """Processes STEP operation updates with retry scheduling.""" + + def process( + self, + update: OperationUpdate, + current_op: Operation | None, + notifier: ExecutionNotifier, + execution_arn: str, + ) -> Operation: + """Process STEP operation update with scheduler integration for retries.""" + match update.action: + case OperationAction.START: + return self._translate_update_to_operation( + update=update, + current_operation=current_op, + status=OperationStatus.STARTED, + ) + case OperationAction.RETRY: + # set Status=PENDING, next attempt time, attempt count + 1 + delay = ( + update.step_options.next_attempt_delay_seconds + if update.step_options + else 0 + ) + next_attempt_time = datetime.now(UTC) + timedelta(seconds=delay) + + # Build new step_details with incremented attempt + current_attempt = ( + current_op.step_details.attempt + if current_op and current_op.step_details + else 0 + ) + new_step_details = StepDetails( + attempt=current_attempt + 1, + next_attempt_timestamp=next_attempt_time, + result=( + current_op.step_details.result + if current_op and current_op.step_details + else None + ), + error=( + current_op.step_details.error + if current_op and current_op.step_details + else None + ), + ) + + # Create new operation with updated step_details + retry_operation = Operation( + operation_id=update.operation_id, + operation_type=update.operation_type, + status=OperationStatus.PENDING, + parent_id=update.parent_id, + name=update.name, + start_timestamp=( + current_op.start_timestamp if current_op else datetime.now(UTC) + ), + end_timestamp=None, + sub_type=update.sub_type, + execution_details=current_op.execution_details + if current_op + else None, + context_details=current_op.context_details if current_op else None, + step_details=new_step_details, + wait_details=current_op.wait_details if current_op else None, + callback_details=current_op.callback_details + if current_op + else None, + chained_invoke_details=current_op.chained_invoke_details + if current_op + else None, + ) + + # Schedule step retry timer to fire after delay + notifier.notify_step_retry_scheduled( + execution_arn=execution_arn, + operation_id=update.operation_id, + delay=delay, + ) + return retry_operation + case OperationAction.SUCCEED: + return self._translate_update_to_operation( + update=update, + current_operation=current_op, + status=OperationStatus.SUCCEEDED, + ) + case OperationAction.FAIL: + return self._translate_update_to_operation( + update=update, + current_operation=current_op, + status=OperationStatus.FAILED, + ) + case _: + msg: str = "Invalid action for STEP operation." + + raise InvalidParameterValueException(msg) diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/wait.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/wait.py new file mode 100644 index 0000000..01cc69b --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/wait.py @@ -0,0 +1,95 @@ +"""Wait operation processor for handling WAIT operation updates.""" + +from __future__ import annotations + +import logging +import os +from datetime import UTC, datetime, timedelta +from typing import TYPE_CHECKING + +from aws_durable_execution_sdk_python.lambda_service import ( + Operation, + OperationAction, + OperationStatus, + OperationUpdate, + WaitDetails, +) + +from aws_durable_execution_sdk_python_testing.checkpoint.processors.base import ( + OperationProcessor, +) +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) + + +if TYPE_CHECKING: + from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier + + +class WaitProcessor(OperationProcessor): + """Processes WAIT operation updates with timer scheduling.""" + + def process( + self, + update: OperationUpdate, + current_op: Operation | None, + notifier: ExecutionNotifier, + execution_arn: str, + ) -> Operation: + """Process WAIT operation update with scheduler integration for timers.""" + match update.action: + case OperationAction.START: + wait_seconds = ( + update.wait_options.wait_seconds if update.wait_options else 0 + ) + time_scale = float(os.getenv("DURABLE_EXECUTION_TIME_SCALE", "1.0")) + logging.info("Using DURABLE_EXECUTION_TIME_SCALE: %f", time_scale) + scaled_wait_seconds = wait_seconds * time_scale + + scheduled_end_timestamp = datetime.now(UTC) + timedelta( + seconds=scaled_wait_seconds + ) + + # Create WaitDetails with scheduled timestamp + wait_details = WaitDetails( + scheduled_end_timestamp=scheduled_end_timestamp + ) + + # Create new operation with wait details + wait_operation = Operation( + operation_id=update.operation_id, + operation_type=update.operation_type, + status=OperationStatus.STARTED, + parent_id=update.parent_id, + name=update.name, + start_timestamp=datetime.now(UTC), + end_timestamp=None, + sub_type=update.sub_type, + execution_details=None, + context_details=None, + step_details=None, + wait_details=wait_details, + callback_details=None, + chained_invoke_details=None, + ) + + # Schedule wait timer to complete after delay + notifier.notify_wait_timer_scheduled( + execution_arn=execution_arn, + operation_id=update.operation_id, + delay=scaled_wait_seconds, + ) + return wait_operation + case OperationAction.CANCEL: + # TODO: need to cancel the WAIT in the executor + # TODO: increase sequence id + return self._translate_update_to_operation( + update=update, + current_operation=current_op, + status=OperationStatus.CANCELLED, + ) + case _: + msg: str = "Invalid action for WAIT operation." + + raise InvalidParameterValueException(msg) diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/transformer.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/transformer.py new file mode 100644 index 0000000..cd37b8a --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/transformer.py @@ -0,0 +1,104 @@ +"""Operation transformer for converting OperationUpdates to Operations.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from aws_durable_execution_sdk_python.lambda_service import ( + Operation, + OperationType, + OperationUpdate, +) + +from aws_durable_execution_sdk_python_testing.checkpoint.processors.callback import ( + CallbackProcessor, +) +from aws_durable_execution_sdk_python_testing.checkpoint.processors.context import ( + ContextProcessor, +) +from aws_durable_execution_sdk_python_testing.checkpoint.processors.execution import ( + ExecutionProcessor, +) +from aws_durable_execution_sdk_python_testing.checkpoint.processors.step import ( + StepProcessor, +) +from aws_durable_execution_sdk_python_testing.checkpoint.processors.wait import ( + WaitProcessor, +) +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) + + +if TYPE_CHECKING: + from collections.abc import MutableMapping + + from aws_durable_execution_sdk_python_testing.checkpoint.processors.base import ( + OperationProcessor, + ) + +from typing import ClassVar + + +class OperationTransformer: + """Transforms OperationUpdates to Operations while maintaining order and triggering scheduler actions.""" + + _DEFAULT_PROCESSORS: ClassVar[dict[OperationType, OperationProcessor]] = { + OperationType.STEP: StepProcessor(), + OperationType.WAIT: WaitProcessor(), + OperationType.CONTEXT: ContextProcessor(), + OperationType.CALLBACK: CallbackProcessor(), + OperationType.EXECUTION: ExecutionProcessor(), + } + + def __init__( + self, + processors: MutableMapping[OperationType, OperationProcessor] | None = None, + ): + self.processors = processors if processors else self._DEFAULT_PROCESSORS + + def process_updates( + self, + updates: list[OperationUpdate], + current_operations: list[Operation], + notifier, + execution_arn: str, + ) -> tuple[list[Operation], list[OperationUpdate]]: + """Transform updates maintaining operation order and return (operations, updates).""" + op_map = {op.operation_id: op for op in current_operations} + + # Start with copy of current operations list + result_operations = current_operations.copy() + + for update in updates: + processor = self.processors.get(update.operation_type) + if processor: + current_op = op_map.get(update.operation_id) + updated_op = processor.process( + update=update, + current_op=current_op, + notifier=notifier, + execution_arn=execution_arn, + ) + + if updated_op is not None: + if update.operation_id in op_map: + # Update existing operation in-place + for i, op in enumerate(result_operations): # pragma: no branch + # no branch coverage because result_operation empty not reachable here + if op.operation_id == update.operation_id: + result_operations[i] = updated_op + break + else: + # Append new operation to end + result_operations.append(updated_op) + + # Update map for future lookups + op_map[update.operation_id] = updated_op + else: + msg: str = ( + f"Checkpoint for {update.operation_type} is not implemented yet." + ) + raise InvalidParameterValueException(msg) + + return result_operations, updates diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/__init__.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/__init__.py new file mode 100644 index 0000000..f97d027 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/__init__.py @@ -0,0 +1 @@ +"""Checkpoint validation module.""" diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/checkpoint.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/checkpoint.py new file mode 100644 index 0000000..86d654d --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/checkpoint.py @@ -0,0 +1,242 @@ +"""Main checkpoint input validator.""" + +from __future__ import annotations + +import json +from typing import TYPE_CHECKING + +from aws_durable_execution_sdk_python.lambda_service import ( + OperationAction, + OperationType, + OperationUpdate, +) + +from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.callback import ( + CallbackOperationValidator, +) +from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.context import ( + ContextOperationValidator, +) +from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.execution import ( + ExecutionOperationValidator, +) +from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.invoke import ( + ChainedInvokeOperationValidator, +) +from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.step import ( + StepOperationValidator, +) +from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.wait import ( + WaitOperationValidator, +) +from aws_durable_execution_sdk_python_testing.checkpoint.validators.transitions import ( + ValidActionsByOperationTypeValidator, +) +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) + + +if TYPE_CHECKING: + from collections.abc import MutableMapping + + from aws_durable_execution_sdk_python_testing.execution import Execution + +MAX_ERROR_PAYLOAD_SIZE_BYTES = 32768 + + +class CheckpointValidator: + """Validates checkpoint input based on current state.""" + + @staticmethod + def validate_input(updates: list[OperationUpdate], execution: Execution) -> None: + """Perform validation on the given input based on the current state.""" + if not updates: + return + + CheckpointValidator._validate_conflicting_execution_update(updates) + CheckpointValidator._validate_parent_id_and_duplicate_id(updates, execution) + + for update in updates: + CheckpointValidator._validate_operation_update(update, execution) + + @staticmethod + def _validate_conflicting_execution_update(updates: list[OperationUpdate]) -> None: + """Validate that there are no conflicting execution updates.""" + execution_updates = [ + update + for update in updates + if update.operation_type == OperationType.EXECUTION + ] + + if len(execution_updates) > 1: + msg_multiple_exec: str = "Cannot checkpoint multiple EXECUTION updates." + + raise InvalidParameterValueException(msg_multiple_exec) + + if execution_updates and updates[-1].operation_type != OperationType.EXECUTION: + msg_exec_last: str = "EXECUTION checkpoint must be the last update." + + raise InvalidParameterValueException(msg_exec_last) + + @staticmethod + def _validate_operation_update( + update: OperationUpdate, execution: Execution + ) -> None: + """Validate a single operation update.""" + CheckpointValidator._validate_inconsistent_operation_metadata(update, execution) + CheckpointValidator._validate_payload_sizes(update) + ValidActionsByOperationTypeValidator.validate( + update.operation_type, update.action + ) + CheckpointValidator._validate_operation_status_transition(update, execution) + + @staticmethod + def _validate_payload_sizes(update: OperationUpdate) -> None: + """Validate that operation payload sizes are not too large.""" + if update.error is not None: + payload = json.dumps(update.error.to_dict()) + if len(payload) > MAX_ERROR_PAYLOAD_SIZE_BYTES: + msg: str = f"Error object size must be less than {MAX_ERROR_PAYLOAD_SIZE_BYTES} bytes." + raise InvalidParameterValueException(msg) + + @staticmethod + def _validate_operation_status_transition( + update: OperationUpdate, execution: Execution + ) -> None: + """Validate that the operation status transition is valid.""" + current_state = None + for operation in execution.operations: + if operation.operation_id == update.operation_id: + current_state = operation + break + + match update.operation_type: + case OperationType.STEP: + StepOperationValidator.validate(current_state, update) + case OperationType.CONTEXT: + ContextOperationValidator.validate(current_state, update) + case OperationType.WAIT: + WaitOperationValidator.validate(current_state, update) + case OperationType.CALLBACK: + CallbackOperationValidator.validate(current_state, update) + case OperationType.CHAINED_INVOKE: + ChainedInvokeOperationValidator.validate(current_state, update) + case OperationType.EXECUTION: + ExecutionOperationValidator.validate(update) + case _: # pragma: no cover + msg: str = "Invalid operation type." + + raise InvalidParameterValueException(msg) + + @staticmethod + def _validate_inconsistent_operation_metadata( + update: OperationUpdate, execution: Execution + ) -> None: + """Validate that operation metadata is consistent with existing operation.""" + current_state = None + for operation in execution.operations: + if operation.operation_id == update.operation_id: + current_state = operation + break + + if current_state is not None: + if ( + update.operation_type is not None + and update.operation_type != current_state.operation_type + ): + msg: str = "Inconsistent operation type." + raise InvalidParameterValueException(msg) + + if ( + update.sub_type is not None + and update.sub_type != current_state.sub_type + ): + msg_subtype: str = "Inconsistent operation subtype." + raise InvalidParameterValueException(msg_subtype) + + if update.name is not None and update.name != current_state.name: + msg_name: str = "Inconsistent operation name." + raise InvalidParameterValueException(msg_name) + + if ( + update.parent_id is not None + and update.parent_id != current_state.parent_id + ): + msg_parent: str = "Inconsistent parent operation id." + raise InvalidParameterValueException(msg_parent) + + @staticmethod + def _validate_parent_id_and_duplicate_id( + updates: list[OperationUpdate], execution: Execution + ) -> None: + """Validate parent IDs and check for duplicate operation IDs. + + Validate that any provided parentId is valid, and also validate no duplicate operation is being + updated at the same time (unless it is a STEP/CONTEXT starting + performing one more non-START action). + """ + operations_started: MutableMapping[str, OperationUpdate] = {} + last_updates_seen: MutableMapping[str, OperationUpdate] = {} + + for update in updates: + if CheckpointValidator._is_invalid_duplicate_update( + update, last_updates_seen + ): + msg_duplicate: str = ( + "Cannot checkpoint multiple operations with the same ID." + ) + raise InvalidParameterValueException(msg_duplicate) + + if not CheckpointValidator._is_valid_parent_for_update( + execution, update, operations_started + ): + msg_parent: str = "Invalid parent operation id." + raise InvalidParameterValueException(msg_parent) + + if update.action == OperationAction.START: + operations_started[update.operation_id] = update + + last_updates_seen[update.operation_id] = update + + @staticmethod + def _is_invalid_duplicate_update( + update: OperationUpdate, last_updates_seen: MutableMapping[str, OperationUpdate] + ) -> bool: + """Check if this is an invalid duplicate update.""" + last_update = last_updates_seen.get(update.operation_id) + if last_update is None: + return False + + if last_update.operation_type in (OperationType.STEP, OperationType.CONTEXT): + # Allow duplicate for STEP/CONTEXT if last was START and current is not START + allow_duplicate = ( + last_update.action == OperationAction.START + and update.action != OperationAction.START + ) + return not allow_duplicate + + return True + + @staticmethod + def _is_valid_parent_for_update( + execution: Execution, + update: OperationUpdate, + operations_started: MutableMapping[str, OperationUpdate], + ) -> bool: + """Check if the parent ID is valid for the update.""" + parent_id = update.parent_id + + if parent_id is None: + return True + + # Check if parent is in operations started in this batch + if parent_id in operations_started: + parent_update = operations_started[parent_id] + return parent_update.operation_type == OperationType.CONTEXT + + # Check if parent exists in current execution state + for operation in execution.operations: + if operation.operation_id == parent_id: + return operation.operation_type == OperationType.CONTEXT + + return False diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/__init__.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/__init__.py new file mode 100644 index 0000000..455b119 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/__init__.py @@ -0,0 +1 @@ +"""Operation-specific validators.""" diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/callback.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/callback.py new file mode 100644 index 0000000..575db81 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/callback.py @@ -0,0 +1,45 @@ +"""Callback operation validator.""" + +from __future__ import annotations + +from aws_durable_execution_sdk_python.lambda_service import ( + Operation, + OperationAction, + OperationStatus, + OperationUpdate, +) + +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) + + +VALID_ACTIONS_FOR_CALLBACK = frozenset( + [ + OperationAction.START, + ] +) + + +class CallbackOperationValidator: + """Validates CALLBACK operation transitions.""" + + _ALLOWED_STATUS_TO_CANCEL = frozenset( + [ + OperationStatus.STARTED, + ] + ) + + @staticmethod + def validate(current_state: Operation | None, update: OperationUpdate) -> None: + """Validate CALLBACK operation update.""" + match update.action: + case OperationAction.START: + if current_state is not None: + msg_callback_exists: str = ( + "Cannot start a CALLBACK that already exist." + ) + raise InvalidParameterValueException(msg_callback_exists) + case _: + msg: str = "Invalid action for CALLBACK operation." + raise InvalidParameterValueException(msg) diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/context.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/context.py new file mode 100644 index 0000000..3104044 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/context.py @@ -0,0 +1,73 @@ +"""Context operation validator.""" + +from __future__ import annotations + +from aws_durable_execution_sdk_python.lambda_service import ( + Operation, + OperationAction, + OperationStatus, + OperationUpdate, +) + +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) + + +VALID_ACTIONS_FOR_CONTEXT = frozenset( + [ + OperationAction.START, + OperationAction.FAIL, + OperationAction.SUCCEED, + ] +) + + +class ContextOperationValidator: + """Validates CONTEXT operation transitions.""" + + _ALLOWED_STATUS_TO_CLOSE = frozenset( + [ + OperationStatus.STARTED, + ] + ) + + @staticmethod + def validate(current_state: Operation | None, update: OperationUpdate) -> None: + """Validate CONTEXT operation update.""" + match update.action: + case OperationAction.START: + if current_state is not None: + msg_context_exists: str = ( + "Cannot start a CONTEXT that already exist." + ) + + raise InvalidParameterValueException(msg_context_exists) + case OperationAction.FAIL | OperationAction.SUCCEED: + if ( + current_state is not None + and current_state.status + not in ContextOperationValidator._ALLOWED_STATUS_TO_CLOSE + ): + msg_context_close: str = "Invalid current CONTEXT state to close." + + raise InvalidParameterValueException(msg_context_close) + if update.action == OperationAction.FAIL and update.payload is not None: + msg_context_fail_payload: str = ( + "Cannot provide a Payload for FAIL action." + ) + + raise InvalidParameterValueException(msg_context_fail_payload) + if ( + update.action == OperationAction.SUCCEED + and update.error is not None + ): + msg_context_succeed_error: str = ( + "Cannot provide an Error for SUCCEED action." + ) + + raise InvalidParameterValueException(msg_context_succeed_error) + case _: + msg_context_invalid: str = "Invalid CONTEXT action." + + raise InvalidParameterValueException(msg_context_invalid) diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/execution.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/execution.py new file mode 100644 index 0000000..5e66677 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/execution.py @@ -0,0 +1,47 @@ +"""Execution operation validator.""" + +from __future__ import annotations + +from aws_durable_execution_sdk_python.lambda_service import ( + OperationAction, + OperationUpdate, +) + +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) + + +VALID_ACTIONS_FOR_EXECUTION = frozenset( + [ + OperationAction.SUCCEED, + OperationAction.FAIL, + ] +) + + +class ExecutionOperationValidator: + """Validates EXECUTION operation transitions.""" + + @staticmethod + def validate(update: OperationUpdate) -> None: + """Validate EXECUTION operation update.""" + match update.action: + case OperationAction.SUCCEED: + if update.error is not None: + msg_exec_succeed_error: str = ( + "Cannot provide an Error for SUCCEED action." + ) + + raise InvalidParameterValueException(msg_exec_succeed_error) + case OperationAction.FAIL: + if update.payload is not None: + msg_exec_fail_payload: str = ( + "Cannot provide a Payload for FAIL action." + ) + + raise InvalidParameterValueException(msg_exec_fail_payload) + case _: + msg_exec_invalid: str = "Invalid EXECUTION action." + + raise InvalidParameterValueException(msg_exec_invalid) diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/invoke.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/invoke.py new file mode 100644 index 0000000..1c28712 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/invoke.py @@ -0,0 +1,56 @@ +"""Invoke operation validator.""" + +from __future__ import annotations + +from aws_durable_execution_sdk_python.lambda_service import ( + Operation, + OperationAction, + OperationStatus, + OperationUpdate, +) + +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) + + +VALID_ACTIONS_FOR_INVOKE = frozenset( + [ + OperationAction.START, + OperationAction.CANCEL, + ] +) + + +class ChainedInvokeOperationValidator: + """Validates INVOKE operation transitions.""" + + _ALLOWED_STATUS_TO_CANCEL = frozenset( + [ + OperationStatus.STARTED, + ] + ) + + @staticmethod + def validate(current_state: Operation | None, update: OperationUpdate) -> None: + """Validate INVOKE operation update.""" + match update.action: + case OperationAction.START: + if current_state is not None: + msg_invoke_exists: str = ( + "Cannot start an INVOKE that already exist." + ) + + raise InvalidParameterValueException(msg_invoke_exists) + case OperationAction.CANCEL: + if ( + current_state is None + or current_state.status + not in ChainedInvokeOperationValidator._ALLOWED_STATUS_TO_CANCEL + ): + msg_invoke_cancel: str = "Cannot cancel an INVOKE that does not exist or has already completed." + raise InvalidParameterValueException(msg_invoke_cancel) + case _: + msg_invoke_invalid: str = "Invalid INVOKE action." + + raise InvalidParameterValueException(msg_invoke_invalid) diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/step.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/step.py new file mode 100644 index 0000000..52c5388 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/step.py @@ -0,0 +1,106 @@ +"""Step operation validator.""" + +from __future__ import annotations + +from aws_durable_execution_sdk_python.lambda_service import ( + Operation, + OperationAction, + OperationStatus, + OperationUpdate, +) + +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) + + +VALID_ACTIONS_FOR_STEP = frozenset( + [ + OperationAction.START, + OperationAction.FAIL, + OperationAction.RETRY, + OperationAction.SUCCEED, + ] +) + + +class StepOperationValidator: + """Validates STEP operation transitions.""" + + _ALLOWED_STATUS_TO_CLOSE = frozenset( + [ + OperationStatus.STARTED, + OperationStatus.READY, + ] + ) + + _ALLOWED_STATUS_TO_START = frozenset( + [ + OperationStatus.READY, + ] + ) + + _ALLOWED_STATUS_TO_REATTEMPT = frozenset( + [ + OperationStatus.STARTED, + OperationStatus.READY, + ] + ) + + @staticmethod + def validate(current_state: Operation | None, update: OperationUpdate) -> None: + """Validate STEP operation update.""" + if current_state is None: + return + + match update.action: + case OperationAction.START: + if ( + current_state.status + not in StepOperationValidator._ALLOWED_STATUS_TO_START + ): + msg_step_start: str = "Invalid current STEP state to start." + + raise InvalidParameterValueException(msg_step_start) + case OperationAction.FAIL | OperationAction.SUCCEED: + if ( + current_state.status + not in StepOperationValidator._ALLOWED_STATUS_TO_CLOSE + ): + msg_step_close: str = "Invalid current STEP state to close." + + raise InvalidParameterValueException(msg_step_close) + if update.action == OperationAction.FAIL and update.payload is not None: + msg_fail_payload: str = "Cannot provide a Payload for FAIL action." + + raise InvalidParameterValueException(msg_fail_payload) + if ( + update.action == OperationAction.SUCCEED + and update.error is not None + ): + msg_succeed_error: str = ( + "Cannot provide an Error for SUCCEED action." + ) + + raise InvalidParameterValueException(msg_succeed_error) + case OperationAction.RETRY: + if ( + current_state.status + not in StepOperationValidator._ALLOWED_STATUS_TO_REATTEMPT + ): + msg_step_retry: str = "Invalid current STEP state to re-attempt." + + raise InvalidParameterValueException(msg_step_retry) + if update.step_options is None: + msg_step_options: str = "Invalid StepOptions for the given action." + + raise InvalidParameterValueException(msg_step_options) + if update.error is not None and update.payload is not None: + msg_retry_both: str = ( + "Cannot provide both error and payload to RETRY a STEP." + ) + raise InvalidParameterValueException(msg_retry_both) + case _: + msg_step_invalid: str = "Invalid STEP action." + + raise InvalidParameterValueException(msg_step_invalid) diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/wait.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/wait.py new file mode 100644 index 0000000..1858b3b --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/wait.py @@ -0,0 +1,54 @@ +"""Wait operation validator.""" + +from __future__ import annotations + +from aws_durable_execution_sdk_python.lambda_service import ( + Operation, + OperationAction, + OperationStatus, + OperationUpdate, +) + +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) + + +VALID_ACTIONS_FOR_WAIT = frozenset( + [ + OperationAction.START, + OperationAction.CANCEL, + ] +) + + +class WaitOperationValidator: + """Validates WAIT operation transitions.""" + + _ALLOWED_STATUS_TO_CANCEL = frozenset( + [ + OperationStatus.STARTED, + ] + ) + + @staticmethod + def validate(current_state: Operation | None, update: OperationUpdate) -> None: + """Validate WAIT operation update.""" + match update.action: + case OperationAction.START: + if current_state is not None: + msg_wait_exists: str = "Cannot start a WAIT that already exist." + + raise InvalidParameterValueException(msg_wait_exists) + case OperationAction.CANCEL: + if ( + current_state is None + or current_state.status + not in WaitOperationValidator._ALLOWED_STATUS_TO_CANCEL + ): + msg_wait_cancel: str = "Cannot cancel a WAIT that does not exist or has already completed." + raise InvalidParameterValueException(msg_wait_cancel) + case _: + msg_wait_invalid: str = "Invalid WAIT action." + + raise InvalidParameterValueException(msg_wait_invalid) diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/transitions.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/transitions.py new file mode 100644 index 0000000..fff45a1 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/transitions.py @@ -0,0 +1,66 @@ +"""Validator for valid actions by operation type.""" + +from __future__ import annotations + +from typing import ClassVar + +from aws_durable_execution_sdk_python.lambda_service import ( + OperationAction, + OperationType, +) + +from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.callback import ( + VALID_ACTIONS_FOR_CALLBACK, +) +from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.context import ( + VALID_ACTIONS_FOR_CONTEXT, +) +from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.execution import ( + VALID_ACTIONS_FOR_EXECUTION, +) +from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.invoke import ( + VALID_ACTIONS_FOR_INVOKE, +) +from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.step import ( + VALID_ACTIONS_FOR_STEP, +) +from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.wait import ( + VALID_ACTIONS_FOR_WAIT, +) +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) + + +class ValidActionsByOperationTypeValidator: + """Validates that the given action is valid for the given operation type.""" + + _VALID_ACTIONS_BY_OPERATION_TYPE: ClassVar[ + dict[OperationType, frozenset[OperationAction]] + ] = { + OperationType.STEP: VALID_ACTIONS_FOR_STEP, + OperationType.CONTEXT: VALID_ACTIONS_FOR_CONTEXT, + OperationType.WAIT: VALID_ACTIONS_FOR_WAIT, + OperationType.CALLBACK: VALID_ACTIONS_FOR_CALLBACK, + OperationType.CHAINED_INVOKE: VALID_ACTIONS_FOR_INVOKE, + OperationType.EXECUTION: VALID_ACTIONS_FOR_EXECUTION, + } + + @staticmethod + def validate(operation_type: OperationType, action: OperationAction) -> None: + """Validate that the action is valid for the operation type.""" + valid_actions = ( + ValidActionsByOperationTypeValidator._VALID_ACTIONS_BY_OPERATION_TYPE.get( + operation_type + ) + ) + + if valid_actions is None: + msg_unknown_op: str = "Unknown operation type." + + raise InvalidParameterValueException(msg_unknown_op) + + if action not in valid_actions: + msg_invalid_action: str = "Invalid action for the given operation type." + + raise InvalidParameterValueException(msg_invalid_action) diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/cli.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/cli.py new file mode 100644 index 0000000..85dcb7d --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/cli.py @@ -0,0 +1,498 @@ +"""Command-line interface for the AWS Durable Functions Local Runner. + +This module provides the dex-local-runner CLI with commands for: +- start-server: Start the local web server +- invoke: Invoke a durable execution +- get-durable-execution: Get execution details +- get-durable-execution-history: Get execution history +""" + +from __future__ import annotations + +import argparse +import json +import logging +import os +import sys +import uuid +from dataclasses import dataclass +from typing import Any +from urllib.parse import urljoin + +import aws_durable_execution_sdk_python +import boto3 # type: ignore +from urllib.error import HTTPError, URLError +from urllib.request import Request, urlopen + +from botocore.exceptions import ConnectionError # type: ignore + +from aws_durable_execution_sdk_python_testing.exceptions import ( + DurableFunctionsLocalRunnerError, + DurableFunctionsTestError, +) +from aws_durable_execution_sdk_python_testing.model import ( + StartDurableExecutionInput, +) +from aws_durable_execution_sdk_python_testing.runner import WebRunner, WebRunnerConfig +from aws_durable_execution_sdk_python_testing.stores.base import StoreType +from aws_durable_execution_sdk_python_testing.web.server import WebServiceConfig + + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class CliConfig: + """Configuration for the CLI application with environment variable support.""" + + # Server configuration + host: str = "0.0.0.0" # noqa:S104 + port: int = 5000 + log_level: int = logging.INFO + lambda_endpoint: str = "http://127.0.0.1:3001" + local_runner_endpoint: str = "http://0.0.0.0:5000" + local_runner_region: str = "us-west-2" + local_runner_mode: str = "local" + + # Store configuration + store_type: StoreType = StoreType.MEMORY + store_path: str | None = None + + @classmethod + def from_environment(cls) -> CliConfig: + """Create configuration from environment variables with defaults.""" + # Convert log level string to integer if provided + log_level_str = os.getenv("AWS_DEX_LOG_LEVEL", "INFO") + log_level = logging.getLevelNamesMapping().get(log_level_str, logging.INFO) + + return cls( + host=os.getenv("AWS_DEX_HOST", "0.0.0.0"), # noqa:S104 + port=int(os.getenv("AWS_DEX_PORT", "5000")), + log_level=log_level, + lambda_endpoint=os.getenv( + "AWS_DEX_LAMBDA_ENDPOINT", "http://127.0.0.1:3001" + ), + local_runner_endpoint=os.getenv( + "AWS_DEX_LOCAL_RUNNER_ENDPOINT", "http://0.0.0.0:5000" + ), + local_runner_region=os.getenv("AWS_DEX_LOCAL_RUNNER_REGION", "us-west-2"), + local_runner_mode=os.getenv("AWS_DEX_LOCAL_RUNNER_MODE", "local"), + store_type=StoreType(os.getenv("AWS_DEX_STORE_TYPE", "memory")), + store_path=os.getenv("AWS_DEX_STORE_PATH"), + ) + + +class CliApp: + """Main CLI application for dex-local-runner.""" + + def __init__(self) -> None: + """Initialize the CLI application.""" + self.config = CliConfig.from_environment() + + def run(self, args: list[str] | None = None) -> int: + """Run the CLI application with the given arguments. + + Args: + args: Command line arguments. If None, uses sys.argv[1:] + + Returns: + Exit code (0 for success, non-zero for error) + """ + try: + parser = self._create_parsers() + parsed_args = parser.parse_args(args) + + # Configure logging based on log level + if hasattr(parsed_args, "log_level") and isinstance( + parsed_args.log_level, str + ): + level = logging.getLevelNamesMapping().get( + parsed_args.log_level, logging.INFO + ) + else: + # config.log_level is always an integer + level = self.config.log_level + + logging.basicConfig( + level=level, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + logging.getLogger("botocore").setLevel(logging.WARNING) + + # Execute the appropriate command + return parsed_args.func(parsed_args) + + except SystemExit as e: + # argparse calls sys.exit() for help, errors, etc. + return int(e.code) if e.code is not None else 1 + except KeyboardInterrupt: + print("\nOperation cancelled by user", file=sys.stderr) # noqa: T201 + return 130 # Standard exit code for SIGINT + except DurableFunctionsTestError: + logger.exception("Error") + return 1 + except Exception: + logger.exception("Unexpected error.") + return 1 + + def _create_parsers(self) -> argparse.ArgumentParser: + """Create the argument parsers for all commands.""" + parser = argparse.ArgumentParser( + prog="dex-local-runner", + description="AWS Durable Functions Local Runner CLI", + ) + + subparsers = parser.add_subparsers( + dest="command", help="Available commands", required=True + ) + + # Create individual parsers + self._create_start_server_parser(subparsers) + self._create_invoke_parser(subparsers) + self._create_get_durable_execution_parser(subparsers) + self._create_get_durable_execution_history_parser(subparsers) + + return parser + + # region parsers + + def _create_start_server_parser(self, subparsers) -> None: + """Create the start-server command parser.""" + start_server_parser = subparsers.add_parser( + "start-server", help="Start the local Durable Functions Server" + ) + start_server_parser.add_argument( + "--host", + default=self.config.host, + help=f"Server bind address (default: {self.config.host}, env: AWS_DEX_HOST)", + ) + start_server_parser.add_argument( + "--port", + type=int, + default=self.config.port, + help=f"Server port (default: {self.config.port}, env: AWS_DEX_PORT)", + ) + start_server_parser.add_argument( + "--log-level", + type=str, + choices=list(logging.getLevelNamesMapping().keys()), + default=logging.getLevelName(self.config.log_level), + help=f"Logging level (default: {logging.getLevelName(self.config.log_level)}, env: AWS_DEX_LOG_LEVEL)", + ) + start_server_parser.add_argument( + "--lambda-endpoint", + default=self.config.lambda_endpoint, + help=f"Lambda Service endpoint (default: {self.config.lambda_endpoint}, env: AWS_DEX_LAMBDA_ENDPOINT)", + ) + start_server_parser.add_argument( + "--local-runner-endpoint", + default=self.config.local_runner_endpoint, + help=f"Local Runner endpoint (default: {self.config.local_runner_endpoint}, env: AWS_DEX_LOCAL_RUNNER_ENDPOINT)", + ) + start_server_parser.add_argument( + "--local-runner-region", + default=self.config.local_runner_region, + help=f"Local Runner region (default: {self.config.local_runner_region}, env: AWS_DEX_LOCAL_RUNNER_REGION)", + ) + start_server_parser.add_argument( + "--local-runner-mode", + default=self.config.local_runner_mode, + help=f"Local Runner mode (default: {self.config.local_runner_mode}, env: AWS_DEX_LOCAL_RUNNER_MODE)", + ) + start_server_parser.add_argument( + "--store-type", + choices=[store_type.value for store_type in StoreType], + default=self.config.store_type.value, + help=f"Store type for execution persistence (default: {self.config.store_type.value}, env: AWS_DEX_STORE_TYPE)", + ) + start_server_parser.add_argument( + "--store-path", + default=self.config.store_path, + help=f"Path for filesystem store (default: {self.config.store_path or '.durable_executions'}, env: AWS_DEX_STORE_PATH)", + ) + start_server_parser.set_defaults(func=self.start_server_command) + + def _create_invoke_parser(self, subparsers) -> None: + """Create the invoke command parser.""" + invoke_parser = subparsers.add_parser( + "invoke", help="Invoke a Durable Execution" + ) + invoke_parser.add_argument( + "--function-name", required=True, help="Function name (required)" + ) + invoke_parser.add_argument( + "--input", default="{}", help="Input data (default: {})" + ) + invoke_parser.add_argument( + "--durable-execution-name", help="Durable execution name (optional)" + ) + invoke_parser.set_defaults(func=self.invoke_command) + + def _create_get_durable_execution_parser(self, subparsers) -> None: + """Create the get-durable-execution command parser.""" + get_execution_parser = subparsers.add_parser( + "get-durable-execution", help="Get execution details" + ) + get_execution_parser.add_argument( + "--durable-execution-arn", + required=True, + help="Durable execution ARN (required)", + ) + get_execution_parser.set_defaults(func=self.get_durable_execution_command) + + def _create_get_durable_execution_history_parser(self, subparsers) -> None: + """Create the get-durable-execution-history command parser.""" + get_history_parser = subparsers.add_parser( + "get-durable-execution-history", help="Get execution history" + ) + get_history_parser.add_argument( + "--durable-execution-arn", + required=True, + help="Durable execution ARN (required)", + ) + get_history_parser.set_defaults(func=self.get_durable_execution_history_command) + + # endregion parsers + + # region commands + + def start_server_command(self, args: argparse.Namespace) -> int: + """Execute the start-server command. + + Args: + args: Parsed command line arguments + + Returns: + Exit code (0 for success, non-zero for error) + """ + try: + # Create web service configuration from CLI arguments + web_config = WebServiceConfig( + host=args.host, + port=args.port, + log_level=args.log_level, + ) + + # Create web runner configuration with composition + runner_config = WebRunnerConfig( + web_service=web_config, + lambda_endpoint=args.lambda_endpoint, + local_runner_endpoint=args.local_runner_endpoint, + local_runner_region=args.local_runner_region, + local_runner_mode=args.local_runner_mode, + store_type=StoreType(args.store_type), + store_path=args.store_path, + ) + + logger.info( + "Starting Durable Functions Local Runner on %s:%s", + args.host, + args.port, + ) + logger.info("Configuration:") + logger.info(" Host: %s", args.host) + logger.info(" Port: %s", args.port) + logger.info(" Log Level: %s", args.log_level) + logger.info(" Lambda Endpoint: %s", args.lambda_endpoint) + logger.info(" Local Runner Endpoint: %s", args.local_runner_endpoint) + logger.info(" Local Runner Region: %s", args.local_runner_region) + logger.info(" Local Runner Mode: %s", args.local_runner_mode) + logger.info(" Store Type: %s", args.store_type) + if StoreType(args.store_type) == StoreType.FILESYSTEM: + store_path = args.store_path or ".durable_executions" + logger.info(" Store Path: %s", store_path) + + # Use runner as context manager for proper lifecycle + with WebRunner(runner_config) as runner: + logger.info("Server started successfully. Press Ctrl+C to stop.") + runner.serve_forever() + + return 0 # noqa: TRY300 + + except KeyboardInterrupt: + logger.info("Received shutdown signal, stopping server...") + return 130 # Standard exit code for SIGINT + except Exception: + logger.exception("Failed to start server") + return 1 + + def invoke_command(self, args: argparse.Namespace) -> int: + """Execute the invoke command. + + Args: + args: Parsed command line arguments + + Returns: + Exit code (0 for success, non-zero for error) + """ + # Validate input JSON + try: + json.loads(args.input) # Just validate, don't store + except json.JSONDecodeError: + logger.exception("JSON decode error") + return 1 + + try: + # Create StartDurableExecutionInput + start_input = StartDurableExecutionInput( + account_id="123456789012", # Default account ID for local testing + function_name=args.function_name, + function_qualifier="$LATEST", # Default qualifier + execution_name=args.durable_execution_name + or f"{args.function_name}-execution", + execution_timeout_seconds=300, # 5 minutes default + execution_retention_period_days=7, # 1 week default + invocation_id=str(uuid.uuid4()), # Generate unique invocation ID + input=args.input, + ) + + # Make HTTP request to start-durable-execution endpoint + endpoint_url = self.config.local_runner_endpoint + url = urljoin(endpoint_url, "/start-durable-execution") + + payload = start_input.to_dict() + data = json.dumps(payload).encode("utf-8") + req = Request( + url, + data=data, + headers={"Content-Type": "application/json"}, + method="POST", + ) + + try: + with urlopen(req, timeout=10) as response: # noqa: S310 + result = json.loads(response.read().decode("utf-8")) + print(json.dumps(result, indent=2)) # noqa: T201 + return 0 + except HTTPError as e: + try: + error_data = json.loads(e.read().decode("utf-8")) + logger.exception("HTTP error response") + print( # noqa: T201 + f"Error: {error_data.get('ErrorMessage', 'Unknown error')}", + file=sys.stderr, + ) + except json.JSONDecodeError: + logger.exception("Non-JSON error response") + return 1 + + except URLError: + logger.exception( + "Error: Could not connect to the local runner server. Is it running?" + ) + return 1 + except Exception: + logger.exception("Unexpected error in invoke command") + return 1 + + def get_durable_execution_command(self, args: argparse.Namespace) -> int: + """Execute the get-durable-execution command. + + Args: + args: Parsed command line arguments + + Returns: + Exit code (0 for success, non-zero for error) + """ + try: + # Set up boto3 client with local endpoint + client = self._create_boto3_client() + + # Call get_durable_execution + response = client.get_durable_execution( + DurableExecutionArn=args.durable_execution_arn + ) + + # Print formatted response + print(json.dumps(response, indent=2, default=str)) # noqa: T201 + return 0 # noqa: TRY300 + + except client.exceptions.InvalidParameterValueException as e: + print(f"Error: Invalid parameter - {e}", file=sys.stderr) # noqa: T201 + return 1 + except client.exceptions.ResourceNotFoundException as e: + print(f"Error: Execution not found - {e}", file=sys.stderr) # noqa: T201 + return 1 + except client.exceptions.TooManyRequestsException as e: + print(f"Error: Too many requests - {e}", file=sys.stderr) # noqa: T201 + return 1 + except client.exceptions.ServiceException as e: + print(f"Error: Service error - {e}", file=sys.stderr) # noqa: T201 + return 1 + except ConnectionError: + logger.exception( + "Error: Could not connect to the local runner server. Is it running?" + ) + return 1 + except Exception: + logger.exception("Unexpected error in get-durable-execution command") + return 1 + + def get_durable_execution_history_command(self, args: argparse.Namespace) -> int: + """Execute the get-durable-execution-history command. + + TODO: implement - this is incomplete + + Args: + args: Parsed command line arguments + + Returns: + Exit code (0 for success, non-zero for error) + """ + try: + # Set up boto3 client with local endpoint + client = self._create_boto3_client() + + # Call get_durable_execution_history + response = client.get_durable_execution_history( + DurableExecutionArn=args.durable_execution_arn + ) + + print(json.dumps(response, indent=2, default=str)) # noqa: T201 + return 0 # noqa: TRY300 + + except Exception: + logger.exception("General error") + return 1 + + # endregion commands + + def _create_boto3_client( + self, endpoint_url: str | None = None, region_name: str | None = None + ) -> Any: + """Create boto3 client for Lambda service. + + Args: + endpoint_url: Optional endpoint URL override + region_name: Optional region name override + + Returns: + Configured boto3 client for local runner + + Raises: + Exception: If client creation fails + """ + try: + # Use provided values or fall back to config + final_endpoint = endpoint_url or self.config.local_runner_endpoint + final_region = region_name or self.config.local_runner_region + + # Create client with local endpoint - no AWS access keys required + return boto3.client( + "lambda", + endpoint_url=final_endpoint, + region_name=final_region, + ) + except Exception as e: + msg = f"Failed to create boto3 client: {e}" + raise DurableFunctionsLocalRunnerError(msg) from e + + +def main() -> int: + """Main entry point for the dex-local-runner CLI.""" + app = CliApp() + return app.run() + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/client.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/client.py new file mode 100644 index 0000000..a68f0cc --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/client.py @@ -0,0 +1,50 @@ +"""An in-memory service client, that can replace the boto lambda service client.""" + +import datetime + +from aws_durable_execution_sdk_python.lambda_service import ( + CheckpointOutput, + DurableServiceClient, + OperationUpdate, + StateOutput, +) + +from aws_durable_execution_sdk_python_testing.checkpoint.processor import ( + CheckpointProcessor, +) + + +class InMemoryServiceClient(DurableServiceClient): + """An in-memory service client, that can replace the boto lambda service client.""" + + def __init__(self, checkpoint_processor: CheckpointProcessor): + self._checkpoint_processor: CheckpointProcessor = checkpoint_processor + + def checkpoint( + self, + durable_execution_arn: str, # noqa: ARG002 + checkpoint_token: str, + updates: list[OperationUpdate], + client_token: str | None, + ) -> CheckpointOutput: + # durable_execution_arn is not used in in-memory testing + return self._checkpoint_processor.process_checkpoint( + checkpoint_token, updates, client_token + ) + + def get_execution_state( + self, + durable_execution_arn: str, # noqa: ARG002 + checkpoint_token: str, + next_marker: str, + max_items: int = 1000, + ) -> StateOutput: + # durable_execution_arn is not used in in-memory testing + return self._checkpoint_processor.get_execution_state( + checkpoint_token, next_marker, max_items + ) + + def stop(self, execution_arn: str, payload: bytes | None) -> datetime.datetime: # noqa: ARG002 + # TODO: implement + # Return current time for in-memory testing + return datetime.datetime.now(tz=datetime.UTC) diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/exceptions.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/exceptions.py new file mode 100644 index 0000000..8d51e2f --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/exceptions.py @@ -0,0 +1,288 @@ +"""Exceptions for the Durable Executions Testing Library. + +This module provides AWS-compliant exceptions that serialize to the exact JSON format +expected by boto3 and AWS services. All exceptions follow Smithy model definitions +for field names and structure. + +## AWS-Compliant Error Format + +All AWS API exceptions inherit from `AwsApiException` and implement the `to_dict()` method +to serialize to AWS-compliant JSON format. The format varies by exception type based on +their Smithy model definitions: + +### Standard Format (most exceptions): +```json +{ + "Type": "ExceptionName", + "message": "Error message" // or "Message" depending on Smithy definition +} +``` + +### Special Cases: +- `ExecutionAlreadyStartedException`: No "Type" field, includes "DurableExecutionArn" +```json +{ + "message": "Error message", + "DurableExecutionArn": "arn:aws:states:..." +} +``` + +## Field Name Conventions + +Field names follow the exact Smithy model definitions: +- `InvalidParameterValueException`: uses lowercase "message" +- `CallbackTimeoutException`: uses lowercase "message" +- `ResourceNotFoundException`: uses capital "Message" +- `ServiceException`: uses capital "Message" +- `ExecutionAlreadyStartedException`: uses lowercase "message" + "DurableExecutionArn" + +## HTTP Status Codes + +Each exception maps to appropriate HTTP status codes: +- 400: InvalidParameterValueException (Bad Request) +- 404: ResourceNotFoundException (Not Found) +- 408: CallbackTimeoutException (Request Timeout) +- 409: ExecutionAlreadyStartedException (Conflict) +- 500: ServiceException (Internal Server Error) + +## Usage Examples + +```python +# Create and serialize an exception +exception = InvalidParameterValueException("Invalid parameter value") +json_dict = exception.to_dict() +# Result: {"Type": "InvalidParameterValueException", "message": "Invalid parameter value"} + +# HTTP response creation +from aws_durable_execution_sdk_python_testing.web.models import HTTPResponse + +response = HTTPResponse.create_error_from_exception(exception) +# Creates HTTP 400 response with AWS-compliant JSON body +``` + +## Boto3 Compatibility + +All exceptions are designed to be compatible with boto3's error handling: +- JSON structure matches boto3 expectations +- Field names match Smithy model definitions +- Type field values match exception class names +- Can be deserialized by boto3's error factory + +Avoid any non-stdlib references in this module, it is at the bottom of the dependency chain. +""" + +from __future__ import annotations + +from typing import Any + + +# region Local Runner +class DurableFunctionsLocalRunnerError(Exception): + """Base class for Durable Executions exceptions""" + + +class UnknownRouteError(DurableFunctionsLocalRunnerError): + """No route matches the requested path pattern.""" + + def __init__(self, method: str, path: str) -> None: + """Initialize UnknownRouteError with method and path. + + Args: + method: HTTP method (GET, POST, etc.) + path: Request path that couldn't be matched + """ + self.method = method + self.path = path + message = f"Unknown path pattern: {method} {path}" + super().__init__(message) + + +# endregion Local Runner + + +class SerializationError(DurableFunctionsLocalRunnerError): + """Exception for serialization errors.""" + + +# region Testing +class DurableFunctionsTestError(Exception): + """Base class for testing errors.""" + + +# endregion Testing + + +# region AWS API Exceptions +class AwsApiException(DurableFunctionsLocalRunnerError): # noqa: N818 + """Base class for AWS API-style exceptions that can be serialized to AWS format.""" + + http_status_code: int = 500 # Default to server error + + def to_dict(self) -> dict[str, Any]: + """Serialize to AWS-compliant JSON structure.""" + raise NotImplementedError + + +# Smithy-Mapped Exceptions (defined in Smithy models) +class InvalidParameterValueException(AwsApiException): + """Exception for invalid parameter values.""" + + http_status_code = 400 + + def __init__(self, message: str) -> None: + """Initialize with message field (lowercase per Smithy definition).""" + self.message = message + super().__init__(message) + + def to_dict(self) -> dict[str, Any]: + """Serialize to AWS-compliant JSON structure.""" + return {"Type": "InvalidParameterValueException", "message": self.message} + + +class ResourceNotFoundException(AwsApiException): + """Exception for resource not found errors.""" + + http_status_code = 404 + + def __init__( + self, + Message: str, # noqa: N803 + ) -> None: # Capital M per Smithy definition + """Initialize with Message field (capital M per Smithy definition).""" + self.Message = Message + super().__init__(Message) + + def to_dict(self) -> dict[str, Any]: + """Serialize to AWS-compliant JSON structure.""" + return {"Type": "ResourceNotFoundException", "Message": self.Message} + + +class ServiceException(AwsApiException): + """Exception for general service errors.""" + + http_status_code = 500 + + def __init__( + self, + Message: str, # noqa: N803 + ) -> None: # Capital M per Smithy definition + """Initialize with Message field (capital M per Smithy definition).""" + self.Message = Message + super().__init__(Message) + + def to_dict(self) -> dict[str, Any]: + """Serialize to AWS-compliant JSON structure.""" + return {"Type": "ServiceException", "Message": self.Message} + + +class ExecutionAlreadyStartedException(AwsApiException): + """Exception for execution already started errors.""" + + http_status_code = 409 + + def __init__(self, message: str, DurableExecutionArn: str) -> None: # noqa: N803 + """Initialize with message and DurableExecutionArn fields.""" + self.message = message + self.DurableExecutionArn = DurableExecutionArn + super().__init__(message) + + def to_dict(self) -> dict[str, Any]: + """Serialize to AWS-compliant JSON structure (no Type field per Smithy definition).""" + return { + "message": self.message, + "DurableExecutionArn": self.DurableExecutionArn, + } + + +class ExecutionConflictException(AwsApiException): + """Exception for execution conflict errors.""" + + http_status_code = 409 + + def __init__(self, message: str) -> None: + """Initialize with message field.""" + self.message = message + super().__init__(message) + + def to_dict(self) -> dict[str, Any]: + """Serialize to AWS-compliant JSON structure.""" + return {"Type": "ExecutionConflictException", "message": self.message} + + +class CallbackTimeoutException(AwsApiException): + """Exception for callback timeout errors.""" + + http_status_code = 408 + + def __init__(self, message: str) -> None: + """Initialize with message field (lowercase per Smithy definition).""" + self.message = message + super().__init__(message) + + def to_dict(self) -> dict[str, Any]: + """Serialize to AWS-compliant JSON structure.""" + return {"Type": "CallbackTimeoutException", "message": self.message} + + +class TooManyRequestsException(AwsApiException): + """Exception for too many requests errors.""" + + http_status_code = 429 + + def __init__(self, message: str) -> None: + """Initialize with message field (lowercase per Smithy definition).""" + self.message = message + super().__init__(message) + + def to_dict(self) -> dict[str, Any]: + """Serialize to AWS-compliant JSON structure.""" + return {"Type": "TooManyRequestsException", "message": self.message} + + +# Unmapped Exceptions (thrown by services but not in Smithy) +class IllegalStateException(AwsApiException): + """IllegalStateException.""" + + http_status_code = 500 + + def __init__(self, message: str) -> None: + """Initialize with message field.""" + self.message = message + super().__init__(message) + + def to_dict(self) -> dict[str, Any]: + """Serialize to AWS-compliant JSON structure (maps to ServiceException).""" + return {"Type": "ServiceException", "Message": self.message} + + +class RuntimeException(AwsApiException): + """RuntimeException.""" + + http_status_code = 500 + + def __init__(self, message: str) -> None: + """Initialize with message field.""" + self.message = message + super().__init__(message) + + def to_dict(self) -> dict[str, Any]: + """Serialize to AWS-compliant JSON structure (maps to ServiceException).""" + return {"Type": "ServiceException", "Message": self.message} + + +class IllegalArgumentException(AwsApiException): + """IllegalArgumentException.""" + + http_status_code = 400 + + def __init__(self, message: str) -> None: + """Initialize with message field.""" + self.message = message + super().__init__(message) + + def to_dict(self) -> dict[str, Any]: + """Serialize to AWS-compliant JSON structure (maps to InvalidParameterValueException).""" + return {"Type": "InvalidParameterValueException", "message": self.message} + + +# endregion AWS API Exceptions diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/execution.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/execution.py new file mode 100644 index 0000000..a1096f1 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/execution.py @@ -0,0 +1,444 @@ +from __future__ import annotations + +from dataclasses import replace +from datetime import UTC, datetime +from enum import Enum +from threading import Lock +from typing import Any +from uuid import uuid4 + +from aws_durable_execution_sdk_python.execution import ( + DurableExecutionInvocationOutput, + InvocationStatus, +) +from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, + ExecutionDetails, + Operation, + OperationStatus, + OperationType, + OperationUpdate, +) + +from aws_durable_execution_sdk_python_testing.exceptions import ( + IllegalStateException, + InvalidParameterValueException, +) + +# Import AWS exceptions +from aws_durable_execution_sdk_python_testing.model import ( + InvocationCompletedDetails, + StartDurableExecutionInput, +) +from aws_durable_execution_sdk_python_testing.token import ( + CheckpointToken, +) + + +class ExecutionStatus(Enum): + """Execution status for API responses.""" + + RUNNING = "RUNNING" + SUCCEEDED = "SUCCEEDED" + FAILED = "FAILED" + STOPPED = "STOPPED" + TIMED_OUT = "TIMED_OUT" + + +class Execution: + """Execution state.""" + + def __init__( + self, + durable_execution_arn: str, + start_input: StartDurableExecutionInput, + operations: list[Operation], + ): + self.durable_execution_arn: str = durable_execution_arn + # operation is frozen, it won't mutate - no need to clone/deep-copy + self.start_input: StartDurableExecutionInput = start_input + self.operations: list[Operation] = operations + self.updates: list[OperationUpdate] = [] + self.invocation_completions: list[InvocationCompletedDetails] = [] + self.used_tokens: set[str] = set() + # TODO: this will need to persist/rehydrate depending on inmemory vs sqllite store + self._token_sequence: int = 0 + self._state_lock: Lock = Lock() + self.is_complete: bool = False + self.result: DurableExecutionInvocationOutput | None = None + self.consecutive_failed_invocation_attempts: int = 0 + self.close_status: ExecutionStatus | None = None + + @property + def token_sequence(self) -> int: + """Get current token sequence value.""" + return self._token_sequence + + def current_status(self) -> ExecutionStatus: + """Get execution status.""" + if not self.is_complete: + return ExecutionStatus.RUNNING + + if not self.close_status: + msg: str = "close_status must be set when execution is complete" + raise IllegalStateException(msg) + + return self.close_status + + @staticmethod + def new(input: StartDurableExecutionInput) -> Execution: # noqa: A002 + # make a nicer arn + # Pattern: arn:(aws[a-zA-Z-]*)?:lambda:[a-z]{2}(-gov)?-[a-z]+-\d{1}:\d{12}:durable-execution:[a-zA-Z0-9-_\.]+:[a-zA-Z0-9-_\.]+:[a-zA-Z0-9-_\.]+ + # Example: arn:aws:lambda:us-east-1:123456789012:durable-execution:myDurableFunction:myDurableExecutionName:ce67da72-3701-4f83-9174-f4189d27b0a5 + return Execution( + durable_execution_arn=str(uuid4()) + + "/" + + (input.invocation_id or str(uuid4())), + start_input=input, + operations=[], + ) + + def to_json_dict(self) -> dict[str, Any]: + """Serialize execution to JSON-serializable dictionary""" + return { + "DurableExecutionArn": self.durable_execution_arn, + "StartInput": self.start_input.to_dict(), + "Operations": [op.to_json_dict() for op in self.operations], + "Updates": [update.to_dict() for update in self.updates], + "InvocationCompletions": [ + completion.to_json_dict() for completion in self.invocation_completions + ], + "UsedTokens": list(self.used_tokens), + "TokenSequence": self._token_sequence, + "IsComplete": self.is_complete, + "Result": self.result.to_dict() if self.result else None, + "ConsecutiveFailedInvocationAttempts": self.consecutive_failed_invocation_attempts, + "CloseStatus": self.close_status.value if self.close_status else None, + } + + @classmethod + def from_json_dict(cls, data: dict[str, Any]) -> Execution: + """Deserialize execution from dictionary.""" + # Reconstruct start_input + start_input = StartDurableExecutionInput.from_dict(data["StartInput"]) + + # Reconstruct operations + operations = [ + Operation.from_json_dict(op_data) for op_data in data["Operations"] + ] + + # Create execution + execution = cls( + durable_execution_arn=data["DurableExecutionArn"], + start_input=start_input, + operations=operations, + ) + + # Set additional fields + execution.updates = [ + OperationUpdate.from_dict(update_data) for update_data in data["Updates"] + ] + execution.invocation_completions = [ + InvocationCompletedDetails.from_json_dict(item) + for item in data.get("InvocationCompletions", []) + ] + execution.used_tokens = set(data["UsedTokens"]) + execution._token_sequence = data["TokenSequence"] # noqa: SLF001 + execution.is_complete = data["IsComplete"] + execution.result = ( + DurableExecutionInvocationOutput.from_dict(data["Result"]) + if data["Result"] + else None + ) + execution.consecutive_failed_invocation_attempts = data[ + "ConsecutiveFailedInvocationAttempts" + ] + close_status_str = data.get("CloseStatus") + execution.close_status = ( + ExecutionStatus(close_status_str) if close_status_str else None + ) + + return execution + + def start(self) -> None: + if self.start_input.invocation_id is None: + msg: str = "invocation_id is required" + raise InvalidParameterValueException(msg) + with self._state_lock: + self.operations.append( + Operation( + operation_id=self.start_input.invocation_id, + parent_id=None, + name=self.start_input.execution_name, + start_timestamp=datetime.now(UTC), + operation_type=OperationType.EXECUTION, + status=OperationStatus.STARTED, + execution_details=ExecutionDetails( + input_payload=self.start_input.get_normalized_input() + ), + ) + ) + + def get_operation_execution_started(self) -> Operation: + if not self.operations: + msg: str = "execution not started." + + raise IllegalStateException(msg) + + return self.operations[0] + + def get_new_checkpoint_token(self) -> str: + """Generate a new checkpoint token with incremented sequence""" + with self._state_lock: + self._token_sequence += 1 + new_token_sequence = self._token_sequence + token = CheckpointToken( + execution_arn=self.durable_execution_arn, + token_sequence=new_token_sequence, + ) + token_str = token.to_str() + self.used_tokens.add(token_str) + return token_str + + def get_navigable_operations(self) -> list[Operation]: + """Get list of operations, but exclude child operations where the parent has already completed.""" + return self.operations + + def get_assertable_operations(self) -> list[Operation]: + """Get list of operations, but exclude the EXECUTION operations""" + # TODO: this excludes EXECUTION at start, but can there be an EXECUTION at the end if there was a checkpoint with large payload? + return self.operations[1:] + + def has_pending_operations(self, execution: Execution) -> bool: + """True if execution has pending operations.""" + + for operation in execution.operations: + if ( + operation.operation_type == OperationType.STEP + and operation.status == OperationStatus.PENDING + ) or ( + operation.operation_type + in [ + OperationType.WAIT, + OperationType.CALLBACK, + OperationType.CHAINED_INVOKE, + ] + and operation.status == OperationStatus.STARTED + ): + return True + return False + + def record_invocation_completion( + self, start_timestamp: datetime, end_timestamp: datetime, request_id: str + ) -> None: + """Record an invocation completion event.""" + self.invocation_completions.append( + InvocationCompletedDetails( + start_timestamp=start_timestamp, + end_timestamp=end_timestamp, + request_id=request_id, + ) + ) + + def complete_success(self, result: str | None) -> None: + """Complete execution successfully (DecisionType.COMPLETE_WORKFLOW_EXECUTION).""" + self.result = DurableExecutionInvocationOutput( + status=InvocationStatus.SUCCEEDED, result=result + ) + self.is_complete = True + self.close_status = ExecutionStatus.SUCCEEDED + self._end_execution(OperationStatus.SUCCEEDED) + + def complete_fail(self, error: ErrorObject) -> None: + """Complete execution with failure (DecisionType.FAIL_WORKFLOW_EXECUTION).""" + self.result = DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED, error=error + ) + self.is_complete = True + self.close_status = ExecutionStatus.FAILED + self._end_execution(OperationStatus.FAILED) + + def complete_timeout(self, error: ErrorObject) -> None: + """Complete execution with timeout.""" + self.result = DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED, error=error + ) + self.is_complete = True + self.close_status = ExecutionStatus.TIMED_OUT + self._end_execution(OperationStatus.TIMED_OUT) + + def complete_stopped(self, error: ErrorObject) -> None: + """Complete execution as terminated (TerminateWorkflowExecutionV2Request).""" + self.result = DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED, error=error + ) + self.is_complete = True + self.close_status = ExecutionStatus.STOPPED + self._end_execution(OperationStatus.STOPPED) + + def find_operation(self, operation_id: str) -> tuple[int, Operation]: + """Find operation by ID, return index and operation.""" + for i, operation in enumerate(self.operations): + if operation.operation_id == operation_id: + return i, operation + msg: str = f"Attempting to update state of an Operation [{operation_id}] that doesn't exist" + raise IllegalStateException(msg) + + def find_callback_operation(self, callback_id: str) -> tuple[int, Operation]: + """Find callback operation by callback_id, return index and operation.""" + for i, operation in enumerate(self.operations): + if ( + operation.operation_type == OperationType.CALLBACK + and operation.callback_details + and operation.callback_details.callback_id == callback_id + ): + return i, operation + msg: str = f"Callback operation with callback_id [{callback_id}] not found" + raise IllegalStateException(msg) + + def complete_wait(self, operation_id: str) -> Operation: + """Complete WAIT operation when timer fires.""" + index, operation = self.find_operation(operation_id) + + # Validate + if operation.status != OperationStatus.STARTED: + msg_wait_not_started: str = f"Attempting to transition a Wait Operation[{operation_id}] to SUCCEEDED when it's not STARTED" + raise IllegalStateException(msg_wait_not_started) + if operation.operation_type != OperationType.WAIT: + msg_not_wait: str = ( + f"Expected WAIT operation, got {operation.operation_type}" + ) + raise IllegalStateException(msg_not_wait) + + # Thread-safe increment sequence and operation update + with self._state_lock: + self._token_sequence += 1 + # Build and assign updated operation + self.operations[index] = replace( + operation, + status=OperationStatus.SUCCEEDED, + end_timestamp=datetime.now(UTC), + ) + return self.operations[index] + + def complete_retry(self, operation_id: str) -> Operation: + """Complete STEP retry when timer fires.""" + index, operation = self.find_operation(operation_id) + + # Validate + if operation.status != OperationStatus.PENDING: + msg_step_not_pending: str = f"Attempting to transition a Step Operation[{operation_id}] to READY when it's not PENDING" + raise IllegalStateException(msg_step_not_pending) + if operation.operation_type != OperationType.STEP: + msg_not_step: str = ( + f"Expected STEP operation, got {operation.operation_type}" + ) + raise IllegalStateException(msg_not_step) + + # Thread-safe increment sequence and operation update + with self._state_lock: + self._token_sequence += 1 + # Build updated step_details with cleared next_attempt_timestamp + new_step_details = None + if operation.step_details: + new_step_details = replace( + operation.step_details, next_attempt_timestamp=None + ) + + # Build updated operation + updated_operation = replace( + operation, status=OperationStatus.READY, step_details=new_step_details + ) + + # Assign + self.operations[index] = updated_operation + return updated_operation + + def complete_callback_success( + self, callback_id: str, result: bytes | None = None + ) -> Operation: + """Complete CALLBACK operation with success.""" + index, operation = self.find_callback_operation(callback_id) + if operation.status != OperationStatus.STARTED: + msg: str = f"Callback operation [{callback_id}] is not in STARTED state" + raise IllegalStateException(msg) + + with self._state_lock: + self._token_sequence += 1 + updated_callback_details = None + if operation.callback_details: + updated_callback_details = replace( + operation.callback_details, + result=result.decode() if result else None, + ) + + self.operations[index] = replace( + operation, + status=OperationStatus.SUCCEEDED, + end_timestamp=datetime.now(UTC), + callback_details=updated_callback_details, + ) + return self.operations[index] + + def complete_callback_failure( + self, callback_id: str, error: ErrorObject + ) -> Operation: + """Complete CALLBACK operation with failure.""" + index, operation = self.find_callback_operation(callback_id) + + if operation.status != OperationStatus.STARTED: + msg: str = f"Callback operation [{callback_id}] is not in STARTED state" + raise IllegalStateException(msg) + + with self._state_lock: + self._token_sequence += 1 + updated_callback_details = None + if operation.callback_details: + updated_callback_details = replace( + operation.callback_details, error=error + ) + + self.operations[index] = replace( + operation, + status=OperationStatus.FAILED, + end_timestamp=datetime.now(UTC), + callback_details=updated_callback_details, + ) + return self.operations[index] + + def complete_callback_timeout( + self, callback_id: str, error: ErrorObject + ) -> Operation: + """Complete CALLBACK operation with timeout.""" + index, operation = self.find_callback_operation(callback_id) + + if operation.status != OperationStatus.STARTED: + msg: str = f"Callback operation [{callback_id}] is not in STARTED state" + raise IllegalStateException(msg) + + with self._state_lock: + self._token_sequence += 1 + updated_callback_details = None + if operation.callback_details: + updated_callback_details = replace( + operation.callback_details, error=error + ) + + self.operations[index] = replace( + operation, + status=OperationStatus.TIMED_OUT, + end_timestamp=datetime.now(UTC), + callback_details=updated_callback_details, + ) + return self.operations[index] + + def _end_execution(self, status: OperationStatus) -> None: + """Set the end_timestamp on the main EXECUTION operation when execution completes.""" + execution_op: Operation = self.get_operation_execution_started() + if execution_op.operation_type == OperationType.EXECUTION: + with self._state_lock: + self.operations[0] = replace( + execution_op, + status=status, + end_timestamp=datetime.now(UTC), + ) diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/executor.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/executor.py new file mode 100644 index 0000000..02a1504 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/executor.py @@ -0,0 +1,1234 @@ +"""Execution life-cycle logic.""" + +from __future__ import annotations + +import logging +import time +import uuid +from datetime import UTC, datetime +from typing import TYPE_CHECKING + +from aws_durable_execution_sdk_python.execution import ( + DurableExecutionInvocationInput, + DurableExecutionInvocationOutput, + InvocationStatus, +) +from aws_durable_execution_sdk_python.lambda_service import ( + CallbackTimeoutType, + ErrorObject, + Operation, + OperationUpdate, + OperationStatus, + OperationType, + CallbackOptions, +) + +from aws_durable_execution_sdk_python_testing.exceptions import ( + ExecutionAlreadyStartedException, + IllegalStateException, + InvalidParameterValueException, + ResourceNotFoundException, +) +from aws_durable_execution_sdk_python_testing.execution import Execution +from aws_durable_execution_sdk_python_testing.model import ( + CheckpointDurableExecutionResponse, + CheckpointUpdatedExecutionState, + EventCreationContext, + EventType, + GetDurableExecutionHistoryResponse, + GetDurableExecutionResponse, + GetDurableExecutionStateResponse, + ListDurableExecutionsByFunctionResponse, + ListDurableExecutionsResponse, + SendDurableExecutionCallbackFailureResponse, + SendDurableExecutionCallbackHeartbeatResponse, + SendDurableExecutionCallbackSuccessResponse, + StartDurableExecutionInput, + StartDurableExecutionOutput, + StopDurableExecutionResponse, + TERMINAL_STATUSES, +) +from aws_durable_execution_sdk_python_testing.model import ( + Event as HistoryEvent, +) +from aws_durable_execution_sdk_python_testing.model import ( + Execution as ExecutionSummary, +) +from aws_durable_execution_sdk_python_testing.observer import ExecutionObserver +from aws_durable_execution_sdk_python_testing.token import CallbackToken + + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + from concurrent.futures import Future + + from aws_durable_execution_sdk_python_testing.checkpoint.processor import ( + CheckpointProcessor, + ) + from aws_durable_execution_sdk_python_testing.invoker import Invoker + from aws_durable_execution_sdk_python_testing.scheduler import Event, Scheduler + from aws_durable_execution_sdk_python_testing.stores.base import ExecutionStore + +logger = logging.getLogger(__name__) + + +class Executor(ExecutionObserver): + MAX_CONSECUTIVE_FAILED_ATTEMPTS: int = 5 + RETRY_BACKOFF_SECONDS: int = 5 + + def __init__( + self, + store: ExecutionStore, + scheduler: Scheduler, + invoker: Invoker, + checkpoint_processor: CheckpointProcessor, + ): + self._store = store + self._scheduler = scheduler + self._invoker = invoker + self._checkpoint_processor = checkpoint_processor + self._completion_events: dict[str, Event] = {} + self._callback_timeouts: dict[str, Future] = {} + self._callback_heartbeats: dict[str, Future] = {} + self._execution_timeout: Future | None = None + + def start_execution( + self, + input: StartDurableExecutionInput, # noqa: A002 + ) -> StartDurableExecutionOutput: + # Generate invocation_id if not provided + if input.invocation_id is None: + input = StartDurableExecutionInput( + account_id=input.account_id, + function_name=input.function_name, + function_qualifier=input.function_qualifier, + execution_name=input.execution_name, + execution_timeout_seconds=input.execution_timeout_seconds, + execution_retention_period_days=input.execution_retention_period_days, + invocation_id=str(uuid.uuid4()), + trace_fields=input.trace_fields, + tenant_id=input.tenant_id, + input=input.input, + lambda_endpoint=input.lambda_endpoint, + ) + + execution = Execution.new(input=input) + execution.start() + self._store.save(execution) + logger.debug("Created execution with ARN: %s", execution.durable_execution_arn) + + completion_event = self._scheduler.create_event() + self._completion_events[execution.durable_execution_arn] = completion_event + + # Schedule execution timeout + if input.execution_timeout_seconds > 0: + + def timeout_handler(): + error = ErrorObject.from_message( + f"Execution timed out after {input.execution_timeout_seconds} seconds." + ) + self.on_timed_out(execution.durable_execution_arn, error) + + self._execution_timeout = self._scheduler.call_later( + timeout_handler, + delay=input.execution_timeout_seconds, + completion_event=completion_event, + ) + + # Schedule initial invocation to run immediately + self._invoke_execution(execution.durable_execution_arn) + + return StartDurableExecutionOutput( + execution_arn=execution.durable_execution_arn + ) + + def get_execution(self, execution_arn: str) -> Execution: + """Get execution by ARN. + + Args: + execution_arn: The execution ARN to retrieve + + Returns: + Execution: The execution object + + Raises: + ResourceNotFoundException: If execution does not exist + """ + try: + return self._store.load(execution_arn) + except KeyError as e: + msg: str = f"Execution {execution_arn} not found" + raise ResourceNotFoundException(msg) from e + + def get_execution_details(self, execution_arn: str) -> GetDurableExecutionResponse: + """Get detailed execution information for web API response. + + Args: + execution_arn: The execution ARN to retrieve + + Returns: + GetDurableExecutionResponse: Detailed execution information + + Raises: + ResourceNotFoundException: If execution does not exist + """ + execution = self.get_execution(execution_arn) + + # Extract execution details from the first operation (EXECUTION type) + execution_op = execution.get_operation_execution_started() + status = execution.current_status().value + + # Extract result and error from execution result + result = None + error = None + if execution.result: + if execution.result.status == InvocationStatus.SUCCEEDED: + result = execution.result.result + elif execution.result.status == InvocationStatus.FAILED: + error = execution.result.error + + return GetDurableExecutionResponse( + durable_execution_arn=execution.durable_execution_arn, + durable_execution_name=execution.start_input.execution_name, + function_arn=f"arn:aws:lambda:us-east-1:123456789012:function:{execution.start_input.function_name}", + status=status, + start_timestamp=execution_op.start_timestamp + if execution_op.start_timestamp + else datetime.now(UTC), + input_payload=execution_op.execution_details.input_payload + if execution_op.execution_details + else None, + result=result, + error=error, + end_timestamp=execution_op.end_timestamp + if execution_op.end_timestamp + else None, + version="1.0", + ) + + def list_executions( + self, + function_name: str | None = None, + function_version: str | None = None, # noqa: ARG002 + execution_name: str | None = None, + status_filter: str | None = None, + started_after: str | None = None, + started_before: str | None = None, + marker: str | None = None, + max_items: int | None = None, + reverse_order: bool = False, # noqa: FBT001, FBT002 + ) -> ListDurableExecutionsResponse: + """List executions with filtering and pagination. + + Args: + function_name: Filter by function name + function_version: Filter by function version + execution_name: Filter by execution name + status_filter: Filter by status (RUNNING, SUCCEEDED, FAILED) + started_after: Filter executions started after this time + started_before: Filter executions started before this time + marker: Pagination marker + max_items: Maximum items to return (default 50) + reverse_order: Return results in reverse chronological order + + Returns: + ListDurableExecutionsResponse: List of executions with pagination + """ + # Convert marker to offset + offset: int = 0 + if marker: + try: + offset = int(marker) + except ValueError: + offset = 0 + + # Query store directly with parameters + executions, next_marker = self._store.query( + function_name=function_name, + execution_name=execution_name, + status_filter=status_filter, + started_after=started_after, + started_before=started_before, + limit=max_items or 50, + offset=offset, + reverse_order=reverse_order, + ) + + # Convert to ExecutionSummary objects + execution_summaries: list[ExecutionSummary] = [ + ExecutionSummary.from_execution(execution, execution.current_status().value) + for execution in executions + ] + + return ListDurableExecutionsResponse( + durable_executions=execution_summaries, next_marker=next_marker + ) + + def list_executions_by_function( + self, + function_name: str, + qualifier: str | None = None, # noqa: ARG002 + execution_name: str | None = None, + status_filter: str | None = None, + started_after: str | None = None, + started_before: str | None = None, + marker: str | None = None, + max_items: int | None = None, + reverse_order: bool = False, # noqa: FBT001, FBT002 + ) -> ListDurableExecutionsByFunctionResponse: + """List executions for a specific function. + + Args: + function_name: The function name to filter by + qualifier: Function qualifier/version + execution_name: Filter by execution name + status_filter: Filter by status (RUNNING, SUCCEEDED, FAILED) + started_after: Filter executions started after this time + started_before: Filter executions started before this time + marker: Pagination marker + max_items: Maximum items to return (default 50) + reverse_order: Return results in reverse chronological order + + Returns: + ListDurableExecutionsByFunctionResponse: List of executions for the function + """ + # Use the general list_executions method with function_name filter + list_response = self.list_executions( + function_name=function_name, + execution_name=execution_name, + status_filter=status_filter, + started_after=started_after, + started_before=started_before, + marker=marker, + max_items=max_items, + reverse_order=reverse_order, + ) + + return ListDurableExecutionsByFunctionResponse( + durable_executions=list_response.durable_executions, + next_marker=list_response.next_marker, + ) + + def stop_execution( + self, execution_arn: str, error: ErrorObject | None = None + ) -> StopDurableExecutionResponse: + """Stop a running execution. + + Args: + execution_arn: The execution ARN to stop + error: Optional error to use when stopping the execution + + Returns: + StopDurableExecutionResponse: Response containing end timestamp + + Raises: + ResourceNotFoundException: If execution does not exist + """ + execution = self.get_execution(execution_arn) + + if execution.is_complete: + # Idempotent: return the existing stop timestamp + execution_op = execution.get_operation_execution_started() + stop_timestamp = execution_op.end_timestamp or datetime.now(UTC) + return StopDurableExecutionResponse(stop_timestamp=stop_timestamp) + + # Use provided error or create a default one + stop_error = error or ErrorObject.from_message( + "Execution stopped by user request" + ) + + # Stop sets TERMINATED close status (different from fail) + logger.exception("[%s] Stopping execution.", execution_arn) + execution.complete_stopped(error=stop_error) # Sets CloseStatus.TERMINATED + self._store.update(execution) + self._complete_events(execution_arn=execution_arn) + + return StopDurableExecutionResponse(stop_timestamp=datetime.now(UTC)) + + def get_execution_state( + self, + execution_arn: str, + checkpoint_token: str | None = None, + marker: str | None = None, + max_items: int | None = None, + ) -> GetDurableExecutionStateResponse: + """Get execution state with operations. + + Args: + execution_arn: The execution ARN + checkpoint_token: Checkpoint token for state consistency + marker: Pagination marker + max_items: Maximum items to return + + Returns: + GetDurableExecutionStateResponse: Execution state with operations + + Raises: + ResourceNotFoundException: If execution does not exist + InvalidParameterValueException: If checkpoint token is invalid + """ + execution = self.get_execution(execution_arn) + + # TODO: Validate checkpoint token if provided + if checkpoint_token and checkpoint_token not in execution.used_tokens: + msg: str = f"Invalid checkpoint token: {checkpoint_token}" + raise InvalidParameterValueException(msg) + + # Get operations (excluding the initial EXECUTION operation for state) + operations = execution.get_assertable_operations() + + # Apply pagination + if max_items is None: + max_items = 100 + + # Simple pagination - in real implementation would need proper marker handling + start_index = 0 + if marker: + try: + start_index = int(marker) + except ValueError: + start_index = 0 + + end_index = start_index + max_items + paginated_operations = operations[start_index:end_index] + + next_marker = None + if end_index < len(operations): + next_marker = str(end_index) + + return GetDurableExecutionStateResponse( + operations=paginated_operations, next_marker=next_marker + ) + + def get_execution_history( + self, + execution_arn: str, + include_execution_data: bool = False, # noqa: FBT001, FBT002 + reverse_order: bool = False, # noqa: FBT001, FBT002 + marker: str | None = None, + max_items: int | None = None, + ) -> GetDurableExecutionHistoryResponse: + """Get execution history with events. + + Args: + execution_arn: The execution ARN + include_execution_data: Whether to include execution data in events + reverse_order: Return events in reverse chronological order + marker: Pagination marker (event_id) + max_items: Maximum items to return + + Returns: + GetDurableExecutionHistoryResponse: Execution history with events + + Raises: + ResourceNotFoundException: If execution does not exist + """ + execution: Execution = self.get_execution(execution_arn) + + # Generate events + all_events: list[HistoryEvent] = [] + ops: list[Operation] = execution.operations + updates: list[OperationUpdate] = execution.updates + updates_dict: dict[str, OperationUpdate] = {u.operation_id: u for u in updates} + durable_execution_arn: str = execution.durable_execution_arn + + # Add InvocationCompleted events + for completion in execution.invocation_completions: + invocation_event = HistoryEvent.create_invocation_completed( + event_id=0, # Temporary, will be reassigned + event_timestamp=completion.end_timestamp, + start_timestamp=completion.start_timestamp, + end_timestamp=completion.end_timestamp, + request_id=completion.request_id, + ) + all_events.append(invocation_event) + + # Generate all events first (without final event IDs) + for op in ops: + operation_update: OperationUpdate | None = updates_dict.get( + op.operation_id, None + ) + + if op.status is OperationStatus.PENDING: + if ( + op.operation_type is not OperationType.CHAINED_INVOKE + or op.start_timestamp is None + ): + continue + context: EventCreationContext = EventCreationContext( + op, + 0, # Temporary event_id, will be reassigned after sorting + durable_execution_arn, + execution.start_input, + execution.result, + operation_update, + include_execution_data, + ) + pending = HistoryEvent.create_chained_invoke_event_pending(context) + all_events.append(pending) + if op.start_timestamp is not None: + context = EventCreationContext( + op, + 0, # Temporary event_id, will be reassigned after sorting + durable_execution_arn, + execution.start_input, + execution.result, + operation_update, + include_execution_data, + ) + started = HistoryEvent.create_event_started(context) + all_events.append(started) + if op.end_timestamp is not None and op.status in TERMINAL_STATUSES: + context = EventCreationContext( + op, + 0, # Temporary event_id, will be reassigned after sorting + durable_execution_arn, + execution.start_input, + execution.result, + operation_update, + include_execution_data, + ) + finished = HistoryEvent.create_event_terminated(context) + all_events.append(finished) + + # Sort events by timestamp to get correct chronological order + all_events.sort(key=lambda event: event.event_timestamp) + + # Reassign event IDs based on chronological order + all_events = [ + HistoryEvent.from_event_with_id(event, i) + for i, event in enumerate(all_events, 1) + ] + + # Apply cursor-based pagination + if max_items is None: + max_items = 100 + + # Handle pagination marker + if reverse_order: + all_events.reverse() + start_index: int = 0 + if marker: + try: + marker_event_id: int = int(marker) + # Find the index of the first event with event_id >= marker + start_index = len(all_events) + for i, e in enumerate(all_events): + is_valid_page_start: bool = ( + e.event_id < marker_event_id + if reverse_order + else e.event_id >= marker_event_id + ) + if is_valid_page_start: + start_index = i + break + except ValueError: + start_index = 0 + + # Get paginated events + end_index: int = start_index + max_items + paginated_events: list[HistoryEvent] = all_events[start_index:end_index] + + # Generate next marker + next_marker: str | None = None + if end_index < len(all_events): + if reverse_order: + # Next marker is the event_id of the last returned event + next_marker = ( + str(paginated_events[-1].event_id) if paginated_events else None + ) + else: + # Next marker is the event_id of the next event after the last returned + next_marker = ( + str(all_events[end_index].event_id) + if end_index < len(all_events) + else None + ) + + return GetDurableExecutionHistoryResponse( + events=paginated_events, next_marker=next_marker + ) + + def checkpoint_execution( + self, + execution_arn: str, + checkpoint_token: str, + updates: list[OperationUpdate] | None = None, + client_token: str | None = None, + ) -> CheckpointDurableExecutionResponse: + """Process checkpoint for an execution. + + Args: + execution_arn: The execution ARN + checkpoint_token: Current checkpoint token + updates: List of operation updates to process + client_token: Client token for idempotency + + Returns: + CheckpointDurableExecutionResponse: Updated checkpoint token and state + + Raises: + ResourceNotFoundException: If execution does not exist + InvalidParameterValueException: If checkpoint token is invalid + """ + execution = self.get_execution(execution_arn) + + # Validate checkpoint token + if checkpoint_token not in execution.used_tokens: + msg: str = f"Invalid checkpoint token: {checkpoint_token}" + raise InvalidParameterValueException(msg) + + if updates: + checkpoint_output = self._checkpoint_processor.process_checkpoint( + checkpoint_token=checkpoint_token, + updates=updates, + client_token=client_token, + ) + + new_execution_state = None + if checkpoint_output.new_execution_state: + new_execution_state = CheckpointUpdatedExecutionState( + operations=checkpoint_output.new_execution_state.operations, + next_marker=checkpoint_output.new_execution_state.next_marker, + ) + + return CheckpointDurableExecutionResponse( + checkpoint_token=checkpoint_output.checkpoint_token, + new_execution_state=new_execution_state, + ) + + # Save execution state after generating new token + new_checkpoint_token = execution.get_new_checkpoint_token() + self._store.update(execution) + + return CheckpointDurableExecutionResponse( + checkpoint_token=new_checkpoint_token, + new_execution_state=None, + ) + + def send_callback_success( + self, + callback_id: str, + result: bytes | None = None, + ) -> SendDurableExecutionCallbackSuccessResponse: + """Send callback success response. + + Args: + callback_id: The callback ID to respond to + result: Optional result data for the callback + + Returns: + SendDurableExecutionCallbackSuccessResponse: Empty response + + Raises: + InvalidParameterValueException: If callback_id is invalid + ResourceNotFoundException: If callback does not exist + """ + if not callback_id: + msg: str = "callback_id is required" + raise InvalidParameterValueException(msg) + + try: + callback_token = CallbackToken.from_str(callback_id) + execution = self.get_execution(callback_token.execution_arn) + execution.complete_callback_success(callback_id, result) + self._store.update(execution) + self._cleanup_callback_timeouts(callback_id) + self._invoke_execution(callback_token.execution_arn) + logger.info("Callback success completed for callback_id: %s", callback_id) + except Exception as e: + msg = f"Failed to process callback success: {e}" + raise ResourceNotFoundException(msg) from e + + return SendDurableExecutionCallbackSuccessResponse() + + def send_callback_failure( + self, + callback_id: str, + error: ErrorObject | None = None, + ) -> SendDurableExecutionCallbackFailureResponse: + """Send callback failure response. + + Args: + callback_id: The callback ID to respond to + error: Optional error object for the callback failure + + Returns: + SendDurableExecutionCallbackFailureResponse: Empty response + + Raises: + InvalidParameterValueException: If callback_id is invalid + ResourceNotFoundException: If callback does not exist + """ + if not callback_id: + msg: str = "callback_id is required" + raise InvalidParameterValueException(msg) + + callback_error: ErrorObject = error or ErrorObject.from_message("") + + try: + callback_token: CallbackToken = CallbackToken.from_str(callback_id) + execution: Execution = self.get_execution(callback_token.execution_arn) + execution.complete_callback_failure(callback_id, callback_error) + self._store.update(execution) + self._cleanup_callback_timeouts(callback_id) + self._invoke_execution(callback_token.execution_arn) + logger.info("Callback failure completed for callback_id: %s", callback_id) + except Exception as e: + msg = f"Failed to process callback failure: {e}" + raise ResourceNotFoundException(msg) from e + + return SendDurableExecutionCallbackFailureResponse() + + def send_callback_heartbeat( + self, callback_id: str + ) -> SendDurableExecutionCallbackHeartbeatResponse: + """Send callback heartbeat to keep callback alive. + + Args: + callback_id: The callback ID to send heartbeat for + + Returns: + SendDurableExecutionCallbackHeartbeatResponse: Empty response + + Raises: + InvalidParameterValueException: If callback_id is invalid + ResourceNotFoundException: If callback does not exist + """ + if not callback_id: + msg: str = "callback_id is required" + raise InvalidParameterValueException(msg) + + try: + callback_token: CallbackToken = CallbackToken.from_str(callback_id) + execution: Execution = self.get_execution(callback_token.execution_arn) + + # Find callback operation to verify it exists and is active + _, operation = execution.find_callback_operation(callback_id) + if operation.status != OperationStatus.STARTED: + msg = f"Callback {callback_id} is not active" + raise ResourceNotFoundException(msg) + + # Reset heartbeat timeout if configured + self._reset_callback_heartbeat_timeout( + callback_id, execution.durable_execution_arn + ) + logger.info("Callback heartbeat processed for callback_id: %s", callback_id) + except Exception as e: + msg = f"Failed to process callback heartbeat: {e}" + raise ResourceNotFoundException(msg) from e + + return SendDurableExecutionCallbackHeartbeatResponse() + + def _validate_invocation_response_and_store( + self, + execution_arn: str, + response: DurableExecutionInvocationOutput, + execution: Execution, + ): + """Validate response status and save it to the store if fine. + + Raises: + InvalidParameterValueException: If the response status is invalid. + IllegalStateException: If the response status is valid but the execution is already completed. + """ + if execution.is_complete: + msg_already_complete: str = "Execution already completed, ignoring result" + + raise IllegalStateException(msg_already_complete) + + if response.status is None: + msg_status_required: str = "Response status is required" + + raise InvalidParameterValueException(msg_status_required) + + match response.status: + case InvocationStatus.FAILED: + if response.result is not None: + msg_failed_result: str = ( + "Cannot provide a Result for FAILED status." + ) + raise InvalidParameterValueException(msg_failed_result) + logger.info("[%s] Execution failed", execution_arn) + self._complete_workflow( + execution_arn, result=None, error=response.error + ) + + case InvocationStatus.SUCCEEDED: + if response.error is not None: + msg_success_error: str = ( + "Cannot provide an Error for SUCCEEDED status." + ) + raise InvalidParameterValueException(msg_success_error) + logger.info("[%s] Execution succeeded", execution_arn) + self._complete_workflow( + execution_arn, result=response.result, error=None + ) + + case InvocationStatus.PENDING: + if not execution.has_pending_operations(execution): + msg_pending_ops: str = ( + "Cannot return PENDING status with no pending operations." + ) + raise InvalidParameterValueException(msg_pending_ops) + logger.info("[%s] Execution pending async work", execution_arn) + + case _: + msg_unexpected_status: str = ( + f"Unexpected invocation status: {response.status}" + ) + raise IllegalStateException(msg_unexpected_status) + + def _invoke_handler(self, execution_arn: str) -> Callable[[], Awaitable[None]]: + """Create a parameterless callable that captures execution arn for the scheduler.""" + + async def invoke() -> None: + execution: Execution = self._store.load(execution_arn) + + # Early exit if execution is already completed - like Java's COMPLETED check + if execution.is_complete: + logger.info( + "[%s] Execution already completed, ignoring result", execution_arn + ) + return + + try: + invocation_input: DurableExecutionInvocationInput = ( + self._invoker.create_invocation_input(execution=execution) + ) + + self._store.save(execution) + + invocation_start = datetime.now(UTC) + invoke_response = self._invoker.invoke( + execution.start_input.function_name, + invocation_input, + execution.start_input.lambda_endpoint, + ) + invocation_end = datetime.now(UTC) + + # Reload execution after invocation in case it was completed via checkpoint + execution = self._store.load(execution_arn) + + # Record invocation completion and save immediately + execution.record_invocation_completion( + invocation_start, invocation_end, invoke_response.request_id + ) + self._store.save(execution) + + if execution.is_complete: + logger.info( + "[%s] Execution completed during invocation, ignoring result", + execution_arn, + ) + return + + # Process successful received response - validate status and handle accordingly + response = invoke_response.invocation_output + try: + self._validate_invocation_response_and_store( + execution_arn, response, execution + ) + except (InvalidParameterValueException, IllegalStateException) as e: + logger.warning( + "[%s] Lambda output validation failure: %s", execution_arn, e + ) + error_obj = ErrorObject.from_exception(e) + self._retry_invocation(execution, error_obj) + + except ResourceNotFoundException: + logger.warning( + "[%s] Function No longer exists: %s", + execution_arn, + execution.start_input.function_name, + ) + error_obj = ErrorObject.from_message( + message=f"Function not found: {execution.start_input.function_name}" + ) + self._fail_workflow(execution_arn, error_obj) + + except Exception as e: # noqa: BLE001 + # Handle invocation errors (network, function not found, etc.) + logger.warning("[%s] Invocation failed: %s", execution_arn, e) + error_obj = ErrorObject.from_exception(e) + self._retry_invocation(execution, error_obj) + + return invoke + + def _invoke_execution(self, execution_arn: str, delay: float = 0) -> None: + """Invoke execution after delay in seconds.""" + completion_event = self._completion_events.get(execution_arn) + self._scheduler.call_later( + self._invoke_handler(execution_arn), + delay=delay, + completion_event=completion_event, + ) + + def _complete_workflow( + self, execution_arn: str, result: str | None, error: ErrorObject | None + ): + """Complete workflow - handles both success and failure with terminal state validation.""" + execution = self._store.load(execution_arn) + + if execution.is_complete: + msg: str = "Cannot make multiple close workflow decisions." + + raise IllegalStateException(msg) + + if error is not None: + self.fail_execution(execution_arn, error) + else: + self.complete_execution(execution_arn, result) + + def _fail_workflow(self, execution_arn: str, error: ErrorObject): + """Fail workflow with terminal state validation.""" + execution = self._store.load(execution_arn) + + if execution.is_complete: + msg: str = "Cannot make multiple close workflow decisions." + + raise IllegalStateException(msg) + + self.fail_execution(execution_arn, error) + + def _retry_invocation(self, execution: Execution, error: ErrorObject): + """Handle retry logic or fail execution if retries exhausted.""" + if ( + execution.consecutive_failed_invocation_attempts + > self.MAX_CONSECUTIVE_FAILED_ATTEMPTS + ): + # Exhausted retries - fail the execution + self._fail_workflow( + execution_arn=execution.durable_execution_arn, error=error + ) + else: + # Schedule retry with backoff + execution.consecutive_failed_invocation_attempts += 1 + self._store.save(execution) + self._invoke_execution( + execution_arn=execution.durable_execution_arn, + delay=self.RETRY_BACKOFF_SECONDS, + ) + + def _complete_events(self, execution_arn: str): + # complete doesn't actually checkpoint explicitly + if event := self._completion_events.get(execution_arn): + event.set() + if self._execution_timeout: + self._execution_timeout.cancel() + self._execution_timeout = None + + def wait_until_complete( + self, execution_arn: str, timeout: float | None = None + ) -> bool: + """Block until execution completion. Don't do this unless you actually want to block. + + Args + timeout (int|float|None): Wait for event to set until this timeout. + + Returns: + True when set. False if the event timed out without being set. + """ + if event := self._completion_events.get(execution_arn): + return event.wait(timeout) + + # this really shouldn't happen - implies execution timed out? + msg: str = "execution does not exist." + + raise ResourceNotFoundException(msg) + + def complete_execution(self, execution_arn: str, result: str | None = None) -> None: + """Complete execution successfully (COMPLETE_WORKFLOW_EXECUTION decision).""" + logger.debug("[%s] Completing execution with result: %s", execution_arn, result) + execution: Execution = self._store.load(execution_arn=execution_arn) + execution.complete_success(result=result) # Sets CloseStatus.COMPLETED + self._store.update(execution) + if execution.result is None: + msg: str = "Execution result is required" + raise IllegalStateException(msg) + self._complete_events(execution_arn=execution_arn) + + def fail_execution(self, execution_arn: str, error: ErrorObject) -> None: + """Fail execution with error (FAIL_WORKFLOW_EXECUTION decision).""" + logger.error("[%s] Completing execution with error: %s", execution_arn, error) + execution: Execution = self._store.load(execution_arn=execution_arn) + execution.complete_fail(error=error) # Sets CloseStatus.FAILED + self._store.update(execution) + # set by complete_fail + if execution.result is None: + msg: str = "Execution result is required" + raise IllegalStateException(msg) + self._complete_events(execution_arn=execution_arn) + + def _on_wait_succeeded(self, execution_arn: str, operation_id: str) -> None: + """Private method - called when a wait operation completes successfully.""" + execution = self._store.load(execution_arn) + + if execution.is_complete: + logger.info( + "[%s] Execution already completed, ignoring wait succeeded event", + execution_arn, + ) + return + + try: + execution.complete_wait(operation_id=operation_id) + self._store.update(execution) + logger.debug( + "[%s] Wait succeeded for operation %s", execution_arn, operation_id + ) + except Exception: + logger.exception("[%s] Error processing wait succeeded.", execution_arn) + + def _on_retry_ready(self, execution_arn: str, operation_id: str) -> None: + """Private method - called when a retry delay has elapsed and retry is ready.""" + execution = self._store.load(execution_arn) + + if execution.is_complete: + logger.info( + "[%s] Execution already completed, ignoring retry", execution_arn + ) + return + + try: + execution.complete_retry(operation_id=operation_id) + self._store.update(execution) + logger.debug( + "[%s] Retry ready for operation %s", execution_arn, operation_id + ) + except Exception: + logger.exception("[%s] Error processing retry ready.", execution_arn) + + # region ExecutionObserver + def on_completed(self, execution_arn: str, result: str | None = None) -> None: + """Complete execution successfully. Observer method triggered by notifier.""" + self.complete_execution(execution_arn, result) + + def on_failed(self, execution_arn: str, error: ErrorObject) -> None: + """Fail execution. Observer method triggered by notifier.""" + self.fail_execution(execution_arn, error) + + def on_timed_out(self, execution_arn: str, error: ErrorObject) -> None: + """Handle execution timeout (workflow timeout). Observer method triggered by notifier.""" + logger.exception("[%s] Execution timed out.", execution_arn) + execution: Execution = self._store.load(execution_arn=execution_arn) + execution.complete_timeout(error=error) # Sets CloseStatus.TIMED_OUT + self._store.update(execution) + self._complete_events(execution_arn=execution_arn) + + def on_stopped(self, execution_arn: str, error: ErrorObject) -> None: + """Handle execution stop. Observer method triggered by notifier.""" + # This should not be called directly - stop_execution handles termination + self.fail_execution(execution_arn, error) + + def on_wait_timer_scheduled( + self, execution_arn: str, operation_id: str, delay: float + ) -> None: + """Schedule a wait operation. Observer method triggered by notifier.""" + logger.debug("[%s] scheduling wait with delay: %d", execution_arn, delay) + + def wait_handler() -> None: + self._on_wait_succeeded(execution_arn, operation_id) + self._invoke_execution(execution_arn, delay=0) + + completion_event = self._completion_events.get(execution_arn) + self._scheduler.call_later( + wait_handler, delay=delay, completion_event=completion_event + ) + + def on_step_retry_scheduled( + self, execution_arn: str, operation_id: str, delay: float + ) -> None: + """Schedule a retry a step. Observer method triggered by notifier.""" + logger.debug( + "[%s] scheduling retry for %s with delay: %d", + execution_arn, + operation_id, + delay, + ) + + def retry_handler() -> None: + self._on_retry_ready(execution_arn, operation_id) + self._invoke_execution(execution_arn, delay=0) + + completion_event = self._completion_events.get(execution_arn) + self._scheduler.call_later( + retry_handler, delay=delay, completion_event=completion_event + ) + + def on_callback_created( + self, + execution_arn: str, + operation_id: str, + callback_options: CallbackOptions | None, + callback_token: CallbackToken, + ) -> None: + """Handle callback creation. Observer method triggered by notifier.""" + callback_id = callback_token.to_str() + logger.debug( + "[%s] Callback created for operation %s with callback_id: %s", + execution_arn, + operation_id, + callback_id, + ) + + # Schedule callback timeouts if configured + self._schedule_callback_timeouts(execution_arn, callback_options, callback_id) + + # endregion ExecutionObserver + + # region Callback Timeouts + def _schedule_callback_timeouts( + self, + execution_arn: str, + callback_options: CallbackOptions | None, + callback_id: str, + ) -> None: + """Schedule callback timeout and heartbeat timeout if configured.""" + try: + if not callback_options: + return + + completion_event = self._completion_events.get(execution_arn) + + # Schedule main timeout if configured + if callback_options.timeout_seconds > 0: + + def timeout_handler(): + self._on_callback_timeout(execution_arn, callback_id) + + timeout_future = self._scheduler.call_later( + timeout_handler, + delay=callback_options.timeout_seconds, + completion_event=completion_event, + ) + self._callback_timeouts[callback_id] = timeout_future + + # Schedule heartbeat timeout if configured + if callback_options.heartbeat_timeout_seconds > 0: + + def heartbeat_timeout_handler(): + self._on_callback_heartbeat_timeout(execution_arn, callback_id) + + heartbeat_future = self._scheduler.call_later( + heartbeat_timeout_handler, + delay=callback_options.heartbeat_timeout_seconds, + completion_event=completion_event, + ) + self._callback_heartbeats[callback_id] = heartbeat_future + + except Exception: + logger.exception( + "[%s] Error scheduling callback timeouts for %s", + execution_arn, + callback_id, + ) + + def _reset_callback_heartbeat_timeout( + self, callback_id: str, execution_arn: str + ) -> None: + """Reset the heartbeat timeout for a callback.""" + # Cancel existing heartbeat timeout + if heartbeat_future := self._callback_heartbeats.pop(callback_id, None): + heartbeat_future.cancel() + + # Find callback options to reschedule heartbeat timeout + try: + callback_token = CallbackToken.from_str(callback_id) + execution = self.get_execution(callback_token.execution_arn) + + callback_options = None + for update in execution.updates: + if ( + update.operation_id == callback_token.operation_id + and update.callback_options + and update.action.value == "START" + ): + callback_options = update.callback_options + break + + if callback_options and callback_options.heartbeat_timeout_seconds > 0: + + def heartbeat_timeout_handler(): + self._on_callback_heartbeat_timeout(execution_arn, callback_id) + + completion_event = self._completion_events.get(execution_arn) + + heartbeat_future = self._scheduler.call_later( + heartbeat_timeout_handler, + delay=callback_options.heartbeat_timeout_seconds, + completion_event=completion_event, + ) + self._callback_heartbeats[callback_id] = heartbeat_future + + except Exception: + logger.exception( + "[%s] Error resetting callback heartbeat timeout for %s", + execution_arn, + callback_id, + ) + + def _cleanup_callback_timeouts(self, callback_id: str) -> None: + """Clean up timeout events for a completed callback.""" + # Clean up main timeout + if timeout_future := self._callback_timeouts.pop(callback_id, None): + timeout_future.cancel() + + # Clean up heartbeat timeout + if heartbeat_future := self._callback_heartbeats.pop(callback_id, None): + heartbeat_future.cancel() + + def _on_callback_timeout(self, execution_arn: str, callback_id: str) -> None: + """Handle callback timeout.""" + try: + callback_token = CallbackToken.from_str(callback_id) + execution = self.get_execution(callback_token.execution_arn) + + if execution.is_complete: + return + + # Fail the callback with timeout error + timeout_error = ErrorObject.from_message( + f"Callback timed out: {CallbackTimeoutType.TIMEOUT.value}" + ) + execution.complete_callback_timeout(callback_id, timeout_error) + self._store.update(execution) + logger.warning("[%s] Callback %s timed out", execution_arn, callback_id) + self._invoke_execution(callback_token.execution_arn) + except Exception: + logger.exception( + "[%s] Error processing callback timeout for %s", + execution_arn, + callback_id, + ) + + def _on_callback_heartbeat_timeout( + self, execution_arn: str, callback_id: str + ) -> None: + """Handle callback heartbeat timeout.""" + try: + callback_token = CallbackToken.from_str(callback_id) + execution = self.get_execution(callback_token.execution_arn) + + if execution.is_complete: + return + + # Fail the callback with heartbeat timeout error + + heartbeat_error = ErrorObject.from_message( + f"Callback heartbeat timed out: {CallbackTimeoutType.HEARTBEAT.value}" + ) + execution.complete_callback_timeout(callback_id, heartbeat_error) + self._store.update(execution) + logger.warning( + "[%s] Callback %s heartbeat timed out", execution_arn, callback_id + ) + self._invoke_execution(callback_token.execution_arn) + except Exception: + logger.exception( + "[%s] Error processing callback heartbeat timeout for %s", + execution_arn, + callback_id, + ) + + # endregion Callback Timeouts diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/invoker.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/invoker.py new file mode 100644 index 0000000..26143c9 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/invoker.py @@ -0,0 +1,340 @@ +from __future__ import annotations + +import json +from dataclasses import dataclass +from threading import Lock +from typing import TYPE_CHECKING, Any, Protocol +from uuid import uuid4 + +import boto3 # type: ignore +from botocore.config import Config # type: ignore + +from aws_durable_execution_sdk_python.execution import ( + DurableExecutionInvocationInput, + DurableExecutionInvocationInputWithClient, + DurableExecutionInvocationOutput, + InitialExecutionState, +) + +from aws_durable_execution_sdk_python_testing.exceptions import ( + DurableFunctionsTestError, + InvalidParameterValueException, + ResourceNotFoundException, +) +from aws_durable_execution_sdk_python_testing.model import LambdaContext + + +if TYPE_CHECKING: + from collections.abc import Callable + + from aws_durable_execution_sdk_python_testing.client import InMemoryServiceClient + from aws_durable_execution_sdk_python_testing.execution import Execution + + +# Max Lambda function timeout is 15 minutes (900s); we give headroom for +# network round-trip and RIE startup. +_LAMBDA_READ_TIMEOUT_SECONDS = 960 +_LAMBDA_CLIENT_CONFIG = Config( + read_timeout=_LAMBDA_READ_TIMEOUT_SECONDS, + retries={"max_attempts": 0}, +) + + +def create_lambda_client(endpoint_url: str | None, region_name: str) -> Any: + """Create a boto3 Lambda client configured for durable function invocations.""" + + return boto3.client( + "lambda", + endpoint_url=endpoint_url, + region_name=region_name, + config=_LAMBDA_CLIENT_CONFIG, + ) + + +@dataclass(frozen=True) +class InvokeResponse: + """Response from invoking a durable function.""" + + invocation_output: DurableExecutionInvocationOutput + request_id: str + + +def create_test_lambda_context() -> LambdaContext: + # Create client context as a dictionary, not as objects + # LambdaContext.__init__ expects dictionaries and will create the objects internally + client_context_dict = { + "custom": {"test_key": "test_value"}, + "env": {"platform": "test", "make": "test", "model": "test"}, + "client": { + "installation_id": "test-installation-123", + "app_title": "TestApp", + "app_version_name": "1.0.0", + "app_version_code": "100", + "app_package_name": "com.test.app", + }, + } + + cognito_identity_dict = { + "cognitoIdentityId": "test-cognito-identity-123", + "cognitoIdentityPoolId": "us-west-2:test-pool-456", + } + + return LambdaContext( + aws_request_id="test-invoke-12345", + client_context=client_context_dict, + identity=cognito_identity_dict, + invoked_function_arn="arn:aws:lambda:us-west-2:123456789012:function:test-function", + tenant_id="test-tenant-789", + ) + + +class Invoker(Protocol): + def create_invocation_input( + self, execution: Execution + ) -> DurableExecutionInvocationInput: ... # pragma: no cover + + def invoke( + self, + function_name: str, + input: DurableExecutionInvocationInput, + endpoint_url: str | None = None, + ) -> InvokeResponse: ... # pragma: no cover + + def update_endpoint( + self, endpoint_url: str, region_name: str + ) -> None: ... # pragma: no cover + + +class InProcessInvoker(Invoker): + def __init__(self, handler: Callable, service_client: InMemoryServiceClient): + self.handler = handler + self.service_client = service_client + + def create_invocation_input( + self, execution: Execution + ) -> DurableExecutionInvocationInput: + return DurableExecutionInvocationInputWithClient( + durable_execution_arn=execution.durable_execution_arn, + # TODO: this needs better logic - use existing if not used yet, vs create new + checkpoint_token=execution.get_new_checkpoint_token(), + initial_execution_state=InitialExecutionState( + operations=execution.operations, + next_marker="", + ), + service_client=self.service_client, + ) + + def invoke( + self, + function_name: str, # noqa: ARG002 + input: DurableExecutionInvocationInput, + endpoint_url: str | None = None, # noqa: ARG002 + ) -> InvokeResponse: + # TODO: reasses if function_name will be used in future + input_with_client = DurableExecutionInvocationInputWithClient.from_durable_execution_invocation_input( + input, self.service_client + ) + context = create_test_lambda_context() + response_dict = self.handler(input_with_client, context) + output = DurableExecutionInvocationOutput.from_dict(response_dict) + return InvokeResponse( + invocation_output=output, request_id=context.aws_request_id + ) + + def update_endpoint(self, endpoint_url: str, region_name: str) -> None: + """No-op for in-process invoker.""" + + +class LambdaInvoker(Invoker): + def __init__(self, lambda_client: Any) -> None: + self.lambda_client = lambda_client + # Maps execution_arn -> endpoint for that execution + # Maps endpoint -> client to reuse clients across executions + self._execution_endpoints: dict[str, str] = {} + self._endpoint_clients: dict[str, Any] = {} + self._current_endpoint: str = "" # Track current endpoint for new executions + self._lock = Lock() + + @staticmethod + def create(endpoint_url: str, region_name: str) -> LambdaInvoker: + """Create with the boto lambda client.""" + invoker = LambdaInvoker(create_lambda_client(endpoint_url, region_name)) + invoker._current_endpoint = endpoint_url + invoker._endpoint_clients[endpoint_url] = invoker.lambda_client + return invoker + + def update_endpoint(self, endpoint_url: str, region_name: str) -> None: + """Update the Lambda client endpoint.""" + # Cache client by endpoint to reuse across executions + with self._lock: + if endpoint_url not in self._endpoint_clients: + self._endpoint_clients[endpoint_url] = create_lambda_client( + endpoint_url, region_name + ) + self.lambda_client = self._endpoint_clients[endpoint_url] + self._current_endpoint = endpoint_url + + def _get_client_for_execution( + self, + durable_execution_arn: str, + lambda_endpoint: str | None = None, + region_name: str | None = None, + ) -> Any: + """Get the appropriate client for this execution.""" + # Use provided endpoint or fall back to cached endpoint for this execution + if lambda_endpoint: + if lambda_endpoint not in self._endpoint_clients: + self._endpoint_clients[lambda_endpoint] = create_lambda_client( + lambda_endpoint, region_name or "us-east-1" + ) + return self._endpoint_clients[lambda_endpoint] + + # Fallback to cached endpoint + if durable_execution_arn not in self._execution_endpoints: + with self._lock: + if durable_execution_arn not in self._execution_endpoints: + self._execution_endpoints[durable_execution_arn] = ( + self._current_endpoint + ) + + endpoint = self._execution_endpoints[durable_execution_arn] + + # If no endpoint configured, fall back to default client + if not endpoint: + return self.lambda_client + + return self._endpoint_clients[endpoint] + + def create_invocation_input( + self, execution: Execution + ) -> DurableExecutionInvocationInput: + return DurableExecutionInvocationInput( + durable_execution_arn=execution.durable_execution_arn, + checkpoint_token=execution.get_new_checkpoint_token(), + initial_execution_state=InitialExecutionState( + operations=execution.operations, + next_marker="", + ), + ) + + def invoke( + self, + function_name: str, + input: DurableExecutionInvocationInput, + endpoint_url: str | None = None, + ) -> InvokeResponse: + """Invoke AWS Lambda function and return durable execution result. + + Args: + function_name: Name of the Lambda function to invoke + input: Durable execution invocation input + endpoint_url: Lambda endpoint url + + Returns: + InvokeResponse: Response containing invocation output and request ID + + Raises: + ResourceNotFoundException: If function does not exist + InvalidParameterValueException: If parameters are invalid + DurableFunctionsTestError: For other invocation failures + """ + + # Parameter validation + if not function_name or not function_name.strip(): + msg = "Function name is required" + raise InvalidParameterValueException(msg) + + # Get the client for this execution + client = self._get_client_for_execution( + input.durable_execution_arn, endpoint_url + ) + + try: + # Invoke AWS Lambda function using standard invoke method + response = client.invoke( + FunctionName=function_name, + InvocationType="RequestResponse", # Synchronous invocation + Payload=json.dumps(input.to_json_dict()), + ) + + # Check HTTP status code + status_code = response.get("StatusCode") + if status_code not in (200, 202, 204): + msg = f"Lambda invocation failed with status code: {status_code}" + raise DurableFunctionsTestError(msg) + + # Check for function errors + if "FunctionError" in response: + error_payload = response["Payload"].read().decode("utf-8") + msg = f"Lambda invocation failed with status {status_code}: {error_payload}" + raise DurableFunctionsTestError(msg) + + # Parse response payload + response_payload = response["Payload"].read().decode("utf-8") + response_dict = json.loads(response_payload) + + # Extract request ID from response headers (x-amzn-RequestId or x-amzn-request-id) + headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {}) + request_id = ( + headers.get("x-amzn-RequestId") + or headers.get("x-amzn-request-id") + or f"local-{uuid4()}" + ) + + # Convert to DurableExecutionInvocationOutput + output = DurableExecutionInvocationOutput.from_dict(response_dict) + return InvokeResponse(invocation_output=output, request_id=request_id) + + except client.exceptions.ResourceNotFoundException as e: + msg = f"Function not found: {function_name}" + raise ResourceNotFoundException(msg) from e + except client.exceptions.InvalidParameterValueException as e: + msg = f"Invalid parameter: {e}" + raise InvalidParameterValueException(msg) from e + except ( + client.exceptions.TooManyRequestsException, + client.exceptions.ServiceException, + client.exceptions.ResourceConflictException, + client.exceptions.InvalidRequestContentException, + client.exceptions.RequestTooLargeException, + client.exceptions.UnsupportedMediaTypeException, + client.exceptions.InvalidRuntimeException, + client.exceptions.InvalidZipFileException, + client.exceptions.ResourceNotReadyException, + client.exceptions.SnapStartTimeoutException, + client.exceptions.SnapStartNotReadyException, + client.exceptions.SnapStartException, + client.exceptions.RecursiveInvocationException, + ) as e: + msg = f"Lambda invocation failed: {e}" + raise DurableFunctionsTestError(msg) from e + except ( + client.exceptions.InvalidSecurityGroupIDException, + client.exceptions.EC2ThrottledException, + client.exceptions.EFSMountConnectivityException, + client.exceptions.SubnetIPAddressLimitReachedException, + client.exceptions.EC2UnexpectedException, + client.exceptions.InvalidSubnetIDException, + client.exceptions.EC2AccessDeniedException, + client.exceptions.EFSIOException, + client.exceptions.ENILimitReachedException, + client.exceptions.EFSMountTimeoutException, + client.exceptions.EFSMountFailureException, + ) as e: + msg = f"Lambda infrastructure error: {e}" + raise DurableFunctionsTestError(msg) from e + except ( + client.exceptions.KMSAccessDeniedException, + client.exceptions.KMSDisabledException, + client.exceptions.KMSNotFoundException, + client.exceptions.KMSInvalidStateException, + ) as e: + msg = f"Lambda KMS error: {e}" + raise DurableFunctionsTestError(msg) from e + except Exception as e: + # Handle any remaining exceptions, including custom ones like DurableExecutionAlreadyStartedException + if "DurableExecutionAlreadyStartedException" in str(type(e)): + msg = f"Durable execution already started: {e}" + raise DurableFunctionsTestError(msg) from e + msg = f"Unexpected error during Lambda invocation: {e}" + raise DurableFunctionsTestError(msg) from e diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/model.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/model.py new file mode 100644 index 0000000..12e69a8 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/model.py @@ -0,0 +1,3296 @@ +"""Model classes for the web API.""" + +from __future__ import annotations + +import datetime +import json +from dataclasses import dataclass, replace +from enum import Enum +from typing import Any + +from aws_durable_execution_sdk_python.execution import DurableExecutionInvocationOutput + +# Import existing types from the main SDK - REUSE EVERYTHING POSSIBLE +from aws_durable_execution_sdk_python.lambda_service import ( + CallbackDetails, + CallbackOptions, + ChainedInvokeDetails, + ChainedInvokeOptions, + ContextDetails, + ContextOptions, + ErrorObject, + ExecutionDetails, + Operation, + OperationAction, + OperationStatus, + OperationSubType, + OperationType, + OperationUpdate, + StepDetails, + StepOptions, + TimestampConverter, + WaitDetails, + WaitOptions, +) +from aws_durable_execution_sdk_python.types import ( + LambdaContext as LambdaContextProtocol, +) +from dateutil.tz import UTC + +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) + + +class EventType(Enum): + """Event types for durable execution events.""" + + EXECUTION_STARTED = "ExecutionStarted" + EXECUTION_SUCCEEDED = "ExecutionSucceeded" + EXECUTION_FAILED = "ExecutionFailed" + EXECUTION_TIMED_OUT = "ExecutionTimedOut" + EXECUTION_STOPPED = "ExecutionStopped" + CONTEXT_STARTED = "ContextStarted" + CONTEXT_SUCCEEDED = "ContextSucceeded" + CONTEXT_FAILED = "ContextFailed" + WAIT_STARTED = "WaitStarted" + WAIT_SUCCEEDED = "WaitSucceeded" + WAIT_CANCELLED = "WaitCancelled" + STEP_STARTED = "StepStarted" + STEP_SUCCEEDED = "StepSucceeded" + STEP_FAILED = "StepFailed" + CHAINED_INVOKE_STARTED = "ChainedInvokeStarted" + CHAINED_INVOKE_SUCCEEDED = "ChainedInvokeSucceeded" + CHAINED_INVOKE_FAILED = "ChainedInvokeFailed" + CHAINED_INVOKE_TIMED_OUT = "ChainedInvokeTimedOut" + CHAINED_INVOKE_STOPPED = "ChainedInvokeStopped" + CALLBACK_STARTED = "CallbackStarted" + CALLBACK_SUCCEEDED = "CallbackSucceeded" + CALLBACK_FAILED = "CallbackFailed" + CALLBACK_TIMED_OUT = "CallbackTimedOut" + INVOCATION_COMPLETED = "InvocationCompleted" + + +TERMINAL_STATUSES: set[OperationStatus] = { + OperationStatus.SUCCEEDED, + OperationStatus.FAILED, + OperationStatus.TIMED_OUT, + OperationStatus.STOPPED, + OperationStatus.CANCELLED, +} + + +@dataclass(frozen=True) +class LambdaContext(LambdaContextProtocol): + """Lambda context for testing.""" + + aws_request_id: str + log_group_name: str | None = None + log_stream_name: str | None = None + function_name: str | None = None + memory_limit_in_mb: str | None = None + function_version: str | None = None + invoked_function_arn: str | None = None + tenant_id: str | None = None + client_context: dict | None = None + identity: dict | None = None + + def get_remaining_time_in_millis(self) -> int: + return 900000 # 15 minutes default + + def log(self, msg) -> None: + pass # No-op for testing + + +# region web_api_models +# Web API specific models (not in Smithy but needed for web interface) +@dataclass(frozen=True) +class StartDurableExecutionInput: + """Input for starting a durable execution via web API.""" + + account_id: str + function_name: str + function_qualifier: str + execution_name: str + execution_timeout_seconds: int + execution_retention_period_days: int + invocation_id: str | None = None + trace_fields: dict | None = None + tenant_id: str | None = None + input: str | None = None + lambda_endpoint: str | None = None # Endpoint for this specific execution + + @classmethod + def from_dict(cls, data: dict) -> StartDurableExecutionInput: + # Validate required fields and raise AWS-compliant exceptions + required_fields = [ + "AccountId", + "FunctionName", + "FunctionQualifier", + "ExecutionName", + "ExecutionTimeoutSeconds", + "ExecutionRetentionPeriodDays", + ] + + for field in required_fields: + if field not in data: + msg: str = f"Missing required field: {field}" + raise InvalidParameterValueException(msg) + + return cls( + account_id=data["AccountId"], + function_name=data["FunctionName"], + function_qualifier=data["FunctionQualifier"], + execution_name=data["ExecutionName"], + execution_timeout_seconds=data["ExecutionTimeoutSeconds"], + execution_retention_period_days=data["ExecutionRetentionPeriodDays"], + invocation_id=data.get("InvocationId"), + trace_fields=data.get("TraceFields"), + tenant_id=data.get("TenantId"), + input=data.get("Input"), + lambda_endpoint=data.get("LambdaEndpoint", None), + ) + + def to_dict(self) -> dict[str, Any]: + result = { + "AccountId": self.account_id, + "FunctionName": self.function_name, + "FunctionQualifier": self.function_qualifier, + "ExecutionName": self.execution_name, + "ExecutionTimeoutSeconds": self.execution_timeout_seconds, + "ExecutionRetentionPeriodDays": self.execution_retention_period_days, + } + if self.invocation_id is not None: + result["InvocationId"] = self.invocation_id + if self.trace_fields is not None: + result["TraceFields"] = self.trace_fields + if self.tenant_id is not None: + result["TenantId"] = self.tenant_id + if self.input is not None: + result["Input"] = self.input + if self.lambda_endpoint is not None: + result["LambdaEndpoint"] = self.lambda_endpoint + return result + + def get_normalized_input(self): + """ + Normalize input string to be JSON deserializable. + Avoid double coding json input. + """ + # Try to parse once + try: + _ = json.loads(self.input) + return self.input + except (json.JSONDecodeError, TypeError): + # Not valid JSON, treat as plain string and encode it + return json.dumps(self.input) + + +@dataclass(frozen=True) +class StartDurableExecutionOutput: + """Output from starting a durable execution via web API.""" + + execution_arn: str | None = None + + @classmethod + def from_dict(cls, data: dict) -> StartDurableExecutionOutput: + return cls(execution_arn=data.get("ExecutionArn")) + + def to_dict(self) -> dict[str, Any]: + result = {} + if self.execution_arn is not None: + result["ExecutionArn"] = self.execution_arn + return result + + +# endregion web_api_models + + +# region smithy_api_models +# Smithy-based API models +@dataclass(frozen=True) +class GetDurableExecutionRequest: + """Request to get durable execution details.""" + + durable_execution_arn: str + + @classmethod + def from_dict(cls, data: dict) -> GetDurableExecutionRequest: + return cls(durable_execution_arn=data["DurableExecutionArn"]) + + def to_dict(self) -> dict[str, Any]: + return {"DurableExecutionArn": self.durable_execution_arn} + + +@dataclass(frozen=True) +class GetDurableExecutionResponse: + """Response containing durable execution details.""" + + durable_execution_arn: str + durable_execution_name: str + function_arn: str + status: str + start_timestamp: datetime.datetime + input_payload: str | None = None + result: str | None = None + error: ErrorObject | None = None + end_timestamp: datetime.datetime | None = None + version: str | None = None + + @classmethod + def from_dict(cls, data: dict) -> GetDurableExecutionResponse: + error = None + if error_data := data.get("Error"): + error = ErrorObject.from_dict(error_data) + + return cls( + durable_execution_arn=data["DurableExecutionArn"], + durable_execution_name=data["DurableExecutionName"], + function_arn=data["FunctionArn"], + status=data["Status"], + start_timestamp=data["StartTimestamp"], + input_payload=data.get("InputPayload"), + result=data.get("Result"), + error=error, + end_timestamp=data.get("EndTimestamp"), + version=data.get("Version"), + ) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = { + "DurableExecutionArn": self.durable_execution_arn, + "DurableExecutionName": self.durable_execution_name, + "FunctionArn": self.function_arn, + "Status": self.status, + "StartTimestamp": self.start_timestamp, + } + if self.input_payload is not None: + result["InputPayload"] = self.input_payload + if self.result is not None: + result["Result"] = self.result + if self.error is not None: + result["Error"] = self.error.to_dict() + if self.end_timestamp is not None: + result["EndTimestamp"] = self.end_timestamp + if self.end_timestamp is not None: + result["EndTimestamp"] = self.end_timestamp + if self.version is not None: + result["Version"] = self.version + return result + + +@dataclass(frozen=True) +class Execution: + """Execution summary structure from Smithy model.""" + + durable_execution_arn: str + durable_execution_name: str + function_arn: str + status: str + start_timestamp: datetime.datetime + end_timestamp: datetime.datetime | None = None + + @classmethod + def from_dict(cls, data: dict) -> Execution: + return cls( + durable_execution_arn=data["DurableExecutionArn"], + durable_execution_name=data["DurableExecutionName"], + function_arn=data.get( + "FunctionArn", "" + ), # Make optional for backward compatibility + status=data["Status"], + start_timestamp=data["StartTimestamp"], + end_timestamp=data.get("EndTimestamp"), + ) + + def to_dict(self) -> dict[str, Any]: + result = { + "DurableExecutionArn": self.durable_execution_arn, + "DurableExecutionName": self.durable_execution_name, + "Status": self.status, + "StartTimestamp": self.start_timestamp, + } + if self.function_arn: # Only include if not empty + result["FunctionArn"] = self.function_arn + if self.end_timestamp is not None: + result["EndTimestamp"] = self.end_timestamp + return result + + @classmethod + def from_execution(cls, execution, status: str) -> Execution: + """Create ExecutionSummary from Execution object.""" + + execution_op = execution.get_operation_execution_started() + return cls( + durable_execution_arn=execution.durable_execution_arn, + durable_execution_name=execution.start_input.execution_name, + function_arn=f"arn:aws:lambda:us-east-1:123456789012:function:{execution.start_input.function_name}", + status=status, + start_timestamp=execution_op.start_timestamp + if execution_op.start_timestamp + else datetime.datetime.now(datetime.UTC), + end_timestamp=execution_op.end_timestamp + if execution_op.end_timestamp + else None, + ) + + +@dataclass(frozen=True) +class ListDurableExecutionsRequest: + """Request to list durable executions.""" + + function_name: str | None = None + function_version: str | None = None + durable_execution_name: str | None = None + status_filter: list[str] | None = None + started_after: str | None = None + started_before: str | None = None + marker: str | None = None + max_items: int = 0 + reverse_order: bool | None = None + + @classmethod + def from_dict(cls, data: dict) -> ListDurableExecutionsRequest: + # Handle query parameters that may be lists + function_name = data.get("FunctionName") + if isinstance(function_name, list): + function_name = function_name[0] if function_name else None + + function_version = data.get("FunctionVersion") + if isinstance(function_version, list): + function_version = function_version[0] if function_version else None + + durable_execution_name = data.get("DurableExecutionName") + if isinstance(durable_execution_name, list): + durable_execution_name = ( + durable_execution_name[0] if durable_execution_name else None + ) + + status_filter = data.get("StatusFilter") + if isinstance(status_filter, list): + status_filter = status_filter if status_filter else None + elif status_filter: + status_filter = [status_filter] + + started_after = data.get("StartedAfter") + if isinstance(started_after, list): + started_after = started_after[0] if started_after else None + + started_before = data.get("StartedBefore") + if isinstance(started_before, list): + started_before = started_before[0] if started_before else None + + marker = data.get("Marker") + if isinstance(marker, list): + marker = marker[0] if marker else None + + max_items = data.get("MaxItems", 0) + if isinstance(max_items, list): + max_items = int(max_items[0]) if max_items else 0 + + reverse_order = data.get("ReverseOrder") + if isinstance(reverse_order, list): + reverse_order = ( + reverse_order[0].lower() in ("true", "1", "yes") + if reverse_order + else None + ) + elif isinstance(reverse_order, str): + reverse_order = reverse_order.lower() in ("true", "1", "yes") + + return cls( + function_name=function_name, + function_version=function_version, + durable_execution_name=durable_execution_name, + status_filter=status_filter, + started_after=started_after, + started_before=started_before, + marker=marker, + max_items=max_items, + reverse_order=reverse_order, + ) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {} + if self.function_name is not None: + result["FunctionName"] = self.function_name + if self.function_version is not None: + result["FunctionVersion"] = self.function_version + if self.durable_execution_name is not None: + result["DurableExecutionName"] = self.durable_execution_name + if self.status_filter is not None: + result["StatusFilter"] = self.status_filter + if self.started_after is not None: + result["StartedAfter"] = self.started_after + if self.started_before is not None: + result["StartedBefore"] = self.started_before + if self.marker is not None: + result["Marker"] = self.marker + if self.max_items is not None: + result["MaxItems"] = self.max_items + if self.reverse_order is not None: + result["ReverseOrder"] = self.reverse_order + return result + + +@dataclass(frozen=True) +class ListDurableExecutionsResponse: + """Response containing list of durable executions.""" + + durable_executions: list[Execution] + next_marker: str | None = None + + @classmethod + def from_dict(cls, data: dict) -> ListDurableExecutionsResponse: + executions = [ + Execution.from_dict(exec_data) + for exec_data in data.get("DurableExecutions", []) + ] + return cls( + durable_executions=executions, + next_marker=data.get("NextMarker"), + ) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = { + "DurableExecutions": [exe.to_dict() for exe in self.durable_executions] + } + if self.next_marker is not None: + result["NextMarker"] = self.next_marker + return result + + +@dataclass(frozen=True) +class StopDurableExecutionRequest: + """Request to stop a durable execution.""" + + durable_execution_arn: str + error: ErrorObject | None = None + + @classmethod + def from_dict(cls, data: dict) -> StopDurableExecutionRequest: + error = None + if error_data := data.get("Error"): + error = ErrorObject.from_dict(error_data) + + return cls( + durable_execution_arn=data["DurableExecutionArn"], + error=error, + ) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {"DurableExecutionArn": self.durable_execution_arn} + if self.error is not None: + result["Error"] = self.error.to_dict() + return result + + +@dataclass(frozen=True) +class StopDurableExecutionResponse: + """Response from stopping a durable execution.""" + + stop_timestamp: datetime.datetime + + @classmethod + def from_dict(cls, data: dict) -> StopDurableExecutionResponse: + return cls(stop_timestamp=data["StopTimestamp"]) + + def to_dict(self) -> dict[str, Any]: + return {"StopTimestamp": self.stop_timestamp} + + +@dataclass(frozen=True) +class GetDurableExecutionStateRequest: + """Request to get durable execution state.""" + + durable_execution_arn: str + checkpoint_token: str + marker: str | None = None + max_items: int = 0 + + @classmethod + def from_dict(cls, data: dict) -> GetDurableExecutionStateRequest: + return cls( + durable_execution_arn=data["DurableExecutionArn"], + checkpoint_token=data["CheckpointToken"], + marker=data.get("Marker"), + max_items=data.get("MaxItems", 0), + ) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = { + "DurableExecutionArn": self.durable_execution_arn, + "CheckpointToken": self.checkpoint_token, + } + if self.marker is not None: + result["Marker"] = self.marker + if self.max_items is not None: + result["MaxItems"] = self.max_items + return result + + +@dataclass(frozen=True) +class GetDurableExecutionStateResponse: + """Response containing durable execution state operations.""" + + operations: list[Operation] + next_marker: str | None = None + + @classmethod + def from_dict(cls, data: dict) -> GetDurableExecutionStateResponse: + operations = [ + Operation.from_dict(op_data) for op_data in data.get("Operations", []) + ] + return cls( + operations=operations, + next_marker=data.get("NextMarker"), + ) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = { + "Operations": [op.to_dict() for op in self.operations] + } + if self.next_marker is not None: + result["NextMarker"] = self.next_marker + return result + + +# endregion smithy_api_models + + +# region event_structures +# Event-related structures from Smithy model +@dataclass(frozen=True) +class EventInput: + """Event input structure.""" + + payload: str | None = None + truncated: bool = False + + @classmethod + def from_dict(cls, data: dict) -> EventInput: + return cls( + payload=data.get("Payload"), + truncated=data.get("Truncated", False), + ) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {"Truncated": self.truncated} + if self.payload is not None: + result["Payload"] = self.payload + return result + + @classmethod + def from_details( + cls, + details: ExecutionDetails, + include: bool = False, # noqa: FBT001, FBT002 + ) -> EventInput: + details_input: str | None = details.input_payload if details else None + payload: str | None = details_input if include else None + truncated: bool = not include + return cls(payload=payload, truncated=truncated) + + @classmethod + def from_start_durable_execution_input( + cls, + start_durable_execution_input: StartDurableExecutionInput, + include: bool = False, # noqa: FBT001, FBT002 + ) -> EventInput: + input: str | None = start_durable_execution_input.input + truncated: bool = not include + return cls(input, truncated) + + +@dataclass(frozen=True) +class EventResult: + """Event result structure.""" + + payload: str | None = None + truncated: bool = False + + @classmethod + def from_dict(cls, data: dict) -> EventResult: + return cls( + payload=data.get("Payload"), + truncated=data.get("Truncated", False), + ) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {"Truncated": self.truncated} + if self.payload is not None: + result["Payload"] = self.payload + return result + + @classmethod + def from_details( + cls, + details: CallbackDetails | StepDetails | ChainedInvokeDetails | ContextDetails, + include: bool = False, # noqa: FBT001, FBT002 + ) -> EventResult: + details_result: str | None = details.result if details else None + payload: str | None = details_result if include else None + truncated: bool = not include + return cls(payload=payload, truncated=truncated) + + @classmethod + def from_durable_execution_invocation_output( + cls, + durable_execution_invocation_output: DurableExecutionInvocationOutput, + include: bool = False, # noqa: FBT001, FBT002 + ) -> EventResult: + truncated: bool = not include + return cls(durable_execution_invocation_output.result, truncated) + + +@dataclass(frozen=True) +class EventError: + """Event error structure.""" + + payload: ErrorObject | None = None + truncated: bool = False + + @classmethod + def from_dict(cls, data: dict) -> EventError: + payload = None + if payload_data := data.get("Payload"): + payload = ErrorObject.from_dict(payload_data) + + return cls( + payload=payload, + truncated=data.get("Truncated", False), + ) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {"Truncated": self.truncated} + if self.payload is not None: + result["Payload"] = self.payload.to_dict() + return result + + @classmethod + def from_details( + cls, + details: CallbackDetails | StepDetails | ChainedInvokeDetails | ContextDetails, + include: bool = False, # noqa: FBT001, FBT002 + ) -> EventError: + error_object: ErrorObject | None = details.error if details else None + truncated: bool = not include + return cls(error_object, truncated) + + @classmethod + def from_durable_execution_invocation_output( + cls, + durable_execution_invocation_output: DurableExecutionInvocationOutput, + include: bool = False, # noqa: FBT001, FBT002 + ) -> EventError: + truncated: bool = not include + return cls(durable_execution_invocation_output.error, truncated) + + +@dataclass(frozen=True) +class RetryDetails: + """Retry details structure.""" + + current_attempt: int = 0 + next_attempt_delay_seconds: int | None = None + + @classmethod + def from_dict(cls, data: dict) -> RetryDetails: + return cls( + current_attempt=data.get("CurrentAttempt", 0), + next_attempt_delay_seconds=data.get("NextAttemptDelaySeconds"), + ) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {"CurrentAttempt": self.current_attempt} + if self.next_attempt_delay_seconds is not None: + result["NextAttemptDelaySeconds"] = self.next_attempt_delay_seconds + return result + + +# Event detail structures +@dataclass(frozen=True) +class ExecutionStartedDetails: + """Execution started event details.""" + + input: EventInput | None = None + execution_timeout: int | None = None + + @classmethod + def from_dict(cls, data: dict) -> ExecutionStartedDetails: + input_data = None + if input_dict := data.get("Input"): + input_data = EventInput.from_dict(input_dict) + + return cls( + input=input_data, + execution_timeout=data.get("ExecutionTimeout"), + ) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {} + if self.input is not None: + result["Input"] = self.input.to_dict() + if self.execution_timeout is not None: + result["ExecutionTimeout"] = self.execution_timeout + return result + + +@dataclass(frozen=True) +class ExecutionSucceededDetails: + """Execution succeeded event details.""" + + result: EventResult | None = None + + @classmethod + def from_dict(cls, data: dict) -> ExecutionSucceededDetails: + result_data = None + if result_dict := data.get("Result"): + result_data = EventResult.from_dict(result_dict) + + return cls(result=result_data) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {} + if self.result is not None: + result["Result"] = self.result.to_dict() + return result + + +@dataclass(frozen=True) +class ExecutionFailedDetails: + """Execution failed event details.""" + + error: EventError | None = None + + @classmethod + def from_dict(cls, data: dict) -> ExecutionFailedDetails: + error_data = None + if error_dict := data.get("Error"): + error_data = EventError.from_dict(error_dict) + + return cls(error=error_data) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {} + if self.error is not None: + result["Error"] = self.error.to_dict() + return result + + +@dataclass(frozen=True) +class ExecutionTimedOutDetails: + """Execution timed out event details.""" + + error: EventError | None = None + + @classmethod + def from_dict(cls, data: dict) -> ExecutionTimedOutDetails: + error_data = None + if error_dict := data.get("Error"): + error_data = EventError.from_dict(error_dict) + + return cls(error=error_data) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {} + if self.error is not None: + result["Error"] = self.error.to_dict() + return result + + +@dataclass(frozen=True) +class ExecutionStoppedDetails: + """Execution stopped event details.""" + + error: EventError | None = None + + @classmethod + def from_dict(cls, data: dict) -> ExecutionStoppedDetails: + error_data = None + if error_dict := data.get("Error"): + error_data = EventError.from_dict(error_dict) + + return cls(error=error_data) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {} + if self.error is not None: + result["Error"] = self.error.to_dict() + return result + + +@dataclass(frozen=True) +class ContextStartedDetails: + """Context started event details.""" + + @classmethod + def from_dict(cls, data: dict) -> ContextStartedDetails: # noqa: ARG003 + return cls() + + def to_dict(self) -> dict[str, Any]: + return {} + + +@dataclass(frozen=True) +class ContextSucceededDetails: + """Context succeeded event details.""" + + result: EventResult | None = None + + @classmethod + def from_dict(cls, data: dict) -> ContextSucceededDetails: + result_data = None + if result_dict := data.get("Result"): + result_data = EventResult.from_dict(result_dict) + + return cls(result=result_data) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {} + if self.result is not None: + result["Result"] = self.result.to_dict() + return result + + +@dataclass(frozen=True) +class ContextFailedDetails: + """Context failed event details.""" + + error: EventError | None = None + + @classmethod + def from_dict(cls, data: dict) -> ContextFailedDetails: + error_data = None + if error_dict := data.get("Error"): + error_data = EventError.from_dict(error_dict) + + return cls(error=error_data) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {} + if self.error is not None: + result["Error"] = self.error.to_dict() + return result + + +@dataclass(frozen=True) +class WaitStartedDetails: + """Wait started event details.""" + + duration: int | None = None + scheduled_end_timestamp: datetime.datetime | None = None + + @classmethod + def from_dict(cls, data: dict) -> WaitStartedDetails: + return cls( + duration=data.get("Duration"), + scheduled_end_timestamp=data.get("ScheduledEndTimestamp"), + ) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {} + if self.duration is not None: + result["Duration"] = self.duration + if self.scheduled_end_timestamp is not None: + result["ScheduledEndTimestamp"] = self.scheduled_end_timestamp + return result + + +@dataclass(frozen=True) +class WaitSucceededDetails: + """Wait succeeded event details.""" + + duration: int | None = None + + @classmethod + def from_dict(cls, data: dict) -> WaitSucceededDetails: + return cls(duration=data.get("Duration")) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {} + if self.duration is not None: + result["Duration"] = self.duration + return result + + +@dataclass(frozen=True) +class WaitCancelledDetails: + """Wait cancelled event details.""" + + error: EventError | None = None + + @classmethod + def from_dict(cls, data: dict) -> WaitCancelledDetails: + error_data = None + if error_dict := data.get("Error"): + error_data = EventError.from_dict(error_dict) + + return cls(error=error_data) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {} + if self.error is not None: + result["Error"] = self.error.to_dict() + return result + + +@dataclass(frozen=True) +class StepStartedDetails: + """Step started event details.""" + + @classmethod + def from_dict(cls, data: dict) -> StepStartedDetails: # noqa: ARG003 + return cls() + + def to_dict(self) -> dict[str, Any]: + return {} + + +@dataclass(frozen=True) +class StepSucceededDetails: + """Step succeeded event details.""" + + result: EventResult | None = None + retry_details: RetryDetails | None = None + + @classmethod + def from_dict(cls, data: dict) -> StepSucceededDetails: + result_data = None + if result_dict := data.get("Result"): + result_data = EventResult.from_dict(result_dict) + + retry_details_data = None + if retry_dict := data.get("RetryDetails"): + retry_details_data = RetryDetails.from_dict(retry_dict) + + return cls(result=result_data, retry_details=retry_details_data) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {} + if self.result is not None: + result["Result"] = self.result.to_dict() + if self.retry_details is not None: + result["RetryDetails"] = self.retry_details.to_dict() + return result + + +@dataclass(frozen=True) +class StepFailedDetails: + """Step failed event details.""" + + error: EventError | None = None + retry_details: RetryDetails | None = None + + @classmethod + def from_dict(cls, data: dict) -> StepFailedDetails: + error_data = None + if error_dict := data.get("Error"): + error_data = EventError.from_dict(error_dict) + + retry_details_data = None + if retry_dict := data.get("RetryDetails"): + retry_details_data = RetryDetails.from_dict(retry_dict) + + return cls(error=error_data, retry_details=retry_details_data) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {} + if self.error is not None: + result["Error"] = self.error.to_dict() + if self.retry_details is not None: + result["RetryDetails"] = self.retry_details.to_dict() + return result + + +@dataclass(frozen=True) +class ChainedInvokePendingDetails: + """Chained Invoke Pending event details.""" + + input: EventInput | None = None + function_name: str | None = None + + @classmethod + def from_dict(cls, data: dict) -> ChainedInvokePendingDetails: + input_data = None + if input_dict := data.get("Input"): + input_data = EventInput.from_dict(input_dict) + + return cls( + input=input_data, + function_name=data.get("FunctionName"), + ) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {} + if self.input is not None: + result["Input"] = self.input.to_dict() + if self.function_name is not None: + result["FunctionName"] = self.function_name + return result + + +@dataclass(frozen=True) +class ChainedInvokeStartedDetails: + """Chained invoke started event details.""" + + durable_execution_arn: str | None = None + + @classmethod + def from_dict(cls, data: dict) -> ChainedInvokeStartedDetails: + return cls( + durable_execution_arn=data.get("DurableExecutionArn"), + ) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {} + if self.durable_execution_arn is not None: + result["DurableExecutionArn"] = self.durable_execution_arn + return result + + +@dataclass(frozen=True) +class ChainedInvokeSucceededDetails: + """Chained invoke succeeded event details.""" + + result: EventResult | None = None + + @classmethod + def from_dict(cls, data: dict) -> ChainedInvokeSucceededDetails: + result_data = None + if result_dict := data.get("Result"): + result_data = EventResult.from_dict(result_dict) + + return cls(result=result_data) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {} + if self.result is not None: + result["Result"] = self.result.to_dict() + return result + + +@dataclass(frozen=True) +class ChainedInvokeFailedDetails: + """Chained invoke failed event details.""" + + error: EventError | None = None + + @classmethod + def from_dict(cls, data: dict) -> ChainedInvokeFailedDetails: + error_data = None + if error_dict := data.get("Error"): + error_data = EventError.from_dict(error_dict) + + return cls(error=error_data) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {} + if self.error is not None: + result["Error"] = self.error.to_dict() + return result + + +@dataclass(frozen=True) +class ChainedInvokeTimedOutDetails: + """Chained invoke timed out event details.""" + + error: EventError | None = None + + @classmethod + def from_dict(cls, data: dict) -> ChainedInvokeTimedOutDetails: + error_data = None + if error_dict := data.get("Error"): + error_data = EventError.from_dict(error_dict) + + return cls(error=error_data) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {} + if self.error is not None: + result["Error"] = self.error.to_dict() + return result + + +@dataclass(frozen=True) +class ChainedInvokeStoppedDetails: + """Chained invoke stopped event details.""" + + error: EventError | None = None + + @classmethod + def from_dict(cls, data: dict) -> ChainedInvokeStoppedDetails: + error_data = None + if error_dict := data.get("Error"): + error_data = EventError.from_dict(error_dict) + + return cls(error=error_data) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {} + if self.error is not None: + result["Error"] = self.error.to_dict() + return result + + +@dataclass(frozen=True) +class CallbackStartedDetails: + """Callback started event details.""" + + callback_id: str | None = None + heartbeat_timeout: int | None = None + timeout: int | None = None + + @classmethod + def from_dict(cls, data: dict) -> CallbackStartedDetails: + return cls( + callback_id=data.get("CallbackId"), + heartbeat_timeout=data.get("HeartbeatTimeout"), + timeout=data.get("Timeout"), + ) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {} + if self.callback_id is not None: + result["CallbackId"] = self.callback_id + if self.heartbeat_timeout is not None: + result["HeartbeatTimeout"] = self.heartbeat_timeout + if self.timeout is not None: + result["Timeout"] = self.timeout + return result + + +@dataclass(frozen=True) +class CallbackSucceededDetails: + """Callback succeeded event details.""" + + result: EventResult | None = None + + @classmethod + def from_dict(cls, data: dict) -> CallbackSucceededDetails: + result_data = None + if result_dict := data.get("Result"): + result_data = EventResult.from_dict(result_dict) + + return cls(result=result_data) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {} + if self.result is not None: + result["Result"] = self.result.to_dict() + return result + + +@dataclass(frozen=True) +class CallbackFailedDetails: + """Callback failed event details.""" + + error: EventError | None = None + + @classmethod + def from_dict(cls, data: dict) -> CallbackFailedDetails: + error_data = None + if error_dict := data.get("Error"): + error_data = EventError.from_dict(error_dict) + + return cls(error=error_data) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {} + if self.error is not None: + result["Error"] = self.error.to_dict() + return result + + +@dataclass(frozen=True) +class CallbackTimedOutDetails: + """Callback timed out event details.""" + + error: EventError | None = None + + @classmethod + def from_dict(cls, data: dict) -> CallbackTimedOutDetails: + error_data = None + if error_dict := data.get("Error"): + error_data = EventError.from_dict(error_dict) + + return cls(error=error_data) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {} + if self.error is not None: + result["Error"] = self.error.to_dict() + return result + + +@dataclass(frozen=True) +class InvocationCompletedDetails: + """Invocation completed event details.""" + + start_timestamp: datetime.datetime + end_timestamp: datetime.datetime + request_id: str + + @classmethod + def from_dict(cls, data: dict) -> InvocationCompletedDetails: + return cls( + start_timestamp=data["StartTimestamp"], + end_timestamp=data["EndTimestamp"], + request_id=data["RequestId"], + ) + + @classmethod + def from_json_dict(cls, data: dict) -> InvocationCompletedDetails: + """Deserialize from JSON dict with Unix millisecond timestamps.""" + start_ts: datetime.datetime | None = TimestampConverter.from_unix_millis( + data["StartTimestamp"] + ) # type: ignore[arg-type] + end_ts: datetime.datetime | None = TimestampConverter.from_unix_millis( + data["EndTimestamp"] + ) # type: ignore[arg-type] + + if start_ts is None or end_ts is None: + raise InvalidParameterValueException( + "StartTimestamp and EndTimestamp cannot be null" + ) + + return cls( + start_timestamp=start_ts, + end_timestamp=end_ts, + request_id=data["RequestId"], + ) + + def to_dict(self) -> dict[str, Any]: + return { + "StartTimestamp": self.start_timestamp, + "EndTimestamp": self.end_timestamp, + "RequestId": self.request_id, + } + + def to_json_dict(self) -> dict[str, Any]: + """Convert to JSON-serializable dict with Unix millisecond timestamps.""" + return { + "StartTimestamp": TimestampConverter.to_unix_millis(self.start_timestamp), + "EndTimestamp": TimestampConverter.to_unix_millis(self.end_timestamp), + "RequestId": self.request_id, + } + + +# endregion event_structures + + +@dataclass(frozen=True) +class EventCreationContext: + operation: Operation + event_id: int + durable_execution_arn: str + start_durable_execution_input: StartDurableExecutionInput + durable_execution_invocation_output: DurableExecutionInvocationOutput | None = None + operation_update: OperationUpdate | None = None + include_execution_data: bool = False # noqa: FBT001, FBT002 + + @classmethod + def create( + cls, + operation: Operation, + event_id: int, + durable_execution_arn: str, + start_input: StartDurableExecutionInput, + result: DurableExecutionInvocationOutput | None = None, + operation_update: OperationUpdate | None = None, + include_execution_data: bool = False, # noqa: FBT001, FBT002 + ) -> EventCreationContext: + return cls( + operation=operation, + event_id=event_id, + durable_execution_arn=durable_execution_arn, + start_durable_execution_input=start_input, + durable_execution_invocation_output=result, + operation_update=operation_update, + include_execution_data=include_execution_data, + ) + + @property + def sub_type(self) -> str | None: + return self.operation.sub_type.value if self.operation.sub_type else None + + def get_retry_details(self) -> RetryDetails | None: + if not self.operation.step_details or not self.operation_update: + return None + + delay = 0 + if ( + self.operation_update.operation_type == OperationType.STEP + and self.operation_update.step_options + ): + delay = self.operation_update.step_options.next_attempt_delay_seconds + + return RetryDetails( + current_attempt=self.operation.step_details.attempt, + next_attempt_delay_seconds=delay, + ) + + @property + def start_timestamp(self) -> datetime.datetime: + return ( + self.operation.start_timestamp + if self.operation.start_timestamp is not None + else datetime.datetime.now(UTC) + ) + + @property + def end_timestamp(self) -> datetime.datetime: + return ( + self.operation.end_timestamp + if self.operation.end_timestamp is not None + else datetime.datetime.now(UTC) + ) + + +# region event_class +@dataclass(frozen=True) +class Event: + """Event structure from Smithy model.""" + + event_type: str + event_timestamp: datetime.datetime + sub_type: str | None = None + event_id: int = 1 + operation_id: str | None = None + name: str | None = None + parent_id: str | None = None + execution_started_details: ExecutionStartedDetails | None = None + execution_succeeded_details: ExecutionSucceededDetails | None = None + execution_failed_details: ExecutionFailedDetails | None = None + execution_timed_out_details: ExecutionTimedOutDetails | None = None + execution_stopped_details: ExecutionStoppedDetails | None = None + context_started_details: ContextStartedDetails | None = None + context_succeeded_details: ContextSucceededDetails | None = None + context_failed_details: ContextFailedDetails | None = None + wait_started_details: WaitStartedDetails | None = None + wait_succeeded_details: WaitSucceededDetails | None = None + wait_cancelled_details: WaitCancelledDetails | None = None + step_started_details: StepStartedDetails | None = None + step_succeeded_details: StepSucceededDetails | None = None + step_failed_details: StepFailedDetails | None = None + chained_invoke_pending_details: ChainedInvokePendingDetails | None = None + chained_invoke_started_details: ChainedInvokeStartedDetails | None = None + chained_invoke_succeeded_details: ChainedInvokeSucceededDetails | None = None + chained_invoke_failed_details: ChainedInvokeFailedDetails | None = None + chained_invoke_timed_out_details: ChainedInvokeTimedOutDetails | None = None + chained_invoke_stopped_details: ChainedInvokeStoppedDetails | None = None + callback_started_details: CallbackStartedDetails | None = None + callback_succeeded_details: CallbackSucceededDetails | None = None + callback_failed_details: CallbackFailedDetails | None = None + callback_timed_out_details: CallbackTimedOutDetails | None = None + invocation_completed_details: InvocationCompletedDetails | None = None + + @classmethod + def from_dict(cls, data: dict) -> Event: + # Parse all the detail structures + execution_started_details = None + if details_data := data.get("ExecutionStartedDetails"): + execution_started_details = ExecutionStartedDetails.from_dict(details_data) + + execution_succeeded_details = None + if details_data := data.get("ExecutionSucceededDetails"): + execution_succeeded_details = ExecutionSucceededDetails.from_dict( + details_data + ) + + execution_failed_details = None + if details_data := data.get("ExecutionFailedDetails"): + execution_failed_details = ExecutionFailedDetails.from_dict(details_data) + + execution_timed_out_details = None + if details_data := data.get("ExecutionTimedOutDetails"): + execution_timed_out_details = ExecutionTimedOutDetails.from_dict( + details_data + ) + + execution_stopped_details = None + if details_data := data.get("ExecutionStoppedDetails"): + execution_stopped_details = ExecutionStoppedDetails.from_dict(details_data) + + context_started_details = None + if details_data := data.get("ContextStartedDetails"): + context_started_details = ContextStartedDetails.from_dict(details_data) + + context_succeeded_details = None + if details_data := data.get("ContextSucceededDetails"): + context_succeeded_details = ContextSucceededDetails.from_dict(details_data) + + context_failed_details = None + if details_data := data.get("ContextFailedDetails"): + context_failed_details = ContextFailedDetails.from_dict(details_data) + + wait_started_details = None + if details_data := data.get("WaitStartedDetails"): + wait_started_details = WaitStartedDetails.from_dict(details_data) + + wait_succeeded_details = None + if details_data := data.get("WaitSucceededDetails"): + wait_succeeded_details = WaitSucceededDetails.from_dict(details_data) + + wait_cancelled_details = None + if details_data := data.get("WaitCancelledDetails"): + wait_cancelled_details = WaitCancelledDetails.from_dict(details_data) + + step_started_details = None + if details_data := data.get("StepStartedDetails"): + step_started_details = StepStartedDetails.from_dict(details_data) + + step_succeeded_details = None + if details_data := data.get("StepSucceededDetails"): + step_succeeded_details = StepSucceededDetails.from_dict(details_data) + + step_failed_details = None + if details_data := data.get("StepFailedDetails"): + step_failed_details = StepFailedDetails.from_dict(details_data) + + chained_invoke_pending_details = None + if details_data := data.get("ChainedInvokePendingDetails"): + chained_invoke_pending_details = ChainedInvokePendingDetails.from_dict( + details_data + ) + + chained_invoke_started_details = None + if details_data := data.get("ChainedInvokeStartedDetails"): + chained_invoke_started_details = ChainedInvokeStartedDetails.from_dict( + details_data + ) + + chained_invoke_succeeded_details = None + if details_data := data.get("ChainedInvokeSucceededDetails"): + chained_invoke_succeeded_details = ChainedInvokeSucceededDetails.from_dict( + details_data + ) + + chained_invoke_failed_details = None + if details_data := data.get("ChainedInvokeFailedDetails"): + chained_invoke_failed_details = ChainedInvokeFailedDetails.from_dict( + details_data + ) + + chained_invoke_timed_out_details = None + if details_data := data.get("ChainedInvokeTimedOutDetails"): + chained_invoke_timed_out_details = ChainedInvokeTimedOutDetails.from_dict( + details_data + ) + + chained_invoke_stopped_details = None + if details_data := data.get("ChainedInvokeStoppedDetails"): + chained_invoke_stopped_details = ChainedInvokeStoppedDetails.from_dict( + details_data + ) + + callback_started_details = None + if details_data := data.get("CallbackStartedDetails"): + callback_started_details = CallbackStartedDetails.from_dict(details_data) + + callback_succeeded_details = None + if details_data := data.get("CallbackSucceededDetails"): + callback_succeeded_details = CallbackSucceededDetails.from_dict( + details_data + ) + + callback_failed_details = None + if details_data := data.get("CallbackFailedDetails"): + callback_failed_details = CallbackFailedDetails.from_dict(details_data) + + callback_timed_out_details = None + if details_data := data.get("CallbackTimedOutDetails"): + callback_timed_out_details = CallbackTimedOutDetails.from_dict(details_data) + + invocation_completed_details = None + if details_data := data.get("InvocationCompletedDetails"): + invocation_completed_details = InvocationCompletedDetails.from_dict( + details_data + ) + + return cls( + event_type=data["EventType"], + event_timestamp=data["EventTimestamp"], + sub_type=data.get("SubType"), + event_id=data.get("EventId", 1), + operation_id=data.get("Id"), + name=data.get("Name"), + parent_id=data.get("ParentId"), + execution_started_details=execution_started_details, + execution_succeeded_details=execution_succeeded_details, + execution_failed_details=execution_failed_details, + execution_timed_out_details=execution_timed_out_details, + execution_stopped_details=execution_stopped_details, + context_started_details=context_started_details, + context_succeeded_details=context_succeeded_details, + context_failed_details=context_failed_details, + wait_started_details=wait_started_details, + wait_succeeded_details=wait_succeeded_details, + wait_cancelled_details=wait_cancelled_details, + step_started_details=step_started_details, + step_succeeded_details=step_succeeded_details, + step_failed_details=step_failed_details, + chained_invoke_pending_details=chained_invoke_pending_details, + chained_invoke_started_details=chained_invoke_started_details, + chained_invoke_succeeded_details=chained_invoke_succeeded_details, + chained_invoke_failed_details=chained_invoke_failed_details, + chained_invoke_timed_out_details=chained_invoke_timed_out_details, + chained_invoke_stopped_details=chained_invoke_stopped_details, + callback_started_details=callback_started_details, + callback_succeeded_details=callback_succeeded_details, + callback_failed_details=callback_failed_details, + callback_timed_out_details=callback_timed_out_details, + invocation_completed_details=invocation_completed_details, + ) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = { + "EventType": self.event_type, + "EventTimestamp": self.event_timestamp, + "EventId": self.event_id, + } + if self.sub_type is not None: + result["SubType"] = self.sub_type + if self.operation_id is not None: + result["Id"] = self.operation_id + if self.name is not None: + result["Name"] = self.name + if self.parent_id is not None: + result["ParentId"] = self.parent_id + if self.execution_started_details is not None: + result["ExecutionStartedDetails"] = self.execution_started_details.to_dict() + if self.execution_succeeded_details is not None: + result["ExecutionSucceededDetails"] = ( + self.execution_succeeded_details.to_dict() + ) + if self.execution_failed_details is not None: + result["ExecutionFailedDetails"] = self.execution_failed_details.to_dict() + if self.execution_timed_out_details is not None: + result["ExecutionTimedOutDetails"] = ( + self.execution_timed_out_details.to_dict() + ) + if self.execution_stopped_details is not None: + result["ExecutionStoppedDetails"] = self.execution_stopped_details.to_dict() + if self.context_started_details is not None: + result["ContextStartedDetails"] = self.context_started_details.to_dict() + if self.context_succeeded_details is not None: + result["ContextSucceededDetails"] = self.context_succeeded_details.to_dict() + if self.context_failed_details is not None: + result["ContextFailedDetails"] = self.context_failed_details.to_dict() + if self.wait_started_details is not None: + result["WaitStartedDetails"] = self.wait_started_details.to_dict() + if self.wait_succeeded_details is not None: + result["WaitSucceededDetails"] = self.wait_succeeded_details.to_dict() + if self.wait_cancelled_details is not None: + result["WaitCancelledDetails"] = self.wait_cancelled_details.to_dict() + if self.step_started_details is not None: + result["StepStartedDetails"] = self.step_started_details.to_dict() + if self.step_succeeded_details is not None: + result["StepSucceededDetails"] = self.step_succeeded_details.to_dict() + if self.step_failed_details is not None: + result["StepFailedDetails"] = self.step_failed_details.to_dict() + if self.chained_invoke_pending_details is not None: + result["ChainedInvokePendingDetails"] = ( + self.chained_invoke_pending_details.to_dict() + ) + if self.chained_invoke_started_details is not None: + result["ChainedInvokeStartedDetails"] = ( + self.chained_invoke_started_details.to_dict() + ) + if self.chained_invoke_succeeded_details is not None: + result["ChainedInvokeSucceededDetails"] = ( + self.chained_invoke_succeeded_details.to_dict() + ) + if self.chained_invoke_failed_details is not None: + result["ChainedInvokeFailedDetails"] = ( + self.chained_invoke_failed_details.to_dict() + ) + if self.chained_invoke_timed_out_details is not None: + result["ChainedInvokeTimedOutDetails"] = ( + self.chained_invoke_timed_out_details.to_dict() + ) + if self.chained_invoke_stopped_details is not None: + result["ChainedInvokeStoppedDetails"] = ( + self.chained_invoke_stopped_details.to_dict() + ) + if self.callback_started_details is not None: + result["CallbackStartedDetails"] = self.callback_started_details.to_dict() + if self.callback_succeeded_details is not None: + result["CallbackSucceededDetails"] = ( + self.callback_succeeded_details.to_dict() + ) + if self.callback_failed_details is not None: + result["CallbackFailedDetails"] = self.callback_failed_details.to_dict() + if self.callback_timed_out_details is not None: + result["CallbackTimedOutDetails"] = ( + self.callback_timed_out_details.to_dict() + ) + if self.invocation_completed_details is not None: + result["InvocationCompletedDetails"] = ( + self.invocation_completed_details.to_dict() + ) + return result + + # region execution + @classmethod + def create_execution_event_started(cls, context: EventCreationContext) -> Event: + execution_details: ExecutionDetails | None = context.operation.execution_details + event_input: EventInput | None = ( + EventInput.from_details(execution_details, context.include_execution_data) + if execution_details + else None + ) + execution_timeout: int | None = ( + context.start_durable_execution_input.execution_timeout_seconds + ) + + return cls( + event_type=EventType.EXECUTION_STARTED.value, + event_timestamp=context.start_timestamp, + sub_type=context.sub_type, + event_id=context.event_id, + operation_id=context.operation.operation_id, + name=context.operation.name, + parent_id=context.operation.parent_id, + execution_started_details=ExecutionStartedDetails( + input=event_input, + execution_timeout=execution_timeout, + ), + ) + + @classmethod + def create_execution_event_succeeded(cls, context: EventCreationContext) -> Event: + result: EventResult | None = ( + EventResult.from_durable_execution_invocation_output( + context.durable_execution_invocation_output, + context.include_execution_data, + ) + if context.durable_execution_invocation_output + else None + ) + return cls( + event_type=EventType.EXECUTION_SUCCEEDED.value, + event_timestamp=context.end_timestamp, + sub_type=context.sub_type, + event_id=context.event_id, + operation_id=context.operation.operation_id, + name=context.operation.name, + parent_id=context.operation.parent_id, + execution_succeeded_details=ExecutionSucceededDetails(result=result), + ) + + @classmethod + def create_execution_event_failed(cls, context: EventCreationContext) -> Event: + error: EventError | None = ( + EventError.from_durable_execution_invocation_output( + context.durable_execution_invocation_output, + include=context.include_execution_data, + ) + if context.durable_execution_invocation_output + else None + ) + return cls( + event_type=EventType.EXECUTION_FAILED.value, + event_timestamp=context.end_timestamp, + sub_type=context.sub_type, + event_id=context.event_id, + operation_id=context.operation.operation_id, + name=context.operation.name, + parent_id=context.operation.parent_id, + execution_failed_details=ExecutionFailedDetails(error=error), + ) + + @classmethod + def create_execution_event_timed_out(cls, context: EventCreationContext) -> Event: + error: EventError | None = ( + EventError.from_durable_execution_invocation_output( + context.durable_execution_invocation_output, + include=context.include_execution_data, + ) + if context.durable_execution_invocation_output + else None + ) + return cls( + event_type=EventType.EXECUTION_TIMED_OUT.value, + event_timestamp=context.end_timestamp, + sub_type=context.sub_type, + event_id=context.event_id, + operation_id=context.operation.operation_id, + name=context.operation.name, + parent_id=context.operation.parent_id, + execution_timed_out_details=ExecutionTimedOutDetails(error=error), + ) + + @classmethod + def create_execution_event_stopped(cls, context: EventCreationContext) -> Event: + error: EventError | None = ( + EventError.from_durable_execution_invocation_output( + context.durable_execution_invocation_output, + include=context.include_execution_data, + ) + if context.durable_execution_invocation_output + else None + ) + return cls( + event_type=EventType.EXECUTION_STOPPED.value, + event_timestamp=context.end_timestamp, + sub_type=context.sub_type, + event_id=context.event_id, + operation_id=context.operation.operation_id, + name=context.operation.name, + parent_id=context.operation.parent_id, + execution_stopped_details=ExecutionStoppedDetails(error=error), + ) + + @classmethod + def create_execution_event(cls, context: EventCreationContext) -> Event: + """Create execution event based on action.""" + match context.operation.status: + case OperationStatus.STARTED: + return cls.create_execution_event_started(context) + case OperationStatus.SUCCEEDED: + return cls.create_execution_event_succeeded(context) + case OperationStatus.FAILED: + return cls.create_execution_event_failed(context) + case OperationStatus.TIMED_OUT: + return cls.create_execution_event_timed_out(context) + case OperationStatus.STOPPED: + return cls.create_execution_event_stopped(context) + case _: + msg = f"Operation status {context.operation.status} is not valid for execution operations. Valid statuses are: STARTED, SUCCEEDED, FAILED, TIMED_OUT, STOPPED" + raise InvalidParameterValueException(msg) + + # endregion execution + + # region context + @classmethod + def create_context_event_started(cls, context: EventCreationContext) -> Event: + return cls( + event_type=EventType.CONTEXT_STARTED.value, + event_timestamp=context.start_timestamp, + sub_type=context.sub_type, + event_id=context.event_id, + operation_id=context.operation.operation_id, + name=context.operation.name, + parent_id=context.operation.parent_id, + context_started_details=ContextStartedDetails(), + ) + + @classmethod + def create_context_event_succeeded(cls, context: EventCreationContext) -> Event: + context_details: ContextDetails | None = context.operation.context_details + event_result: EventResult | None = ( + EventResult.from_details(context_details, context.include_execution_data) + if context_details + else None + ) + return cls( + event_type=EventType.CONTEXT_SUCCEEDED.value, + event_timestamp=context.end_timestamp, + sub_type=context.sub_type, + event_id=context.event_id, + operation_id=context.operation.operation_id, + name=context.operation.name, + parent_id=context.operation.parent_id, + context_succeeded_details=ContextSucceededDetails(result=event_result), + ) + + @classmethod + def create_context_event_failed(cls, context: EventCreationContext) -> Event: + context_details: ContextDetails | None = context.operation.context_details + event_error: EventError | None = ( + EventError.from_details(context_details) if context_details else None + ) + return cls( + event_type=EventType.CONTEXT_FAILED.value, + event_timestamp=context.end_timestamp, + sub_type=context.sub_type, + event_id=context.event_id, + operation_id=context.operation.operation_id, + name=context.operation.name, + parent_id=context.operation.parent_id, + context_failed_details=ContextFailedDetails(error=event_error), + ) + + @classmethod + def create_context_event(cls, context: EventCreationContext) -> Event: + """Create context event based on action.""" + match context.operation.status: + case OperationStatus.STARTED: + return cls.create_context_event_started(context) + case OperationStatus.SUCCEEDED: + return cls.create_context_event_succeeded(context) + case OperationStatus.FAILED: + return cls.create_context_event_failed(context) + case _: + msg = ( + f"Operation status {context.operation.status} is not valid for context operations. " + f"Valid statuses are: STARTED, SUCCEEDED, FAILED" + ) + raise InvalidParameterValueException(msg) + + # endregion context + + # region wait + @classmethod + def create_wait_event_started(cls, context: EventCreationContext) -> Event: + wait_details: WaitDetails | None = context.operation.wait_details + scheduled_end_timestamp: datetime.datetime | None = ( + wait_details.scheduled_end_timestamp if wait_details else None + ) + duration: int | None = None + if ( + wait_details + and wait_details.scheduled_end_timestamp + and context.operation.start_timestamp + ): + duration = round( + ( + wait_details.scheduled_end_timestamp + - context.operation.start_timestamp + ).total_seconds() + ) + return cls( + event_type=EventType.WAIT_STARTED.value, + event_timestamp=context.start_timestamp, + sub_type=context.sub_type, + event_id=context.event_id, + operation_id=context.operation.operation_id, + name=context.operation.name, + parent_id=context.operation.parent_id, + wait_started_details=WaitStartedDetails( + duration=duration, + scheduled_end_timestamp=scheduled_end_timestamp, + ), + ) + + @classmethod + def create_wait_event_succeeded(cls, context: EventCreationContext) -> Event: + wait_details: WaitDetails | None = context.operation.wait_details + duration: int | None = None + if ( + wait_details + and wait_details.scheduled_end_timestamp + and context.operation.start_timestamp + ): + duration = round( + ( + wait_details.scheduled_end_timestamp - context.start_timestamp + ).total_seconds() + ) + return cls( + event_type=EventType.WAIT_SUCCEEDED.value, + event_timestamp=context.end_timestamp, + sub_type=context.sub_type, + event_id=context.event_id, + operation_id=context.operation.operation_id, + name=context.operation.name, + parent_id=context.operation.parent_id, + wait_succeeded_details=WaitSucceededDetails(duration=duration), + ) + + @classmethod + def create_wait_event_cancelled(cls, context: EventCreationContext) -> Event: + error: EventError | None = None + if ( + context.operation_update + and context.operation_update.operation_type == OperationType.WAIT + and context.operation_update.action == OperationAction.CANCEL + ): + error = EventError( + context.operation_update.error, not context.include_execution_data + ) + return cls( + event_type=EventType.WAIT_CANCELLED.value, + event_timestamp=context.end_timestamp, + sub_type=context.sub_type, + event_id=context.event_id, + operation_id=context.operation.operation_id, + name=context.operation.name, + parent_id=context.operation.parent_id, + wait_cancelled_details=WaitCancelledDetails(error=error), + ) + + @classmethod + def create_wait_event(cls, context: EventCreationContext) -> Event: + """Create wait event based on action.""" + match context.operation.status: + case OperationStatus.STARTED: + return cls.create_wait_event_started(context) + case OperationStatus.SUCCEEDED: + return cls.create_wait_event_succeeded(context) + case OperationStatus.CANCELLED: + return cls.create_wait_event_cancelled(context) + case _: + msg = ( + f"Operation status {context.operation.status} is not valid for wait operations. " + f"Valid statuses are: STARTED, SUCCEEDED, CANCELLED" + ) + raise InvalidParameterValueException(msg) + + # endregion wait + + # region step + @classmethod + def create_step_event_started(cls, context: EventCreationContext) -> Event: + return cls( + event_type=EventType.STEP_STARTED.value, + event_timestamp=context.start_timestamp, + sub_type=context.sub_type, + event_id=context.event_id, + operation_id=context.operation.operation_id, + name=context.operation.name, + parent_id=context.operation.parent_id, + step_started_details=StepStartedDetails(), + ) + + @classmethod + def create_step_event_succeeded(cls, context: EventCreationContext) -> Event: + step_details: StepDetails | None = context.operation.step_details + event_result: EventResult | None = ( + EventResult.from_details(step_details, context.include_execution_data) + if step_details + else None + ) + return cls( + event_type=EventType.STEP_SUCCEEDED.value, + event_timestamp=context.end_timestamp, + sub_type=context.sub_type, + event_id=context.event_id, + operation_id=context.operation.operation_id, + name=context.operation.name, + parent_id=context.operation.parent_id, + step_succeeded_details=StepSucceededDetails( + result=event_result, + retry_details=context.get_retry_details(), + ), + ) + + @classmethod + def create_step_event_failed(cls, context: EventCreationContext) -> Event: + step_details: StepDetails | None = context.operation.step_details + event_error: EventError | None = ( + EventError.from_details( + step_details, include=context.include_execution_data + ) + if step_details + else None + ) + return cls( + event_type=EventType.STEP_FAILED.value, + event_timestamp=context.end_timestamp, + sub_type=context.sub_type, + event_id=context.event_id, + operation_id=context.operation.operation_id, + name=context.operation.name, + parent_id=context.operation.parent_id, + step_failed_details=StepFailedDetails( + error=event_error, + retry_details=context.get_retry_details(), + ), + ) + + @classmethod + def create_step_event(cls, context: EventCreationContext) -> Event: + """Create step event based on action.""" + match context.operation.status: + case OperationStatus.STARTED: + return cls.create_step_event_started(context) + case OperationStatus.SUCCEEDED: + return cls.create_step_event_succeeded(context) + case OperationStatus.FAILED: + return cls.create_step_event_failed(context) + case _: + msg = ( + f"Operation status {context.operation.status} is not valid for step operations. " + f"Valid statuses are: STARTED, SUCCEEDED, FAILED" + ) + raise InvalidParameterValueException(msg) + + # endregion step + + # region chained_invoke + @classmethod + def create_chained_invoke_event_pending( + cls, context: EventCreationContext + ) -> Event: + input: EventInput = EventInput.from_start_durable_execution_input( + context.start_durable_execution_input, context.include_execution_data + ) + return cls( + event_type=EventType.CHAINED_INVOKE_STARTED.value, + event_timestamp=context.start_timestamp, + sub_type=context.sub_type, + event_id=context.event_id, + operation_id=context.operation.operation_id, + name=context.operation.name, + parent_id=context.operation.parent_id, + chained_invoke_pending_details=ChainedInvokePendingDetails( + input=input, + function_name=context.start_durable_execution_input.function_name, + ), + ) + + @classmethod + def create_chained_invoke_event_started( + cls, context: EventCreationContext + ) -> Event: + return cls( + event_type=EventType.CHAINED_INVOKE_STARTED.value, + event_timestamp=context.start_timestamp, + sub_type=context.sub_type, + event_id=context.event_id, + operation_id=context.operation.operation_id, + name=context.operation.name, + parent_id=context.operation.parent_id, + chained_invoke_started_details=ChainedInvokeStartedDetails( + durable_execution_arn=context.durable_execution_arn + ), + ) + + @classmethod + def create_chained_invoke_event_succeeded( + cls, context: EventCreationContext + ) -> Event: + chained_invoke_details: ChainedInvokeDetails | None = ( + context.operation.chained_invoke_details + ) + event_result: EventResult | None = ( + EventResult.from_details( + chained_invoke_details, context.include_execution_data + ) + if chained_invoke_details + else None + ) + return cls( + event_type=EventType.CHAINED_INVOKE_SUCCEEDED.value, + event_timestamp=context.end_timestamp, + sub_type=context.sub_type, + event_id=context.event_id, + operation_id=context.operation.operation_id, + name=context.operation.name, + parent_id=context.operation.parent_id, + chained_invoke_succeeded_details=ChainedInvokeSucceededDetails( + result=event_result + ), + ) + + @classmethod + def create_chained_invoke_event_failed(cls, context: EventCreationContext) -> Event: + chained_invoke_details: ChainedInvokeDetails | None = ( + context.operation.chained_invoke_details + ) + event_error: EventError | None = ( + EventError.from_details( + chained_invoke_details, include=context.include_execution_data + ) + if chained_invoke_details + else None + ) + return cls( + event_type=EventType.CHAINED_INVOKE_FAILED.value, + event_timestamp=context.end_timestamp, + sub_type=context.sub_type, + event_id=context.event_id, + operation_id=context.operation.operation_id, + name=context.operation.name, + parent_id=context.operation.parent_id, + chained_invoke_failed_details=ChainedInvokeFailedDetails(error=event_error), + ) + + @classmethod + def create_chained_invoke_event_timed_out( + cls, context: EventCreationContext + ) -> Event: + chained_invoke_details: ChainedInvokeDetails | None = ( + context.operation.chained_invoke_details + ) + event_error: EventError | None = ( + EventError.from_details( + chained_invoke_details, include=context.include_execution_data + ) + if chained_invoke_details + else None + ) + return cls( + event_type=EventType.CHAINED_INVOKE_TIMED_OUT.value, + event_timestamp=context.end_timestamp, + sub_type=context.sub_type, + event_id=context.event_id, + operation_id=context.operation.operation_id, + name=context.operation.name, + parent_id=context.operation.parent_id, + chained_invoke_timed_out_details=ChainedInvokeTimedOutDetails( + error=event_error + ), + ) + + @classmethod + def create_chained_invoke_event_stopped( + cls, context: EventCreationContext + ) -> Event: + chained_invoke_details: ChainedInvokeDetails | None = ( + context.operation.chained_invoke_details + ) + event_error: EventError | None = ( + EventError.from_details( + chained_invoke_details, include=context.include_execution_data + ) + if chained_invoke_details + else None + ) + return cls( + event_type=EventType.CHAINED_INVOKE_STOPPED.value, + event_timestamp=context.end_timestamp, + sub_type=context.sub_type, + event_id=context.event_id, + operation_id=context.operation.operation_id, + name=context.operation.name, + parent_id=context.operation.parent_id, + chained_invoke_stopped_details=ChainedInvokeStoppedDetails( + error=event_error + ), + ) + + @classmethod + def create_chained_invoke_event(cls, context: EventCreationContext) -> Event: + """Create chained invoke event based on action.""" + match context.operation.status: + case OperationStatus.PENDING: + return cls.create_chained_invoke_event_pending(context) + case OperationStatus.STARTED: + return cls.create_chained_invoke_event_started(context) + case OperationStatus.SUCCEEDED: + return cls.create_chained_invoke_event_succeeded(context) + case OperationStatus.FAILED: + return cls.create_chained_invoke_event_failed(context) + case OperationStatus.TIMED_OUT: + return cls.create_chained_invoke_event_timed_out(context) + case OperationStatus.STOPPED: + return cls.create_chained_invoke_event_stopped(context) + case _: + msg = ( + f"Operation status {context.operation.status} is not valid for chained invoke operations. Valid statuses are: " + f"STARTED, SUCCEEDED, FAILED, TIMED_OUT, STOPPED" + ) + raise InvalidParameterValueException(msg) + + # endregion chained_invoke + + # region callback + @classmethod + def create_callback_event_started(cls, context: EventCreationContext) -> Event: + callback_details: CallbackDetails | None = context.operation.callback_details + callback_id: str | None = ( + callback_details.callback_id if callback_details else None + ) + callback_options: CallbackOptions | None = ( + context.operation_update.callback_options + if context.operation_update + else None + ) + timeout: int | None = ( + callback_options.timeout_seconds if callback_options else None + ) + heartbeat_timeout: int | None = ( + callback_options.heartbeat_timeout_seconds if callback_options else None + ) + return cls( + event_type=EventType.CALLBACK_STARTED.value, + event_timestamp=context.start_timestamp, + sub_type=context.sub_type, + event_id=context.event_id, + operation_id=context.operation.operation_id, + name=context.operation.name, + parent_id=context.operation.parent_id, + callback_started_details=CallbackStartedDetails( + callback_id=callback_id, + timeout=timeout, + heartbeat_timeout=heartbeat_timeout, + ), + ) + + @classmethod + def create_callback_event_succeeded(cls, context: EventCreationContext) -> Event: + callback_details: CallbackDetails | None = context.operation.callback_details + event_result: EventResult | None = ( + EventResult.from_details(callback_details, context.include_execution_data) + if callback_details + else None + ) + return cls( + event_type=EventType.CALLBACK_SUCCEEDED.value, + event_timestamp=context.end_timestamp, + sub_type=context.sub_type, + event_id=context.event_id, + operation_id=context.operation.operation_id, + name=context.operation.name, + parent_id=context.operation.parent_id, + callback_succeeded_details=CallbackSucceededDetails(result=event_result), + ) + + @classmethod + def create_callback_event_failed(cls, context: EventCreationContext) -> Event: + callback_details: CallbackDetails | None = context.operation.callback_details + event_error: EventError | None = ( + EventError.from_details(callback_details) if callback_details else None + ) + return cls( + event_type=EventType.CALLBACK_FAILED.value, + event_timestamp=context.end_timestamp, + sub_type=context.sub_type, + event_id=context.event_id, + operation_id=context.operation.operation_id, + name=context.operation.name, + parent_id=context.operation.parent_id, + callback_failed_details=CallbackFailedDetails(error=event_error), + ) + + @classmethod + def create_callback_event_timed_out(cls, context: EventCreationContext) -> Event: + callback_details: CallbackDetails | None = context.operation.callback_details + event_error: EventError | None = ( + EventError.from_details(callback_details) if callback_details else None + ) + return cls( + event_type=EventType.CALLBACK_TIMED_OUT.value, + event_timestamp=context.end_timestamp, + sub_type=context.sub_type, + event_id=context.event_id, + operation_id=context.operation.operation_id, + name=context.operation.name, + parent_id=context.operation.parent_id, + callback_timed_out_details=CallbackTimedOutDetails(error=event_error), + ) + + @classmethod + def create_callback_event(cls, context: EventCreationContext) -> Event: + """Create callback event based on action.""" + match context.operation.status: + case OperationStatus.STARTED: + return cls.create_callback_event_started(context) + case OperationStatus.SUCCEEDED: + return cls.create_callback_event_succeeded(context) + case OperationStatus.FAILED: + return cls.create_callback_event_failed(context) + case OperationStatus.TIMED_OUT: + return cls.create_callback_event_timed_out(context) + case _: + msg = ( + f"Operation status {context.operation.status} is not valid for callback operations. " + f"Valid statuses are: STARTED, SUCCEEDED, FAILED, TIMED_OUT" + ) + raise InvalidParameterValueException(msg) + + # endregion callback + + # region invocation_completed + @classmethod + def create_invocation_completed( + cls, + event_id: int, + event_timestamp: datetime.datetime, + start_timestamp: datetime.datetime, + end_timestamp: datetime.datetime, + request_id: str, + ) -> Event: + """Create invocation completed event.""" + return cls( + event_type=EventType.INVOCATION_COMPLETED.value, + event_timestamp=event_timestamp, + event_id=event_id, + invocation_completed_details=InvocationCompletedDetails( + start_timestamp=start_timestamp, + end_timestamp=end_timestamp, + request_id=request_id, + ), + ) + + # endregion invocation_completed + + @classmethod + def create_event_started(cls, context: EventCreationContext) -> Event: + """Convert operation to started event.""" + if context.operation.start_timestamp is None: + msg: str = "Operation start timestamp cannot be None when converting to started event" + raise InvalidParameterValueException(msg) + + match context.operation.operation_type: + case OperationType.EXECUTION: + return cls.create_execution_event_started(context) + case OperationType.CONTEXT: + return cls.create_context_event_started(context) + case OperationType.WAIT: + return cls.create_wait_event_started(context) + case OperationType.STEP: + return cls.create_step_event_started(context) + case OperationType.CHAINED_INVOKE: + return cls.create_chained_invoke_event_started(context) + case OperationType.CALLBACK: + return cls.create_callback_event_started(context) + case _: + msg = f"Unknown operation type: {context.operation.operation_type}" + raise InvalidParameterValueException(msg) + + @classmethod + def from_event_with_id(cls, event: Event, event_id: int) -> Event: + """Create a new Event from an existing event with updated event_id.""" + return cls( + event_type=event.event_type, + event_timestamp=event.event_timestamp, + sub_type=event.sub_type, + event_id=event_id, + operation_id=event.operation_id, + name=event.name, + parent_id=event.parent_id, + execution_started_details=event.execution_started_details, + execution_succeeded_details=event.execution_succeeded_details, + execution_failed_details=event.execution_failed_details, + execution_timed_out_details=event.execution_timed_out_details, + execution_stopped_details=event.execution_stopped_details, + context_started_details=event.context_started_details, + context_succeeded_details=event.context_succeeded_details, + context_failed_details=event.context_failed_details, + wait_started_details=event.wait_started_details, + wait_succeeded_details=event.wait_succeeded_details, + wait_cancelled_details=event.wait_cancelled_details, + step_started_details=event.step_started_details, + step_succeeded_details=event.step_succeeded_details, + step_failed_details=event.step_failed_details, + chained_invoke_pending_details=event.chained_invoke_pending_details, + chained_invoke_started_details=event.chained_invoke_started_details, + chained_invoke_succeeded_details=event.chained_invoke_succeeded_details, + chained_invoke_failed_details=event.chained_invoke_failed_details, + chained_invoke_timed_out_details=event.chained_invoke_timed_out_details, + chained_invoke_stopped_details=event.chained_invoke_stopped_details, + callback_started_details=event.callback_started_details, + callback_succeeded_details=event.callback_succeeded_details, + callback_failed_details=event.callback_failed_details, + callback_timed_out_details=event.callback_timed_out_details, + ) + + @classmethod + def create_event_terminated(cls, context: EventCreationContext) -> Event: + """Convert operation to finished event.""" + operation: Operation = context.operation + if operation.end_timestamp is None: + msg: str = "Operation end timestamp cannot be None when converting to finished event" + raise InvalidParameterValueException(msg) + + if operation.status not in TERMINAL_STATUSES: + msg = f"Operation status must be one of SUCCEEDED, FAILED, TIMED_OUT, STOPPED, or CANCELLED. Got: {operation.status}" + raise InvalidParameterValueException(msg) + + match operation.operation_type: + case OperationType.EXECUTION: + return cls.create_execution_event(context) + case OperationType.CONTEXT: + return cls.create_context_event(context) + case OperationType.WAIT: + return cls.create_wait_event(context) + case OperationType.STEP: + return cls.create_step_event(context) + case OperationType.CHAINED_INVOKE: + return cls.create_chained_invoke_event(context) + case OperationType.CALLBACK: + return cls.create_callback_event(context) + case _: + msg = f"Unknown operation type: {operation.operation_type}" + raise InvalidParameterValueException(msg) + + +# endregion event_class + + +# region history_models +@dataclass(frozen=True) +class HistoryEventTypeConfig: + """Configuration for how to process a specific event type.""" + + operation_type: OperationType | None + operation_status: OperationStatus | None + is_start_event: bool + is_end_event: bool + has_result: bool # Whether this event type contains result/error data + + +# Mapping of event types to their processing configuration +# This matches the TypeScript historyEventTypes constant +HISTORY_EVENT_TYPES: dict[str, HistoryEventTypeConfig] = { + "ExecutionStarted": HistoryEventTypeConfig( + operation_type=OperationType.EXECUTION, + operation_status=OperationStatus.STARTED, + is_start_event=True, + is_end_event=False, + has_result=False, + ), + "ExecutionFailed": HistoryEventTypeConfig( + operation_type=OperationType.EXECUTION, + operation_status=OperationStatus.FAILED, + is_start_event=False, + is_end_event=True, + has_result=False, + ), + "ExecutionStopped": HistoryEventTypeConfig( + operation_type=OperationType.EXECUTION, + operation_status=OperationStatus.STOPPED, + is_start_event=False, + is_end_event=True, + has_result=False, + ), + "ExecutionSucceeded": HistoryEventTypeConfig( + operation_type=OperationType.EXECUTION, + operation_status=OperationStatus.SUCCEEDED, + is_start_event=False, + is_end_event=True, + has_result=False, + ), + "ExecutionTimedOut": HistoryEventTypeConfig( + operation_type=OperationType.EXECUTION, + operation_status=OperationStatus.TIMED_OUT, + is_start_event=False, + is_end_event=True, + has_result=False, + ), + "CallbackStarted": HistoryEventTypeConfig( + operation_type=OperationType.CALLBACK, + operation_status=OperationStatus.STARTED, + is_start_event=True, + is_end_event=False, + has_result=False, + ), + "CallbackFailed": HistoryEventTypeConfig( + operation_type=OperationType.CALLBACK, + operation_status=OperationStatus.FAILED, + is_start_event=False, + is_end_event=True, + has_result=True, + ), + "CallbackSucceeded": HistoryEventTypeConfig( + operation_type=OperationType.CALLBACK, + operation_status=OperationStatus.SUCCEEDED, + is_start_event=False, + is_end_event=True, + has_result=True, + ), + "CallbackTimedOut": HistoryEventTypeConfig( + operation_type=OperationType.CALLBACK, + operation_status=OperationStatus.TIMED_OUT, + is_start_event=False, + is_end_event=True, + has_result=True, + ), + "ContextStarted": HistoryEventTypeConfig( + operation_type=OperationType.CONTEXT, + operation_status=OperationStatus.STARTED, + is_start_event=True, + is_end_event=False, + has_result=False, + ), + "ContextFailed": HistoryEventTypeConfig( + operation_type=OperationType.CONTEXT, + operation_status=OperationStatus.FAILED, + is_start_event=False, + is_end_event=True, + has_result=True, + ), + "ContextSucceeded": HistoryEventTypeConfig( + operation_type=OperationType.CONTEXT, + operation_status=OperationStatus.SUCCEEDED, + is_start_event=False, + is_end_event=True, + has_result=True, + ), + "ChainedInvokeStarted": HistoryEventTypeConfig( + operation_type=OperationType.CHAINED_INVOKE, + operation_status=OperationStatus.STARTED, + is_start_event=True, + is_end_event=False, + has_result=False, + ), + "ChainedInvokeFailed": HistoryEventTypeConfig( + operation_type=OperationType.CHAINED_INVOKE, + operation_status=OperationStatus.FAILED, + is_start_event=False, + is_end_event=True, + has_result=True, + ), + "ChainedInvokeSucceeded": HistoryEventTypeConfig( + operation_type=OperationType.CHAINED_INVOKE, + operation_status=OperationStatus.SUCCEEDED, + is_start_event=False, + is_end_event=True, + has_result=True, + ), + "ChainedInvokeTimedOut": HistoryEventTypeConfig( + operation_type=OperationType.CHAINED_INVOKE, + operation_status=OperationStatus.TIMED_OUT, + is_start_event=False, + is_end_event=True, + has_result=True, + ), + "ChainedInvokeCancelled": HistoryEventTypeConfig( + operation_type=OperationType.CHAINED_INVOKE, + operation_status=OperationStatus.CANCELLED, + is_start_event=False, + is_end_event=True, + has_result=True, + ), + "StepStarted": HistoryEventTypeConfig( + operation_type=OperationType.STEP, + operation_status=OperationStatus.STARTED, + is_start_event=True, + is_end_event=False, + has_result=False, + ), + "StepFailed": HistoryEventTypeConfig( + operation_type=OperationType.STEP, + operation_status=OperationStatus.FAILED, + is_start_event=False, + is_end_event=True, + has_result=True, + ), + "StepSucceeded": HistoryEventTypeConfig( + operation_type=OperationType.STEP, + operation_status=OperationStatus.SUCCEEDED, + is_start_event=False, + is_end_event=True, + has_result=True, + ), + "WaitStarted": HistoryEventTypeConfig( + operation_type=OperationType.WAIT, + operation_status=OperationStatus.STARTED, + is_start_event=True, + is_end_event=False, + has_result=True, + ), + "WaitSucceeded": HistoryEventTypeConfig( + operation_type=OperationType.WAIT, + operation_status=OperationStatus.SUCCEEDED, + is_start_event=False, + is_end_event=True, + has_result=True, + ), + "WaitCancelled": HistoryEventTypeConfig( + operation_type=OperationType.WAIT, + operation_status=OperationStatus.CANCELLED, + is_start_event=False, + is_end_event=True, + has_result=True, + ), + # TODO: add support for populating invocation information from InvocationCompleted event + "InvocationCompleted": HistoryEventTypeConfig( + operation_type=None, + operation_status=None, + is_start_event=False, + is_end_event=False, + has_result=True, + ), +} + + +def events_to_operations(events: list[Event]) -> list[Operation]: + """Convert a list of history events into operations. + + This function processes raw history events and groups them by operation ID, + creating comprehensive operation objects following the TypeScript pattern from + aws-durable-execution-sdk-js-testing. + + Multiple events for the same operation_id are merged together, with each event + contributing its specific fields (e.g., CallbackStarted provides callback_id, + CallbackSucceeded provides result). + + Args: + events: List of history events to process + + Returns: + List of operations, one per unique operation ID + + Raises: + InvalidParameterValueException: When required fields are missing from an event + + Note: + InvocationCompleted events are currently skipped as they don't represent + operations. Future enhancement: populate invocation information from these + events (TODO). + """ + operations_map: dict[str, Operation] = {} + + for event in events: + if not event.event_type: + msg = "Missing required 'event_type' field in event" + raise InvalidParameterValueException(msg) + + # Get event type configuration + event_config: HistoryEventTypeConfig | None = HISTORY_EVENT_TYPES.get( + event.event_type + ) + if not event_config: + msg = f"Unknown event type: {event.event_type}" + raise InvalidParameterValueException(msg) + + # TODO: add support for populating invocation information from InvocationCompleted event + if event.event_type == "InvocationCompleted": + continue + + if not event.operation_id: + msg = f"Missing required 'operation_id' field in event {event.event_id}" + raise InvalidParameterValueException(msg) + + # Get previous operation if it exists + previous_operation: Operation | None = operations_map.get(event.operation_id) + + # Get operation type and status from configuration + operation_type: OperationType = ( + event_config.operation_type or OperationType.EXECUTION + ) + status: OperationStatus = ( + event_config.operation_status or OperationStatus.PENDING + ) + + # Parse sub_type + sub_type: OperationSubType | None = None + if event.sub_type: + try: + sub_type = OperationSubType(event.sub_type) + except ValueError as e: + raise InvalidParameterValueException(str(e)) from e + + # Create base operation + operation = Operation( + operation_id=event.operation_id, + operation_type=operation_type, + status=status, + name=event.name, + parent_id=event.parent_id, + sub_type=sub_type, + start_timestamp=datetime.datetime.now(tz=datetime.timezone.utc), + ) + + # Merge with previous operation if it exists + # Most fields are immutable, so they get preserved from previous events + if previous_operation: + operation = replace( + operation, + name=operation.name or previous_operation.name, + parent_id=operation.parent_id or previous_operation.parent_id, + sub_type=operation.sub_type or previous_operation.sub_type, + start_timestamp=previous_operation.start_timestamp, + end_timestamp=previous_operation.end_timestamp, + execution_details=previous_operation.execution_details, + context_details=previous_operation.context_details, + step_details=previous_operation.step_details, + wait_details=previous_operation.wait_details, + callback_details=previous_operation.callback_details, + chained_invoke_details=previous_operation.chained_invoke_details, + ) + + # Set timestamps based on event configuration + if event_config.is_start_event: + operation = replace(operation, start_timestamp=event.event_timestamp) + if event_config.is_end_event: + operation = replace(operation, end_timestamp=event.event_timestamp) + + # Add operation-specific details incrementally + # Each event type contributes only the fields it has + + # EXECUTION details + if ( + operation_type == OperationType.EXECUTION + and event.execution_started_details + and event.execution_started_details.input + ): + operation = replace( + operation, + execution_details=ExecutionDetails( + input_payload=event.execution_started_details.input.payload + ), + ) + + # CALLBACK details - merge callback_id, result, and error from different events + if operation_type == OperationType.CALLBACK: + existing_cb: CallbackDetails | None = operation.callback_details + callback_id: str = existing_cb.callback_id if existing_cb else "" + result: str | None = existing_cb.result if existing_cb else None + error: ErrorObject | None = existing_cb.error if existing_cb else None + + # CallbackStarted provides callback_id + if event.callback_started_details: + callback_id = event.callback_started_details.callback_id or callback_id + + # CallbackSucceeded provides result + if ( + event.callback_succeeded_details + and event.callback_succeeded_details.result + ): + result = event.callback_succeeded_details.result.payload + + # CallbackFailed provides error + if event.callback_failed_details and event.callback_failed_details.error: + error = event.callback_failed_details.error.payload + + # CallbackTimedOut provides error + if ( + event.callback_timed_out_details + and event.callback_timed_out_details.error + ): + error = event.callback_timed_out_details.error.payload + + operation = replace( + operation, + callback_details=CallbackDetails( + callback_id=callback_id, + result=result, + error=error, + ), + ) + + # STEP details - only update if this event type has result data + if operation_type == OperationType.STEP and event_config.has_result: + existing_step: StepDetails | None = operation.step_details + result_val: str | None = existing_step.result if existing_step else None + error_val: ErrorObject | None = ( + existing_step.error if existing_step else None + ) + attempt: int = existing_step.attempt if existing_step else 0 + next_attempt_ts: datetime.datetime | None = ( + existing_step.next_attempt_timestamp if existing_step else None + ) + + # StepSucceeded provides result + if event.step_succeeded_details: + if event.step_succeeded_details.result: + result_val = event.step_succeeded_details.result.payload + if event.step_succeeded_details.retry_details: + attempt = event.step_succeeded_details.retry_details.current_attempt + + # StepFailed provides error and retry details + if event.step_failed_details: + if event.step_failed_details.error: + error_val = event.step_failed_details.error.payload + if event.step_failed_details.retry_details: + attempt = event.step_failed_details.retry_details.current_attempt + if ( + event.step_failed_details.retry_details.next_attempt_delay_seconds + is not None + ): + next_attempt_ts = event.event_timestamp + datetime.timedelta( + seconds=event.step_failed_details.retry_details.next_attempt_delay_seconds + ) + + operation = replace( + operation, + step_details=StepDetails( + result=result_val, + error=error_val, + attempt=attempt, + next_attempt_timestamp=next_attempt_ts, + ), + ) + + # WAIT details + if operation_type == OperationType.WAIT and event.wait_started_details: + operation = replace( + operation, + wait_details=WaitDetails( + scheduled_end_timestamp=event.wait_started_details.scheduled_end_timestamp + ), + ) + + # CONTEXT details - only update if this event type has result data (matching TypeScript hasResult) + if operation_type == OperationType.CONTEXT and event_config.has_result: + if ( + event.context_succeeded_details + and event.context_succeeded_details.result + ): + operation = replace( + operation, + context_details=ContextDetails( + result=event.context_succeeded_details.result.payload, + error=None, + ), + ) + elif event.context_failed_details and event.context_failed_details.error: + operation = replace( + operation, + context_details=ContextDetails( + result=None, + error=event.context_failed_details.error.payload, + ), + ) + + # CHAINED_INVOKE details - only update if this event type has result data (matching TypeScript hasResult) + if operation_type == OperationType.CHAINED_INVOKE and event_config.has_result: + if ( + event.chained_invoke_succeeded_details + and event.chained_invoke_succeeded_details.result + ): + operation = replace( + operation, + chained_invoke_details=ChainedInvokeDetails( + result=event.chained_invoke_succeeded_details.result.payload, + error=None, + ), + ) + elif ( + event.chained_invoke_failed_details + and event.chained_invoke_failed_details.error + ): + operation = replace( + operation, + chained_invoke_details=ChainedInvokeDetails( + result=None, + error=event.chained_invoke_failed_details.error.payload, + ), + ) + + # Store in map + operations_map[event.operation_id] = operation + + return list(operations_map.values()) + + +@dataclass(frozen=True) +class GetDurableExecutionHistoryRequest: + """Request to get durable execution history.""" + + durable_execution_arn: str + include_execution_data: bool | None = None + reverse_order: bool | None = None + marker: str | None = None + max_items: int = 0 + + @classmethod + def from_dict(cls, data: dict) -> GetDurableExecutionHistoryRequest: + return cls( + durable_execution_arn=data["DurableExecutionArn"], + include_execution_data=data.get("IncludeExecutionData"), + reverse_order=data.get("ReverseOrder"), + marker=data.get("Marker"), + max_items=data.get("MaxItems", 0), + ) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {"DurableExecutionArn": self.durable_execution_arn} + if self.include_execution_data is not None: + result["IncludeExecutionData"] = self.include_execution_data + if self.reverse_order is not None: + result["ReverseOrder"] = self.reverse_order + if self.marker is not None: + result["Marker"] = self.marker + if self.max_items is not None: + result["MaxItems"] = self.max_items + return result + + +@dataclass(frozen=True) +class GetDurableExecutionHistoryResponse: + """Response containing durable execution history events.""" + + events: list[Event] + next_marker: str | None = None + + @classmethod + def from_dict(cls, data: dict) -> GetDurableExecutionHistoryResponse: + events = [Event.from_dict(event_data) for event_data in data.get("Events", [])] + return cls( + events=events, + next_marker=data.get("NextMarker"), + ) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {"Events": [event.to_dict() for event in self.events]} + if self.next_marker is not None: + result["NextMarker"] = self.next_marker + return result + + +@dataclass(frozen=True) +class ListDurableExecutionsByFunctionRequest: + """Request to list durable executions by function.""" + + function_name: str + qualifier: str | None = None + durable_execution_name: str | None = None + status_filter: list[str] | None = None + started_after: str | None = None + started_before: str | None = None + marker: str | None = None + max_items: int = 0 + reverse_order: bool | None = None + + @classmethod + def from_dict(cls, data: dict) -> ListDurableExecutionsByFunctionRequest: + # Handle query parameters that may be lists + function_name = data.get("FunctionName") + if isinstance(function_name, list): + function_name = function_name[0] if function_name else "" + elif not function_name: + function_name = "" + + qualifier = data.get("Qualifier") or data.get("functionVersion") + if isinstance(qualifier, list): + qualifier = qualifier[0] if qualifier else None + + durable_execution_name = data.get("DurableExecutionName") or data.get( + "executionName" + ) + if isinstance(durable_execution_name, list): + durable_execution_name = ( + durable_execution_name[0] if durable_execution_name else None + ) + + status_filter = data.get("StatusFilter") or data.get("statusFilter") + if isinstance(status_filter, list): + status_filter = status_filter if status_filter else None + elif status_filter: + status_filter = [status_filter] + + started_after = data.get("StartedAfter") or data.get("startedAfter") + if isinstance(started_after, list): + started_after = started_after[0] if started_after else None + + started_before = data.get("StartedBefore") or data.get("startedBefore") + if isinstance(started_before, list): + started_before = started_before[0] if started_before else None + + marker = data.get("Marker") or data.get("marker") + if isinstance(marker, list): + marker = marker[0] if marker else None + + max_items = data.get("MaxItems") or data.get("maxItems", 0) + if isinstance(max_items, list): + max_items = int(max_items[0]) if max_items else 0 + + reverse_order = data.get("ReverseOrder") or data.get("reverseOrder") + if isinstance(reverse_order, list): + reverse_order = ( + reverse_order[0].lower() in ("true", "1", "yes") + if reverse_order + else None + ) + elif isinstance(reverse_order, str): + reverse_order = reverse_order.lower() in ("true", "1", "yes") + + return cls( + function_name=function_name, + qualifier=qualifier, + durable_execution_name=durable_execution_name, + status_filter=status_filter, + started_after=started_after, + started_before=started_before, + marker=marker, + max_items=max_items, + reverse_order=reverse_order, + ) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {"FunctionName": self.function_name} + if self.qualifier is not None: + result["Qualifier"] = self.qualifier + if self.durable_execution_name is not None: + result["DurableExecutionName"] = self.durable_execution_name + if self.status_filter is not None: + result["StatusFilter"] = self.status_filter + if self.started_after is not None: + result["StartedAfter"] = self.started_after + if self.started_before is not None: + result["StartedBefore"] = self.started_before + if self.marker is not None: + result["Marker"] = self.marker + if self.max_items is not None: + result["MaxItems"] = self.max_items + if self.reverse_order is not None: + result["ReverseOrder"] = self.reverse_order + return result + + +@dataclass(frozen=True) +class ListDurableExecutionsByFunctionResponse: + """Response containing list of durable executions by function.""" + + durable_executions: list[Execution] + next_marker: str | None = None + + @classmethod + def from_dict(cls, data: dict) -> ListDurableExecutionsByFunctionResponse: + executions = [ + Execution.from_dict(exec_data) + for exec_data in data.get("DurableExecutions", []) + ] + return cls( + durable_executions=executions, + next_marker=data.get("NextMarker"), + ) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = { + "DurableExecutions": [exe.to_dict() for exe in self.durable_executions] + } + if self.next_marker is not None: + result["NextMarker"] = self.next_marker + return result + + +# endregion history_models + + +# region callback_models +# Callback-related models +@dataclass(frozen=True) +class SendDurableExecutionCallbackSuccessRequest: + """Request to send callback success.""" + + callback_id: str + result: bytes | None = None + + @classmethod + def from_dict(cls, data: dict) -> SendDurableExecutionCallbackSuccessRequest: + return cls( + callback_id=data["CallbackId"], + result=data.get("Result"), + ) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {"CallbackId": self.callback_id} + if self.result is not None: + result["Result"] = self.result + return result + + +@dataclass(frozen=True) +class SendDurableExecutionCallbackSuccessResponse: + """Response from sending callback success.""" + + +@dataclass(frozen=True) +class SendDurableExecutionCallbackFailureRequest: + """Request to send callback failure.""" + + callback_id: str + error: ErrorObject | None = None + + @classmethod + def from_dict( + cls, data: dict, callback_id: str + ) -> SendDurableExecutionCallbackFailureRequest: + error = ErrorObject.from_dict(data) if data else None + + return cls( + callback_id=callback_id, + error=error, + ) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {"CallbackId": self.callback_id} + if self.error is not None: + result["Error"] = self.error.to_dict() + return result + + +@dataclass(frozen=True) +class SendDurableExecutionCallbackFailureResponse: + """Response from sending callback failure.""" + + +@dataclass(frozen=True) +class SendDurableExecutionCallbackHeartbeatRequest: + """Request to send callback heartbeat.""" + + callback_id: str + + @classmethod + def from_dict(cls, data: dict) -> SendDurableExecutionCallbackHeartbeatRequest: + return cls(callback_id=data["CallbackId"]) + + def to_dict(self) -> dict[str, Any]: + return {"CallbackId": self.callback_id} + + +@dataclass(frozen=True) +class SendDurableExecutionCallbackHeartbeatResponse: + """Response from sending callback heartbeat.""" + + +# endregion callback_models + + +# region checkpoint_models +# Checkpoint-related models +@dataclass(frozen=True) +class CheckpointUpdatedExecutionState: + """Updated execution state from checkpoint.""" + + operations: list[Operation] + next_marker: str | None = None + + @classmethod + def from_dict(cls, data: dict) -> CheckpointUpdatedExecutionState: + operations = [ + Operation.from_dict(op_data) for op_data in data.get("Operations", []) + ] + return cls( + operations=operations, + next_marker=data.get("NextMarker"), + ) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = { + "Operations": [op.to_dict() for op in self.operations] + } + if self.next_marker is not None: + result["NextMarker"] = self.next_marker + return result + + +@dataclass(frozen=True) +class CheckpointDurableExecutionRequest: + """Request to checkpoint a durable execution.""" + + durable_execution_arn: str + checkpoint_token: str + updates: list[OperationUpdate] | None = None + client_token: str | None = None + + @classmethod + def from_dict( + cls, data: dict, durable_execution_arn: str + ) -> CheckpointDurableExecutionRequest: + updates = None + if updates_data := data.get("Updates"): + updates = [] + for update_data in updates_data: + # Map dictionary fields to OperationUpdate constructor parameters + operation_update = OperationUpdate( + operation_id=update_data["Id"], + operation_type=OperationType(update_data["Type"]), + action=OperationAction(update_data["Action"]), + parent_id=update_data.get("ParentId"), + name=update_data.get("Name"), + sub_type=OperationSubType(update_data["SubType"]) + if update_data.get("SubType") + else None, + payload=update_data.get("Payload"), + error=ErrorObject.from_dict(update_data["Error"]) + if update_data.get("Error") + else None, + context_options=ContextOptions.from_dict( + update_data["ContextOptions"] + ) + if update_data.get("ContextOptions") + else None, + step_options=StepOptions.from_dict(update_data["StepOptions"]) + if update_data.get("StepOptions") + else None, + wait_options=WaitOptions.from_dict(update_data["WaitOptions"]) + if update_data.get("WaitOptions") + else None, + callback_options=CallbackOptions.from_dict( + update_data["CallbackOptions"] + ) + if update_data.get("CallbackOptions") + else None, + chained_invoke_options=ChainedInvokeOptions.from_dict( + update_data["ChainedInvokeOptions"] + ) + if update_data.get("ChainedInvokeOptions") + else None, + ) + updates.append(operation_update) + + return cls( + durable_execution_arn=durable_execution_arn, + checkpoint_token=data["CheckpointToken"], + updates=updates, + client_token=data.get("ClientToken"), + ) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = { + "DurableExecutionArn": self.durable_execution_arn, + "CheckpointToken": self.checkpoint_token, + } + if self.updates is not None: + result["Updates"] = [update.to_dict() for update in self.updates] + if self.client_token is not None: + result["ClientToken"] = self.client_token + return result + + +@dataclass(frozen=True) +class CheckpointDurableExecutionResponse: + """Response from checkpointing a durable execution.""" + + checkpoint_token: str + new_execution_state: CheckpointUpdatedExecutionState | None = None + + @classmethod + def from_dict(cls, data: dict) -> CheckpointDurableExecutionResponse: + new_execution_state = None + if state_data := data.get("NewExecutionState"): + new_execution_state = CheckpointUpdatedExecutionState.from_dict(state_data) + + return cls( + checkpoint_token=data["CheckpointToken"], + new_execution_state=new_execution_state, + ) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {"CheckpointToken": self.checkpoint_token} + if self.new_execution_state is not None: + result["NewExecutionState"] = self.new_execution_state.to_dict() + return result + + +# endregion checkpoint_models + + +# region error_models +# Error response structure for consistent error handling +@dataclass(frozen=True) +class ErrorResponse: + """Structured error response for web service operations.""" + + error_type: str + error_message: str + error_code: str | None = None + request_id: str | None = None + + @classmethod + def from_dict(cls, data: dict) -> ErrorResponse: + """Create ErrorResponse from dictionary. + + Args: + data: Dictionary containing error data + + Returns: + ErrorResponse: The error response object + """ + error_data = data.get("error", data) # Support both nested and flat structures + return cls( + error_type=error_data["type"], + error_message=error_data["message"], + error_code=error_data.get("code"), + request_id=error_data.get("requestId"), + ) + + def to_dict(self) -> dict[str, Any]: + """Convert ErrorResponse to dictionary. + + Returns: + dict: Dictionary representation of the error response + """ + error_data: dict[str, Any] = { + "type": self.error_type, + "message": self.error_message, + } + + if self.error_code is not None: + error_data["code"] = self.error_code + if self.request_id is not None: + error_data["requestId"] = self.request_id + + return {"error": error_data} + + +# endregion error_models diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/observer.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/observer.py new file mode 100644 index 0000000..1b518ce --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/observer.py @@ -0,0 +1,144 @@ +"""Checkpoint processors can notify the Execution of notable event state changes. Observer pattern.""" + +from __future__ import annotations + +import threading +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +from aws_durable_execution_sdk_python_testing.token import CallbackToken + +if TYPE_CHECKING: + from collections.abc import Callable + + from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, + CallbackOptions, + ) + + +class ExecutionObserver(ABC): + """Observer for execution lifecycle events.""" + + @abstractmethod + def on_completed(self, execution_arn: str, result: str | None = None) -> None: + """Called when execution completes successfully.""" + + @abstractmethod + def on_failed(self, execution_arn: str, error: ErrorObject) -> None: + """Called when execution fails.""" + + @abstractmethod + def on_timed_out(self, execution_arn: str, error: ErrorObject) -> None: + """Called when execution times out.""" + + @abstractmethod + def on_stopped(self, execution_arn: str, error: ErrorObject) -> None: + """Called when execution is stopped.""" + + @abstractmethod + def on_wait_timer_scheduled( + self, execution_arn: str, operation_id: str, delay: float + ) -> None: + """Called when wait timer scheduled.""" + + @abstractmethod + def on_step_retry_scheduled( + self, execution_arn: str, operation_id: str, delay: float + ) -> None: + """Called when step retry scheduled.""" + + @abstractmethod + def on_callback_created( + self, + execution_arn: str, + operation_id: str, + callback_options: CallbackOptions | None, + callback_token: CallbackToken, + ) -> None: + """Called when callback is created.""" + + +class ExecutionNotifier: + """Notifies observers about execution events. Thread-safe.""" + + def __init__(self) -> None: + self._observers: list[ExecutionObserver] = [] + self._lock = threading.RLock() + + def add_observer(self, observer: ExecutionObserver) -> None: + """Add an observer to be notified of execution events.""" + with self._lock: + self._observers.append(observer) + + def _notify_observers(self, method: Callable, *args, **kwargs) -> None: + """Notify all observers by calling the specified method.""" + with self._lock: + observers = self._observers.copy() + for observer in observers: + getattr(observer, method.__name__)(*args, **kwargs) + + # region event emitters + def notify_completed(self, execution_arn: str, result: str | None = None) -> None: + """Notify observers about execution completion.""" + self._notify_observers( + ExecutionObserver.on_completed, execution_arn=execution_arn, result=result + ) + + def notify_failed(self, execution_arn: str, error: ErrorObject) -> None: + """Notify observers about execution failure.""" + self._notify_observers( + ExecutionObserver.on_failed, execution_arn=execution_arn, error=error + ) + + def notify_timed_out(self, execution_arn: str, error: ErrorObject) -> None: + """Notify observers about execution timeout.""" + self._notify_observers( + ExecutionObserver.on_timed_out, execution_arn=execution_arn, error=error + ) + + def notify_stopped(self, execution_arn: str, error: ErrorObject) -> None: + """Notify observers about execution being stopped.""" + self._notify_observers( + ExecutionObserver.on_stopped, execution_arn=execution_arn, error=error + ) + + def notify_wait_timer_scheduled( + self, execution_arn: str, operation_id: str, delay: float + ) -> None: + """Notify observers about wait timer scheduling.""" + self._notify_observers( + ExecutionObserver.on_wait_timer_scheduled, + execution_arn=execution_arn, + operation_id=operation_id, + delay=delay, + ) + + def notify_step_retry_scheduled( + self, execution_arn: str, operation_id: str, delay: float + ) -> None: + """Notify observers about step retry scheduling.""" + self._notify_observers( + ExecutionObserver.on_step_retry_scheduled, + execution_arn=execution_arn, + operation_id=operation_id, + delay=delay, + ) + + def notify_callback_created( + self, + execution_arn: str, + operation_id: str, + callback_options: CallbackOptions | None, + callback_token: CallbackToken, + ) -> None: + """Notify observers about callback creation.""" + self._notify_observers( + ExecutionObserver.on_callback_created, + execution_arn=execution_arn, + operation_id=operation_id, + callback_options=callback_options, + callback_token=callback_token, + ) + + # endregion event emitters diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/py.typed b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/py.typed new file mode 100644 index 0000000..7ef2116 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/py.typed @@ -0,0 +1 @@ +# Marker file that indicates this package supports typing diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/runner.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/runner.py new file mode 100644 index 0000000..d60e774 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/runner.py @@ -0,0 +1,1167 @@ +from __future__ import annotations + +import json +import logging +import os +import time +from dataclasses import dataclass, field +from typing import ( + TYPE_CHECKING, + Any, + Concatenate, + ParamSpec, + Protocol, + Self, + TypeVar, + cast, +) + +import aws_durable_execution_sdk_python +import boto3 # type: ignore +from botocore.exceptions import ClientError # type: ignore +from aws_durable_execution_sdk_python.execution import ( + InvocationStatus, + durable_execution, +) +from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, + OperationPayload, + OperationStatus, + OperationSubType, + OperationType, +) +from aws_durable_execution_sdk_python.lambda_service import Operation as SvcOperation + +from aws_durable_execution_sdk_python_testing.checkpoint.processor import ( + CheckpointProcessor, +) +from aws_durable_execution_sdk_python_testing.client import InMemoryServiceClient +from aws_durable_execution_sdk_python_testing.exceptions import ( + DurableFunctionsLocalRunnerError, + DurableFunctionsTestError, + InvalidParameterValueException, + ResourceNotFoundException, +) +from aws_durable_execution_sdk_python_testing.executor import Executor +from aws_durable_execution_sdk_python_testing.invoker import ( + InProcessInvoker, + LambdaInvoker, + create_lambda_client, +) +from aws_durable_execution_sdk_python_testing.model import ( + GetDurableExecutionHistoryResponse, + GetDurableExecutionResponse, + StartDurableExecutionInput, + StartDurableExecutionOutput, + events_to_operations, +) +from aws_durable_execution_sdk_python_testing.scheduler import Scheduler +from aws_durable_execution_sdk_python_testing.stores.base import ( + ExecutionStore, + StoreType, +) +from aws_durable_execution_sdk_python_testing.stores.filesystem import ( + FileSystemExecutionStore, +) +from aws_durable_execution_sdk_python_testing.stores.memory import ( + InMemoryExecutionStore, +) +from aws_durable_execution_sdk_python_testing.stores.sqlite import SQLiteExecutionStore +from aws_durable_execution_sdk_python_testing.web.server import WebServer + + +if TYPE_CHECKING: + import datetime + from collections.abc import Callable, MutableMapping + + from aws_durable_execution_sdk_python.context import DurableContext + from aws_durable_execution_sdk_python.execution import InvocationStatus + + from aws_durable_execution_sdk_python_testing.execution import Execution + from aws_durable_execution_sdk_python_testing.web.server import WebServiceConfig + from aws_durable_execution_sdk_python_testing.model import Event + + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class WebRunnerConfig: + """Configuration for the WebRunner using composition pattern. + + This configuration class encapsulates all settings needed to run the web server + for durable functions testing, including HTTP server configuration and Lambda + service configuration. + """ + + # HTTP server configuration (existing WebServiceConfig) + web_service: WebServiceConfig + + # Lambda service configuration (web runner specific) + lambda_endpoint: str = "http://127.0.0.1:3001" + local_runner_endpoint: str = "http://0.0.0.0:5000" + local_runner_region: str = "us-west-2" + local_runner_mode: str = "local" + + # Store configuration + store_type: StoreType = StoreType.MEMORY + store_path: str | None = None # Path for filesystem store + + +@dataclass(frozen=True) +class Operation: + operation_id: str + operation_type: OperationType + status: OperationStatus + parent_id: str | None = field(default=None, kw_only=True) + name: str | None = field(default=None, kw_only=True) + sub_type: OperationSubType | None = field(default=None, kw_only=True) + start_timestamp: datetime.datetime | None = field(default=None, kw_only=True) + end_timestamp: datetime.datetime | None = field(default=None, kw_only=True) + + +T = TypeVar("T", bound=Operation) +P = ParamSpec("P") + + +class OperationFactory(Protocol): + @staticmethod + def from_svc_operation( + operation: SvcOperation, all_operations: list[SvcOperation] | None = None + ) -> Operation: ... + + +@dataclass(frozen=True) +class ExecutionOperation(Operation): + input_payload: str | None = None + + @staticmethod + def from_svc_operation( + operation: SvcOperation, + all_operations: list[SvcOperation] | None = None, # noqa: ARG004 + ) -> ExecutionOperation: + if operation.operation_type != OperationType.EXECUTION: + msg: str = f"Expected EXECUTION operation, got {operation.operation_type}" + raise InvalidParameterValueException(msg) + return ExecutionOperation( + operation_id=operation.operation_id, + operation_type=operation.operation_type, + status=operation.status, + parent_id=operation.parent_id, + name=operation.name, + sub_type=operation.sub_type, + start_timestamp=operation.start_timestamp, + end_timestamp=operation.end_timestamp, + input_payload=( + operation.execution_details.input_payload + if operation.execution_details + else None + ), + ) + + +@dataclass(frozen=True) +class ContextOperation(Operation): + child_operations: list[Operation] + result: OperationPayload | None = None + error: ErrorObject | None = None + + @staticmethod + def from_svc_operation( + operation: SvcOperation, all_operations: list[SvcOperation] | None = None + ) -> ContextOperation: + if operation.operation_type != OperationType.CONTEXT: + msg: str = f"Expected CONTEXT operation, got {operation.operation_type}" + raise InvalidParameterValueException(msg) + + child_operations = [] + if all_operations: + child_operations = [ + create_operation(op, all_operations) + for op in all_operations + if op.parent_id == operation.operation_id + ] + + return ContextOperation( + operation_id=operation.operation_id, + operation_type=operation.operation_type, + status=operation.status, + parent_id=operation.parent_id, + name=operation.name, + sub_type=operation.sub_type, + start_timestamp=operation.start_timestamp, + end_timestamp=operation.end_timestamp, + child_operations=child_operations, + result=operation.context_details.result + if operation.context_details + else None, + error=operation.context_details.error + if operation.context_details + else None, + ) + + def get_operation_by_name(self, name: str) -> Operation: + for operation in self.child_operations: + if operation.name == name: + return operation + msg: str = f"Child Operation with name '{name}' not found" + raise DurableFunctionsTestError(msg) + + def get_step(self, name: str) -> StepOperation: + return cast(StepOperation, self.get_operation_by_name(name)) + + def get_wait(self, name: str) -> WaitOperation: + return cast(WaitOperation, self.get_operation_by_name(name)) + + def get_context(self, name: str) -> ContextOperation: + return cast(ContextOperation, self.get_operation_by_name(name)) + + def get_callback(self, name: str) -> CallbackOperation: + return cast(CallbackOperation, self.get_operation_by_name(name)) + + def get_invoke(self, name: str) -> InvokeOperation: + return cast(InvokeOperation, self.get_operation_by_name(name)) + + def get_execution(self, name: str) -> ExecutionOperation: + return cast(ExecutionOperation, self.get_operation_by_name(name)) + + +@dataclass(frozen=True) +class StepOperation(ContextOperation): + attempt: int = 0 + next_attempt_timestamp: datetime.datetime | None = None + result: OperationPayload | None = None + error: ErrorObject | None = None + + @staticmethod + def from_svc_operation( + operation: SvcOperation, all_operations: list[SvcOperation] | None = None + ) -> StepOperation: + if operation.operation_type != OperationType.STEP: + msg: str = f"Expected STEP operation, got {operation.operation_type}" + raise InvalidParameterValueException(msg) + + child_operations = [] + if all_operations: + child_operations = [ + create_operation(op, all_operations) + for op in all_operations + if op.parent_id == operation.operation_id + ] + + return StepOperation( + operation_id=operation.operation_id, + operation_type=operation.operation_type, + status=operation.status, + parent_id=operation.parent_id, + name=operation.name, + sub_type=operation.sub_type, + start_timestamp=operation.start_timestamp, + end_timestamp=operation.end_timestamp, + child_operations=child_operations, + attempt=operation.step_details.attempt if operation.step_details else 0, + next_attempt_timestamp=( + operation.step_details.next_attempt_timestamp + if operation.step_details + else None + ), + result=operation.step_details.result if operation.step_details else None, + error=operation.step_details.error if operation.step_details else None, + ) + + +@dataclass(frozen=True) +class WaitOperation(Operation): + scheduled_end_timestamp: datetime.datetime | None = None + + @staticmethod + def from_svc_operation( + operation: SvcOperation, + all_operations: list[SvcOperation] | None = None, # noqa: ARG004 + ) -> WaitOperation: + if operation.operation_type != OperationType.WAIT: + msg: str = f"Expected WAIT operation, got {operation.operation_type}" + raise InvalidParameterValueException(msg) + return WaitOperation( + operation_id=operation.operation_id, + operation_type=operation.operation_type, + status=operation.status, + parent_id=operation.parent_id, + name=operation.name, + sub_type=operation.sub_type, + start_timestamp=operation.start_timestamp, + end_timestamp=operation.end_timestamp, + scheduled_end_timestamp=( + operation.wait_details.scheduled_end_timestamp + if operation.wait_details + else None + ), + ) + + +@dataclass(frozen=True) +class CallbackOperation(ContextOperation): + callback_id: str | None = None + result: OperationPayload | None = None + error: ErrorObject | None = None + + @staticmethod + def from_svc_operation( + operation: SvcOperation, all_operations: list[SvcOperation] | None = None + ) -> CallbackOperation: + if operation.operation_type != OperationType.CALLBACK: + msg: str = f"Expected CALLBACK operation, got {operation.operation_type}" + raise InvalidParameterValueException(msg) + + child_operations = [] + if all_operations: + child_operations = [ + create_operation(op, all_operations) + for op in all_operations + if op.parent_id == operation.operation_id + ] + + return CallbackOperation( + operation_id=operation.operation_id, + operation_type=operation.operation_type, + status=operation.status, + parent_id=operation.parent_id, + name=operation.name, + sub_type=operation.sub_type, + start_timestamp=operation.start_timestamp, + end_timestamp=operation.end_timestamp, + child_operations=child_operations, + callback_id=( + operation.callback_details.callback_id + if operation.callback_details + else None + ), + result=operation.callback_details.result + if operation.callback_details + else None, + error=operation.callback_details.error + if operation.callback_details + else None, + ) + + +@dataclass(frozen=True) +class InvokeOperation(Operation): + result: OperationPayload | None = None + error: ErrorObject | None = None + + @staticmethod + def from_svc_operation( + operation: SvcOperation, + all_operations: list[SvcOperation] | None = None, # noqa: ARG004 + ) -> InvokeOperation: + if operation.operation_type != OperationType.CHAINED_INVOKE: + msg: str = f"Expected INVOKE operation, got {operation.operation_type}" + raise InvalidParameterValueException(msg) + return InvokeOperation( + operation_id=operation.operation_id, + operation_type=operation.operation_type, + status=operation.status, + parent_id=operation.parent_id, + name=operation.name, + sub_type=operation.sub_type, + start_timestamp=operation.start_timestamp, + end_timestamp=operation.end_timestamp, + result=operation.chained_invoke_details.result + if operation.chained_invoke_details + else None, + error=operation.chained_invoke_details.error + if operation.chained_invoke_details + else None, + ) + + +OPERATION_FACTORIES: MutableMapping[OperationType, type[OperationFactory]] = { + OperationType.EXECUTION: ExecutionOperation, + OperationType.CONTEXT: ContextOperation, + OperationType.STEP: StepOperation, + OperationType.WAIT: WaitOperation, + OperationType.CHAINED_INVOKE: InvokeOperation, + OperationType.CALLBACK: CallbackOperation, +} + + +def create_operation( + svc_operation: SvcOperation, all_operations: list[SvcOperation] | None = None +) -> Operation: + operation_class: type[OperationFactory] | None = OPERATION_FACTORIES.get( + svc_operation.operation_type + ) + if not operation_class: + msg: str = f"Unknown operation type: {svc_operation.operation_type}" + raise DurableFunctionsTestError(msg) + return operation_class.from_svc_operation(svc_operation, all_operations) + + +def _get_callback_id_from_events( + events: list[Event], name: str | None = None +) -> str | None: + """ + Get callback ID from execution history for callbacks that haven't completed. + + Args: + execution_arn: The ARN of the execution to query. + name: Optional callback name to search for. If not provided, returns the latest callback. + + Returns: + The callback ID string for a non-completed callback, or None if not found. + + Raises: + DurableFunctionsTestError: If the named callback has already succeeded/failed/timed out. + """ + callback_started_events = [ + event for event in events if event.event_type == "CallbackStarted" + ] + + if not callback_started_events: + return None + + completed_callback_ids = { + event.event_id + for event in events + if event.event_type + in ["CallbackSucceeded", "CallbackFailed", "CallbackTimedOut"] + } + + if name is not None: + for event in callback_started_events: + if event.name == name: + callback_id = event.event_id + if callback_id in completed_callback_ids: + raise DurableFunctionsTestError( + f"Callback {name} has already completed (succeeded/failed/timed out)" + ) + return ( + event.callback_started_details.callback_id + if event.callback_started_details + else None + ) + return None + + # If name is not provided, find the latest non-completed callback event + active_callbacks = [ + event + for event in callback_started_events + if event.event_id not in completed_callback_ids + ] + + if not active_callbacks: + return None + + latest_event = active_callbacks[-1] + return ( + latest_event.callback_started_details.callback_id + if latest_event.callback_started_details + else None + ) + + +@dataclass(frozen=True) +class DurableFunctionTestResult: + status: InvocationStatus + operations: list[Operation] + result: OperationPayload | None = None + error: ErrorObject | None = None + + @classmethod + def create(cls, execution: Execution) -> DurableFunctionTestResult: + operations = [] + for operation in execution.operations: + if operation.operation_type is OperationType.EXECUTION: + # don't want the EXECUTION operations in the list test code asserts against + continue + + if operation.parent_id is None: + operations.append(create_operation(operation, execution.operations)) + + if execution.result is None: + msg: str = "Execution result must exist to create test result." + raise DurableFunctionsTestError(msg) + + return cls( + status=execution.result.status, + operations=operations, + result=execution.result.result, + error=execution.result.error, + ) + + @classmethod + def from_execution_history( + cls, + execution_response: GetDurableExecutionResponse, + history_response: GetDurableExecutionHistoryResponse, + ) -> DurableFunctionTestResult: + """Create test result from execution history responses. + + Factory method for cloud runner that builds DurableFunctionTestResult + from GetDurableExecution and GetDurableExecutionHistory API responses. + """ + # Map status string to InvocationStatus enum + try: + status = InvocationStatus[execution_response.status] + except KeyError: + logger.warning( + "Unknown status: %s, defaulting to FAILED", execution_response.status + ) + status = InvocationStatus.FAILED + + # Convert Events to Operations - group by operation_id and merge + try: + svc_operations = events_to_operations(history_response.events) + except Exception as e: + logger.warning("Failed to convert events to operations: %s", e) + svc_operations = [] + + # Build operation tree (exclude EXECUTION type from top level) + operations = [] + for svc_op in svc_operations: + if svc_op.operation_type == OperationType.EXECUTION: + continue + if svc_op.parent_id is None: + operations.append(create_operation(svc_op, svc_operations)) + + return cls( + status=status, + operations=operations, + result=execution_response.result, + error=execution_response.error, + ) + + def get_operation_by_name(self, name: str) -> Operation: + for operation in self.operations: + if operation.name == name: + return operation + msg: str = f"Operation with name '{name}' not found" + raise DurableFunctionsTestError(msg) + + def get_step(self, name: str) -> StepOperation: + return cast(StepOperation, self.get_operation_by_name(name)) + + def get_wait(self, name: str) -> WaitOperation: + return cast(WaitOperation, self.get_operation_by_name(name)) + + def get_context(self, name: str) -> ContextOperation: + return cast(ContextOperation, self.get_operation_by_name(name)) + + def get_callback(self, name: str) -> CallbackOperation: + return cast(CallbackOperation, self.get_operation_by_name(name)) + + def get_invoke(self, name: str) -> InvokeOperation: + return cast(InvokeOperation, self.get_operation_by_name(name)) + + def get_execution(self, name: str) -> ExecutionOperation: + return cast(ExecutionOperation, self.get_operation_by_name(name)) + + def get_all_operations(self) -> list[Operation]: + """Recursively get all operations including nested ones.""" + all_ops = [] + stack = list(self.operations) + while stack: + op = stack.pop() + all_ops.append(op) + # Add child operations to stack (if they exist) + if hasattr(op, "child_operations") and op.child_operations: + stack.extend(op.child_operations) + return all_ops + + +class DurableFunctionTestRunner: + def __init__(self, handler: Callable, poll_interval: float = 1.0): + self._scheduler: Scheduler = Scheduler() + self._scheduler.start() + self._store = InMemoryExecutionStore() + self.poll_interval = poll_interval + self._checkpoint_processor = CheckpointProcessor( + store=self._store, scheduler=self._scheduler + ) + self._service_client = InMemoryServiceClient(self._checkpoint_processor) + self._invoker = InProcessInvoker(handler, self._service_client) + self._executor = Executor( + store=self._store, + scheduler=self._scheduler, + invoker=self._invoker, + checkpoint_processor=self._checkpoint_processor, + ) + + # Wire up observer pattern - CheckpointProcessor uses this to notify executor of state changes + self._checkpoint_processor.add_execution_observer(self._executor) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def close(self): + self._scheduler.stop() + + def run( + self, + input: str | None = None, # noqa: A002 + timeout: int = 900, + function_name: str = "test-function", + execution_name: str = "execution-name", + account_id: str = "123456789012", + ) -> DurableFunctionTestResult: + execution_arn = self.run_async( + input=input, + timeout=timeout, + function_name=function_name, + execution_name=execution_name, + account_id=account_id, + ) + + return self.wait_for_result(execution_arn=execution_arn, timeout=timeout) + + def send_callback_success( + self, callback_id: str, result: bytes | None = None + ) -> None: + self._executor.send_callback_success(callback_id=callback_id, result=result) + + def send_callback_failure( + self, callback_id: str, error: ErrorObject | None = None + ) -> None: + self._executor.send_callback_failure(callback_id=callback_id, error=error) + + def send_callback_heartbeat(self, callback_id: str) -> None: + self._executor.send_callback_heartbeat(callback_id=callback_id) + + def run_async( + self, + input: str | None = None, # noqa: A002 + timeout: int = 900, + function_name: str = "test-function", + execution_name: str = "execution-name", + account_id: str = "123456789012", + ) -> str: + start_input = StartDurableExecutionInput( + account_id=account_id, + function_name=function_name, + function_qualifier="$LATEST", + execution_name=execution_name, + execution_timeout_seconds=timeout, + execution_retention_period_days=7, + invocation_id="inv-12345678-1234-1234-1234-123456789012", + trace_fields={"trace_id": "abc123", "span_id": "def456"}, + tenant_id="tenant-001", + input=input, + ) + + output: StartDurableExecutionOutput = self._executor.start_execution( + start_input + ) + + if output.execution_arn is None: + msg_arn: str = "Execution ARN must exist to run test." + raise DurableFunctionsTestError(msg_arn) + return output.execution_arn + + def wait_for_result( + self, execution_arn: str, timeout: int = 60 + ) -> DurableFunctionTestResult: + # Block until completion + completed = self._executor.wait_until_complete(execution_arn, timeout) + + if not completed: + msg_timeout: str = "Execution did not complete within timeout" + + raise TimeoutError(msg_timeout) + + execution: Execution = self._store.load(execution_arn) + return DurableFunctionTestResult.create(execution=execution) + + def wait_for_callback( + self, execution_arn: str, name: str | None = None, timeout: int = 60 + ) -> str: + start_time = time.time() + + while time.time() - start_time < timeout: + try: + history_response = self._executor.get_execution_history(execution_arn) + callback_id = _get_callback_id_from_events( + events=history_response.events, name=name + ) + if callback_id: + return callback_id + except ResourceNotFoundException as e: + pass + except Exception as e: + msg = f"Failed to fetch execution history: {e}" + raise DurableFunctionsTestError(msg) from e + + # Wait before next poll + time.sleep(self.poll_interval) + + # Timeout reached + elapsed = time.time() - start_time + msg = f"Callback did not available within {timeout}s (elapsed: {elapsed:.1f}s." + raise TimeoutError(msg) + + +class DurableChildContextTestRunner(DurableFunctionTestRunner): + """Test a durable block, annotated with @durable_with_child_context, in isolation.""" + + def __init__( + self, + context_function: Callable[Concatenate[DurableContext, P], Any], + *args, + **kwargs, + ): + # wrap the durable context around a durable execution handler as a convenience to run directly + @durable_execution + def handler(event: Any, context: DurableContext): # noqa: ARG001 + return context_function(*args, **kwargs)(context) + + super().__init__(handler) + + +class WebRunner: + """Web server runner for durable functions testing with HTTP API endpoints.""" + + def __init__(self, config: WebRunnerConfig) -> None: + """Initialize WebRunner with configuration. + + Args: + config: WebRunnerConfig containing server and Lambda service settings + """ + self._config = config + self._server: WebServer | None = None + self._scheduler: Scheduler | None = None + self._store: ExecutionStore | None = None + self._invoker: LambdaInvoker | None = None + self._executor: Executor | None = None + + def __enter__(self) -> Self: + """Context manager entry point. + + Returns: + WebRunner: Self for use in with statement + """ + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Context manager exit point with cleanup. + + Args: + exc_type: Exception type if an exception occurred + exc_val: Exception value if an exception occurred + exc_tb: Exception traceback if an exception occurred + """ + self.stop() + + def start(self) -> None: + """Start the server and initialize all dependencies. + + Creates and configures all required components including scheduler, + store, invoker, executor, and web server. It does not however start + serving web requests, for that you need serve_forever. + + Raises: + DurableFunctionsLocalRunnerError: If server is already started + """ + if self._server is not None: + msg = "Server is already running" + raise DurableFunctionsLocalRunnerError(msg) + + # Create dependencies and server + if self._config.store_type == StoreType.SQLITE: + store_path = self._config.store_path + self._store = SQLiteExecutionStore.create_and_initialize(store_path) + elif self._config.store_type == StoreType.FILESYSTEM: + store_path = self._config.store_path or ".durable_executions" + self._store = FileSystemExecutionStore.create(store_path) + else: + self._store = InMemoryExecutionStore() + self._scheduler = Scheduler() + self._invoker = LambdaInvoker(self._create_boto3_client()) + + # Create shared CheckpointProcessor + checkpoint_processor = CheckpointProcessor(self._store, self._scheduler) + + # Create executor with all dependencies including checkpoint processor + self._executor = Executor( + store=self._store, + scheduler=self._scheduler, + invoker=self._invoker, + checkpoint_processor=checkpoint_processor, + ) + + # Add executor as observer to the checkpoint processor + checkpoint_processor.add_execution_observer(self._executor) + + # Start the scheduler + self._scheduler.start() + + # Create web server with configuration and executor + self._server = WebServer( + config=self._config.web_service, executor=self._executor + ) + + def serve_forever(self) -> None: + """Start serving HTTP requests indefinitely. + + Delegates to the underlying WebServer.serve_forever() method. + This method blocks until the server is stopped. + + Raises: + DurableFunctionsLocalRunnerError: If server has not been started + """ + if self._server is None: + msg = "Server not started" + raise DurableFunctionsLocalRunnerError(msg) + + # This blocks until KeyboardInterrupt - let caller handle the exception + self._server.serve_forever() + + def stop(self) -> None: + """Stop the web server and cleanup resources. + + Gracefully shuts down the server, scheduler, and cleans up + all allocated resources. Safe to call multiple times. + Handles cleanup exceptions gracefully to ensure all resources + are cleaned up even if some fail. + """ + if self._server is not None: + try: + self._server.server_close() + except Exception: + # Log the exception but continue cleanup + logger.exception("error closing web server") + + self._server = None + + if self._scheduler is not None: + try: + self._scheduler.stop() + except Exception: + logger.exception("error stopping scheduler") + self._scheduler = None + + self._store = None + self._invoker = None + self._executor = None + + def _create_boto3_client(self) -> Any: + """Create boto3 client for Lambda service. + + Creates a boto3 client with the local runner endpoint and region from configuration. + + Returns: + Configured boto3 client for Lambda service + + Raises: + Exception: If client creation fails - exceptions propagate naturally + for CLI to handle as general Exception + """ + return create_lambda_client( + endpoint_url=self._config.lambda_endpoint, + region_name=self._config.local_runner_region, + ) + + +class DurableFunctionCloudTestRunner: + """Test runner that executes durable functions against actual AWS Lambda backend. + + This runner invokes deployed Lambda functions and polls for execution completion, + providing the same interface as DurableFunctionTestRunner for seamless test + compatibility between local and cloud modes. + + Example: + >>> runner = DurableFunctionCloudTestRunner( + ... function_name="HelloWorld-Python-PR-123", region="us-west-2" + ... ) + >>> with runner: + ... result = runner.run(input={"name": "World"}, timeout=60) + >>> assert result.current_status == InvocationStatus.SUCCEEDED + """ + + def __init__( + self, + function_name: str, + region: str = "us-west-2", + lambda_endpoint: str | None = None, + poll_interval: float = 1.0, + ): + """Initialize cloud test runner.""" + self.function_name = function_name + self.region = region + self.lambda_endpoint = lambda_endpoint + self.poll_interval = poll_interval + + client_config = boto3.session.Config(parameter_validation=False) + self.lambda_client = boto3.client( + "lambda", + endpoint_url=lambda_endpoint, + region_name=region, + config=client_config, + ) + + def run( + self, + input: str | None = None, # noqa: A002 + timeout: int = 60, + ) -> DurableFunctionTestResult: + """Execute function on AWS Lambda and wait for completion.""" + logger.info( + "Invoking Lambda function: %s (timeout: %ds)", self.function_name, timeout + ) + + # JSON encode input + payload = json.dumps(input) + + # Invoke Lambda function + try: + response = self.lambda_client.invoke( + FunctionName=self.function_name, + InvocationType="RequestResponse", + Payload=payload, + ) + except Exception as e: + msg = f"Failed to invoke Lambda function {self.function_name}: {e}" + raise DurableFunctionsTestError(msg) from e + + # Check HTTP status code, 200 for RequestResponse + status_code = response.get("StatusCode") + if status_code != 200: + error_payload = response["Payload"].read().decode("utf-8") + msg = f"Lambda invocation failed with status {status_code}: {error_payload}" + raise DurableFunctionsTestError(msg) + + # Check for function errors, we want to return function error for testing purpose + if "FunctionError" in response: + error_payload = response["Payload"].read().decode("utf-8") + logger.warning("Lambda function failed: %s", error_payload) + + result_payload = response["Payload"].read().decode("utf-8") + logger.info( + "Lambda invocation completed, response: %s", + result_payload, + ) + + # Extract durable execution ARN from response headers + # The InvocationResponse includes X-Amz-Durable-Execution-Arn header + execution_arn = response.get("DurableExecutionArn") + if not execution_arn: + msg = ( + f"No DurableExecutionArn in response for function {self.function_name}" + ) + raise DurableFunctionsTestError(msg) + + return self.wait_for_result(execution_arn=execution_arn, timeout=timeout) + + def run_async( + self, + input: str | None = None, # noqa: A002 + timeout: int = 60, + ) -> str: + """Execute function on AWS Lambda asynchronously""" + logger.info( + "Invoking Lambda function: %s (timeout: %ds)", self.function_name, timeout + ) + payload = json.dumps(input) + try: + response = self.lambda_client.invoke( + FunctionName=self.function_name, + InvocationType="Event", + Payload=payload, + ) + except Exception as e: + msg = f"Failed to invoke Lambda function {self.function_name}: {e}" + raise DurableFunctionsTestError(msg) from e + + # Check HTTP status code, 202 for Event + status_code = response.get("StatusCode") + if status_code != 202: + error_payload = response["Payload"].read().decode("utf-8") + msg = f"Lambda invocation failed with status {status_code}: {error_payload}" + raise DurableFunctionsTestError(msg) + + return response.get("DurableExecutionArn") + + def send_callback_success( + self, callback_id: str, result: bytes | None = None + ) -> None: + try: + self.lambda_client.send_durable_execution_callback_success( + CallbackId=callback_id, Result=result + ) + except Exception as e: + msg = f"Failed to send callback success for {self.function_name}, callback_id {callback_id}: {e}" + raise DurableFunctionsTestError(msg) from e + + def send_callback_failure( + self, callback_id: str, error: ErrorObject | None = None + ) -> None: + try: + self.lambda_client.send_durable_execution_callback_failure( + CallbackId=callback_id, Error=error.to_dict() if error else None + ) + except Exception as e: + msg = f"Failed to send callback failure for {self.function_name}, callback_id {callback_id}: {e}" + raise DurableFunctionsTestError(msg) from e + + def send_callback_heartbeat(self, callback_id: str) -> None: + try: + self.lambda_client.send_durable_execution_callback_heartbeat( + CallbackId=callback_id + ) + except Exception as e: + msg = f"Failed to send callback heartbeat for {self.function_name}, callback_id {callback_id}: {e}" + raise DurableFunctionsTestError(msg) from e + + def _wait_for_completion( + self, execution_arn: str, timeout: int + ) -> GetDurableExecutionResponse: + """Poll execution status until completion or timeout. + + Args: + execution_arn: ARN of the durable execution + timeout: Maximum seconds to wait + + Returns: + GetDurableExecutionResponse with typed execution details + + Raises: + TimeoutError: If execution doesn't complete within timeout + DurableFunctionsTestError: If status check fails + """ + start_time = time.time() + last_status = None + + while time.time() - start_time < timeout: + try: + execution_dict = self.lambda_client.get_durable_execution( + DurableExecutionArn=execution_arn + ) + execution = GetDurableExecutionResponse.from_dict(execution_dict) + except Exception as e: + msg = f"Failed to get execution status: {e}" + raise DurableFunctionsTestError(msg) from e + + # Log status changes + if execution.status != last_status: + logger.info("Execution status: %s", execution.status) + last_status = execution.status + + # Check if execution completed + if execution.status == "SUCCEEDED": + logger.info("Execution succeeded") + return execution + if execution.status == "FAILED": + logger.warning("Execution failed") + return execution + if execution.status in ["TIMED_OUT", "ABORTED"]: + logger.warning("Execution terminated: %s", execution.status) + return execution + + # Wait before next poll + time.sleep(self.poll_interval) + + # Timeout reached + elapsed = time.time() - start_time + msg = ( + f"Execution did not complete within {timeout}s " + f"(elapsed: {elapsed:.1f}s, last status: {last_status})" + ) + raise TimeoutError(msg) + + def wait_for_result( + self, execution_arn: str, timeout: int = 60 + ) -> DurableFunctionTestResult: + # Poll for completion + execution_response = self._wait_for_completion(execution_arn, timeout) + + try: + history_response = self._fetch_execution_history(execution_arn) + except Exception as e: + msg = f"Failed to fetch execution history: {e}" + raise DurableFunctionsTestError(msg) from e + + # Build test result from execution history + return DurableFunctionTestResult.from_execution_history( + execution_response, history_response + ) + + def wait_for_callback( + self, execution_arn: str, name: str | None = None, timeout: int = 60 + ) -> str: + """ + Wait for and retrieve a callback ID from a Step Functions execution. + + Polls the execution history at regular intervals until a callback ID is found + or the timeout is reached. + + Args: + execution_arn: Execution Arn + name: Specific callback name, default to None + timeout: Maximum time in seconds to wait for callback. Defaults to 60. + + Returns: + str: The callback ID/token retrieved from the execution history + + Raises: + TimeoutError: If callback is not found within the specified timeout period + DurableFunctionsTestError: If there's an error fetching execution history + (excluding retryable errors) + """ + start_time = time.time() + + while time.time() - start_time < timeout: + try: + history_response = self._fetch_execution_history(execution_arn) + callback_id = _get_callback_id_from_events( + events=history_response.events, name=name + ) + if callback_id: + return callback_id + except ClientError as e: + error_code = e.response["Error"]["Code"] + # retryable error, the execution may not start yet in async invoke situation + if error_code in ["ResourceNotFoundException"]: + pass + else: + msg = f"Failed to fetch execution history: {e}" + raise DurableFunctionsTestError(msg) from e + except DurableFunctionsTestError as e: + raise e + except Exception as e: + msg = f"Failed to fetch execution history: {e}" + raise DurableFunctionsTestError(msg) from e + + # Wait before next poll + time.sleep(self.poll_interval) + + # Timeout reached + elapsed = time.time() - start_time + msg = f"Callback did not available within {timeout}s (elapsed: {elapsed:.1f}s." + raise TimeoutError(msg) + + def _fetch_execution_history( + self, execution_arn: str + ) -> GetDurableExecutionHistoryResponse: + """Retrieve execution history from Lambda service. + + Args: + execution_arn: ARN of the durable execution + + Returns: + GetDurableExecutionHistoryResponse with typed Event objects + + Raises: + ClientError: If lambda client encounter error + """ + history_dict = self.lambda_client.get_durable_execution_history( + DurableExecutionArn=execution_arn, + IncludeExecutionData=True, + ) + history_response = GetDurableExecutionHistoryResponse.from_dict(history_dict) + + logger.info("Retrieved %d events from history", len(history_response.events)) + + return history_response diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/scheduler.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/scheduler.py new file mode 100644 index 0000000..a45b942 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/scheduler.py @@ -0,0 +1,246 @@ +"""A Scheduler that can run awaitables or standard sync callables on a schedule once or repeatedly.""" + +from __future__ import annotations + +import asyncio +import itertools +import logging +import threading +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from collections.abc import Callable + from concurrent.futures import Future + +logger = logging.getLogger(__name__) + + +class Event: + """An event created by Scheduler that will block on wait until it's set.""" + + def __init__(self, scheduler: Scheduler, asyncio_event: asyncio.Event) -> None: + self._scheduler: Scheduler = scheduler + self._asyncio_event: asyncio.Event = asyncio_event + self._exception: Exception | None = None + + def set(self): + """Set the event with this to unblock wait.""" + self._scheduler.set_event(self._asyncio_event) + + def set_exception(self, exception: Exception): + """Set exception and unblock waiters.""" + self._exception = exception + self._scheduler.set_event(self._asyncio_event) + + def wait(self, timeout: float | None = None, *, clear_on_set: bool = True) -> bool: + """Wait until the event is set. + + Args: + timeout (int | float | None): Wait for event to set until this timeout. + clear_on_set (bool): Remove the event from the Scheduler on completion. + Use this if you won't re-use the event. + + Returns: + True when set. False if the event timed out without being set. + + Raises: + Exception: If an exception was stored via set_exception(). + """ + result = self._scheduler.wait_for_event(self._asyncio_event, timeout) + if clear_on_set: + self._scheduler.remove_event(self._asyncio_event) + if result and self._exception: + raise self._exception + return result + + def remove(self): + """Remove the event from the Scheduler. Do this to avoid build-up of many events in the scheduler.""" + self._scheduler.remove_event(self._asyncio_event) + + +class Scheduler: + """A Scheduler to run callables later, repeatedly or raise events.""" + + def __init__(self) -> None: + self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() + self._ready_event: threading.Event = threading.Event() + self._thread: threading.Thread = threading.Thread( + target=self._start_loop, daemon=True + ) + self._running: bool = False + self._events: set[asyncio.Event] = set() + + # region context manager + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() + + # endregion context manager + + # region event loop + def start(self): + """Start the scheduler. Not thread-safe.""" + if self._running: + return + + self._running = True + + self._thread.start() + # Wait for inside of loop to notify it's ready (meaning _start_loop has completed) + self._ready_event.wait() + + def stop(self): + """Stop the scheduler, releasing resources. Not thread-safe.""" + if not self._running: + return + + self._running = False + self._loop.call_soon_threadsafe(self._cleanup_and_stop) + self._thread.join() + + def is_started(self) -> bool: + """Return True if the scheduler is started.""" + return self._running + + def event_count(self) -> int: + """Return the number of events in the scheduler.""" + return len(self._events) + + def task_count(self) -> int: + """Return the number of tasks in the scheduler.""" + if not self._running: + return 0 + return len(asyncio.all_tasks(self._loop)) + + def _cleanup_and_stop(self): + """Cancel all tasks and clear all events. Stop the event-loop.""" + # Cancel all tasks + for task in asyncio.all_tasks(self._loop): + task.cancel() + + # Clear events (don't set them) + self._events.clear() + + self._loop.stop() + + def _start_loop(self): + """Initialize the event-loop. The ready event notifies that the loop is started.""" + asyncio.set_event_loop(self._loop) + # signal that loop is ready from within the loop + self._loop.call_soon(self._ready_event.set) + # block indefinitely - call_soon with the read_event will run soon as the loop starts + self._loop.run_forever() + + # endregion event loop + # region Tasks + def call_later( + self, + func: Callable[[], Any], + delay: float = 0, + count: int | None = 1, + completion_event: Event | None = None, + ) -> Future[Any]: + """Call func after the delay. + + If func is async it runs inside a thread-safe coroutine. If func is sync it runs in its own + threadpool, so it won't block the event loop. + + Args: + func (Callable[[], Any]): The function to call later. This can be an async or a standard + sync function. + delay (float | int): Delay in seconds before calling func. + count (int | None): Number of times to call func. Default is 1 (call once). + Use None for infinite repeats. + completion_event (Event | None): Event to notify on exception. + + Returns: Future that completes when the scheduled work is done. + """ + # infinite counter if count = None, else it maxes out at count + loop_iter: itertools.count[int] | range = ( + itertools.count() if count is None else range(count) + ) + + async def delayed_func() -> Any: + try: + for _ in loop_iter: + await asyncio.sleep(delay) + + try: + if asyncio.iscoroutinefunction(func): + result = await func() + else: + result = await asyncio.to_thread(func) + return result # noqa: TRY300 + except Exception as err: + if completion_event: + completion_event.set_exception(err) + else: + msg: str = "error in scheduled task" + logger.exception(msg) + raise + except asyncio.CancelledError: # noqa: TRY302 + # might want to handle more things here + raise + + future: Future[Any] = asyncio.run_coroutine_threadsafe( + delayed_func(), self._loop + ) + return future + + # endregion Tasks + + # region Events + + def create_event(self) -> Event: + """Create an event controlled by the Scheduler to signal between threads and coroutines.""" + # create event inside the Scheduler event-loop + future: Future[asyncio.Event] = asyncio.run_coroutine_threadsafe( + self._create_event(), self._loop + ) + + # Add timeout to prevent surprising "hangs" if for whatever reason event fails to create. + # result with block. Do NOT call anything in _create_event that calls back into scheduler + # methods because it could create a circular depdendency which will deadlock. + event = future.result(timeout=5.0) + return Event(self, event) + + def wait_for_event( + self, event: asyncio.Event, timeout: float | None = None + ) -> bool: + """Run event's wait inside the Scheduler event-loop.""" + if event not in self._events: + return False + + future: Future[bool] = asyncio.run_coroutine_threadsafe( + asyncio.wait_for(event.wait(), timeout), self._loop + ) + + try: + return future.result() + except TimeoutError: + return False + + def set_event(self, event: asyncio.Event): + """Set event inside the Scheduler event-loop.""" + if event in self._events: + self._loop.call_soon_threadsafe(event.set) + + def remove_event(self, event: asyncio.Event): + """Remove event from Scheduler in the Scheduler event-loop.""" + + def _remove(): + self._events.discard(event) + + self._loop.call_soon_threadsafe(_remove) + + async def _create_event(self) -> asyncio.Event: + """Create event and add it to the scheduler events list.""" + event = asyncio.Event() + self._events.add(event) + return event + + # endregion Events diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/__init__.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/__init__.py @@ -0,0 +1 @@ + diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/base.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/base.py new file mode 100644 index 0000000..ca87e28 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/base.py @@ -0,0 +1,147 @@ +"""Base classes and protocols for execution stores.""" + +from __future__ import annotations + +from datetime import UTC +from enum import Enum +from typing import TYPE_CHECKING, Protocol + + +if TYPE_CHECKING: + from aws_durable_execution_sdk_python.lambda_service import Operation + + from aws_durable_execution_sdk_python_testing.execution import Execution + + +class StoreType(Enum): + """Supported execution store types.""" + + MEMORY = "memory" + FILESYSTEM = "filesystem" + SQLITE = "sqlite" + + +class ExecutionStore(Protocol): + """Protocol for execution storage implementations.""" + + # ignore cover because coverage doesn't understand elipses + def save(self, execution: Execution) -> None: ... # pragma: no cover + def load(self, execution_arn: str) -> Execution: ... # pragma: no cover + def update(self, execution: Execution) -> None: ... # pragma: no cover + def query( + self, + function_name: str | None = None, + execution_name: str | None = None, + status_filter: str | None = None, + started_after: str | None = None, + started_before: str | None = None, + limit: int | None = None, + offset: int = 0, + reverse_order: bool = False, # noqa: FBT001, FBT002 + ) -> tuple[list[Execution], str | None]: ... # pragma: no cover + def list_all( + self, + ) -> list[Execution]: ... # pragma: no cover # Keep for backward compatibility + + +class BaseExecutionStore(ExecutionStore): + """Base implementation for execution stores with shared query logic.""" + + @staticmethod + def process_query( + executions: list[Execution], + function_name: str | None = None, + execution_name: str | None = None, + status_filter: str | None = None, + started_after: str | None = None, + started_before: str | None = None, + limit: int | None = None, + offset: int = 0, + reverse_order: bool = False, # noqa: FBT001, FBT002 + ) -> tuple[list[Execution], str | None]: + """Apply filtering, sorting, and pagination to executions.""" + # Apply filters + filtered: list[Execution] = [] + for execution in executions: + if function_name and execution.start_input.function_name != function_name: + continue + if ( + execution_name + and execution.start_input.execution_name != execution_name + ): + continue + + # Status filtering + if status_filter and execution.current_status().value != status_filter: + continue + + # Time filtering + if started_after or started_before: + try: + operation: Operation = execution.get_operation_execution_started() + if operation.start_timestamp: + timestamp: float = ( + operation.start_timestamp.timestamp() + if hasattr(operation.start_timestamp, "timestamp") + else operation.start_timestamp.replace( + tzinfo=UTC + ).timestamp() + ) + if started_after and timestamp < float(started_after): + continue + if started_before and timestamp > float(started_before): + continue + except (ValueError, AttributeError): + continue + + filtered.append(execution) + + # Sort by start timestamp + def get_sort_key(exe: Execution): + try: + op: Operation = exe.get_operation_execution_started() + if op.start_timestamp: + return ( + op.start_timestamp.timestamp() + if hasattr(op.start_timestamp, "timestamp") + else op.start_timestamp.replace(tzinfo=UTC).timestamp() + ) + except Exception: # noqa: BLE001, S110 + pass + return 0 + + filtered.sort(key=get_sort_key, reverse=reverse_order) + + # Apply pagination + if limit is not None and limit > 0: + end_idx: int = offset + limit + paginated: list[Execution] = filtered[offset:end_idx] + has_more: bool = end_idx < len(filtered) + next_marker: str | None = str(end_idx) if has_more else None + return paginated, next_marker + return filtered[offset:], None + + def query( + self, + function_name: str | None = None, + execution_name: str | None = None, + status_filter: str | None = None, + started_after: str | None = None, + started_before: str | None = None, + limit: int | None = None, + offset: int = 0, + reverse_order: bool = False, # noqa: FBT001, FBT002 + ) -> tuple[list[Execution], str | None]: + """Apply filtering, sorting, and pagination to executions.""" + executions: list[Execution] = self.list_all() + return self.process_query( + executions, + function_name=function_name, + execution_name=execution_name, + status_filter=status_filter, + started_after=started_after, + started_before=started_before, + limit=limit, + offset=offset, + reverse_order=reverse_order, + ) diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/filesystem.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/filesystem.py new file mode 100644 index 0000000..f0f3154 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/filesystem.py @@ -0,0 +1,79 @@ +"""File system-based execution store implementation.""" + +from __future__ import annotations + +import json +import logging +from pathlib import Path + +from aws_durable_execution_sdk_python_testing.exceptions import ( + ResourceNotFoundException, +) +from aws_durable_execution_sdk_python_testing.execution import Execution +from aws_durable_execution_sdk_python_testing.stores.base import ( + BaseExecutionStore, +) + + +class FileSystemExecutionStore(BaseExecutionStore): + """File system-based execution store for persistence.""" + + def __init__(self, storage_dir: Path) -> None: + self._storage_dir = storage_dir + + @classmethod + def create(cls, storage_dir: str | Path | None = None) -> FileSystemExecutionStore: + """Create a FileSystemExecutionStore with directory creation. + + Args: + storage_dir: Directory path for storage. Defaults to '.durable_executions' + + Returns: + FileSystemExecutionStore instance with created directory + """ + path = Path(storage_dir) if storage_dir else Path(".durable_executions") + path.mkdir(exist_ok=True) + return cls(storage_dir=path) + + def _get_file_path(self, execution_arn: str) -> Path: + """Get file path for execution ARN.""" + # Use ARN as filename with .json extension, replacing unsafe characters + safe_filename = execution_arn.replace(":", "_").replace("/", "_") + return self._storage_dir / f"{safe_filename}.json" + + def save(self, execution: Execution) -> None: + """Save execution to file system.""" + file_path = self._get_file_path(execution.durable_execution_arn) + data = execution.to_json_dict() + + with open(file_path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + + def load(self, execution_arn: str) -> Execution: + """Load execution from file system.""" + file_path = self._get_file_path(execution_arn) + if not file_path.exists(): + msg = f"Execution {execution_arn} not found" + raise ResourceNotFoundException(msg) + + with open(file_path, encoding="utf-8") as f: + data = json.load(f) + + return Execution.from_json_dict(data) + + def update(self, execution: Execution) -> None: + """Update execution in file system (same as save).""" + self.save(execution) + + def list_all(self) -> list[Execution]: + """List all executions from file system.""" + executions = [] + for file_path in self._storage_dir.glob("*.json"): + try: + with open(file_path, encoding="utf-8") as f: + data = json.load(f) + executions.append(Execution.from_json_dict(data)) + except (json.JSONDecodeError, KeyError, OSError) as e: + logging.warning("Skipping corrupted file %s: %s", file_path, e) + continue + return executions diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/memory.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/memory.py new file mode 100644 index 0000000..5e6e083 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/memory.py @@ -0,0 +1,38 @@ +"""In-memory execution store implementation.""" + +from __future__ import annotations + +from threading import Lock +from typing import TYPE_CHECKING + +from aws_durable_execution_sdk_python_testing.stores.base import ( + BaseExecutionStore, +) + + +if TYPE_CHECKING: + from aws_durable_execution_sdk_python_testing.execution import Execution + + +class InMemoryExecutionStore(BaseExecutionStore): + """Dict-based storage for testing.""" + + def __init__(self) -> None: + self._store: dict[str, Execution] = {} + self._lock: Lock = Lock() + + def save(self, execution: Execution) -> None: + with self._lock: + self._store[execution.durable_execution_arn] = execution + + def load(self, execution_arn: str) -> Execution: + with self._lock: + return self._store[execution_arn] + + def update(self, execution: Execution) -> None: + with self._lock: + self._store[execution.durable_execution_arn] = execution + + def list_all(self) -> list[Execution]: + with self._lock: + return list(self._store.values()) diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/sqlite.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/sqlite.py new file mode 100644 index 0000000..fac1ca4 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/stores/sqlite.py @@ -0,0 +1,273 @@ +"""SQLite-based execution store implementation.""" + +from __future__ import annotations + +import json +import sqlite3 +from datetime import datetime +from pathlib import Path +from typing import Any, cast + +from aws_durable_execution_sdk_python_testing.exceptions import ( + ResourceNotFoundException, + InvalidParameterValueException, + RuntimeException, +) +from aws_durable_execution_sdk_python_testing.execution import Execution +from aws_durable_execution_sdk_python_testing.stores.base import ( + ExecutionStore, +) + + +class SQLiteExecutionStore(ExecutionStore): + """SQLite-based execution store for efficient querying.""" + + def __init__(self, db_path: Path) -> None: + self.db_path: Path = db_path + + @classmethod + def create_and_initialize( + cls, db_path: Path | str | None = None + ) -> SQLiteExecutionStore: + """Create SQLite store with default path.""" + path: Path = Path(db_path) if db_path else Path("durable-executions.db") + path.parent.mkdir(exist_ok=True) + store: SQLiteExecutionStore = cls(path) + store._init_db() + return store + + def _get_connection(self) -> sqlite3.Connection: + """Get SQLite connection with optimizations.""" + conn: sqlite3.Connection = sqlite3.connect(self.db_path, timeout=30.0) + conn.execute("PRAGMA journal_mode=WAL;") + conn.execute("PRAGMA synchronous=NORMAL;") + return conn + + def _init_db(self) -> None: + """Initialize database schema.""" + try: + with self._get_connection() as conn: + conn.execute(""" + CREATE TABLE IF NOT EXISTS executions ( + durable_execution_arn TEXT PRIMARY KEY, + function_name TEXT NOT NULL, + execution_name TEXT, + status TEXT NOT NULL, + start_timestamp REAL, + end_timestamp REAL, + data TEXT NOT NULL + ) + """) + # Create indexes for better query performance + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_function_name ON executions(function_name)" + ) + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_status ON executions(status)" + ) + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_start_timestamp ON executions(start_timestamp)" + ) + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_composite ON executions(function_name, status, start_timestamp)" + ) + except sqlite3.Error as e: + raise RuntimeError(f"Failed to initialize database: {e}") from e + + def save(self, execution: Execution) -> None: + """Save execution to SQLite.""" + try: + execution_op = execution.get_operation_execution_started() + status: str = execution.current_status().value + + with self._get_connection() as conn: + conn.execute( + """ + INSERT OR REPLACE INTO executions + (durable_execution_arn, function_name, execution_name, status, start_timestamp, end_timestamp, data) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, + ( + execution.durable_execution_arn, + execution.start_input.function_name, + execution.start_input.execution_name, + status, + execution_op.start_timestamp.timestamp() + if execution_op.start_timestamp + else None, + execution_op.end_timestamp.timestamp() + if execution_op.end_timestamp + else None, + json.dumps(execution.to_json_dict()), + ), + ) + except sqlite3.Error as e: + raise RuntimeError( + f"Failed to save execution {execution.durable_execution_arn}: {e}" + ) from e + except (AttributeError, TypeError) as e: + raise ValueError(f"Invalid execution data: {e}") from e + + def load(self, execution_arn: str) -> Execution: + """Load execution from SQLite.""" + try: + with self._get_connection() as conn: + cursor: sqlite3.Cursor = conn.execute( + "SELECT data FROM executions WHERE durable_execution_arn = ?", + (execution_arn,), + ) + row: tuple[str] | None = cursor.fetchone() + + if not row: + raise ResourceNotFoundException(f"Execution {execution_arn} not found") + + return Execution.from_json_dict(json.loads(row[0])) + except sqlite3.Error as e: + raise RuntimeError(f"Failed to load execution {execution_arn}: {e}") from e + except json.JSONDecodeError as e: + raise ValueError( + f"Corrupted execution data for {execution_arn}: {e}" + ) from e + + def update(self, execution: Execution) -> None: + """Update execution (same as save).""" + self.save(execution) + + def query( + self, + function_name: str | None = None, + execution_name: str | None = None, + status_filter: str | None = None, + started_after: str | None = None, + started_before: str | None = None, + limit: int | None = None, + offset: int = 0, + reverse_order: bool = False, + ) -> tuple[list[Execution], str | None]: + """Query executions with efficient SQL filtering.""" + try: + # Build query safely with parameterized conditions + conditions: list[str] = [] + params: list[str | float | int] = [] + + if function_name: + conditions.append("function_name = ?") + params.append(function_name) + + if execution_name: + conditions.append("execution_name = ?") + params.append(execution_name) + + if status_filter: + conditions.append("status = ?") + params.append(status_filter) + + if started_after: + started_after_float: float = datetime.fromisoformat( + started_after + ).timestamp() + conditions.append("start_timestamp >= ?") + params.append(started_after_float) + + if started_before: + started_before_float: float = datetime.fromisoformat( + started_before + ).timestamp() + conditions.append("start_timestamp <= ?") + params.append(started_before_float) + + # Build WHERE clause safely + where_clause: str = "" + if conditions: + where_clause = "WHERE " + " AND ".join(conditions) + + # Build ORDER BY clause + order_direction: str = "DESC" if reverse_order else "ASC" + order_clause: str = f"ORDER BY start_timestamp {order_direction}" + + # For better performance, only get metadata for counting and pagination + base_query: str = f"FROM executions {where_clause}" + count_query: str = f"SELECT COUNT(*) {base_query}" + + limit_exists: bool = limit is not None and limit > 0 + + # Only fetch data we need + if limit_exists: + data_query: str = f"SELECT durable_execution_arn, data {base_query} {order_clause} LIMIT ? OFFSET ?" + params_with_limit: list[str | float | int] = params + [ + cast(int, limit), + offset, + ] + else: + data_query = ( + f"SELECT durable_execution_arn, data {base_query} {order_clause}" + ) + params_with_limit = params + + with self._get_connection() as conn: + # Get total count for pagination + total_count: int = int(conn.execute(count_query, params).fetchone()[0]) + + # Get actual data + cursor: sqlite3.Cursor = conn.execute(data_query, params_with_limit) + rows: list[tuple[str, str]] = cursor.fetchall() + + # Only deserialize the executions we actually need + executions: list[Execution] = [] + for durable_execution_arn, data in rows: + try: + executions.append(Execution.from_json_dict(json.loads(data))) + except (json.JSONDecodeError, ValueError) as e: + # Log corrupted data but continue with other records + print( + f"Warning: Skipping corrupted execution {durable_execution_arn}: {e}" + ) + continue + + # Calculate pagination + has_more: bool = limit_exists and (offset + len(executions) < total_count) + next_marker: str | None = ( + str(offset + len(executions)) if has_more else None + ) + + return executions, next_marker + + except sqlite3.Error as e: + raise RuntimeException(f"Query failed: {e}") from e + except ValueError as e: + raise InvalidParameterValueException( + f"Invalid query parameters: {e}" + ) from e + + def list_all(self) -> list[Execution]: + """List all executions (for backward compatibility).""" + executions, _ = self.query() + return executions + + def get_execution_metadata(self, execution_arn: str) -> dict[str, Any] | None: + """Get just the metadata without full deserialization for performance.""" + try: + with self._get_connection() as conn: + cursor: sqlite3.Cursor = conn.execute( + "SELECT function_name, execution_name, status, start_timestamp, end_timestamp FROM executions WHERE durable_execution_arn = ?", + (execution_arn,), + ) + row: tuple[str, str | None, str, float | None, float | None] | None = ( + cursor.fetchone() + ) + + if not row: + return None + + return { + "durable_execution_arn": execution_arn, + "function_name": row[0], + "execution_name": row[1], + "status": row[2], + "start_timestamp": row[3], + "end_timestamp": row[4], + } + except sqlite3.Error as e: + raise RuntimeError( + f"Failed to get metadata for {execution_arn}: {e}" + ) from e diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/token.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/token.py new file mode 100644 index 0000000..23d81be --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/token.py @@ -0,0 +1,49 @@ +"""Token models.""" + +from __future__ import annotations + +import base64 +import json +from dataclasses import dataclass + + +@dataclass(frozen=True) +class CheckpointToken: + """Model a checkpoint token. This isn't exactly the same format as the actual svc, but it will do for testing purposes.""" + + execution_arn: str + token_sequence: int + + def to_str(self) -> str: + data = {"arn": self.execution_arn, "seq": self.token_sequence} + json_str = json.dumps(data, separators=(",", ":")) + # str -> bytes -> base64 bytes -> str + return base64.b64encode(json_str.encode()).decode() + + @classmethod + def from_str(cls, token: str) -> CheckpointToken: + # str -> base64 bytes -> str + decoded = base64.b64decode(token).decode() + data = json.loads(decoded) + return cls(execution_arn=data["arn"], token_sequence=data["seq"]) + + +@dataclass(frozen=True) +class CallbackToken: + """Model a callback token.""" + + execution_arn: str + operation_id: str + + def to_str(self) -> str: + data = {"arn": self.execution_arn, "op": self.operation_id} + json_str = json.dumps(data, separators=(",", ":")) + # str -> bytes -> base64 bytes -> str + return base64.b64encode(json_str.encode()).decode() + + @classmethod + def from_str(cls, token: str) -> CallbackToken: + # str -> base64 bytes -> str + decoded = base64.b64decode(token).decode() + data = json.loads(decoded) + return cls(execution_arn=data["arn"], operation_id=data["op"]) diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/__init__.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/__init__.py new file mode 100644 index 0000000..5f5f19e --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/__init__.py @@ -0,0 +1 @@ +"""Web server module for the AWS Durable Functions SDK Python Testing Framework.""" diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/errors.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/errors.py new file mode 100644 index 0000000..75bc81c --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/errors.py @@ -0,0 +1,8 @@ +"""Error handling utilities for AWS Lambda Durable Functions web service. + +This module is deprecated and will be removed. All error handling now uses +AWS-compliant exception classes directly. +""" + +# This file is kept temporarily for backward compatibility during migration. +# All functionality has been moved to direct AWS exception usage. diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/handlers.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/handlers.py new file mode 100644 index 0000000..e8cb841 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/handlers.py @@ -0,0 +1,813 @@ +"""HTTP endpoint handlers for AWS Lambda Durable Functions operations.""" + +from __future__ import annotations + +import base64 +import json +import logging +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, cast + +from aws_durable_execution_sdk_python_testing.exceptions import ( + AwsApiException, + ExecutionAlreadyStartedException, + ExecutionConflictException, + IllegalStateException, + InvalidParameterValueException, + ServiceException, +) +from aws_durable_execution_sdk_python_testing.model import ( + CheckpointDurableExecutionRequest, + CheckpointDurableExecutionResponse, + GetDurableExecutionHistoryResponse, + GetDurableExecutionStateResponse, + ListDurableExecutionsByFunctionRequest, + ListDurableExecutionsRequest, + ListDurableExecutionsResponse, + SendDurableExecutionCallbackFailureRequest, + SendDurableExecutionCallbackFailureResponse, + SendDurableExecutionCallbackHeartbeatRequest, + SendDurableExecutionCallbackHeartbeatResponse, + SendDurableExecutionCallbackSuccessResponse, + StartDurableExecutionInput, + StartDurableExecutionOutput, + StopDurableExecutionRequest, + StopDurableExecutionResponse, +) +from aws_durable_execution_sdk_python_testing.web.models import ( + HTTPRequest, + HTTPResponse, +) +from aws_durable_execution_sdk_python_testing.web.routes import ( + CallbackFailureRoute, + CallbackHeartbeatRoute, + CallbackSuccessRoute, + CheckpointDurableExecutionRoute, + GetDurableExecutionHistoryRoute, + GetDurableExecutionRoute, + GetDurableExecutionStateRoute, + ListDurableExecutionsByFunctionRoute, + StopDurableExecutionRoute, +) + + +if TYPE_CHECKING: + from aws_durable_execution_sdk_python_testing.executor import Executor + from aws_durable_execution_sdk_python_testing.web.routes import Route + +logger = logging.getLogger(__name__) + + +class EndpointHandler(ABC): + """Abstract base class for HTTP endpoint handlers.""" + + def __init__(self, executor: Executor) -> None: + """Initialize the handler with an executor. + + Args: + executor: The executor instance for handling operations + """ + self.executor = executor + + @abstractmethod + def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: + """Handle an HTTP request and return an HTTP response. + + Args: + parsed_route: The strongly-typed route object + request: The HTTP request data + + Returns: + HTTPResponse: The HTTP response to send to the client + """ + + def _parse_json_body(self, request: HTTPRequest) -> dict[str, Any]: + """Parse JSON body from HTTP request with validation. + + Args: + request: The HTTP request containing the JSON body + + Returns: + dict: The parsed JSON data + + Raises: + InvalidParameterValueException: If the request body is empty + """ + if not request.body: + msg = "Request body is required" + raise InvalidParameterValueException(msg) + return self._parse_json_body_optional(request) + + def _parse_json_body_optional(self, request: HTTPRequest) -> dict[str, Any]: + """Parse JSON body from HTTP request with validation. + + Args: + request: The HTTP request containing the JSON body + + Returns: + dict: The parsed JSON data + + Raises: + InvalidParameterValueException: If the request body is invalid JSON + """ + if not request.body: + return {} + + # Handle both dict and bytes body types + if isinstance(request.body, dict): + return request.body + + try: + return json.loads(request.body.decode("utf-8")) + except (json.JSONDecodeError, UnicodeDecodeError) as e: + msg = f"Invalid JSON in request body: {e}" + raise InvalidParameterValueException(msg) from e + + def _json_response( + self, + status_code: int, + data: dict[str, Any], + additional_headers: dict[str, str] | None = None, + ) -> HTTPResponse: + """Create a JSON HTTP response. + + Args: + status_code: HTTP status code + data: Data to serialize as JSON + additional_headers: Optional additional headers to include + + Returns: + HTTPResponse: The HTTP response with JSON body + """ + return HTTPResponse.create_json(status_code, data, additional_headers) + + def _success_response( + self, data: dict[str, Any], additional_headers: dict[str, str] | None = None + ) -> HTTPResponse: + """Create a successful JSON response (200 OK). + + Args: + data: Data to serialize as JSON + additional_headers: Optional additional headers to include + + Returns: + HTTPResponse: The HTTP response with JSON body + """ + return self._json_response(200, data, additional_headers) + + def _created_response( + self, data: dict[str, Any], additional_headers: dict[str, str] | None = None + ) -> HTTPResponse: + """Create a created JSON response (201 Created). + + Args: + data: Data to serialize as JSON + additional_headers: Optional additional headers to include + + Returns: + HTTPResponse: The HTTP response with JSON body + """ + return self._json_response(201, data, additional_headers) + + def _no_content_response( + self, additional_headers: dict[str, str] | None = None + ) -> HTTPResponse: + """Create a no content response (204 No Content). + + Args: + additional_headers: Optional additional headers to include + + Returns: + HTTPResponse: The HTTP response with empty body + """ + return HTTPResponse.create_empty(204, additional_headers) + + # Removed deprecated _error_response method - use AWS exceptions directly + + def _parse_callback_result_payload(self, request: HTTPRequest) -> bytes: + """Parse callback result payload from request body. + + Args: + request: The HTTP request containing the binary payload + + Returns: + bytes: The result payload + + Raises: + InvalidParameterValueException: If payload parsing fails + """ + if not request.body or not isinstance(request.body, bytes): + return b"" + + return request.body + + def _parse_query_param(self, request: HTTPRequest, param_name: str) -> str | None: + """Parse a single query parameter from the request. + + Args: + request: The HTTP request + param_name: Name of the query parameter + + Returns: + str | None: The parameter value or None if not present + """ + param_values = request.query_params.get(param_name) + return param_values[0] if param_values else None + + def _parse_query_param_list( + self, request: HTTPRequest, param_name: str + ) -> list[str]: + """Parse a query parameter that can have multiple values. + + Args: + request: The HTTP request + param_name: Name of the query parameter + + Returns: + list[str]: List of parameter values (empty if not present) + """ + return request.query_params.get(param_name, []) + + def _validate_required_fields( + self, data: dict[str, Any], required_fields: list[str] + ) -> None: + """Validate that required fields are present in the data. + + Args: + data: The data dictionary to validate + required_fields: List of required field names + + Raises: + ValueError: If any required field is missing + """ + missing_fields = [field for field in required_fields if field not in data] + if missing_fields: + msg = f"Missing required fields: {', '.join(missing_fields)}" + raise InvalidParameterValueException(msg) + + def _handle_aws_exception(self, exception: AwsApiException) -> HTTPResponse: + """Handle AWS API exceptions directly. + + Args: + exception: The AWS API exception + + Returns: + HTTPResponse: AWS-compliant error response + """ + # Log server errors + if exception.http_status_code >= 500: # noqa: PLR2004 + logger.exception("Server error: %s", exception) + return HTTPResponse.create_error_from_exception(exception) + + def _handle_framework_exception(self, exception: Exception) -> HTTPResponse: + """Handle framework exceptions by mapping to AWS exceptions. + + Args: + exception: The framework exception + + Returns: + HTTPResponse: AWS-compliant error response + """ + if isinstance(exception, (ValueError | KeyError)): + return HTTPResponse.create_error_from_exception( + InvalidParameterValueException(str(exception)) + ) + logger.exception("Unexpected error: %s", exception) + return HTTPResponse.create_error_from_exception( + ServiceException(str(exception)) + ) + + +class StartExecutionHandler(EndpointHandler): + """Handler for POST /start-durable-execution.""" + + def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: # noqa: ARG002 + """Handle start execution request. + + Args: + parsed_route: The strongly-typed route object + request: The HTTP request data + + Returns: + HTTPResponse: The HTTP response to send to the client + """ + logger.debug("🌟 HANDLER: Received POST /start-durable-execution request") + try: + body_data: dict[str, Any] = self._parse_json_body(request) + logger.debug("🌟 HANDLER: Parsed request body successfully") + + start_input: StartDurableExecutionInput = ( + StartDurableExecutionInput.from_dict(body_data) + ) + logger.debug( + "🌟 HANDLER: Created StartDurableExecutionInput, calling executor.start_execution()" + ) + + start_output: StartDurableExecutionOutput = self.executor.start_execution( + start_input + ) + logger.debug("🌟 HANDLER: executor.start_execution() returned successfully") + + response_data: dict[str, Any] = start_output.to_dict() + + # Return HTTP 201 Created response + return self._created_response(response_data) + + except IllegalStateException as e: + # For StartExecution operations, map to ExecutionAlreadyStartedException + aws_exception = ExecutionAlreadyStartedException( + str(e), + "arn:aws:lambda:us-east-1:123456789012:function:test", + ) + return HTTPResponse.create_error_from_exception(aws_exception) + except AwsApiException as e: + return self._handle_aws_exception(e) + except Exception as e: # noqa: BLE001 + return self._handle_framework_exception(e) + + +class GetDurableExecutionHandler(EndpointHandler): + """Handler for GET /2025-12-01/durable-executions/{arn}.""" + + def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: # noqa: ARG002 + """Handle get durable execution request. + + Args: + parsed_route: The strongly-typed route object + request: The HTTP request data + + Returns: + HTTPResponse: The HTTP response to send to the client + """ + try: + route = cast(GetDurableExecutionRoute, parsed_route) + + execution_response = self.executor.get_execution_details(route.arn) + + response_data: dict[str, Any] = execution_response.to_dict() + + # HTTP 200 OK response + return self._success_response(response_data) + + except AwsApiException as e: + return self._handle_aws_exception(e) + except Exception as e: # noqa: BLE001 + return self._handle_framework_exception(e) + + +class CheckpointDurableExecutionHandler(EndpointHandler): + """Handler for POST /2025-12-01/durable-executions/{arn}/checkpoint.""" + + def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: + """Handle checkpoint durable execution request. + + Args: + parsed_route: The strongly-typed route object + request: The HTTP request data + + Returns: + HTTPResponse: The HTTP response to send to the client + """ + try: + body_data: dict[str, Any] = self._parse_json_body(request) + + checkpoint_route = cast(CheckpointDurableExecutionRoute, parsed_route) + execution_arn: str = checkpoint_route.arn + + checkpoint_request: CheckpointDurableExecutionRequest = ( + CheckpointDurableExecutionRequest.from_dict(body_data, execution_arn) + ) + + checkpoint_response: CheckpointDurableExecutionResponse = ( + self.executor.checkpoint_execution( + execution_arn, + checkpoint_request.checkpoint_token, + checkpoint_request.updates, + checkpoint_request.client_token, + ) + ) + + response_data: dict[str, Any] = checkpoint_response.to_dict() + + return self._success_response(response_data) + + except AwsApiException as e: + return self._handle_aws_exception(e) + except Exception as e: # noqa: BLE001 + return self._handle_framework_exception(e) + + +class StopDurableExecutionHandler(EndpointHandler): + """Handler for POST /2025-12-01/durable-executions/{arn}/stop.""" + + def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: + """Handle stop durable execution request. + + Args: + parsed_route: The strongly-typed route object + request: The HTTP request data + + Returns: + HTTPResponse: The HTTP response to send to the client + """ + try: + body_data: dict[str, Any] = self._parse_json_body_optional(request) + + stop_route = cast(StopDurableExecutionRoute, parsed_route) + execution_arn: str = stop_route.arn + + body_data["DurableExecutionArn"] = execution_arn + stop_request: StopDurableExecutionRequest = ( + StopDurableExecutionRequest.from_dict(body_data) + ) + + stop_response: StopDurableExecutionResponse = self.executor.stop_execution( + execution_arn, stop_request.error + ) + + response_data: dict[str, Any] = stop_response.to_dict() + + return self._success_response(response_data) + + except AwsApiException as e: + return self._handle_aws_exception(e) + except Exception as e: # noqa: BLE001 + return self._handle_framework_exception(e) + + +class GetDurableExecutionStateHandler(EndpointHandler): + """Handler for GET /2025-12-01/durable-executions/{arn}/state.""" + + def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: # noqa: ARG002 + """Handle get durable execution state request. + + Args: + parsed_route: The strongly-typed route object + request: The HTTP request data + + Returns: + HTTPResponse: The HTTP response to send to the client + """ + try: + state_route = cast(GetDurableExecutionStateRoute, parsed_route) + execution_arn: str = state_route.arn + + state_response: GetDurableExecutionStateResponse = ( + self.executor.get_execution_state(execution_arn) + ) + + response_data: dict[str, Any] = state_response.to_dict() + + return self._success_response(response_data) + + except AwsApiException as e: + return self._handle_aws_exception(e) + except Exception as e: # noqa: BLE001 + return self._handle_framework_exception(e) + + +class GetDurableExecutionHistoryHandler(EndpointHandler): + """Handler for GET /2025-12-01/durable-executions/{arn}/history.""" + + def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: + """Handle get durable execution history request. + + Args: + parsed_route: The strongly-typed route object + request: The HTTP request data + + Returns: + HTTPResponse: The HTTP response to send to the client + """ + try: + history_route = cast(GetDurableExecutionHistoryRoute, parsed_route) + execution_arn: str = history_route.arn + + max_items: str | None = self._parse_query_param(request, "MaxItems") + marker: str | None = self._parse_query_param(request, "Marker") + include_execution_data_str: str | None = self._parse_query_param( + request, "IncludeExecutionData" + ) + include_execution_data: bool = ( + include_execution_data_str == "true" + if include_execution_data_str + else False + ) + + history_response: GetDurableExecutionHistoryResponse = ( + self.executor.get_execution_history( + execution_arn, + include_execution_data=include_execution_data, + reverse_order=False, + marker=marker, + max_items=int(max_items) if max_items else None, + ) + ) + + response_data: dict[str, Any] = history_response.to_dict() + + return self._success_response(response_data) + + except AwsApiException as e: + return self._handle_aws_exception(e) + except Exception as e: # noqa: BLE001 + return self._handle_framework_exception(e) + + +class ListDurableExecutionsHandler(EndpointHandler): + """Handler for GET /2025-12-01/durable-executions.""" + + def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: # noqa: ARG002 + """Handle list durable executions request. + + Args: + parsed_route: The strongly-typed route object + request: The HTTP request data + + Returns: + HTTPResponse: The HTTP response to send to the client + """ + try: + list_request: ListDurableExecutionsRequest = ( + ListDurableExecutionsRequest.from_dict(request.query_params) + ) + + # Call executor method with correct attribute mapping + list_response: ListDurableExecutionsResponse = self.executor.list_executions( + function_name=list_request.function_name, + function_version=list_request.function_version, + execution_name=list_request.durable_execution_name, # Map to executor parameter + status_filter=list_request.status_filter[0] + if list_request.status_filter + else None, # Executor expects single string + started_after=list_request.started_after, + started_before=list_request.started_before, + marker=list_request.marker, + max_items=list_request.max_items + if list_request.max_items > 0 + else None, + reverse_order=list_request.reverse_order or False, + ) + + # Serialize response + response_data: dict[str, Any] = list_response.to_dict() + + return self._success_response(response_data) + + except AwsApiException as e: + return self._handle_aws_exception(e) + except Exception as e: # noqa: BLE001 + return self._handle_framework_exception(e) + + +class ListDurableExecutionsByFunctionHandler(EndpointHandler): + """Handler for GET /2025-12-01/functions/{function_name}/durable-executions.""" + + @staticmethod + def _validate_function_name(function_name: str) -> None: + """Validate function name parameter.""" + if not function_name or not function_name.strip(): + msg = "Function name is required" + raise InvalidParameterValueException(msg) + + def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: + """Handle list durable executions by function request. + + Args: + parsed_route: The strongly-typed route object + request: The HTTP request data + + Returns: + HTTPResponse: The HTTP response to send to the client + """ + function_route = cast(ListDurableExecutionsByFunctionRoute, parsed_route) + function_name: str = function_route.function_name + + # Validate function name before processing + self._validate_function_name(function_name) + + try: + # Add function name from route to query params + query_params = dict(request.query_params) + query_params["FunctionName"] = [function_name] + list_request = ListDurableExecutionsByFunctionRequest.from_dict( + query_params + ) + + list_response = self.executor.list_executions_by_function( + function_name=list_request.function_name, + qualifier=list_request.qualifier, + execution_name=list_request.durable_execution_name, + status_filter=list_request.status_filter[0] + if list_request.status_filter + else None, + started_after=list_request.started_after, + started_before=list_request.started_before, + marker=list_request.marker, + max_items=list_request.max_items + if list_request.max_items > 0 + else None, + reverse_order=list_request.reverse_order or False, + ) + + return self._success_response(list_response.to_dict()) + + except AwsApiException as e: + return self._handle_aws_exception(e) + except Exception as e: # noqa: BLE001 + return self._handle_framework_exception(e) + + +class SendDurableExecutionCallbackSuccessHandler(EndpointHandler): + """Handler for POST /2025-12-01/durable-execution-callbacks/{callback_id}/succeed.""" + + def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: + """Handle send durable execution callback success request. + + Args: + parsed_route: The strongly-typed route object + request: The HTTP request data + + Returns: + HTTPResponse: The HTTP response to send to the client + """ + try: + callback_route = cast(CallbackSuccessRoute, parsed_route) + callback_id: str = callback_route.callback_id + + result_bytes: bytes = self._parse_callback_result_payload(request) + + callback_response: SendDurableExecutionCallbackSuccessResponse = ( # noqa: F841 + self.executor.send_callback_success( + callback_id=callback_id, result=result_bytes + ) + ) + + logger.debug( + "Callback %s succeeded with result: %s", + callback_id, + result_bytes.decode("utf-8", errors="replace"), + ) + + # Callback success response is empty + return self._success_response({}) + + except IllegalStateException as e: + # For callback operations, map to ExecutionConflictException + aws_exception = ExecutionConflictException(str(e)) + return HTTPResponse.create_error_from_exception(aws_exception) + except AwsApiException as e: + return self._handle_aws_exception(e) + except Exception as e: # noqa: BLE001 + return self._handle_framework_exception(e) + + +class SendDurableExecutionCallbackFailureHandler(EndpointHandler): + """Handler for POST /2025-12-01/durable-execution-callbacks/{callback_id}/fail.""" + + def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: + """Handle send durable execution callback failure request. + + Args: + parsed_route: The strongly-typed route object + request: The HTTP request data + + Returns: + HTTPResponse: The HTTP response to send to the client + """ + try: + callback_route = cast(CallbackFailureRoute, parsed_route) + callback_id: str = callback_route.callback_id + + body_data: dict[str, Any] = self._parse_json_body_optional(request) + callback_request: SendDurableExecutionCallbackFailureRequest = ( + SendDurableExecutionCallbackFailureRequest.from_dict( + body_data, callback_id + ) + ) + + callback_response: SendDurableExecutionCallbackFailureResponse = ( # noqa: F841 + self.executor.send_callback_failure( + callback_id=callback_id, error=callback_request.error + ) + ) + + logger.debug( + "Callback %s failed with error: %s", callback_id, callback_request.error + ) + + # Callback failure response is empty + return self._success_response({}) + + except IllegalStateException as e: + # For callback operations, map to ExecutionConflictException + aws_exception = ExecutionConflictException(str(e)) + return HTTPResponse.create_error_from_exception(aws_exception) + except AwsApiException as e: + return self._handle_aws_exception(e) + except Exception as e: # noqa: BLE001 + return self._handle_framework_exception(e) + + +class SendDurableExecutionCallbackHeartbeatHandler(EndpointHandler): + """Handler for POST /2025-12-01/durable-execution-callbacks/{callback_id}/heartbeat.""" + + def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: + """Handle send durable execution callback heartbeat request. + + Args: + parsed_route: The strongly-typed route object + request: The HTTP request data + + Returns: + HTTPResponse: The HTTP response to send to the client + """ + try: + # Heartbeat requests don't have a body, only callback_id from URL + callback_route = cast(CallbackHeartbeatRoute, parsed_route) + callback_id: str = callback_route.callback_id + + callback_response: SendDurableExecutionCallbackHeartbeatResponse = ( # noqa: F841 + self.executor.send_callback_heartbeat(callback_id=callback_id) + ) + + # Callback heartbeat response is empty + return self._success_response({}) + + except IllegalStateException as e: + # For callback operations, map to ExecutionConflictException + aws_exception = ExecutionConflictException(str(e)) + return HTTPResponse.create_error_from_exception(aws_exception) + except AwsApiException as e: + return self._handle_aws_exception(e) + except Exception as e: # noqa: BLE001 + return self._handle_framework_exception(e) + + +# TODO: should this be /ping instead? +class HealthHandler(EndpointHandler): + """Handler for GET /health.""" + + def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: # noqa: ARG002 + """Handle health check request. + + Args: + parsed_route: The strongly-typed route object + request: The HTTP request data + + Returns: + HTTPResponse: The HTTP response to send to the client + """ + return self._success_response({"status": "healthy"}) + + +class MetricsHandler(EndpointHandler): + """Handler for GET /metrics.""" + + def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: # noqa: ARG002 + """Handle metrics request. + + Args: + parsed_route: The strongly-typed route object + request: The HTTP request data + + Returns: + HTTPResponse: The HTTP response to send to the client + """ + # TODO: Implement metrics collection logic + return self._success_response({"metrics": {}}) + + +class UpdateLambdaEndpointHandler(EndpointHandler): + """Handler for PUT /lambda-endpoint.""" + + def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: # noqa: ARG002 + """Handle update Lambda endpoint request. + + Args: + parsed_route: The strongly-typed route object + request: The HTTP request data + + Returns: + HTTPResponse: The HTTP response to send to the client + """ + try: + body = self._parse_json_body(request) + endpoint_url = body.get("EndpointUrl") + region_name = body.get("RegionName", "us-east-1") + + if not endpoint_url: + return self._handle_aws_exception( + InvalidParameterValueException("EndpointUrl is required") + ) + + # Update the invoker's Lambda endpoint + invoker = self.executor._invoker # noqa: SLF001 + logger.info("Updating lambda endpoint to %s", endpoint_url) + invoker.update_endpoint(endpoint_url, region_name) + return self._success_response( + {"message": "Lambda endpoint updated successfully"} + ) + + except Exception as e: # noqa: BLE001 + return self._handle_framework_exception(e) diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/models.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/models.py new file mode 100644 index 0000000..eebd0fe --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/models.py @@ -0,0 +1,266 @@ +"""HTTP request/response data models and utilities for the web runner.""" + +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass +from typing import Any, Protocol + +from aws_durable_execution_sdk_python_testing.exceptions import ( + AwsApiException, + InvalidParameterValueException, +) + +# Removed deprecated imports from web.errors +from aws_durable_execution_sdk_python_testing.web.routes import Route +from aws_durable_execution_sdk_python_testing.web.serialization import ( + AwsRestJsonDeserializer, + JSONSerializer, + Serializer, +) + + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class HTTPRequest: + """HTTP request data model with dict or bytes body for handler logic.""" + + method: str + path: Route + headers: dict[str, str] + query_params: dict[str, list[str]] + body: dict[str, Any] | bytes + + @classmethod + def from_raw_bytes( + cls, + body_bytes: bytes, + method: str = "POST", + path: Route | None = None, + headers: dict[str, str] | None = None, + query_params: dict[str, list[str]] | None = None, + ) -> HTTPRequest: + """Create HTTPRequest with raw bytes body (no parsing).""" + if headers is None: + headers = {} + if query_params is None: + query_params = {} + if path is None: + path = Route.from_string("") + + return cls( + method=method, + path=path, + headers=headers, + query_params=query_params, + body=body_bytes, + ) + + @classmethod + def from_bytes( + cls, + body_bytes: bytes, + operation_name: str | None = None, + method: str = "POST", + path: Route | None = None, + headers: dict[str, str] | None = None, + query_params: dict[str, list[str]] | None = None, + ) -> HTTPRequest: + """Create HTTPRequest from raw bytes, deserializing to dict body. + + Args: + body_bytes: Raw bytes to deserialize + operation_name: Optional AWS operation name for boto deserialization + method: HTTP method (default: POST) + path: Route object (required for actual usage) + headers: HTTP headers (default: empty dict) + query_params: Query parameters (default: empty dict) + + Returns: + HTTPRequest: Request with deserialized dict body + + Raises: + InvalidParameterValueException: If deserialization fails with both AWS and JSON methods + """ + if headers is None: + headers = {} + if query_params is None: + query_params = {} + + # Skip body parsing for GET requests + if method == "GET": + body_dict = {} + logger.debug("GET request, skipping body parsing") + # Try AWS deserialization first if operation_name provided + elif operation_name: + try: + deserializer = AwsRestJsonDeserializer.create(operation_name) + body_dict = deserializer.from_bytes(body_bytes) + logger.debug( + "Successfully deserialized request using AWS deserializer for %s", + operation_name, + ) + except InvalidParameterValueException as e: + logger.warning( + "AWS deserialization failed for %s, falling back to JSON: %s", + operation_name, + e, + ) + # Fall back to standard JSON + try: + body_dict = json.loads(body_bytes.decode("utf-8")) + logger.debug( + "Successfully deserialized request using JSON fallback" + ) + except (json.JSONDecodeError, UnicodeDecodeError) as json_error: + msg = f"Both AWS and JSON deserialization failed: AWS error: {e}, JSON error: {json_error}" + raise InvalidParameterValueException(msg) from json_error + else: + # Use standard JSON deserialization + try: + if body_bytes == b"": + body_dict = {} + else: + body_dict = json.loads(body_bytes.decode("utf-8")) + logger.debug("Successfully deserialized request using standard JSON") + except (json.JSONDecodeError, UnicodeDecodeError) as e: + msg = f"JSON deserialization failed: {e}" + raise InvalidParameterValueException(msg) from e + + # Handle case where path is None for testing + if path is None: + path = Route.from_string("") + + return cls( + method=method, + path=path, + headers=headers, + query_params=query_params, + body=body_dict, + ) + + +@dataclass(frozen=True) +class HTTPResponse: + """HTTP response data model with dict body and serialization capabilities.""" + + status_code: int + headers: dict[str, str] + body: dict[str, Any] + serializer: Serializer = JSONSerializer() + + def body_to_bytes(self) -> bytes: + """Convert response dict body to bytes for HTTP transmission. + + Returns: + bytes: Serialized response body + + Raises: + InvalidParameterValueException: If serialization fails with both AWS and JSON methods + """ + result = self.serializer.to_bytes(data=self.body) + logger.debug("Serialized result - before: %s, after: %s", self.body, result) + return result + + @classmethod + def from_dict( + cls, + data: dict[str, Any], + status_code: int = 200, + headers: dict[str, str] | None = None, + ) -> HTTPResponse: + """Create HTTPResponse from dict data. + + Args: + data: Response data as dictionary + status_code: HTTP status code (default: 200) + headers: HTTP headers (default: empty dict) + + Returns: + HTTPResponse: Response with dict body + """ + if headers is None: + headers = {} + + return cls(status_code=status_code, headers=headers, body=data) + + @staticmethod + def create_json( + status_code: int, + data: dict[str, Any], + additional_headers: dict[str, str] | None = None, + ) -> HTTPResponse: + """Create a JSON HTTP response. + + Args: + status_code: HTTP status code + data: Data to serialize as JSON + additional_headers: Optional additional headers to include + + Returns: + HTTPResponse: The HTTP response with dict body + """ + headers = {"Content-Type": "application/json"} + if additional_headers: + headers.update(additional_headers) + + return HTTPResponse(status_code=status_code, headers=headers, body=data) + + # Removed deprecated create_error method - use create_error_from_exception instead + + @staticmethod + def create_error_from_exception(aws_exception: AwsApiException) -> HTTPResponse: + """Create AWS-compliant error response from AwsApiException. + + Args: + aws_exception: The AWS API exception to convert to HTTP response + + Returns: + HTTPResponse: The HTTP error response with AWS-compliant format + """ + if not isinstance(aws_exception, AwsApiException): + msg = f"Expected AwsApiException, got {type(aws_exception)}" + raise TypeError(msg) + + # Use exception's http_status_code and to_dict() method + # This removes the wrapper "error" object to match AWS format + error_data = aws_exception.to_dict() + return HTTPResponse.create_json(aws_exception.http_status_code, error_data) + + @staticmethod + def create_empty( + status_code: int, additional_headers: dict[str, str] | None = None + ) -> HTTPResponse: + """Create an empty HTTP response. + + Args: + status_code: HTTP status code + additional_headers: Optional additional headers to include + + Returns: + HTTPResponse: The HTTP response with empty dict body + """ + headers = {} + if additional_headers: + headers.update(additional_headers) + + return HTTPResponse(status_code=status_code, headers=headers, body={}) + + +class OperationHandler(Protocol): + """Protocol for handling HTTP operations with strongly-typed paths.""" + + def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: + """Handle an HTTP request and return an HTTP response. + + Args: + parsed_route: The strongly-typed route object + request: The HTTP request data + + Returns: + HTTPResponse: The HTTP response to send to the client + """ + ... # pragma: no cover diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/routes.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/routes.py new file mode 100644 index 0000000..bc09906 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/routes.py @@ -0,0 +1,692 @@ +"""Strongly-typed route parsing system for HTTP request routing.""" + +from __future__ import annotations + +from dataclasses import dataclass +from urllib.parse import unquote + +from aws_durable_execution_sdk_python_testing.exceptions import ( + UnknownRouteError, +) + + +@dataclass(frozen=True) +class Route: + """Base route with segments and pattern matching capabilities.""" + + raw_path: str + segments: list[str] + + @classmethod + def from_route(cls, _route: Route) -> Route: + """Create a typed route from a base Route. + + Note: Call is_match(route, method) first to ensure the route is valid for this type. + + Args: + route: Base route to convert + + Returns: + Typed route instance + + Raises: + NotImplementedError: This is an abstract method that must be implemented by subclasses + """ + msg = "Subclasses must implement from_route()" + raise NotImplementedError(msg) + + @classmethod + def from_string(cls, path: str) -> Route: + """Create a Route from a string. + + Each segment is URL-decoded; ``raw_path`` is preserved as the + original wire path. Splitting on ``/`` happens before decoding so + that an encoded ``%2F`` inside a captured value (e.g. an ARN that + contains ``/``) stays inside its segment instead of being treated + as a path separator. + + Args: + path: The raw path string + + Returns: + Route instance with parsed, URL-decoded segments + """ + # Remove leading/trailing slashes, split on '/', then URL-decode each + # segment. Order matters: split on the literal '/' first so '%2F'- + # encoded slashes inside values don't act as separators. + segments = [unquote(s) for s in path.strip("/").split("/") if s] + return cls(raw_path=path, segments=segments) + + def matches_pattern(self, pattern: list[str]) -> bool: + """Check if route matches the given pattern. + + Args: + pattern: List of pattern segments. Use '*' for wildcards. + + Returns: + True if the route matches the pattern + """ + if len(self.segments) != len(pattern): + return False + + for segment, pattern_part in zip(self.segments, pattern, strict=False): + if pattern_part not in ("*", segment): + return False + return True + + @classmethod + def is_match(cls, _route: Route, _method: str) -> bool: + """Check if the route and HTTP method match this route type. + + Args: + _route: Route to check + _method: HTTP method to check + + Returns: + True if the route and method match + + Raises: + NotImplementedError: This is an abstract method that must be implemented by subclasses + """ + msg = "Subclasses must implement is_match()" + raise NotImplementedError(msg) + + +@dataclass(frozen=True) +class StartExecutionRoute(Route): + """Route: POST /start-durable-execution""" + + @classmethod + def is_match(cls, route: Route, method: str) -> bool: + """Check if the route and HTTP method match this route type. + + Args: + route: Route to check + method: HTTP method to check + + Returns: + True if the route and method match + """ + return route.raw_path == "/start-durable-execution" and method == "POST" + + @classmethod + def from_route(cls, route: Route) -> StartExecutionRoute: + """Create a StartExecutionRoute from a base Route. + + Note: Call is_match(route, method) first to ensure the route is valid for this type. + + Args: + route: Base route to convert + + Returns: + StartExecutionRoute instance + """ + return cls(raw_path=route.raw_path, segments=route.segments) + + +@dataclass(frozen=True) +class GetDurableExecutionRoute(Route): + """Route: GET /2025-12-01/durable-executions/{arn}""" + + arn: str + + @classmethod + def is_match(cls, route: Route, method: str) -> bool: + """Check if the route and HTTP method match this route type. + + Args: + route: Route to check + method: HTTP method to check + + Returns: + True if the route and method match + """ + return ( + route.matches_pattern(["2025-12-01", "durable-executions", "*"]) + and method == "GET" + ) + + @classmethod + def from_route(cls, route: Route) -> GetDurableExecutionRoute: + """Create a GetDurableExecutionRoute from a base Route. + + Note: Call is_match(route, method) first to ensure the route is valid for this type. + + Args: + route: Base route to convert + + Returns: + GetDurableExecutionRoute instance with extracted ARN + """ + return cls( + raw_path=route.raw_path, + segments=route.segments, + arn=route.segments[2], + ) + + +@dataclass(frozen=True) +class CheckpointDurableExecutionRoute(Route): + """Route: POST /2025-12-01/durable-executions/{arn}/checkpoint""" + + arn: str + + @classmethod + def is_match(cls, route: Route, method: str) -> bool: + """Check if the route and HTTP method match this route type. + + Args: + route: Route to check + method: HTTP method to check + + Returns: + True if the route and method match + """ + return ( + route.matches_pattern( + ["2025-12-01", "durable-executions", "*", "checkpoint"] + ) + and method == "POST" + ) + + @classmethod + def from_route(cls, route: Route) -> CheckpointDurableExecutionRoute: + """Create a CheckpointDurableExecutionRoute from a base Route. + + Note: Call is_match(route, method) first to ensure the route is valid for this type. + + Args: + route: Base route to convert + + Returns: + CheckpointDurableExecutionRoute instance with extracted ARN + """ + return cls( + raw_path=route.raw_path, + segments=route.segments, + arn=route.segments[2], + ) + + +@dataclass(frozen=True) +class StopDurableExecutionRoute(Route): + """Route: POST /2025-12-01/durable-executions/{arn}/stop""" + + arn: str + + @classmethod + def is_match(cls, route: Route, method: str) -> bool: + """Check if the route and HTTP method match this route type. + + Args: + route: Route to check + method: HTTP method to check + + Returns: + True if the route and method match + """ + return ( + route.matches_pattern(["2025-12-01", "durable-executions", "*", "stop"]) + and method == "POST" + ) + + @classmethod + def from_route(cls, route: Route) -> StopDurableExecutionRoute: + """Create a StopDurableExecutionRoute from a base Route. + + Note: Call is_match(route, method) first to ensure the route is valid for this type. + + Args: + route: Base route to convert + + Returns: + StopDurableExecutionRoute instance with extracted ARN + """ + return cls( + raw_path=route.raw_path, + segments=route.segments, + arn=route.segments[2], + ) + + +@dataclass(frozen=True) +class GetDurableExecutionStateRoute(Route): + """Route: GET /2025-12-01/durable-executions/{arn}/state""" + + arn: str + + @classmethod + def is_match(cls, route: Route, method: str) -> bool: + """Check if the route and HTTP method match this route type. + + Args: + route: Route to check + method: HTTP method to check + + Returns: + True if the route and method match + """ + return ( + route.matches_pattern(["2025-12-01", "durable-executions", "*", "state"]) + and method == "GET" + ) + + @classmethod + def from_route(cls, route: Route) -> GetDurableExecutionStateRoute: + """Create a GetDurableExecutionStateRoute from a base Route. + + Note: Call is_match(route, method) first to ensure the route is valid for this type. + + Args: + route: Base route to convert + + Returns: + GetDurableExecutionStateRoute instance with extracted ARN + """ + return cls( + raw_path=route.raw_path, + segments=route.segments, + arn=route.segments[2], + ) + + +@dataclass(frozen=True) +class GetDurableExecutionHistoryRoute(Route): + """Route: GET /2025-12-01/durable-executions/{arn}/history""" + + arn: str + + @classmethod + def is_match(cls, route: Route, method: str) -> bool: + """Check if the route and HTTP method match this route type. + + Args: + route: Route to check + method: HTTP method to check + + Returns: + True if the route and method match + """ + return ( + route.matches_pattern(["2025-12-01", "durable-executions", "*", "history"]) + and method == "GET" + ) + + @classmethod + def from_route(cls, route: Route) -> GetDurableExecutionHistoryRoute: + """Create a GetDurableExecutionHistoryRoute from a base Route. + + Note: Call is_match(route, method) first to ensure the route is valid for this type. + + Args: + route: Base route to convert + + Returns: + GetDurableExecutionHistoryRoute instance with extracted ARN + """ + return cls( + raw_path=route.raw_path, + segments=route.segments, + arn=route.segments[2], + ) + + +@dataclass(frozen=True) +class ListDurableExecutionsRoute(Route): + """Route: GET /2025-12-01/durable-executions""" + + @classmethod + def is_match(cls, route: Route, method: str) -> bool: + """Check if the route and HTTP method match this route type. + + Args: + route: Route to check + method: HTTP method to check + + Returns: + True if the route and method match + """ + return ( + route.matches_pattern(["2025-12-01", "durable-executions"]) + and method == "GET" + ) + + @classmethod + def from_route(cls, route: Route) -> ListDurableExecutionsRoute: + """Create a ListDurableExecutionsRoute from a base Route. + + Note: Call is_match(route, method) first to ensure the route is valid for this type. + + Args: + route: Base route to convert + + Returns: + ListDurableExecutionsRoute instance + """ + return cls(raw_path=route.raw_path, segments=route.segments) + + +@dataclass(frozen=True) +class ListDurableExecutionsByFunctionRoute(Route): + """Route: GET /2025-12-01/functions/{function_name}/durable-executions""" + + function_name: str + + @classmethod + def is_match(cls, route: Route, method: str) -> bool: + """Check if the route and HTTP method match this route type. + + Args: + route: Route to check + method: HTTP method to check + + Returns: + True if the route and method match + """ + return ( + route.matches_pattern( + ["2025-12-01", "functions", "*", "durable-executions"] + ) + and method == "GET" + ) + + @classmethod + def from_route(cls, route: Route) -> ListDurableExecutionsByFunctionRoute: + """Create a ListDurableExecutionsByFunctionRoute from a base Route. + + Note: Call is_match(route, method) first to ensure the route is valid for this type. + + Args: + route: Base route to convert + + Returns: + ListDurableExecutionsByFunctionRoute instance with extracted function name + """ + return cls( + raw_path=route.raw_path, + segments=route.segments, + function_name=route.segments[2], + ) + + +@dataclass(frozen=True) +class BytesPayloadRoute(Route): + """Base class for routes that handle raw bytes payloads instead of JSON.""" + + +@dataclass(frozen=True) +class CallbackSuccessRoute(BytesPayloadRoute): + """Route: POST /2025-12-01/durable-execution-callbacks/{callback_id}/succeed""" + + callback_id: str + + @classmethod + def is_match(cls, route: Route, method: str) -> bool: + """Check if the route and HTTP method match this route type. + + Args: + route: Route to check + method: HTTP method to check + + Returns: + True if the route and method match + """ + return ( + route.matches_pattern( + ["2025-12-01", "durable-execution-callbacks", "*", "succeed"] + ) + and method == "POST" + ) + + @classmethod + def from_route(cls, route: Route) -> CallbackSuccessRoute: + """Create a CallbackSuccessRoute from a base Route. + + Note: Call is_match(route, method) first to ensure the route is valid for this type. + + Args: + route: Base route to convert + + Returns: + CallbackSuccessRoute instance with extracted callback ID + """ + return cls( + raw_path=route.raw_path, + segments=route.segments, + callback_id=route.segments[2], + ) + + +@dataclass(frozen=True) +class CallbackFailureRoute(BytesPayloadRoute): + """Route: POST /2025-12-01/durable-execution-callbacks/{callback_id}/fail""" + + callback_id: str + + @classmethod + def is_match(cls, route: Route, method: str) -> bool: + """Check if the route and HTTP method match this route type. + + Args: + route: Route to check + method: HTTP method to check + + Returns: + True if the route and method match + """ + return ( + route.matches_pattern( + ["2025-12-01", "durable-execution-callbacks", "*", "fail"] + ) + and method == "POST" + ) + + @classmethod + def from_route(cls, route: Route) -> CallbackFailureRoute: + """Create a CallbackFailureRoute from a base Route. + + Note: Call is_match(route, method) first to ensure the route is valid for this type. + + Args: + route: Base route to convert + + Returns: + CallbackFailureRoute instance with extracted callback ID + """ + return cls( + raw_path=route.raw_path, + segments=route.segments, + callback_id=route.segments[2], + ) + + +@dataclass(frozen=True) +class CallbackHeartbeatRoute(Route): + """Route: POST /2025-12-01/durable-execution-callbacks/{callback_id}/heartbeat""" + + callback_id: str + + @classmethod + def is_match(cls, route: Route, method: str) -> bool: + """Check if the route and HTTP method match this route type. + + Args: + route: Route to check + method: HTTP method to check + + Returns: + True if the route and method match + """ + return ( + route.matches_pattern( + ["2025-12-01", "durable-execution-callbacks", "*", "heartbeat"] + ) + and method == "POST" + ) + + @classmethod + def from_route(cls, route: Route) -> CallbackHeartbeatRoute: + """Create a CallbackHeartbeatRoute from a base Route. + + Note: Call is_match(route, method) first to ensure the route is valid for this type. + + Args: + route: Base route to convert + + Returns: + CallbackHeartbeatRoute instance with extracted callback ID + """ + return cls( + raw_path=route.raw_path, + segments=route.segments, + callback_id=route.segments[2], + ) + + +@dataclass(frozen=True) +class HealthRoute(Route): + """Route: GET /health""" + + @classmethod + def is_match(cls, route: Route, method: str) -> bool: + """Check if the route and HTTP method match this route type. + + Args: + route: Route to check + method: HTTP method to check + + Returns: + True if the route and method match + """ + return route.raw_path == "/health" and method == "GET" + + @classmethod + def from_route(cls, route: Route) -> HealthRoute: + """Create a HealthRoute from a base Route. + + Note: Call is_match(route, method) first to ensure the route is valid for this type. + + Args: + route: Base route to convert + + Returns: + HealthRoute instance + """ + return cls(raw_path=route.raw_path, segments=route.segments) + + +@dataclass(frozen=True) +class UpdateLambdaEndpointRoute(Route): + """Route: PUT /lambda-endpoint""" + + @classmethod + def is_match(cls, route: Route, method: str) -> bool: + """Check if the route and HTTP method match this route type. + + Args: + route: Route to check + method: HTTP method to check + + Returns: + True if the route and method match + """ + return route.raw_path == "/lambda-endpoint" and method == "PUT" + + @classmethod + def from_route(cls, route: Route) -> UpdateLambdaEndpointRoute: + """Create UpdateLambdaEndpointRoute from base route. + + Args: + route: Base route to convert + + Returns: + UpdateLambdaEndpointRoute instance + """ + return cls(raw_path=route.raw_path, segments=route.segments) + + +@dataclass(frozen=True) +class MetricsRoute(Route): + """Route: GET /metrics""" + + @classmethod + def is_match(cls, route: Route, method: str) -> bool: + """Check if the route and HTTP method match this route type. + + Args: + route: Route to check + method: HTTP method to check + + Returns: + True if the route and method match + """ + return route.raw_path == "/metrics" and method == "GET" + + @classmethod + def from_route(cls, route: Route) -> MetricsRoute: + """Create a MetricsRoute from a base Route. + + Note: Call is_match(route, method) first to ensure the route is valid for this type. + + Args: + route: Base route to convert + + Returns: + MetricsRoute instance + """ + return cls(raw_path=route.raw_path, segments=route.segments) + + +# Default registry of all route types for matching +DEFAULT_ROUTE_TYPES: list[type[Route]] = [ + StartExecutionRoute, + GetDurableExecutionRoute, + CheckpointDurableExecutionRoute, + StopDurableExecutionRoute, + GetDurableExecutionStateRoute, + GetDurableExecutionHistoryRoute, + ListDurableExecutionsRoute, + ListDurableExecutionsByFunctionRoute, + CallbackSuccessRoute, + CallbackFailureRoute, + CallbackHeartbeatRoute, + HealthRoute, + UpdateLambdaEndpointRoute, + MetricsRoute, +] + + +class Router: + """HTTP request router that matches routes to strongly-typed route objects.""" + + def __init__(self, route_types: list[type[Route]] | None = None) -> None: + """Initialize the router with route types. + + Args: + route_types: List of route type classes to use for matching. + If None, uses the default route types. + """ + self._route_types = ( + route_types if route_types is not None else DEFAULT_ROUTE_TYPES + ) + + def find_route(self, path: str, method: str) -> Route: + """Find a matching route for the given path and HTTP method. + + Args: + path: The raw path string to parse + method: The HTTP method (GET, POST, etc.) + + Returns: + Strongly-typed Route instance + + Raises: + UnknownRouteError: If the path and method don't match any known pattern + """ + base_route = Route.from_string(path) + + for route_type in self._route_types: + if route_type.is_match(base_route, method): + return route_type.from_route(base_route) + + raise UnknownRouteError(method, path) diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/serialization.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/serialization.py new file mode 100644 index 0000000..6532a66 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/serialization.py @@ -0,0 +1,235 @@ +"""Serialization interfaces and AWS boto integration for HTTP request/response models. + +This module provides Protocol interfaces for serialization and deserialization, +along with AWS-compatible implementations using boto's rest-json serializers. +""" + +from __future__ import annotations + +import json +import os +from typing import Any, Protocol +from datetime import datetime + +import aws_durable_execution_sdk_python +import botocore.loaders # type: ignore +from botocore.model import ServiceModel # type: ignore +from botocore.parsers import create_parser # type: ignore +from botocore.serialize import create_serializer # type: ignore + +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) + + +class Serializer(Protocol): + """Interface for serializing data to bytes.""" + + def to_bytes(self, data: Any) -> bytes: + """Serialize data to bytes. + + Args: + data: The data to serialize + + Returns: + bytes: The serialized data + + Raises: + InvalidParameterValueException: If serialization fails + """ + ... # pragma: no cover + + +class Deserializer(Protocol): + """Interface for deserializing bytes to data.""" + + def from_bytes(self, data: bytes) -> dict[str, Any]: + """Deserialize bytes to dictionary. + + Args: + data: The bytes to deserialize + + Returns: + dict: The deserialized data + + Raises: + InvalidParameterValueException: If deserialization fails + """ + ... # pragma: no cover + + +class JSONSerializer: + """JSON serializer with datetime support.""" + + def to_bytes(self, data: Any) -> bytes: + """Serialize data to JSON bytes.""" + try: + json_string = json.dumps( + data, separators=(",", ":"), default=self._default_handler + ) + return json_string.encode("utf-8") + except (TypeError, ValueError) as e: + raise InvalidParameterValueException( + f"Failed to serialize data to JSON: {str(e)}" + ) + + def _default_handler(self, obj: Any) -> float: + """Handle non-permitive objects.""" + if isinstance(obj, datetime): + return obj.timestamp() + # Raise TypeError for unsupported types + raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable") + + +class AwsRestJsonSerializer: + """AWS rest-json serializer using boto.""" + + def __init__(self, operation_name: str, serializer: Any, operation_model: Any): + """Initialize the AWS rest-json serializer. + + Args: + operation_name: Name of the AWS operation + serializer: Boto serializer instance + operation_model: Boto operation model + """ + self._operation_name = operation_name + self._serializer = serializer + self._operation_model = operation_model + + @classmethod + def create(cls, operation_name: str) -> AwsRestJsonSerializer: + """Create serializer with boto components. + + Args: + operation_name: Name of the AWS operation + + Returns: + AwsRestJsonSerializer: Configured serializer instance + + Raises: + InvalidParameterValueException: If serializer creation fails + """ + try: + # Load service model + loader = botocore.loaders.Loader() + + raw_model = loader.load_service_model("lambda", "service-2") + service_model = ServiceModel(raw_model) + + # Create serializer (rest-json protocol) + serializer = create_serializer("rest-json", include_validation=True) + operation_model = service_model.operation_model(operation_name) + + return cls(operation_name, serializer, operation_model) + except Exception as e: + msg = f"Failed to create serializer for {operation_name}: {e}" + raise InvalidParameterValueException(msg) from e + + def to_bytes(self, data: dict[str, Any]) -> bytes: + """Serialize data using boto rest-json serializer. + + Args: + data: Dictionary data to serialize + + Returns: + bytes: Serialized data + + Raises: + InvalidParameterValueException: If serialization fails + """ + if not self._serializer or not self._operation_model: + msg = f"Serializer not initialized for {self._operation_name}" + raise InvalidParameterValueException(msg) + + try: + serialized = self._serializer.serialize_to_request( + data, self._operation_model + ) + body = serialized.get("body", b"") + + if isinstance(body, str): + return body.encode("utf-8") + + return body # noqa: TRY300 + except Exception as e: + msg = f"Failed to serialize data for {self._operation_name}: {e}" + raise InvalidParameterValueException(msg) from e + + +class AwsRestJsonDeserializer: + """AWS rest-json deserializer using boto.""" + + def __init__(self, operation_name: str, parser: Any, operation_model: Any): + """Initialize the AWS rest-json deserializer. + + Args: + operation_name: Name of the AWS operation + parser: Boto parser instance + operation_model: Boto operation model + """ + self._operation_name = operation_name + self._parser = parser + self._operation_model = operation_model + + @classmethod + def create(cls, operation_name: str) -> AwsRestJsonDeserializer: + """Create deserializer with boto components. + + Args: + operation_name: Name of the AWS operation + + Returns: + AwsRestJsonDeserializer: Configured deserializer instance + + Raises: + InvalidParameterValueException: If deserializer creation fails + """ + try: + # Load service model + loader = botocore.loaders.Loader() + + raw_model = loader.load_service_model("lambda", "service-2") + service_model = ServiceModel(raw_model) + + # Create parser (rest-json protocol) + parser = create_parser("rest-json") + operation_model = service_model.operation_model(operation_name) + + return cls(operation_name, parser, operation_model) + except Exception as e: + msg = f"Failed to create deserializer for {operation_name}: {e}" + raise InvalidParameterValueException(msg) from e + + def from_bytes(self, data: bytes) -> dict[str, Any]: + """Deserialize bytes using boto rest-json parser. + + Args: + data: Bytes to deserialize + + Returns: + dict: Deserialized data + + Raises: + InvalidParameterValueException: If deserialization fails + """ + if not self._parser or not self._operation_model: + msg = f"Parser not initialized for {self._operation_name}" + raise InvalidParameterValueException(msg) + + try: + if self._operation_model.output_shape: + # Create response dict for boto parser + response_dict = { + "body": data, + "headers": {"content-type": "application/json"}, + "status_code": 200, + } + return self._parser.parse( + response_dict, self._operation_model.output_shape + ) + + # If no output shape, just parse as JSON + return json.loads(data.decode("utf-8")) + except Exception as e: + msg = f"Failed to deserialize data for {self._operation_name}: {e}" + raise InvalidParameterValueException(msg) from e diff --git a/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/server.py b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/server.py new file mode 100644 index 0000000..89cb219 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing/web/server.py @@ -0,0 +1,243 @@ +"""Local testing web server for AWS Lambda Durable Functions that mimics the actual Lambda backend services.""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from typing import TYPE_CHECKING, Self +from urllib.parse import parse_qs, urlparse + +from aws_durable_execution_sdk_python_testing.exceptions import ( + AwsApiException, + ServiceException, + UnknownRouteError, +) + + +if TYPE_CHECKING: + from aws_durable_execution_sdk_python_testing.executor import Executor + + +# Removed deprecated imports from web.errors +from aws_durable_execution_sdk_python_testing.web.handlers import ( + CheckpointDurableExecutionHandler, + EndpointHandler, + GetDurableExecutionHandler, + GetDurableExecutionHistoryHandler, + GetDurableExecutionStateHandler, + HealthHandler, + ListDurableExecutionsByFunctionHandler, + ListDurableExecutionsHandler, + MetricsHandler, + SendDurableExecutionCallbackFailureHandler, + SendDurableExecutionCallbackHeartbeatHandler, + SendDurableExecutionCallbackSuccessHandler, + StartExecutionHandler, + StopDurableExecutionHandler, + UpdateLambdaEndpointHandler, +) +from aws_durable_execution_sdk_python_testing.web.models import ( + HTTPRequest, + HTTPResponse, +) +from aws_durable_execution_sdk_python_testing.web.routes import ( + BytesPayloadRoute, + CallbackFailureRoute, + CallbackHeartbeatRoute, + CallbackSuccessRoute, + CheckpointDurableExecutionRoute, + GetDurableExecutionHistoryRoute, + GetDurableExecutionRoute, + GetDurableExecutionStateRoute, + HealthRoute, + ListDurableExecutionsByFunctionRoute, + ListDurableExecutionsRoute, + MetricsRoute, + Route, + Router, + StartExecutionRoute, + StopDurableExecutionRoute, + UpdateLambdaEndpointRoute, +) + + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class WebServiceConfig: + """Configuration for the web service.""" + + host: str = "localhost" + port: int = 5000 + log_level: int = logging.INFO + max_request_size: int = 10 * 1024 * 1024 # 10MB + + +class RequestHandler(BaseHTTPRequestHandler): + """HTTP request handler for durable execution operations.""" + + def __init__(self, request, client_address, server) -> None: + self.executor: Executor = server.executor + self.router: Router = server.router # Access shared router + self.endpoint_handlers: dict[type[Route], EndpointHandler] = ( + server.endpoint_handlers + ) # Access shared handlers + super().__init__(request, client_address, server) + + def do_GET(self) -> None: # noqa: N802 + """Handle GET requests.""" + self._handle_request("GET") + + def do_POST(self) -> None: # noqa: N802 + """Handle POST requests.""" + self._handle_request("POST") + + def do_PUT(self) -> None: # noqa: N802 + """Handle PUT requests.""" + self._handle_request("PUT") + + def _handle_request(self, method: str) -> None: + """Handle HTTP request with strongly-typed routing.""" + try: + # Parse URL path and method into strongly-typed Route object using shared router + url_path: str = self.path.split("?")[0] + parsed_route: Route = self.router.find_route(url_path, method) + + # Find handler for this route type + handler: EndpointHandler | None = self.endpoint_handlers.get( + type(parsed_route) + ) + + if not handler: + raise UnknownRouteError(method, url_path) # noqa: TRY301 + + # Parse query parameters and request body + parsed_url = urlparse(self.path) + query_params: dict[str, list[str]] = parse_qs(parsed_url.query) + content_length: int = int(self.headers.get("Content-Length", 0)) + body_bytes: bytes = ( + self.rfile.read(content_length) if content_length > 0 else b"" + ) + + # For callback operations, use raw bytes directly + if isinstance(parsed_route, BytesPayloadRoute): + request = HTTPRequest.from_raw_bytes( + body_bytes=body_bytes, + method=method, + path=parsed_route, + headers=dict(self.headers), + query_params=query_params, + ) + else: + # Create strongly-typed HTTP request object with pre-parsed body + request = HTTPRequest.from_bytes( + body_bytes=body_bytes, + operation_name=None, + method=method, + path=parsed_route, + headers=dict(self.headers), + query_params=query_params, + ) + + # Handle request with appropriate handler + response: HTTPResponse = handler.handle(parsed_route, request) + + # Send HTTP response + self._send_response(response) + + except Exception as e: + logger.exception("Request handling failed") + + aws_exception: AwsApiException = ( + e if isinstance(e, AwsApiException) else ServiceException(str(e)) + ) + + http_response = HTTPResponse.create_error_from_exception(aws_exception) + self._send_response(http_response) + + def _send_response(self, response: HTTPResponse) -> None: + """Send HTTP response to client.""" + self.send_response(response.status_code) + for header_name, header_value in response.headers.items(): + self.send_header(header_name, header_value) + self.end_headers() + + # Convert response body to bytes for transmission + if response.body: + self.wfile.write(response.body_to_bytes()) + + def log_message(self, format_string: str, *args) -> None: + """Override to use Python logging instead of stderr.""" + logger.info("%s - %s", self.address_string(), format_string % args) + + +class WebServer(ThreadingHTTPServer): + """Multi-threaded HTTP server for durable execution operations.""" + + def __init__(self, config: WebServiceConfig, executor: Executor) -> None: + """Initialize the web server. + + Args: + config: Server configuration + executor: Executor instance for handling operations + """ + self.config = config + self.executor = executor + + # Configure logging + logging.basicConfig(level=config.log_level) + logging.getLogger("botocore").setLevel(logging.WARNING) + + # Create shared router and endpoint handlers + self.router = Router() # Shared across all request handlers + self.endpoint_handlers = ( + self._create_endpoint_handlers() + ) # Shared handler registry + + # Initialize the HTTP server + super().__init__((config.host, config.port), RequestHandler) + + logger.info("Web server initialized on %s:%s", config.host, config.port) + + def _create_endpoint_handlers(self) -> dict[type[Route], EndpointHandler]: + """Create endpoint handlers registry - called once during server initialization.""" + return { + StartExecutionRoute: StartExecutionHandler(self.executor), + GetDurableExecutionRoute: GetDurableExecutionHandler(self.executor), + CheckpointDurableExecutionRoute: CheckpointDurableExecutionHandler( + self.executor + ), + StopDurableExecutionRoute: StopDurableExecutionHandler(self.executor), + GetDurableExecutionStateRoute: GetDurableExecutionStateHandler( + self.executor + ), + GetDurableExecutionHistoryRoute: GetDurableExecutionHistoryHandler( + self.executor + ), + ListDurableExecutionsRoute: ListDurableExecutionsHandler(self.executor), + ListDurableExecutionsByFunctionRoute: ListDurableExecutionsByFunctionHandler( + self.executor + ), + CallbackSuccessRoute: SendDurableExecutionCallbackSuccessHandler( + self.executor + ), + CallbackFailureRoute: SendDurableExecutionCallbackFailureHandler( + self.executor + ), + CallbackHeartbeatRoute: SendDurableExecutionCallbackHeartbeatHandler( + self.executor + ), + HealthRoute: HealthHandler(self.executor), + UpdateLambdaEndpointRoute: UpdateLambdaEndpointHandler(self.executor), + MetricsRoute: MetricsHandler(self.executor), + } + + def __enter__(self) -> Self: + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Context manager exit - cleanup server resources.""" + self.server_close() diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/__init__.py b/packages/aws-durable-execution-sdk-python-testing/tests/__init__.py new file mode 100644 index 0000000..66173ae --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/__init__.py @@ -0,0 +1 @@ +# Test package diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/__init__.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/__init__.py new file mode 100644 index 0000000..78d8de9 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/__init__.py @@ -0,0 +1 @@ +"""Test package""" diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processor_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processor_test.py new file mode 100644 index 0000000..1df24cc --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processor_test.py @@ -0,0 +1,295 @@ +"""Unit tests for CheckpointProcessor.""" + +from unittest.mock import Mock, patch + +import pytest +from aws_durable_execution_sdk_python.lambda_service import ( + CheckpointOutput, + CheckpointUpdatedExecutionState, + OperationAction, + OperationType, + OperationUpdate, + StateOutput, +) + +from aws_durable_execution_sdk_python_testing.checkpoint.processor import ( + CheckpointProcessor, +) +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) +from aws_durable_execution_sdk_python_testing.execution import Execution +from aws_durable_execution_sdk_python_testing.scheduler import Scheduler +from aws_durable_execution_sdk_python_testing.stores.base import ExecutionStore +from aws_durable_execution_sdk_python_testing.token import CheckpointToken + + +def test_init(): + """Test CheckpointProcessor initialization.""" + store = Mock(spec=ExecutionStore) + scheduler = Mock(spec=Scheduler) + + processor = CheckpointProcessor(store, scheduler) + + # Test that processor was created successfully by calling a public method + # This indirectly verifies that internal components were initialized + assert processor is not None + + # Test that we can add observers (verifies notifier is initialized) + observer = Mock() + processor.add_execution_observer(observer) # Should not raise an exception + + +@patch( + "aws_durable_execution_sdk_python_testing.checkpoint.processor.ExecutionNotifier" +) +def test_add_execution_observer(mock_notifier_class): + """Test adding execution observer.""" + store = Mock(spec=ExecutionStore) + scheduler = Mock(spec=Scheduler) + mock_notifier_instance = Mock() + mock_notifier_class.return_value = mock_notifier_instance + + processor = CheckpointProcessor(store, scheduler) + observer = Mock() + + processor.add_execution_observer(observer) + + # Verify observer was added through the notifier's public method + mock_notifier_instance.add_observer.assert_called_once_with(observer) + + +@patch( + "aws_durable_execution_sdk_python_testing.checkpoint.processor.CheckpointValidator" +) +@patch( + "aws_durable_execution_sdk_python_testing.checkpoint.processor.OperationTransformer" +) +def test_process_checkpoint_success(mock_transformer_class, mock_validator): + """Test successful checkpoint processing.""" + # Setup mocks + store = Mock(spec=ExecutionStore) + scheduler = Mock(spec=Scheduler) + mock_transformer_instance = Mock() + mock_transformer_class.return_value = mock_transformer_instance + + processor = CheckpointProcessor(store, scheduler) + + # Mock execution + execution = Mock(spec=Execution) + execution.is_complete = False + execution.token_sequence = 1 + execution.operations = [] + execution.updates = [] + execution.get_new_checkpoint_token.return_value = "new-token" + execution.get_navigable_operations.return_value = [] + + store.load.return_value = execution + + # Mock transformer + mock_transformer_instance.process_updates.return_value = ([], []) + + # Test data + checkpoint_token = "test-token" # noqa: S105 + updates = [ + OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.START, + ) + ] + + # Mock token parsing + with patch.object(CheckpointToken, "from_str") as mock_from_str: + mock_token = Mock() + mock_token.execution_arn = "arn:test" + mock_token.token_sequence = 1 + mock_from_str.return_value = mock_token + + result = processor.process_checkpoint(checkpoint_token, updates, "client-token") + + # Verify calls + store.load.assert_called_once_with("arn:test") + mock_validator.validate_input.assert_called_once_with(updates, execution) + mock_transformer_instance.process_updates.assert_called_once() + store.update.assert_called_once_with(execution) + + # Verify result + assert isinstance(result, CheckpointOutput) + assert result.checkpoint_token == "new-token" # noqa: S105 + assert isinstance(result.new_execution_state, CheckpointUpdatedExecutionState) + + +@patch( + "aws_durable_execution_sdk_python_testing.checkpoint.processor.CheckpointValidator" +) +def test_process_checkpoint_invalid_token_complete_execution(mock_validator): + """Test checkpoint processing with complete execution.""" + store = Mock(spec=ExecutionStore) + scheduler = Mock(spec=Scheduler) + processor = CheckpointProcessor(store, scheduler) + + # Mock execution as complete + execution = Mock(spec=Execution) + execution.is_complete = True + execution.token_sequence = 1 + + store.load.return_value = execution + + checkpoint_token = "test-token" # noqa: S105 + updates = [] + + with patch.object(CheckpointToken, "from_str") as mock_from_str: + mock_token = Mock() + mock_token.execution_arn = "arn:test" + mock_token.token_sequence = 1 + mock_from_str.return_value = mock_token + + with pytest.raises( + InvalidParameterValueException, match="Invalid checkpoint token" + ): + processor.process_checkpoint(checkpoint_token, updates, "client-token") + + +@patch( + "aws_durable_execution_sdk_python_testing.checkpoint.processor.CheckpointValidator" +) +def test_process_checkpoint_invalid_token_sequence(mock_validator): + """Test checkpoint processing with invalid token sequence.""" + store = Mock(spec=ExecutionStore) + scheduler = Mock(spec=Scheduler) + processor = CheckpointProcessor(store, scheduler) + + # Mock execution with different token sequence + execution = Mock(spec=Execution) + execution.is_complete = False + execution.token_sequence = 2 + + store.load.return_value = execution + + checkpoint_token = "test-token" # noqa: S105 + updates = [] + + with patch.object(CheckpointToken, "from_str") as mock_from_str: + mock_token = Mock() + mock_token.execution_arn = "arn:test" + mock_token.token_sequence = 1 # Different from execution + mock_from_str.return_value = mock_token + + with pytest.raises( + InvalidParameterValueException, match="Invalid checkpoint token" + ): + processor.process_checkpoint(checkpoint_token, updates, "client-token") + + +@patch( + "aws_durable_execution_sdk_python_testing.checkpoint.processor.CheckpointValidator" +) +@patch( + "aws_durable_execution_sdk_python_testing.checkpoint.processor.OperationTransformer" +) +def test_process_checkpoint_updates_execution_state( + mock_transformer_class, mock_validator +): + """Test that checkpoint processing updates execution state correctly.""" + store = Mock(spec=ExecutionStore) + scheduler = Mock(spec=Scheduler) + mock_transformer_instance = Mock() + mock_transformer_class.return_value = mock_transformer_instance + + processor = CheckpointProcessor(store, scheduler) + + # Mock execution + execution = Mock(spec=Execution) + execution.is_complete = False + execution.token_sequence = 1 + execution.operations = [] + execution.updates = [] + execution.get_new_checkpoint_token.return_value = "new-token" + execution.get_navigable_operations.return_value = [] + + store.load.return_value = execution + + # Mock transformer to return updated operations and updates + updated_operations = [Mock()] + all_updates = [Mock()] + mock_transformer_instance.process_updates.return_value = ( + updated_operations, + all_updates, + ) + + checkpoint_token = "test-token" # noqa: S105 + updates = [ + OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.START, + ) + ] + + with patch.object(CheckpointToken, "from_str") as mock_from_str: + mock_token = Mock() + mock_token.execution_arn = "arn:test" + mock_token.token_sequence = 1 + mock_from_str.return_value = mock_token + + processor.process_checkpoint(checkpoint_token, updates, "client-token") + + # Verify execution state was updated + assert execution.operations == updated_operations + # Check that updates were extended (execution.updates is a real list) + assert len(execution.updates) == len(all_updates) + + +def test_get_execution_state(): + """Test getting execution state.""" + store = Mock(spec=ExecutionStore) + scheduler = Mock(spec=Scheduler) + processor = CheckpointProcessor(store, scheduler) + + # Mock execution + execution = Mock(spec=Execution) + navigable_ops = [Mock()] + execution.get_navigable_operations.return_value = navigable_ops + + store.load.return_value = execution + + checkpoint_token = "test-token" # noqa: S105 + + with patch.object(CheckpointToken, "from_str") as mock_from_str: + mock_token = Mock() + mock_token.execution_arn = "arn:test" + mock_from_str.return_value = mock_token + + result = processor.get_execution_state(checkpoint_token, "next-marker", 500) + + # Verify calls + store.load.assert_called_once_with("arn:test") + execution.get_navigable_operations.assert_called_once() + + # Verify result + assert isinstance(result, StateOutput) + assert result.operations == navigable_ops + assert result.next_marker is None + + +def test_get_execution_state_default_max_items(): + """Test getting execution state with default max_items.""" + store = Mock(spec=ExecutionStore) + scheduler = Mock(spec=Scheduler) + processor = CheckpointProcessor(store, scheduler) + + execution = Mock(spec=Execution) + execution.get_navigable_operations.return_value = [] + store.load.return_value = execution + + checkpoint_token = "test-token" # noqa: S105 + + with patch.object(CheckpointToken, "from_str") as mock_from_str: + mock_token = Mock() + mock_token.execution_arn = "arn:test" + mock_from_str.return_value = mock_token + + result = processor.get_execution_state(checkpoint_token, "next-marker") + + assert isinstance(result, StateOutput) diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/__init__.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/__init__.py new file mode 100644 index 0000000..78d8de9 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/__init__.py @@ -0,0 +1 @@ +"""Test package""" diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/base_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/base_test.py new file mode 100644 index 0000000..700fd96 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/base_test.py @@ -0,0 +1,447 @@ +"""Tests for base operation processor.""" + +import datetime +from datetime import timedelta +from unittest.mock import Mock + +import pytest +from aws_durable_execution_sdk_python.lambda_service import ( + CallbackDetails, + ChainedInvokeDetails, + ChainedInvokeOptions, + ContextDetails, + ErrorObject, + ExecutionDetails, + Operation, + OperationAction, + OperationStatus, + OperationType, + OperationUpdate, + StepDetails, + WaitDetails, + WaitOptions, + ContextOptions, +) + +from aws_durable_execution_sdk_python_testing.checkpoint.processors.base import ( + OperationProcessor, +) + + +def test_process_not_implemented(): + processor = OperationProcessor() + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.START, + ) + + try: + processor.process(update, None, Mock(), "test-arn") + pytest.fail("Expected NotImplementedError") + except NotImplementedError: + pass + + +class MockProcessor(OperationProcessor): + """Mock processor for testing base functionality.""" + + def process(self, update, current_op, notifier, execution_arn): + return self._translate_update_to_operation( + update, current_op, OperationStatus.STARTED + ) + + def translate_update(self, update, current_op, status): + """Public method to access _translate_update_to_operation for testing.""" + return self._translate_update_to_operation(update, current_op, status) + + def get_end_time(self, current_op, status): + """Public method to access _get_end_time for testing.""" + return self._get_end_time(current_op, status) + + def create_execution_details(self, update): + """Public method to access _create_execution_details for testing.""" + return self._create_execution_details(update) + + def create_context_details(self, update): + """Public method to access _create_context_details for testing.""" + return self._create_context_details(update) + + def create_step_details(self, update, current_operation): + """Public method to access _create_step_details for testing.""" + return self._create_step_details(update, current_operation) + + def create_callback_details(self, update): + """Public method to access _create_callback_details for testing.""" + return self._create_callback_details(update) + + def create_invoke_details(self, update): + """Public method to access _create_invoke_details for testing.""" + return self._create_invoke_details(update) + + def create_wait_details(self, update, current_op): + """Public method to access _create_wait_details for testing.""" + return self._create_wait_details(update, current_op) + + +def test_get_end_time_with_existing_end_timestamp(): + processor = MockProcessor() + end_time = datetime.datetime.now(tz=datetime.UTC) + current_op = Mock() + current_op.end_timestamp = end_time + + result = processor.get_end_time(current_op, OperationStatus.STARTED) + + assert result == end_time + + +def test_get_end_time_with_terminal_status(): + processor = MockProcessor() + current_op = Mock() + current_op.end_timestamp = None + + result = processor.get_end_time(current_op, OperationStatus.SUCCEEDED) + + assert result is not None + assert isinstance(result, datetime.datetime) + + +def test_get_end_time_with_non_terminal_status(): + processor = MockProcessor() + current_op = Mock() + current_op.end_timestamp = None + + result = processor.get_end_time(current_op, OperationStatus.STARTED) + + assert result is None + + +def test_create_execution_details(): + processor = MockProcessor() + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.EXECUTION, + action=OperationAction.START, + payload="test-payload", + ) + + result = processor.create_execution_details(update) + + assert isinstance(result, ExecutionDetails) + assert result.input_payload == "test-payload" + + +def test_create_execution_details_non_execution_type(): + processor = MockProcessor() + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.START, + payload="test-payload", + ) + + result = processor.create_execution_details(update) + + assert result is None + + +def test_create_context_details(): + processor = MockProcessor() + error = ErrorObject.from_message("test error") + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CONTEXT, + action=OperationAction.START, + payload="test-payload", + error=error, + ) + + result = processor.create_context_details(update) + + assert isinstance(result, ContextDetails) + assert result.result == "test-payload" + assert result.error == error + + +def test_create_context_details_non_context_type(): + processor = MockProcessor() + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.START, + payload="test-payload", + ) + + result = processor.create_context_details(update) + + assert result is None + + +def test_create_step_details(): + processor = MockProcessor() + error = ErrorObject.from_message("test error") + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.START, + payload="test-payload", + error=error, + ) + + current_op = Mock() + current_op.step_details = Mock() + current_op.step_details.attempt = Mock() + + result = processor.create_step_details(update, current_op) + + assert isinstance(result, StepDetails) + assert result.result == "test-payload" + assert result.error == error + + +def test_create_context_details_with_replay_children(): + processor = MockProcessor() + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CONTEXT, + action=OperationAction.SUCCEED, + payload="test-payload", + context_options=ContextOptions(replay_children=True), + ) + + result = processor.create_context_details(update) + + assert isinstance(result, ContextDetails) + assert result.result == "test-payload" + assert result.replay_children == True + + +def test_create_step_details_non_step_type(): + processor = MockProcessor() + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CONTEXT, + action=OperationAction.START, + payload="test-payload", + ) + + current_op = Mock() + current_op.step_details = Mock() + current_op.step_details.attempt = Mock() + + result = processor.create_step_details(update, current_op) + + assert result is None + + +def test_create_step_details_without_current_operation(): + processor = MockProcessor() + error = ErrorObject.from_message("test error") + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.START, + payload="test-payload", + error=error, + ) + + result = processor.create_step_details(update, None) + + assert isinstance(result, StepDetails) + assert result.result == "test-payload" + assert result.error == error + assert result.attempt == 0 + + +def test_create_callback_details(): + processor = MockProcessor() + error = ErrorObject.from_message("test error") + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CALLBACK, + action=OperationAction.START, + payload="test-payload", + error=error, + ) + + result = processor.create_callback_details(update) + + assert isinstance(result, CallbackDetails) + assert result.callback_id == "placeholder" + assert result.result == "test-payload" + assert result.error == error + + +def test_create_callback_details_non_callback_type(): + processor = MockProcessor() + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.START, + payload="test-payload", + ) + + result = processor.create_callback_details(update) + + assert result is None + + +def test_create_invoke_details(): + processor = MockProcessor() + error = ErrorObject.from_message("test error") + invoke_options = ChainedInvokeOptions(function_name="test-function") + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CHAINED_INVOKE, + action=OperationAction.START, + payload="test-payload", + error=error, + chained_invoke_options=invoke_options, + ) + + result = processor.create_invoke_details(update) + + assert isinstance(result, ChainedInvokeDetails) + assert result.result == "test-payload" + assert result.error == error + + +def test_create_invoke_details_non_invoke_type(): + processor = MockProcessor() + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.START, + payload="test-payload", + ) + + result = processor.create_invoke_details(update) + + assert result is None + + +def test_create_invoke_details_no_options(): + processor = MockProcessor() + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CHAINED_INVOKE, + action=OperationAction.START, + payload="test-payload", + ) + + result = processor.create_invoke_details(update) + + assert result is None + + +def test_create_wait_details_with_current_operation(): + processor = MockProcessor() + scheduled_end_timestamp = datetime.datetime.now(tz=datetime.UTC) + current_op = Mock() + current_op.wait_details = WaitDetails( + scheduled_end_timestamp=scheduled_end_timestamp + ) + + wait_options = WaitOptions(wait_seconds=30) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.WAIT, + action=OperationAction.START, + wait_options=wait_options, + ) + + result = processor.create_wait_details(update, current_op) + + assert isinstance(result, WaitDetails) + assert result.scheduled_end_timestamp == scheduled_end_timestamp + + +def test_create_wait_details_without_current_operation(): + processor = MockProcessor() + wait_options = WaitOptions(wait_seconds=30) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.WAIT, + action=OperationAction.START, + wait_options=wait_options, + ) + + result = processor.create_wait_details(update, None) + + assert isinstance(result, WaitDetails) + assert result.scheduled_end_timestamp > datetime.datetime.now(tz=datetime.UTC) + + +def test_create_wait_details_non_wait_type(): + processor = MockProcessor() + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.START, + ) + + result = processor.create_wait_details(update, None) + + assert result is None + + +def test_translate_update_to_operation_with_current_operation(): + processor = MockProcessor() + start_time = datetime.datetime.now(tz=datetime.UTC) - timedelta(minutes=5) + current_op = Mock() + current_op.start_timestamp = start_time + + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.START, + parent_id="parent-id", + name="test-operation", + sub_type="test-subtype", + ) + + result = processor.translate_update(update, current_op, OperationStatus.STARTED) + + assert isinstance(result, Operation) + assert result.operation_id == "test-id" + assert result.parent_id == "parent-id" + assert result.name == "test-operation" + assert result.start_timestamp == start_time + assert result.operation_type == OperationType.STEP + assert result.status == OperationStatus.STARTED + assert result.sub_type == "test-subtype" + + +def test_translate_update_to_operation_without_current_operation(): + processor = MockProcessor() + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.START, + parent_id="parent-id", + name="test-operation", + ) + + result = processor.translate_update(update, None, OperationStatus.STARTED) + + assert isinstance(result, Operation) + assert result.operation_id == "test-id" + assert result.parent_id == "parent-id" + assert result.name == "test-operation" + assert result.start_timestamp is not None + assert result.operation_type == OperationType.STEP + assert result.status == OperationStatus.STARTED + + +def test_translate_update_to_operation_with_terminal_status(): + processor = MockProcessor() + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.START, + ) + + result = processor.translate_update(update, None, OperationStatus.SUCCEEDED) + + assert result.end_timestamp is not None + assert result.status == OperationStatus.SUCCEEDED diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/callback_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/callback_test.py new file mode 100644 index 0000000..24bc129 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/callback_test.py @@ -0,0 +1,259 @@ +"""Tests for callback operation processor.""" + +from unittest.mock import Mock + +import pytest +from aws_durable_execution_sdk_python.lambda_service import ( + Operation, + OperationAction, + OperationStatus, + OperationType, + OperationUpdate, +) + +from aws_durable_execution_sdk_python_testing.checkpoint.processors.callback import ( + CallbackProcessor, +) +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) +from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier + + +class MockNotifier(ExecutionNotifier): + """Mock notifier for testing.""" + + def __init__(self): + super().__init__() + self.completed_calls = [] + self.failed_calls = [] + self.wait_timer_calls = [] + self.step_retry_calls = [] + + def notify_completed(self, execution_arn, result=None): + self.completed_calls.append((execution_arn, result)) + + def notify_failed(self, execution_arn, error): + self.failed_calls.append((execution_arn, error)) + + def notify_wait_timer_scheduled(self, execution_arn, operation_id, delay): + self.wait_timer_calls.append((execution_arn, operation_id, delay)) + + def notify_step_retry_scheduled(self, execution_arn, operation_id, delay): + self.step_retry_calls.append((execution_arn, operation_id, delay)) + + +def test_process_start_action(): + processor = CallbackProcessor() + notifier = MockNotifier() + + update = OperationUpdate( + operation_id="callback-123", + operation_type=OperationType.CALLBACK, + action=OperationAction.START, + name="test-callback", + ) + + result = processor.process( + update, None, notifier, "arn:aws:states:us-east-1:123456789012:execution:test" + ) + + assert isinstance(result, Operation) + assert result.operation_id == "callback-123" + assert result.operation_type == OperationType.CALLBACK + assert result.status == OperationStatus.STARTED + assert result.name == "test-callback" + assert result.callback_details is not None + + +def test_process_start_action_with_current_operation(): + processor = CallbackProcessor() + notifier = MockNotifier() + + current_op = Mock() + current_op.start_timestamp = Mock() + + update = OperationUpdate( + operation_id="callback-123", + operation_type=OperationType.CALLBACK, + action=OperationAction.START, + name="test-callback", + ) + + result = processor.process( + update, + current_op, + notifier, + "arn:aws:states:us-east-1:123456789012:execution:test", + ) + + assert isinstance(result, Operation) + assert result.operation_id == "callback-123" + assert result.status == OperationStatus.STARTED + assert result.start_timestamp == current_op.start_timestamp + + +def test_process_invalid_action(): + processor = CallbackProcessor() + notifier = MockNotifier() + + update = OperationUpdate( + operation_id="callback-123", + operation_type=OperationType.CALLBACK, + action=OperationAction.SUCCEED, + name="test-callback", + ) + + with pytest.raises( + InvalidParameterValueException, match="Invalid action for CALLBACK operation" + ): + processor.process( + update, + None, + notifier, + "arn:aws:states:us-east-1:123456789012:execution:test", + ) + + +def test_process_fail_action(): + processor = CallbackProcessor() + notifier = MockNotifier() + + update = OperationUpdate( + operation_id="callback-123", + operation_type=OperationType.CALLBACK, + action=OperationAction.FAIL, + name="test-callback", + ) + + with pytest.raises( + InvalidParameterValueException, match="Invalid action for CALLBACK operation" + ): + processor.process( + update, + None, + notifier, + "arn:aws:states:us-east-1:123456789012:execution:test", + ) + + +def test_process_cancel_action(): + processor = CallbackProcessor() + notifier = MockNotifier() + + update = OperationUpdate( + operation_id="callback-123", + operation_type=OperationType.CALLBACK, + action=OperationAction.CANCEL, + name="test-callback", + ) + + with pytest.raises( + InvalidParameterValueException, match="Invalid action for CALLBACK operation" + ): + processor.process( + update, + None, + notifier, + "arn:aws:states:us-east-1:123456789012:execution:test", + ) + + +def test_process_retry_action(): + processor = CallbackProcessor() + notifier = MockNotifier() + + update = OperationUpdate( + operation_id="callback-123", + operation_type=OperationType.CALLBACK, + action=OperationAction.RETRY, + name="test-callback", + ) + + with pytest.raises( + InvalidParameterValueException, match="Invalid action for CALLBACK operation" + ): + processor.process( + update, + None, + notifier, + "arn:aws:states:us-east-1:123456789012:execution:test", + ) + + +def test_process_with_payload(): + processor = CallbackProcessor() + notifier = MockNotifier() + + update = OperationUpdate( + operation_id="callback-123", + operation_type=OperationType.CALLBACK, + action=OperationAction.START, + name="test-callback", + payload="test-payload", + ) + + result = processor.process( + update, None, notifier, "arn:aws:states:us-east-1:123456789012:execution:test" + ) + + assert result.callback_details.result == "test-payload" + + +def test_process_with_parent_id(): + processor = CallbackProcessor() + notifier = MockNotifier() + + update = OperationUpdate( + operation_id="callback-123", + operation_type=OperationType.CALLBACK, + action=OperationAction.START, + name="test-callback", + parent_id="parent-456", + ) + + result = processor.process( + update, None, notifier, "arn:aws:states:us-east-1:123456789012:execution:test" + ) + + assert result.parent_id == "parent-456" + + +def test_process_with_sub_type(): + processor = CallbackProcessor() + notifier = MockNotifier() + + update = OperationUpdate( + operation_id="callback-123", + operation_type=OperationType.CALLBACK, + action=OperationAction.START, + name="test-callback", + sub_type="activity", + ) + + result = processor.process( + update, None, notifier, "arn:aws:states:us-east-1:123456789012:execution:test" + ) + + assert result.sub_type == "activity" + + +def test_notifier_not_called_for_start(): + processor = CallbackProcessor() + notifier = MockNotifier() + + update = OperationUpdate( + operation_id="callback-123", + operation_type=OperationType.CALLBACK, + action=OperationAction.START, + name="test-callback", + ) + + processor.process( + update, None, notifier, "arn:aws:states:us-east-1:123456789012:execution:test" + ) + + assert len(notifier.completed_calls) == 0 + assert len(notifier.failed_calls) == 0 + assert len(notifier.wait_timer_calls) == 0 + assert len(notifier.step_retry_calls) == 0 diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/context_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/context_test.py new file mode 100644 index 0000000..a070bc1 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/context_test.py @@ -0,0 +1,379 @@ +"""Tests for context operation processor.""" + +from datetime import UTC, datetime +from unittest.mock import Mock + +import pytest +from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, + Operation, + OperationAction, + OperationStatus, + OperationType, + OperationUpdate, +) + +from aws_durable_execution_sdk_python_testing.checkpoint.processors.context import ( + ContextProcessor, +) +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) +from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier + + +class MockNotifier(ExecutionNotifier): + """Mock notifier for testing.""" + + def __init__(self): + super().__init__() + self.completed_calls = [] + self.failed_calls = [] + self.wait_timer_calls = [] + self.step_retry_calls = [] + + def notify_completed(self, execution_arn, result=None): + self.completed_calls.append((execution_arn, result)) + + def notify_failed(self, execution_arn, error): + self.failed_calls.append((execution_arn, error)) + + def notify_wait_timer_scheduled(self, execution_arn, operation_id, delay): + self.wait_timer_calls.append((execution_arn, operation_id, delay)) + + def notify_step_retry_scheduled(self, execution_arn, operation_id, delay): + self.step_retry_calls.append((execution_arn, operation_id, delay)) + + +def test_process_start_action(): + processor = ContextProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="context-123", + operation_type=OperationType.CONTEXT, + action=OperationAction.START, + name="test-context", + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert isinstance(result, Operation) + assert result.operation_id == "context-123" + assert result.operation_type == OperationType.CONTEXT + assert result.status == OperationStatus.STARTED + assert result.name == "test-context" + assert result.context_details is not None + + +def test_process_start_action_with_current_operation(): + processor = ContextProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + current_op = Mock() + current_op.start_timestamp = datetime.now(UTC) + + update = OperationUpdate( + operation_id="context-123", + operation_type=OperationType.CONTEXT, + action=OperationAction.START, + name="test-context", + ) + + result = processor.process(update, current_op, notifier, execution_arn) + + assert result.start_timestamp == current_op.start_timestamp + assert result.status == OperationStatus.STARTED + + +def test_process_succeed_action(): + processor = ContextProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="context-123", + operation_type=OperationType.CONTEXT, + action=OperationAction.SUCCEED, + name="test-context", + payload="success-result", + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert isinstance(result, Operation) + assert result.operation_id == "context-123" + assert result.status == OperationStatus.SUCCEEDED + assert result.context_details.result == "success-result" + assert result.context_details.error is None + + +def test_process_succeed_action_with_current_operation(): + processor = ContextProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + current_op = Mock() + current_op.start_timestamp = datetime.now(UTC) + + update = OperationUpdate( + operation_id="context-123", + operation_type=OperationType.CONTEXT, + action=OperationAction.SUCCEED, + name="test-context", + payload="success-result", + ) + + result = processor.process(update, current_op, notifier, execution_arn) + + assert result.start_timestamp == current_op.start_timestamp + assert result.status == OperationStatus.SUCCEEDED + + +def test_process_fail_action(): + processor = ContextProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + error = ErrorObject.from_message("context failed") + update = OperationUpdate( + operation_id="context-123", + operation_type=OperationType.CONTEXT, + action=OperationAction.FAIL, + name="test-context", + error=error, + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert isinstance(result, Operation) + assert result.operation_id == "context-123" + assert result.status == OperationStatus.FAILED + assert result.context_details.error == error + assert result.context_details.result is None + + +def test_process_fail_action_with_current_operation(): + processor = ContextProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + current_op = Mock() + current_op.start_timestamp = datetime.now(UTC) + + error = ErrorObject.from_message("context failed") + update = OperationUpdate( + operation_id="context-123", + operation_type=OperationType.CONTEXT, + action=OperationAction.FAIL, + name="test-context", + error=error, + ) + + result = processor.process(update, current_op, notifier, execution_arn) + + assert result.start_timestamp == current_op.start_timestamp + assert result.status == OperationStatus.FAILED + + +def test_process_fail_action_with_payload_and_error(): + processor = ContextProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + error = ErrorObject.from_message("context failed") + update = OperationUpdate( + operation_id="context-123", + operation_type=OperationType.CONTEXT, + action=OperationAction.FAIL, + name="test-context", + payload="partial-result", + error=error, + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert result.context_details.result == "partial-result" + assert result.context_details.error == error + + +def test_process_invalid_action(): + processor = ContextProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="context-123", + operation_type=OperationType.CONTEXT, + action=OperationAction.RETRY, + name="test-context", + ) + + with pytest.raises( + InvalidParameterValueException, match="Invalid action for CONTEXT operation" + ): + processor.process(update, None, notifier, execution_arn) + + +def test_process_cancel_action(): + processor = ContextProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="context-123", + operation_type=OperationType.CONTEXT, + action=OperationAction.CANCEL, + name="test-context", + ) + + with pytest.raises( + InvalidParameterValueException, match="Invalid action for CONTEXT operation" + ): + processor.process(update, None, notifier, execution_arn) + + +def test_process_with_parent_id(): + processor = ContextProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="context-123", + operation_type=OperationType.CONTEXT, + action=OperationAction.START, + name="test-context", + parent_id="parent-456", + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert result.parent_id == "parent-456" + + +def test_process_with_sub_type(): + processor = ContextProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="context-123", + operation_type=OperationType.CONTEXT, + action=OperationAction.START, + name="test-context", + sub_type="parallel", + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert result.sub_type == "parallel" + + +def test_process_start_without_payload(): + processor = ContextProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="context-123", + operation_type=OperationType.CONTEXT, + action=OperationAction.START, + name="test-context", + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert result.context_details.result is None + assert result.context_details.error is None + + +def test_process_succeed_without_payload(): + processor = ContextProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="context-123", + operation_type=OperationType.CONTEXT, + action=OperationAction.SUCCEED, + name="test-context", + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert result.context_details.result is None + assert result.context_details.error is None + + +def test_process_fail_without_error(): + processor = ContextProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="context-123", + operation_type=OperationType.CONTEXT, + action=OperationAction.FAIL, + name="test-context", + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert result.context_details.result is None + assert result.context_details.error is None + + +def test_no_notifier_calls(): + processor = ContextProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="context-123", + operation_type=OperationType.CONTEXT, + action=OperationAction.START, + name="test-context", + ) + + processor.process(update, None, notifier, execution_arn) + + assert len(notifier.completed_calls) == 0 + assert len(notifier.failed_calls) == 0 + assert len(notifier.wait_timer_calls) == 0 + assert len(notifier.step_retry_calls) == 0 + + +def test_end_timestamp_set_for_terminal_states(): + processor = ContextProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="context-123", + operation_type=OperationType.CONTEXT, + action=OperationAction.SUCCEED, + name="test-context", + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert result.end_timestamp is not None + + +def test_end_timestamp_not_set_for_non_terminal_states(): + processor = ContextProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="context-123", + operation_type=OperationType.CONTEXT, + action=OperationAction.START, + name="test-context", + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert result.end_timestamp is None diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/execution_processor_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/execution_processor_test.py new file mode 100644 index 0000000..37c91ea --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/execution_processor_test.py @@ -0,0 +1,242 @@ +"""Tests for execution operation processor.""" + +from unittest.mock import Mock + +from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, + OperationAction, + OperationType, + OperationUpdate, +) + +from aws_durable_execution_sdk_python_testing.checkpoint.processors.execution import ( + ExecutionProcessor, +) +from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier + + +class MockNotifier(ExecutionNotifier): + """Mock notifier for testing.""" + + def __init__(self): + super().__init__() + self.completed_calls = [] + self.failed_calls = [] + self.wait_timer_calls = [] + self.step_retry_calls = [] + + def notify_completed(self, execution_arn, result=None): + self.completed_calls.append((execution_arn, result)) + + def notify_failed(self, execution_arn, error): + self.failed_calls.append((execution_arn, error)) + + def notify_wait_timer_scheduled(self, execution_arn, operation_id, delay): + self.wait_timer_calls.append((execution_arn, operation_id, delay)) + + def notify_step_retry_scheduled(self, execution_arn, operation_id, delay): + self.step_retry_calls.append((execution_arn, operation_id, delay)) + + +def test_process_succeed_action(): + processor = ExecutionProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="execution-123", + operation_type=OperationType.EXECUTION, + action=OperationAction.SUCCEED, + payload="success-result", + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert result is None + assert len(notifier.completed_calls) == 1 + assert notifier.completed_calls[0] == (execution_arn, "success-result") + assert len(notifier.failed_calls) == 0 + + +def test_process_succeed_action_with_current_operation(): + processor = ExecutionProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + current_op = Mock() + + update = OperationUpdate( + operation_id="execution-123", + operation_type=OperationType.EXECUTION, + action=OperationAction.SUCCEED, + payload="success-result", + ) + + result = processor.process(update, current_op, notifier, execution_arn) + + assert result is None + assert len(notifier.completed_calls) == 1 + assert notifier.completed_calls[0] == (execution_arn, "success-result") + + +def test_process_succeed_action_without_payload(): + processor = ExecutionProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="execution-123", + operation_type=OperationType.EXECUTION, + action=OperationAction.SUCCEED, + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert result is None + assert len(notifier.completed_calls) == 1 + assert notifier.completed_calls[0] == (execution_arn, None) + + +def test_process_fail_action_with_error(): + processor = ExecutionProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + error = ErrorObject.from_message("execution failed") + update = OperationUpdate( + operation_id="execution-123", + operation_type=OperationType.EXECUTION, + action=OperationAction.FAIL, + error=error, + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert result is None + assert len(notifier.failed_calls) == 1 + assert notifier.failed_calls[0] == (execution_arn, error) + assert len(notifier.completed_calls) == 0 + + +def test_process_fail_action_without_error(): + processor = ExecutionProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="execution-123", + operation_type=OperationType.EXECUTION, + action=OperationAction.FAIL, + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert result is None + assert len(notifier.failed_calls) == 1 + execution_arn_arg, error_arg = notifier.failed_calls[0] + assert execution_arn_arg == execution_arn + assert isinstance(error_arg, ErrorObject) + assert ( + "There is no error details but EXECUTION checkpoint action is not SUCCEED" + in str(error_arg) + ) + + +def test_process_start_action(): + processor = ExecutionProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="execution-123", + operation_type=OperationType.EXECUTION, + action=OperationAction.START, + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert result is None + assert len(notifier.failed_calls) == 1 + execution_arn_arg, error_arg = notifier.failed_calls[0] + assert execution_arn_arg == execution_arn + assert isinstance(error_arg, ErrorObject) + + +def test_process_retry_action(): + processor = ExecutionProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="execution-123", + operation_type=OperationType.EXECUTION, + action=OperationAction.RETRY, + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert result is None + assert len(notifier.failed_calls) == 1 + execution_arn_arg, error_arg = notifier.failed_calls[0] + assert execution_arn_arg == execution_arn + assert isinstance(error_arg, ErrorObject) + + +def test_process_cancel_action(): + processor = ExecutionProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="execution-123", + operation_type=OperationType.EXECUTION, + action=OperationAction.CANCEL, + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert result is None + assert len(notifier.failed_calls) == 1 + execution_arn_arg, error_arg = notifier.failed_calls[0] + assert execution_arn_arg == execution_arn + assert isinstance(error_arg, ErrorObject) + + +def test_process_with_current_operation_and_error(): + processor = ExecutionProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + current_op = Mock() + error = ErrorObject.from_message("custom error") + + update = OperationUpdate( + operation_id="execution-123", + operation_type=OperationType.EXECUTION, + action=OperationAction.FAIL, + error=error, + ) + + result = processor.process(update, current_op, notifier, execution_arn) + + assert result is None + assert len(notifier.failed_calls) == 1 + assert notifier.failed_calls[0] == (execution_arn, error) + + +def test_no_wait_timer_or_step_retry_calls(): + processor = ExecutionProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="execution-123", + operation_type=OperationType.EXECUTION, + action=OperationAction.SUCCEED, + payload="result", + ) + + processor.process(update, None, notifier, execution_arn) + + assert len(notifier.wait_timer_calls) == 0 + assert len(notifier.step_retry_calls) == 0 diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/step_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/step_test.py new file mode 100644 index 0000000..6ba5cc6 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/step_test.py @@ -0,0 +1,421 @@ +"""Tests for step operation processor.""" + +from datetime import UTC, datetime +from unittest.mock import Mock + +import pytest +from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, + Operation, + OperationAction, + OperationStatus, + OperationType, + OperationUpdate, + StepDetails, + StepOptions, +) + +from aws_durable_execution_sdk_python_testing.checkpoint.processors.step import ( + StepProcessor, +) +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) +from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier + + +class MockNotifier(ExecutionNotifier): + """Mock notifier for testing.""" + + def __init__(self): + super().__init__() + self.completed_calls = [] + self.failed_calls = [] + self.wait_timer_calls = [] + self.step_retry_calls = [] + + def notify_completed(self, execution_arn, result=None): + self.completed_calls.append((execution_arn, result)) + + def notify_failed(self, execution_arn, error): + self.failed_calls.append((execution_arn, error)) + + def notify_wait_timer_scheduled(self, execution_arn, operation_id, delay): + self.wait_timer_calls.append((execution_arn, operation_id, delay)) + + def notify_step_retry_scheduled(self, execution_arn, operation_id, delay): + self.step_retry_calls.append((execution_arn, operation_id, delay)) + + +def test_process_start_action(): + processor = StepProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="step-123", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="test-step", + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert isinstance(result, Operation) + assert result.operation_id == "step-123" + assert result.operation_type == OperationType.STEP + assert result.status == OperationStatus.STARTED + assert result.name == "test-step" + assert result.step_details is not None + + +def test_process_start_action_with_current_operation(): + processor = StepProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + current_op = Mock() + current_op.start_timestamp = datetime.now(UTC) + + update = OperationUpdate( + operation_id="step-123", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="test-step", + ) + + result = processor.process(update, current_op, notifier, execution_arn) + + assert result.start_timestamp == current_op.start_timestamp + + +def test_process_retry_action(): + processor = StepProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + current_op = Mock() + current_op.start_timestamp = datetime.now(UTC) + current_op.step_details = StepDetails(attempt=1, result="previous-result") + current_op.execution_details = None + current_op.context_details = None + current_op.wait_details = None + current_op.callback_details = None + current_op.chained_invoke_details = None + + step_options = StepOptions(next_attempt_delay_seconds=30) + update = OperationUpdate( + operation_id="step-123", + operation_type=OperationType.STEP, + action=OperationAction.RETRY, + name="test-step", + step_options=step_options, + ) + + result = processor.process(update, current_op, notifier, execution_arn) + + assert isinstance(result, Operation) + assert result.operation_id == "step-123" + assert result.status == OperationStatus.PENDING + assert result.step_details.attempt == 2 + assert result.step_details.result == "previous-result" + assert result.step_details.next_attempt_timestamp is not None + + assert len(notifier.step_retry_calls) == 1 + assert notifier.step_retry_calls[0] == (execution_arn, "step-123", 30) + + +def test_process_retry_action_without_step_options(): + processor = StepProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + current_op = Mock() + current_op.start_timestamp = datetime.now(UTC) + current_op.step_details = StepDetails(attempt=0) + current_op.execution_details = None + current_op.context_details = None + current_op.wait_details = None + current_op.callback_details = None + current_op.chained_invoke_details = None + + update = OperationUpdate( + operation_id="step-123", + operation_type=OperationType.STEP, + action=OperationAction.RETRY, + name="test-step", + ) + + result = processor.process(update, current_op, notifier, execution_arn) + + assert result.step_details.attempt == 1 + assert len(notifier.step_retry_calls) == 1 + assert notifier.step_retry_calls[0] == (execution_arn, "step-123", 0) + + +def test_process_retry_action_without_current_operation(): + processor = StepProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + step_options = StepOptions(next_attempt_delay_seconds=15) + update = OperationUpdate( + operation_id="step-123", + operation_type=OperationType.STEP, + action=OperationAction.RETRY, + name="test-step", + step_options=step_options, + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert result.step_details.attempt == 1 + assert result.step_details.result is None + assert result.step_details.error is None + + +def test_process_retry_action_without_current_step_details(): + processor = StepProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + current_op = Mock() + current_op.start_timestamp = datetime.now(UTC) + current_op.step_details = None + current_op.execution_details = None + current_op.context_details = None + current_op.wait_details = None + current_op.callback_details = None + current_op.chained_invoke_details = None + + step_options = StepOptions(next_attempt_delay_seconds=45) + update = OperationUpdate( + operation_id="step-123", + operation_type=OperationType.STEP, + action=OperationAction.RETRY, + name="test-step", + step_options=step_options, + ) + + result = processor.process(update, current_op, notifier, execution_arn) + + assert result.step_details.attempt == 1 + + +def test_process_succeed_action(): + processor = StepProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="step-123", + operation_type=OperationType.STEP, + action=OperationAction.SUCCEED, + name="test-step", + payload="success-result", + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert isinstance(result, Operation) + assert result.operation_id == "step-123" + assert result.status == OperationStatus.SUCCEEDED + assert result.step_details.result == "success-result" + + +def test_process_succeed_action_with_current_operation(): + processor = StepProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + current_op = Mock() + current_op.start_timestamp = datetime.now(UTC) + current_op.step_details = StepDetails() + + update = OperationUpdate( + operation_id="step-123", + operation_type=OperationType.STEP, + action=OperationAction.SUCCEED, + name="test-step", + payload="success-result", + ) + + result = processor.process(update, current_op, notifier, execution_arn) + + assert result.start_timestamp == current_op.start_timestamp + assert result.status == OperationStatus.SUCCEEDED + assert result.step_details.attempt == 1 + + +def test_process_fail_action(): + processor = StepProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + error = ErrorObject.from_message("step failed") + update = OperationUpdate( + operation_id="step-123", + operation_type=OperationType.STEP, + action=OperationAction.FAIL, + name="test-step", + error=error, + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert isinstance(result, Operation) + assert result.operation_id == "step-123" + assert result.status == OperationStatus.FAILED + assert result.step_details.error == error + + +def test_process_fail_action_with_current_operation(): + processor = StepProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + current_op = Mock() + current_op.start_timestamp = datetime.now(UTC) + current_op.step_details = StepDetails() + + error = ErrorObject.from_message("step failed") + update = OperationUpdate( + operation_id="step-123", + operation_type=OperationType.STEP, + action=OperationAction.FAIL, + name="test-step", + error=error, + ) + + result = processor.process(update, current_op, notifier, execution_arn) + + assert result.start_timestamp == current_op.start_timestamp + assert result.status == OperationStatus.FAILED + assert result.step_details.attempt == 1 + + +def test_process_invalid_action(): + processor = StepProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="step-123", + operation_type=OperationType.STEP, + action=OperationAction.CANCEL, + name="test-step", + ) + + with pytest.raises( + InvalidParameterValueException, match="Invalid action for STEP operation" + ): + processor.process(update, None, notifier, execution_arn) + + +def test_process_with_parent_id(): + processor = StepProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="step-123", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="test-step", + parent_id="parent-456", + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert result.parent_id == "parent-456" + + +def test_process_with_sub_type(): + processor = StepProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="step-123", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="test-step", + sub_type="lambda", + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert result.sub_type == "lambda" + + +def test_retry_preserves_current_operation_details(): + processor = StepProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + current_op = Mock() + current_op.start_timestamp = datetime.now(UTC) + current_op.step_details = StepDetails( + attempt=2, result="old-result", error=ErrorObject.from_message("old-error") + ) + current_op.execution_details = Mock() + current_op.context_details = Mock() + current_op.wait_details = Mock() + current_op.callback_details = Mock() + current_op.chained_invoke_details = Mock() + + step_options = StepOptions(next_attempt_delay_seconds=60) + update = OperationUpdate( + operation_id="step-123", + operation_type=OperationType.STEP, + action=OperationAction.RETRY, + name="test-step", + step_options=step_options, + ) + + result = processor.process(update, current_op, notifier, execution_arn) + + assert result.step_details.attempt == 3 + assert result.step_details.result == "old-result" + assert result.step_details.error == current_op.step_details.error + assert result.execution_details == current_op.execution_details + assert result.context_details == current_op.context_details + assert result.wait_details == current_op.wait_details + assert result.callback_details == current_op.callback_details + assert result.chained_invoke_details == current_op.chained_invoke_details + + +def test_no_completed_or_failed_calls_for_non_execution_actions(): + processor = StepProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="step-123", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="test-step", + ) + + processor.process(update, None, notifier, execution_arn) + + assert len(notifier.completed_calls) == 0 + assert len(notifier.failed_calls) == 0 + assert len(notifier.wait_timer_calls) == 0 + + +def test_no_step_retry_calls_for_non_retry_actions(): + processor = StepProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="step-123", + operation_type=OperationType.STEP, + action=OperationAction.SUCCEED, + name="test-step", + ) + + processor.process(update, None, notifier, execution_arn) + + assert len(notifier.step_retry_calls) == 0 diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/wait_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/wait_test.py new file mode 100644 index 0000000..42c29aa --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/processors/wait_test.py @@ -0,0 +1,313 @@ +"""Tests for wait operation processor.""" + +from datetime import UTC, datetime +from unittest.mock import Mock + +import pytest +from aws_durable_execution_sdk_python.lambda_service import ( + Operation, + OperationAction, + OperationStatus, + OperationType, + OperationUpdate, + WaitOptions, +) + +from aws_durable_execution_sdk_python_testing.checkpoint.processors.wait import ( + WaitProcessor, +) +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) +from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier + + +class MockNotifier(ExecutionNotifier): + """Mock notifier for testing.""" + + def __init__(self): + super().__init__() + self.completed_calls = [] + self.failed_calls = [] + self.wait_timer_calls = [] + self.step_retry_calls = [] + + def notify_completed(self, execution_arn, result=None): + self.completed_calls.append((execution_arn, result)) + + def notify_failed(self, execution_arn, error): + self.failed_calls.append((execution_arn, error)) + + def notify_wait_timer_scheduled(self, execution_arn, operation_id, delay): + self.wait_timer_calls.append((execution_arn, operation_id, delay)) + + def notify_step_retry_scheduled(self, execution_arn, operation_id, delay): + self.step_retry_calls.append((execution_arn, operation_id, delay)) + + +def test_process_start_action(): + processor = WaitProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + wait_options = WaitOptions(wait_seconds=30) + update = OperationUpdate( + operation_id="wait-123", + operation_type=OperationType.WAIT, + action=OperationAction.START, + name="test-wait", + wait_options=wait_options, + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert isinstance(result, Operation) + assert result.operation_id == "wait-123" + assert result.operation_type == OperationType.WAIT + assert result.status == OperationStatus.STARTED + assert result.name == "test-wait" + assert result.wait_details is not None + assert result.wait_details.scheduled_end_timestamp > datetime.now(UTC) + + assert len(notifier.wait_timer_calls) == 1 + assert notifier.wait_timer_calls[0] == (execution_arn, "wait-123", 30) + + +def test_process_start_action_without_wait_options(): + processor = WaitProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="wait-123", + operation_type=OperationType.WAIT, + action=OperationAction.START, + name="test-wait", + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert isinstance(result, Operation) + assert result.wait_details is not None + + assert len(notifier.wait_timer_calls) == 1 + assert notifier.wait_timer_calls[0] == (execution_arn, "wait-123", 0) + + +def test_process_start_action_with_zero_seconds(): + processor = WaitProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + wait_options = WaitOptions(wait_seconds=0) + update = OperationUpdate( + operation_id="wait-123", + operation_type=OperationType.WAIT, + action=OperationAction.START, + name="test-wait", + wait_options=wait_options, + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert isinstance(result, Operation) + assert result.wait_details is not None + + assert len(notifier.wait_timer_calls) == 1 + assert notifier.wait_timer_calls[0] == (execution_arn, "wait-123", 0) + + +def test_process_start_action_with_parent_id(): + processor = WaitProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + wait_options = WaitOptions(wait_seconds=15) + update = OperationUpdate( + operation_id="wait-123", + operation_type=OperationType.WAIT, + action=OperationAction.START, + name="test-wait", + parent_id="parent-456", + wait_options=wait_options, + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert result.parent_id == "parent-456" + + +def test_process_start_action_with_sub_type(): + processor = WaitProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + wait_options = WaitOptions(wait_seconds=15) + update = OperationUpdate( + operation_id="wait-123", + operation_type=OperationType.WAIT, + action=OperationAction.START, + name="test-wait", + sub_type="timer", + wait_options=wait_options, + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert result.sub_type == "timer" + + +def test_process_cancel_action(): + processor = WaitProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + current_op = Mock() + current_op.start_timestamp = datetime.now(UTC) + + update = OperationUpdate( + operation_id="wait-123", + operation_type=OperationType.WAIT, + action=OperationAction.CANCEL, + name="test-wait", + ) + + result = processor.process(update, current_op, notifier, execution_arn) + + assert isinstance(result, Operation) + assert result.operation_id == "wait-123" + assert result.status == OperationStatus.CANCELLED + assert result.start_timestamp == current_op.start_timestamp + + +def test_process_cancel_action_without_current_operation(): + processor = WaitProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="wait-123", + operation_type=OperationType.WAIT, + action=OperationAction.CANCEL, + name="test-wait", + ) + + result = processor.process(update, None, notifier, execution_arn) + + assert isinstance(result, Operation) + assert result.status == OperationStatus.CANCELLED + + +def test_process_invalid_action(): + processor = WaitProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="wait-123", + operation_type=OperationType.WAIT, + action=OperationAction.SUCCEED, + name="test-wait", + ) + + with pytest.raises( + InvalidParameterValueException, match="Invalid action for WAIT operation" + ): + processor.process(update, None, notifier, execution_arn) + + +def test_process_fail_action(): + processor = WaitProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="wait-123", + operation_type=OperationType.WAIT, + action=OperationAction.FAIL, + name="test-wait", + ) + + with pytest.raises( + InvalidParameterValueException, match="Invalid action for WAIT operation" + ): + processor.process(update, None, notifier, execution_arn) + + +def test_process_retry_action(): + processor = WaitProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + update = OperationUpdate( + operation_id="wait-123", + operation_type=OperationType.WAIT, + action=OperationAction.RETRY, + name="test-wait", + ) + + with pytest.raises( + InvalidParameterValueException, match="Invalid action for WAIT operation" + ): + processor.process(update, None, notifier, execution_arn) + + +def test_wait_details_created_correctly(): + processor = WaitProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + wait_options = WaitOptions(wait_seconds=60) + update = OperationUpdate( + operation_id="wait-123", + operation_type=OperationType.WAIT, + action=OperationAction.START, + name="test-wait", + wait_options=wait_options, + ) + + before_time = datetime.now(UTC) + result = processor.process(update, None, notifier, execution_arn) + + assert result.wait_details.scheduled_end_timestamp > before_time + + +def test_no_completed_or_failed_calls(): + processor = WaitProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + wait_options = WaitOptions(wait_seconds=30) + update = OperationUpdate( + operation_id="wait-123", + operation_type=OperationType.WAIT, + action=OperationAction.START, + name="test-wait", + wait_options=wait_options, + ) + + processor.process(update, None, notifier, execution_arn) + + assert len(notifier.completed_calls) == 0 + assert len(notifier.failed_calls) == 0 + assert len(notifier.step_retry_calls) == 0 + + +def test_cancel_no_timer_scheduled(): + processor = WaitProcessor() + notifier = MockNotifier() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + current_op = Mock() + current_op.start_timestamp = datetime.now(UTC) + + update = OperationUpdate( + operation_id="wait-123", + operation_type=OperationType.WAIT, + action=OperationAction.CANCEL, + name="test-wait", + ) + + processor.process(update, current_op, notifier, execution_arn) + + assert len(notifier.wait_timer_calls) == 0 diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/transformer_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/transformer_test.py new file mode 100644 index 0000000..387a96c --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/transformer_test.py @@ -0,0 +1,394 @@ +"""Unit tests for OperationTransformer.""" + +from unittest.mock import Mock + +import pytest +from aws_durable_execution_sdk_python.lambda_service import ( + OperationAction, + OperationType, + OperationUpdate, +) + +from aws_durable_execution_sdk_python_testing.checkpoint.processors.base import ( + OperationProcessor, +) +from aws_durable_execution_sdk_python_testing.checkpoint.transformer import ( + OperationTransformer, +) +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) + + +class MockProcessor(OperationProcessor): + """Mock processor for testing.""" + + def __init__(self, return_value=None): + self.return_value = return_value + self.process_calls = [] + + def process(self, update, current_op, notifier, execution_arn): + self.process_calls.append((update, current_op, notifier, execution_arn)) + return self.return_value + + +def test_init_with_default_processors(): + """Test initialization with default processors.""" + transformer = OperationTransformer() + + assert OperationType.STEP in transformer.processors + assert OperationType.WAIT in transformer.processors + assert OperationType.CONTEXT in transformer.processors + assert OperationType.CALLBACK in transformer.processors + assert OperationType.EXECUTION in transformer.processors + + +def test_init_with_custom_processors(): + """Test initialization with custom processors.""" + custom_processors = {OperationType.STEP: MockProcessor()} + transformer = OperationTransformer(processors=custom_processors) + + assert transformer.processors == custom_processors + + +def test_process_updates_empty_lists(): + """Test processing with empty updates and operations.""" + transformer = OperationTransformer() + notifier = Mock() + + operations, updates = transformer.process_updates([], [], notifier, "arn:test") + + assert operations == [] + assert updates == [] + + +def test_process_updates_processor_not_found_raises_error(): + """Test that missing processor raises InvalidParameterValueException.""" + transformer = OperationTransformer(processors={OperationType.STEP: MockProcessor()}) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.WAIT, + action=OperationAction.START, + ) + notifier = Mock() + + with pytest.raises( + InvalidParameterValueException, + match="Checkpoint for OperationType.WAIT is not implemented yet.", + ): + transformer.process_updates([update], [], notifier, "arn:test") + + +def test_process_updates_processor_returns_none(): + """Test processing when processor returns None.""" + mock_processor = MockProcessor(return_value=None) + transformer = OperationTransformer(processors={OperationType.STEP: mock_processor}) + + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.START, + ) + notifier = Mock() + + operations, updates = transformer.process_updates( + [update], [], notifier, "arn:test" + ) + + assert operations == [] + assert updates == [update] + assert len(mock_processor.process_calls) == 1 + + +def test_process_updates_new_operation(): + """Test processing creates new operation.""" + new_operation = Mock() + new_operation.operation_id = "new-id" + mock_processor = MockProcessor(return_value=new_operation) + transformer = OperationTransformer(processors={OperationType.STEP: mock_processor}) + + update = OperationUpdate( + operation_id="new-id", + operation_type=OperationType.STEP, + action=OperationAction.START, + ) + notifier = Mock() + + operations, updates = transformer.process_updates( + [update], [], notifier, "arn:test" + ) + + assert len(operations) == 1 + assert operations[0] == new_operation + assert updates == [update] + + +def test_process_updates_existing_operation(): + """Test processing updates existing operation.""" + existing_operation = Mock() + existing_operation.operation_id = "existing-id" + updated_operation = Mock() + updated_operation.operation_id = "existing-id" + + mock_processor = MockProcessor(return_value=updated_operation) + transformer = OperationTransformer(processors={OperationType.STEP: mock_processor}) + + update = OperationUpdate( + operation_id="existing-id", + operation_type=OperationType.STEP, + action=OperationAction.SUCCEED, + ) + notifier = Mock() + + operations, updates = transformer.process_updates( + [update], [existing_operation], notifier, "arn:test" + ) + + assert len(operations) == 1 + assert operations[0] == updated_operation + assert updates == [update] + + +def test_process_updates_multiple_operations_preserve_order(): + """Test processing multiple operations preserves order.""" + op1 = Mock() + op1.operation_id = "op1" + op2 = Mock() + op2.operation_id = "op2" + op3 = Mock() + op3.operation_id = "op3" + + updated_op2 = Mock() + updated_op2.operation_id = "op2" + new_op4 = Mock() + new_op4.operation_id = "op4" + + mock_processor = MockProcessor() + transformer = OperationTransformer(processors={OperationType.STEP: mock_processor}) + + mock_processor.return_value = updated_op2 + + updates = [ + OperationUpdate( + operation_id="op2", + operation_type=OperationType.STEP, + action=OperationAction.SUCCEED, + ), + ] + notifier = Mock() + + operations, result_updates = transformer.process_updates( + updates, [op1, op2, op3], notifier, "arn:test" + ) + + assert len(operations) == 3 + assert operations[0] == op1 + assert operations[1] == updated_op2 + assert operations[2] == op3 + assert result_updates == updates + + mock_processor.return_value = new_op4 + updates2 = [ + OperationUpdate( + operation_id="op4", + operation_type=OperationType.STEP, + action=OperationAction.START, + ) + ] + + operations2, result_updates2 = transformer.process_updates( + updates2, [op1, updated_op2, op3], notifier, "arn:test" + ) + + assert len(operations2) == 4 + assert operations2[0] == op1 + assert operations2[1] == updated_op2 + assert operations2[2] == op3 + assert operations2[3] == new_op4 + + +def test_process_updates_multiple_processors(): + """Test processing with multiple processor types.""" + step_op = Mock() + step_op.operation_id = "step-id" + wait_op = Mock() + wait_op.operation_id = "wait-id" + + step_processor = MockProcessor(return_value=step_op) + wait_processor = MockProcessor(return_value=wait_op) + + transformer = OperationTransformer( + processors={ + OperationType.STEP: step_processor, + OperationType.WAIT: wait_processor, + } + ) + + updates = [ + OperationUpdate( + operation_id="step-id", + operation_type=OperationType.STEP, + action=OperationAction.START, + ), + OperationUpdate( + operation_id="wait-id", + operation_type=OperationType.WAIT, + action=OperationAction.START, + ), + ] + notifier = Mock() + + operations, result_updates = transformer.process_updates( + updates, [], notifier, "arn:test" + ) + + assert len(operations) == 2 + assert operations[0] == step_op + assert operations[1] == wait_op + assert len(step_processor.process_calls) == 1 + assert len(wait_processor.process_calls) == 1 + + +def test_process_updates_passes_correct_parameters(): + """Test that correct parameters are passed to processor.""" + existing_op = Mock() + existing_op.operation_id = "test-id" + mock_processor = MockProcessor(return_value=existing_op) + transformer = OperationTransformer(processors={OperationType.STEP: mock_processor}) + + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.START, + ) + notifier = Mock() + execution_arn = "arn:aws:states:us-east-1:123456789012:execution:test" + + transformer.process_updates([update], [existing_op], notifier, execution_arn) + + call_args = mock_processor.process_calls[0] + assert call_args[0] == update + assert call_args[1] == existing_op + assert call_args[2] == notifier + assert call_args[3] == execution_arn + + +def test_process_updates_new_operation_not_in_map(): + """Test processing creates new operation when operation_id not in current operations.""" + new_operation = Mock() + new_operation.operation_id = "new-id" + mock_processor = MockProcessor(return_value=new_operation) + transformer = OperationTransformer(processors={OperationType.STEP: mock_processor}) + + # Existing operations with different IDs + existing_op = Mock() + existing_op.operation_id = "existing-id" + + update = OperationUpdate( + operation_id="new-id", # Different from existing operation + operation_type=OperationType.STEP, + action=OperationAction.START, + ) + notifier = Mock() + + operations, updates = transformer.process_updates( + [update], [existing_op], notifier, "arn:test" + ) + + # Should have both existing and new operation + assert len(operations) == 2 + assert operations[0] == existing_op # Original operation preserved + assert operations[1] == new_operation # New operation appended + assert updates == [update] + + +def test_process_updates_in_place_update_with_multiple_operations(): + """Test in-place update when operation exists in middle of operations list.""" + # Create three operations + op1 = Mock() + op1.operation_id = "op1" + op2 = Mock() + op2.operation_id = "op2" + op3 = Mock() + op3.operation_id = "op3" + + # Updated version of op2 + updated_op2 = Mock() + updated_op2.operation_id = "op2" + + mock_processor = MockProcessor(return_value=updated_op2) + transformer = OperationTransformer(processors={OperationType.STEP: mock_processor}) + + # Update for op2 (middle operation) + update = OperationUpdate( + operation_id="op2", + operation_type=OperationType.STEP, + action=OperationAction.SUCCEED, + ) + notifier = Mock() + + # Process update with op2 in the middle of the list + operations, updates = transformer.process_updates( + [update], [op1, op2, op3], notifier, "arn:test" + ) + + # Verify in-place update occurred + assert len(operations) == 3 + assert operations[0] == op1 # First operation unchanged + assert operations[1] == updated_op2 # Middle operation updated in-place + assert operations[2] == op3 # Last operation unchanged + assert updates == [update] + + +def test_process_updates_in_place_update_break_coverage(): + """Test to ensure break statement in in-place update loop is covered.""" + # Create operations where target is first in list to ensure break is hit + target_op = Mock() + target_op.operation_id = "target" + other_op = Mock() + other_op.operation_id = "other" + + updated_target = Mock() + updated_target.operation_id = "target" + + mock_processor = MockProcessor(return_value=updated_target) + transformer = OperationTransformer(processors={OperationType.STEP: mock_processor}) + + update = OperationUpdate( + operation_id="target", + operation_type=OperationType.STEP, + action=OperationAction.SUCCEED, + ) + notifier = Mock() + + # Target operation is first - should hit break immediately + operations, updates = transformer.process_updates( + [update], [target_op, other_op], notifier, "arn:test" + ) + + assert len(operations) == 2 + assert operations[0] == updated_target + + +def test_process_updates_empty_operations_list(): + """Test for loop exit when result_operations is empty.""" + updated_op = Mock() + updated_op.operation_id = "test-id" + + mock_processor = MockProcessor(return_value=updated_op) + transformer = OperationTransformer(processors={OperationType.STEP: mock_processor}) + + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.SUCCEED, + ) + notifier = Mock() + + # Empty current_operations list - for loop should exit immediately + operations, updates = transformer.process_updates( + [update], [], notifier, "arn:test" + ) + + assert len(operations) == 1 + assert operations[0] == updated_op diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/__init__.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/__init__.py new file mode 100644 index 0000000..78d8de9 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/__init__.py @@ -0,0 +1 @@ +"""Test package""" diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/checkpoint_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/checkpoint_test.py new file mode 100644 index 0000000..548e988 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/checkpoint_test.py @@ -0,0 +1,688 @@ +"""Unit tests for checkpoint validator.""" + +import json + +import pytest +from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, + Operation, + OperationAction, + OperationStatus, + OperationType, + OperationUpdate, +) + +from aws_durable_execution_sdk_python_testing.checkpoint.validators.checkpoint import ( + MAX_ERROR_PAYLOAD_SIZE_BYTES, + CheckpointValidator, +) +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) +from aws_durable_execution_sdk_python_testing.execution import Execution +from aws_durable_execution_sdk_python_testing.model import StartDurableExecutionInput + + +def _create_test_execution() -> Execution: + """Create a test execution with basic setup.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=900, + execution_retention_period_days=7, + input=json.dumps({"test": "data"}), + invocation_id="test-invocation-id", + ) + execution = Execution.new(start_input) + execution.start() + return execution + + +def test_validate_input_empty_updates(): + """Test validation with empty updates list.""" + execution = _create_test_execution() + CheckpointValidator.validate_input([], execution) + + +def test_validate_input_single_valid_update(): + """Test validation with single valid update.""" + execution = _create_test_execution() + updates = [ + OperationUpdate( + operation_id="test-step-id", + operation_type=OperationType.STEP, + action=OperationAction.START, + ) + ] + CheckpointValidator.validate_input(updates, execution) + + +def test_validate_conflicting_execution_update_multiple(): + """Test validation fails with multiple execution updates.""" + execution = _create_test_execution() + updates = [ + OperationUpdate( + operation_id="exec-1", + operation_type=OperationType.EXECUTION, + action=OperationAction.SUCCEED, + ), + OperationUpdate( + operation_id="exec-2", + operation_type=OperationType.EXECUTION, + action=OperationAction.FAIL, + ), + ] + + with pytest.raises( + InvalidParameterValueException, + match="Cannot checkpoint multiple EXECUTION updates", + ): + CheckpointValidator.validate_input(updates, execution) + + +def test_validate_conflicting_execution_update_not_last(): + """Test validation fails when execution update is not last.""" + execution = _create_test_execution() + updates = [ + OperationUpdate( + operation_id="exec-1", + operation_type=OperationType.EXECUTION, + action=OperationAction.SUCCEED, + ), + OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + ), + ] + + with pytest.raises( + InvalidParameterValueException, + match="EXECUTION checkpoint must be the last update", + ): + CheckpointValidator.validate_input(updates, execution) + + +def test_validate_execution_update_as_last(): + """Test validation passes when execution update is last.""" + execution = _create_test_execution() + updates = [ + OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + ), + OperationUpdate( + operation_id="exec-1", + operation_type=OperationType.EXECUTION, + action=OperationAction.SUCCEED, + ), + ] + CheckpointValidator.validate_input(updates, execution) + + +def test_validate_payload_sizes_error_too_large(): + """Test validation fails when error payload is too large.""" + execution = _create_test_execution() + + large_message = "x" * (MAX_ERROR_PAYLOAD_SIZE_BYTES + 1) + large_error = ErrorObject( + message=large_message, type="TestError", data=None, stack_trace=None + ) + + updates = [ + OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.FAIL, + error=large_error, + ) + ] + + with pytest.raises( + InvalidParameterValueException, + match=f"Error object size must be less than {MAX_ERROR_PAYLOAD_SIZE_BYTES} bytes", + ): + CheckpointValidator.validate_input(updates, execution) + + +def test_validate_payload_sizes_error_within_limit(): + """Test validation passes when error payload is within limit.""" + execution = _create_test_execution() + + small_error = ErrorObject( + message="Small error", type="TestError", data=None, stack_trace=None + ) + updates = [ + OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.FAIL, + error=small_error, + ) + ] + CheckpointValidator.validate_input(updates, execution) + + +def test_validate_duplicate_operation_ids(): + """Test validation allows duplicate operation IDs in same batch. + + With background batching, the SDK can send multiple updates for the same + operation in a single batch (e.g., START followed by SUCCEED). This is + valid behavior and should be allowed. + """ + execution = _create_test_execution() + updates = [ + OperationUpdate( + operation_id="duplicate-id", + operation_type=OperationType.STEP, + action=OperationAction.START, + ), + OperationUpdate( + operation_id="duplicate-id", + operation_type=OperationType.STEP, + action=OperationAction.SUCCEED, + ), + ] + + # Should not raise - duplicate operation IDs are allowed in batches + CheckpointValidator.validate_input(updates, execution) + + +def test_validate_valid_parent_id_in_execution(): + """Test validation passes with valid parent ID from execution.""" + execution = _create_test_execution() + + context_op = Operation( + operation_id="context-1", + operation_type=OperationType.CONTEXT, + status=OperationStatus.STARTED, + ) + execution.operations.append(context_op) + + updates = [ + OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + parent_id="context-1", + ) + ] + CheckpointValidator.validate_input(updates, execution) + + +def test_validate_valid_parent_id_in_updates(): + """Test validation passes with valid parent ID from updates.""" + execution = _create_test_execution() + updates = [ + OperationUpdate( + operation_id="context-1", + operation_type=OperationType.CONTEXT, + action=OperationAction.START, + ), + OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + parent_id="context-1", + ), + ] + CheckpointValidator.validate_input(updates, execution) + + +def test_validate_invalid_parent_id_wrong_type(): + """Test validation fails with parent ID of wrong operation type.""" + execution = _create_test_execution() + + step_op = Operation( + operation_id="step-parent", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + ) + execution.operations.append(step_op) + + updates = [ + OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + parent_id="step-parent", + ) + ] + + with pytest.raises( + InvalidParameterValueException, match="Invalid parent operation id" + ): + CheckpointValidator.validate_input(updates, execution) + + +def test_validate_invalid_parent_id_not_found(): + """Test validation fails with parent ID that doesn't exist.""" + execution = _create_test_execution() + updates = [ + OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + parent_id="non-existent-parent", + ) + ] + + with pytest.raises( + InvalidParameterValueException, match="Invalid parent operation id" + ): + CheckpointValidator.validate_input(updates, execution) + + +def test_validate_no_parent_id(): + """Test validation passes with no parent ID.""" + execution = _create_test_execution() + updates = [ + OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + parent_id=None, + ) + ] + CheckpointValidator.validate_input(updates, execution) + + +def test_validate_operation_status_transition_step(): + """Test validation calls step validator for STEP operations.""" + execution = _create_test_execution() + + step_op = Operation( + operation_id="step-1", + operation_type=OperationType.STEP, + status=OperationStatus.READY, + ) + execution.operations.append(step_op) + + updates = [ + OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + ) + ] + CheckpointValidator.validate_input(updates, execution) + + +def test_validate_operation_status_transition_context(): + """Test validation calls context validator for CONTEXT operations.""" + execution = _create_test_execution() + + context_op = Operation( + operation_id="context-1", + operation_type=OperationType.CONTEXT, + status=OperationStatus.STARTED, + ) + execution.operations.append(context_op) + + updates = [ + OperationUpdate( + operation_id="context-1", + operation_type=OperationType.CONTEXT, + action=OperationAction.SUCCEED, + ) + ] + CheckpointValidator.validate_input(updates, execution) + + +def test_validate_operation_status_transition_wait(): + """Test validation calls wait validator for WAIT operations.""" + execution = _create_test_execution() + + wait_op = Operation( + operation_id="wait-1", + operation_type=OperationType.WAIT, + status=OperationStatus.STARTED, + ) + execution.operations.append(wait_op) + + updates = [ + OperationUpdate( + operation_id="wait-1", + operation_type=OperationType.WAIT, + action=OperationAction.CANCEL, + ) + ] + CheckpointValidator.validate_input(updates, execution) + + +def test_validate_operation_status_transition_invoke(): + """Test validation calls invoke validator for INVOKE operations.""" + execution = _create_test_execution() + + invoke_op = Operation( + operation_id="invoke-1", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.STARTED, + ) + execution.operations.append(invoke_op) + + updates = [ + OperationUpdate( + operation_id="invoke-1", + operation_type=OperationType.CHAINED_INVOKE, + action=OperationAction.CANCEL, + ) + ] + CheckpointValidator.validate_input(updates, execution) + + +def test_validate_operation_status_transition_execution(): + """Test validation calls execution validator for EXECUTION operations.""" + execution = _create_test_execution() + updates = [ + OperationUpdate( + operation_id="exec-1", + operation_type=OperationType.EXECUTION, + action=OperationAction.SUCCEED, + ) + ] + CheckpointValidator.validate_input(updates, execution) + + +def test_validate_inconsistent_operation_type(): + """Test validation fails when operation type is inconsistent.""" + execution = _create_test_execution() + + # Add existing operation + step_op = Operation( + operation_id="op-1", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + ) + execution.operations.append(step_op) + + # Try to update with different type + updates = [ + OperationUpdate( + operation_id="op-1", + operation_type=OperationType.CONTEXT, + action=OperationAction.SUCCEED, + ) + ] + + with pytest.raises( + InvalidParameterValueException, match="Inconsistent operation type" + ): + CheckpointValidator.validate_input(updates, execution) + + +def test_validate_inconsistent_operation_subtype(): + """Test validation fails when operation subtype is inconsistent.""" + execution = _create_test_execution() + + # Add existing operation with subtype + from aws_durable_execution_sdk_python.lambda_service import OperationSubType + + context_op = Operation( + operation_id="op-1", + operation_type=OperationType.CONTEXT, + status=OperationStatus.STARTED, + sub_type=OperationSubType.PARALLEL, + ) + execution.operations.append(context_op) + + # Try to update with different subtype + updates = [ + OperationUpdate( + operation_id="op-1", + operation_type=OperationType.CONTEXT, + action=OperationAction.SUCCEED, + sub_type=OperationSubType.MAP, + ) + ] + + with pytest.raises( + InvalidParameterValueException, match="Inconsistent operation subtype" + ): + CheckpointValidator.validate_input(updates, execution) + + +def test_validate_inconsistent_operation_name(): + """Test validation fails when operation name is inconsistent.""" + execution = _create_test_execution() + + # Add existing operation with name + step_op = Operation( + operation_id="op-1", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + name="original_name", + ) + execution.operations.append(step_op) + + # Try to update with different name + updates = [ + OperationUpdate( + operation_id="op-1", + operation_type=OperationType.STEP, + action=OperationAction.SUCCEED, + name="different_name", + ) + ] + + with pytest.raises( + InvalidParameterValueException, match="Inconsistent operation name" + ): + CheckpointValidator.validate_input(updates, execution) + + +def test_validate_inconsistent_parent_operation_id(): + """Test validation fails when parent operation ID is inconsistent.""" + execution = _create_test_execution() + + # Add TWO context operations + context_op1 = Operation( + operation_id="context-1", + operation_type=OperationType.CONTEXT, + status=OperationStatus.STARTED, + ) + execution.operations.append(context_op1) + + context_op2 = Operation( + operation_id="context-2", + operation_type=OperationType.CONTEXT, + status=OperationStatus.STARTED, + ) + execution.operations.append(context_op2) + + # Add existing step with parent context-1 + step_op = Operation( + operation_id="step-1", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + parent_id="context-1", + ) + execution.operations.append(step_op) + + # Try to update with different parent context-2 (which exists, so passes parent validation) + updates = [ + OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.SUCCEED, + parent_id="context-2", + ) + ] + + with pytest.raises( + InvalidParameterValueException, match="Inconsistent parent operation id" + ): + CheckpointValidator.validate_input(updates, execution) + + +def test_validate_invalid_duplicate_wait_operations(): + """Test validation fails with duplicate WAIT operations.""" + execution = _create_test_execution() + + # WAIT operations cannot have duplicate updates in same batch + updates = [ + OperationUpdate( + operation_id="wait-1", + operation_type=OperationType.WAIT, + action=OperationAction.START, + ), + OperationUpdate( + operation_id="wait-1", + operation_type=OperationType.WAIT, + action=OperationAction.CANCEL, + ), + ] + + with pytest.raises( + InvalidParameterValueException, + match="Cannot checkpoint multiple operations with the same ID", + ): + CheckpointValidator.validate_input(updates, execution) + + +def test_validate_invalid_duplicate_callback_operations(): + """Test validation fails with duplicate CALLBACK operations.""" + execution = _create_test_execution() + + # CALLBACK operations cannot have duplicate updates in same batch + updates = [ + OperationUpdate( + operation_id="callback-1", + operation_type=OperationType.CALLBACK, + action=OperationAction.START, + ), + OperationUpdate( + operation_id="callback-1", + operation_type=OperationType.CALLBACK, + action=OperationAction.SUCCEED, + ), + ] + + with pytest.raises( + InvalidParameterValueException, + match="Cannot checkpoint multiple operations with the same ID", + ): + CheckpointValidator.validate_input(updates, execution) + + +def test_validate_invalid_duplicate_invoke_operations(): + """Test validation fails with duplicate CHAINED_INVOKE operations.""" + execution = _create_test_execution() + + # CHAINED_INVOKE operations cannot have duplicate updates in same batch + updates = [ + OperationUpdate( + operation_id="invoke-1", + operation_type=OperationType.CHAINED_INVOKE, + action=OperationAction.START, + ), + OperationUpdate( + operation_id="invoke-1", + operation_type=OperationType.CHAINED_INVOKE, + action=OperationAction.SUCCEED, + ), + ] + + with pytest.raises( + InvalidParameterValueException, + match="Cannot checkpoint multiple operations with the same ID", + ): + CheckpointValidator.validate_input(updates, execution) + + +def test_validate_invalid_duplicate_execution_operations(): + """Test validation fails with duplicate EXECUTION operations.""" + execution = _create_test_execution() + + # EXECUTION operations cannot have duplicate updates in same batch + # (though this is also caught by _validate_conflicting_execution_update) + updates = [ + OperationUpdate( + operation_id="exec-1", + operation_type=OperationType.EXECUTION, + action=OperationAction.SUCCEED, + ), + OperationUpdate( + operation_id="exec-1", + operation_type=OperationType.EXECUTION, + action=OperationAction.SUCCEED, + ), + ] + + with pytest.raises(InvalidParameterValueException): + CheckpointValidator.validate_input(updates, execution) + + +def test_validate_duplicate_context_start_then_succeed(): + """Test validation allows CONTEXT START followed by SUCCEED.""" + execution = _create_test_execution() + + # CONTEXT operations can have START + non-START in same batch + updates = [ + OperationUpdate( + operation_id="context-1", + operation_type=OperationType.CONTEXT, + action=OperationAction.START, + ), + OperationUpdate( + operation_id="context-1", + operation_type=OperationType.CONTEXT, + action=OperationAction.SUCCEED, + ), + ] + + # Should not raise + CheckpointValidator.validate_input(updates, execution) + + +def test_validate_invalid_duplicate_context_non_start(): + """Test validation fails with duplicate CONTEXT non-START operations.""" + execution = _create_test_execution() + + # CONTEXT operations cannot have duplicate non-START updates + updates = [ + OperationUpdate( + operation_id="context-1", + operation_type=OperationType.CONTEXT, + action=OperationAction.SUCCEED, + ), + OperationUpdate( + operation_id="context-1", + operation_type=OperationType.CONTEXT, + action=OperationAction.SUCCEED, + ), + ] + + with pytest.raises( + InvalidParameterValueException, + match="Cannot checkpoint multiple operations with the same ID", + ): + CheckpointValidator.validate_input(updates, execution) + + +def test_validate_invalid_duplicate_step_non_start(): + """Test validation fails with duplicate STEP non-START operations.""" + execution = _create_test_execution() + + # STEP operations cannot have duplicate non-START updates + updates = [ + OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.SUCCEED, + ), + OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.SUCCEED, + ), + ] + + with pytest.raises( + InvalidParameterValueException, + match="Cannot checkpoint multiple operations with the same ID", + ): + CheckpointValidator.validate_input(updates, execution) diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/__init__.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/__init__.py new file mode 100644 index 0000000..866c947 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/__init__.py @@ -0,0 +1 @@ +"""Test package for operation validators.""" diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/callback_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/callback_test.py new file mode 100644 index 0000000..93f6dd3 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/callback_test.py @@ -0,0 +1,96 @@ +"""Unit tests for callback operation validator.""" + +import pytest +from aws_durable_execution_sdk_python.lambda_service import ( + Operation, + OperationAction, + OperationStatus, + OperationType, + OperationUpdate, +) + +from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.callback import ( + CallbackOperationValidator, +) +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) + + +def test_validate_start_action_with_no_current_state(): + """Test START action with no current state.""" + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CALLBACK, + action=OperationAction.START, + ) + CallbackOperationValidator.validate(None, update) + + +def test_validate_start_action_with_existing_state(): + """Test START action with existing state raises error.""" + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.CALLBACK, + status=OperationStatus.STARTED, + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CALLBACK, + action=OperationAction.START, + ) + + with pytest.raises( + InvalidParameterValueException, + match="Cannot start a CALLBACK that already exist", + ): + CallbackOperationValidator.validate(current_state, update) + + +def test_validate_cancel_action_with_no_current_state(): + """Test CANCEL action with no current state raises error.""" + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CALLBACK, + action=OperationAction.CANCEL, + ) + + with pytest.raises( + InvalidParameterValueException, + match="Invalid action for CALLBACK operation.", + ): + CallbackOperationValidator.validate(None, update) + + +def test_validate_cancel_action_with_completed_state(): + """Test CANCEL action with completed state raises error.""" + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.CALLBACK, + status=OperationStatus.SUCCEEDED, + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CALLBACK, + action=OperationAction.CANCEL, + ) + + with pytest.raises( + InvalidParameterValueException, + match="Invalid action for CALLBACK operation.", + ): + CallbackOperationValidator.validate(current_state, update) + + +def test_validate_invalid_action(): + """Test invalid action raises error.""" + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CALLBACK, + action=OperationAction.SUCCEED, + ) + + with pytest.raises( + InvalidParameterValueException, match="Invalid action for CALLBACK operation." + ): + CallbackOperationValidator.validate(None, update) diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/context_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/context_test.py new file mode 100644 index 0000000..4119c17 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/context_test.py @@ -0,0 +1,257 @@ +"""Tests for context operation validator.""" + +import pytest +from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, + Operation, + OperationAction, + OperationStatus, + OperationType, + OperationUpdate, +) + +from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.context import ( + VALID_ACTIONS_FOR_CONTEXT, + ContextOperationValidator, +) +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) + + +def test_valid_actions_for_context(): + """Test that VALID_ACTIONS_FOR_CONTEXT contains expected actions.""" + expected_actions = { + OperationAction.START, + OperationAction.FAIL, + OperationAction.SUCCEED, + } + assert expected_actions == VALID_ACTIONS_FOR_CONTEXT + + +def test_validate_start_action_with_no_current_state(): + """Test START action validation when no current state exists.""" + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CONTEXT, + action=OperationAction.START, + ) + + # Should not raise exception + ContextOperationValidator.validate(None, update) + + +def test_validate_start_action_with_existing_state(): + """Test START action validation when current state already exists.""" + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.CONTEXT, + status=OperationStatus.STARTED, + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CONTEXT, + action=OperationAction.START, + ) + + with pytest.raises( + InvalidParameterValueException, + match="Cannot start a CONTEXT that already exist.", + ): + ContextOperationValidator.validate(current_state, update) + + +def test_validate_succeed_action_with_started_state(): + """Test SUCCEED action validation with STARTED state.""" + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.CONTEXT, + status=OperationStatus.STARTED, + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CONTEXT, + action=OperationAction.SUCCEED, + payload="success_payload", + ) + + # Should not raise exception + ContextOperationValidator.validate(current_state, update) + + +def test_validate_fail_action_with_started_state(): + """Test FAIL action validation with STARTED state.""" + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.CONTEXT, + status=OperationStatus.STARTED, + ) + error = ErrorObject( + message="test error", type="TestError", data=None, stack_trace=None + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CONTEXT, + action=OperationAction.FAIL, + error=error, + ) + + # Should not raise exception + ContextOperationValidator.validate(current_state, update) + + +def test_validate_succeed_action_with_invalid_status(): + """Test SUCCEED action validation with invalid status.""" + invalid_statuses = [ + OperationStatus.PENDING, + OperationStatus.READY, + OperationStatus.SUCCEEDED, + OperationStatus.FAILED, + OperationStatus.CANCELLED, + OperationStatus.TIMED_OUT, + OperationStatus.STOPPED, + ] + + for status in invalid_statuses: + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.CONTEXT, + status=status, + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CONTEXT, + action=OperationAction.SUCCEED, + payload="success_payload", + ) + + with pytest.raises( + InvalidParameterValueException, + match="Invalid current CONTEXT state to close.", + ): + ContextOperationValidator.validate(current_state, update) + + +def test_validate_fail_action_with_invalid_status(): + """Test FAIL action validation with invalid status.""" + invalid_statuses = [ + OperationStatus.PENDING, + OperationStatus.READY, + OperationStatus.SUCCEEDED, + OperationStatus.FAILED, + OperationStatus.CANCELLED, + OperationStatus.TIMED_OUT, + OperationStatus.STOPPED, + ] + + error = ErrorObject( + message="test error", type="TestError", data=None, stack_trace=None + ) + + for status in invalid_statuses: + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.CONTEXT, + status=status, + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CONTEXT, + action=OperationAction.FAIL, + error=error, + ) + + with pytest.raises( + InvalidParameterValueException, + match="Invalid current CONTEXT state to close.", + ): + ContextOperationValidator.validate(current_state, update) + + +def test_validate_fail_action_with_payload(): + """Test FAIL action validation when payload is provided.""" + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.CONTEXT, + status=OperationStatus.STARTED, + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CONTEXT, + action=OperationAction.FAIL, + payload="invalid_payload", + ) + + with pytest.raises( + InvalidParameterValueException, + match="Cannot provide a Payload for FAIL action.", + ): + ContextOperationValidator.validate(current_state, update) + + +def test_validate_succeed_action_with_error(): + """Test SUCCEED action validation when error is provided.""" + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.CONTEXT, + status=OperationStatus.STARTED, + ) + error = ErrorObject( + message="test error", type="TestError", data=None, stack_trace=None + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CONTEXT, + action=OperationAction.SUCCEED, + error=error, + ) + + with pytest.raises( + InvalidParameterValueException, + match="Cannot provide an Error for SUCCEED action.", + ): + ContextOperationValidator.validate(current_state, update) + + +def test_validate_close_actions_with_no_current_state(): + """Test SUCCEED and FAIL actions validation when no current state exists.""" + # SUCCEED with no current state should pass + succeed_update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CONTEXT, + action=OperationAction.SUCCEED, + payload="success_payload", + ) + ContextOperationValidator.validate(None, succeed_update) + + # FAIL with no current state should pass + error = ErrorObject( + message="test error", type="TestError", data=None, stack_trace=None + ) + fail_update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CONTEXT, + action=OperationAction.FAIL, + error=error, + ) + ContextOperationValidator.validate(None, fail_update) + + +def test_validate_invalid_action(): + """Test validation with invalid action.""" + invalid_actions = [ + OperationAction.RETRY, + OperationAction.CANCEL, + ] + + for action in invalid_actions: + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CONTEXT, + action=action, + ) + + with pytest.raises( + InvalidParameterValueException, match="Invalid CONTEXT action." + ): + ContextOperationValidator.validate(None, update) diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/execution_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/execution_test.py new file mode 100644 index 0000000..2c0e573 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/execution_test.py @@ -0,0 +1,107 @@ +"""Unit tests for execution operation validator.""" + +import pytest +from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, + OperationAction, + OperationType, + OperationUpdate, +) + +from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.execution import ( + ExecutionOperationValidator, +) +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) + + +def test_validate_succeed_action(): + """Test SUCCEED action validation.""" + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.EXECUTION, + action=OperationAction.SUCCEED, + payload="success", + ) + ExecutionOperationValidator.validate(update) + + +def test_validate_fail_action(): + """Test FAIL action validation.""" + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.EXECUTION, + action=OperationAction.FAIL, + error=ErrorObject( + message="Test error", type="TestError", data=None, stack_trace=None + ), + ) + ExecutionOperationValidator.validate(update) + + +def test_validate_succeed_action_with_error(): + """Test SUCCEED action with error raises error.""" + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.EXECUTION, + action=OperationAction.SUCCEED, + error=ErrorObject( + message="Test error", type="TestError", data=None, stack_trace=None + ), + ) + + with pytest.raises( + InvalidParameterValueException, + match="Cannot provide an Error for SUCCEED action", + ): + ExecutionOperationValidator.validate(update) + + +def test_validate_fail_action_with_payload(): + """Test FAIL action with payload raises error.""" + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.EXECUTION, + action=OperationAction.FAIL, + payload="invalid", + ) + + with pytest.raises( + InvalidParameterValueException, match="Cannot provide a Payload for FAIL action" + ): + ExecutionOperationValidator.validate(update) + + +def test_validate_invalid_action(): + """Test invalid action raises error.""" + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.EXECUTION, + action=OperationAction.START, + ) + + with pytest.raises( + InvalidParameterValueException, match="Invalid EXECUTION action" + ): + ExecutionOperationValidator.validate(update) + + +def test_validate_fail_action_without_error(): + """Test FAIL action without error passes validation.""" + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.EXECUTION, + action=OperationAction.FAIL, + ) + ExecutionOperationValidator.validate(update) + + +def test_validate_succeed_action_without_payload(): + """Test SUCCEED action without payload passes validation.""" + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.EXECUTION, + action=OperationAction.SUCCEED, + ) + ExecutionOperationValidator.validate(update) diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/invoke_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/invoke_test.py new file mode 100644 index 0000000..3300b1a --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/invoke_test.py @@ -0,0 +1,109 @@ +"""Unit tests for invoke operation validator.""" + +import pytest +from aws_durable_execution_sdk_python.lambda_service import ( + Operation, + OperationAction, + OperationStatus, + OperationType, + OperationUpdate, +) + +from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.invoke import ( + ChainedInvokeOperationValidator, +) +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) + + +def test_validate_start_action_with_no_current_state(): + """Test START action with no current state.""" + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CHAINED_INVOKE, + action=OperationAction.START, + ) + ChainedInvokeOperationValidator.validate(None, update) + + +def test_validate_start_action_with_existing_state(): + """Test START action with existing state raises error.""" + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.STARTED, + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CHAINED_INVOKE, + action=OperationAction.START, + ) + + with pytest.raises( + InvalidParameterValueException, + match="Cannot start an INVOKE that already exist", + ): + ChainedInvokeOperationValidator.validate(current_state, update) + + +def test_validate_cancel_action_with_started_state(): + """Test CANCEL action with STARTED state.""" + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.STARTED, + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CHAINED_INVOKE, + action=OperationAction.CANCEL, + ) + ChainedInvokeOperationValidator.validate(current_state, update) + + +def test_validate_cancel_action_with_no_current_state(): + """Test CANCEL action with no current state raises error.""" + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CHAINED_INVOKE, + action=OperationAction.CANCEL, + ) + + with pytest.raises( + InvalidParameterValueException, + match="Cannot cancel an INVOKE that does not exist or has already completed", + ): + ChainedInvokeOperationValidator.validate(None, update) + + +def test_validate_cancel_action_with_completed_state(): + """Test CANCEL action with completed state raises error.""" + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.SUCCEEDED, + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CHAINED_INVOKE, + action=OperationAction.CANCEL, + ) + + with pytest.raises( + InvalidParameterValueException, + match="Cannot cancel an INVOKE that does not exist or has already completed", + ): + ChainedInvokeOperationValidator.validate(current_state, update) + + +def test_validate_invalid_action(): + """Test invalid action raises error.""" + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CHAINED_INVOKE, + action=OperationAction.SUCCEED, + ) + + with pytest.raises(InvalidParameterValueException, match="Invalid INVOKE action"): + ChainedInvokeOperationValidator.validate(None, update) diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/step_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/step_test.py new file mode 100644 index 0000000..7d19132 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/step_test.py @@ -0,0 +1,272 @@ +"""Unit tests for step operation validator.""" + +import pytest +from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, + Operation, + OperationAction, + OperationStatus, + OperationType, + OperationUpdate, + StepOptions, +) + +from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.step import ( + StepOperationValidator, +) +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) + + +def test_validate_with_no_current_state(): + """Test validation with no current state.""" + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.START, + ) + StepOperationValidator.validate(None, update) + + +def test_validate_start_action_with_ready_state(): + """Test START action with READY state.""" + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.STEP, + status=OperationStatus.READY, + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.START, + ) + StepOperationValidator.validate(current_state, update) + + +def test_validate_start_action_with_invalid_state(): + """Test START action with invalid state raises error.""" + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.START, + ) + + with pytest.raises( + InvalidParameterValueException, match="Invalid current STEP state to start" + ): + StepOperationValidator.validate(current_state, update) + + +def test_validate_succeed_action_with_started_state(): + """Test SUCCEED action with STARTED state.""" + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.SUCCEED, + payload={"result": "success"}, + ) + StepOperationValidator.validate(current_state, update) + + +def test_validate_fail_action_with_ready_state(): + """Test FAIL action with READY state.""" + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.STEP, + status=OperationStatus.READY, + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.FAIL, + error=ErrorObject( + message="Test error", type="TestError", data=None, stack_trace=None + ), + ) + StepOperationValidator.validate(current_state, update) + + +def test_validate_fail_action_with_invalid_state(): + """Test FAIL action with invalid state raises error.""" + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.FAIL, + ) + + with pytest.raises( + InvalidParameterValueException, match="Invalid current STEP state to close" + ): + StepOperationValidator.validate(current_state, update) + + +def test_validate_fail_action_with_payload(): + """Test FAIL action with payload raises error.""" + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.FAIL, + payload={"invalid": "payload"}, + ) + + with pytest.raises( + InvalidParameterValueException, match="Cannot provide a Payload for FAIL action" + ): + StepOperationValidator.validate(current_state, update) + + +def test_validate_succeed_action_with_error(): + """Test SUCCEED action with error raises error.""" + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.SUCCEED, + error=ErrorObject( + message="Test error", type="TestError", data=None, stack_trace=None + ), + ) + + with pytest.raises( + InvalidParameterValueException, + match="Cannot provide an Error for SUCCEED action", + ): + StepOperationValidator.validate(current_state, update) + + +def test_validate_retry_action_with_started_state(): + """Test RETRY action with STARTED state.""" + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.RETRY, + step_options=StepOptions(next_attempt_delay_seconds=3), + ) + StepOperationValidator.validate(current_state, update) + + +def test_validate_retry_action_with_ready_state(): + """Test RETRY action with READY state.""" + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.STEP, + status=OperationStatus.READY, + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.RETRY, + step_options=StepOptions(next_attempt_delay_seconds=3), + ) + StepOperationValidator.validate(current_state, update) + + +def test_validate_retry_action_with_invalid_state(): + """Test RETRY action with invalid state raises error.""" + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.RETRY, + step_options=StepOptions(next_attempt_delay_seconds=3), + ) + + with pytest.raises( + InvalidParameterValueException, match="Invalid current STEP state to re-attempt" + ): + StepOperationValidator.validate(current_state, update) + + +def test_validate_retry_action_without_step_options(): + """Test RETRY action without step options raises error.""" + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.RETRY, + ) + + with pytest.raises( + InvalidParameterValueException, match="Invalid StepOptions for the given action" + ): + StepOperationValidator.validate(current_state, update) + + +def test_validate_retry_action_with_both_error_and_payload(): + """Test RETRY action with both error and payload raises error.""" + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.RETRY, + step_options=StepOptions(next_attempt_delay_seconds=3), + error=ErrorObject( + message="Test error", type="TestError", data=None, stack_trace=None + ), + payload={"result": "success"}, + ) + + with pytest.raises( + InvalidParameterValueException, + match="Cannot provide both error and payload to RETRY a STEP", + ): + StepOperationValidator.validate(current_state, update) + + +def test_validate_invalid_action(): + """Test invalid action raises error.""" + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.CANCEL, + ) + + with pytest.raises(InvalidParameterValueException, match="Invalid STEP action"): + StepOperationValidator.validate(current_state, update) diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/wait_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/wait_test.py new file mode 100644 index 0000000..77f3536 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/operations/wait_test.py @@ -0,0 +1,108 @@ +"""Unit tests for wait operation validator.""" + +import pytest +from aws_durable_execution_sdk_python.lambda_service import ( + Operation, + OperationAction, + OperationStatus, + OperationType, + OperationUpdate, +) + +from aws_durable_execution_sdk_python_testing.checkpoint.validators.operations.wait import ( + WaitOperationValidator, +) +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) + + +def test_validate_start_action_with_no_current_state(): + """Test START action with no current state.""" + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.WAIT, + action=OperationAction.START, + ) + WaitOperationValidator.validate(None, update) + + +def test_validate_start_action_with_existing_state(): + """Test START action with existing state raises error.""" + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.WAIT, + status=OperationStatus.STARTED, + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.WAIT, + action=OperationAction.START, + ) + + with pytest.raises( + InvalidParameterValueException, match="Cannot start a WAIT that already exist" + ): + WaitOperationValidator.validate(current_state, update) + + +def test_validate_cancel_action_with_started_state(): + """Test CANCEL action with STARTED state.""" + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.WAIT, + status=OperationStatus.STARTED, + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.WAIT, + action=OperationAction.CANCEL, + ) + WaitOperationValidator.validate(current_state, update) + + +def test_validate_cancel_action_with_no_current_state(): + """Test CANCEL action with no current state raises error.""" + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.WAIT, + action=OperationAction.CANCEL, + ) + + with pytest.raises( + InvalidParameterValueException, + match="Cannot cancel a WAIT that does not exist or has already completed", + ): + WaitOperationValidator.validate(None, update) + + +def test_validate_cancel_action_with_completed_state(): + """Test CANCEL action with completed state raises error.""" + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.WAIT, + status=OperationStatus.SUCCEEDED, + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.WAIT, + action=OperationAction.CANCEL, + ) + + with pytest.raises( + InvalidParameterValueException, + match="Cannot cancel a WAIT that does not exist or has already completed", + ): + WaitOperationValidator.validate(current_state, update) + + +def test_validate_invalid_action(): + """Test invalid action raises error.""" + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.WAIT, + action=OperationAction.SUCCEED, + ) + + with pytest.raises(InvalidParameterValueException, match="Invalid WAIT action"): + WaitOperationValidator.validate(None, update) diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/transitions_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/transitions_test.py new file mode 100644 index 0000000..901db3c --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/checkpoint/validators/transitions_test.py @@ -0,0 +1,150 @@ +"""Unit tests for transitions validator.""" + +import pytest +from aws_durable_execution_sdk_python.lambda_service import ( + OperationAction, + OperationType, +) + +from aws_durable_execution_sdk_python_testing.checkpoint.validators.transitions import ( + ValidActionsByOperationTypeValidator, +) +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) + + +def test_validate_step_valid_actions(): + """Test valid actions for STEP operations.""" + valid_actions = [ + OperationAction.START, + OperationAction.FAIL, + OperationAction.RETRY, + OperationAction.SUCCEED, + ] + for action in valid_actions: + ValidActionsByOperationTypeValidator.validate(OperationType.STEP, action) + + +def test_validate_context_valid_actions(): + """Test valid actions for CONTEXT operations.""" + valid_actions = [ + OperationAction.START, + OperationAction.FAIL, + OperationAction.SUCCEED, + ] + for action in valid_actions: + ValidActionsByOperationTypeValidator.validate(OperationType.CONTEXT, action) + + +def test_validate_wait_valid_actions(): + """Test valid actions for WAIT operations.""" + valid_actions = [ + OperationAction.START, + OperationAction.CANCEL, + ] + for action in valid_actions: + ValidActionsByOperationTypeValidator.validate(OperationType.WAIT, action) + + +def test_validate_callback_valid_actions(): + """Test valid actions for CALLBACK operations.""" + valid_actions = [ + OperationAction.START, + ] + for action in valid_actions: + ValidActionsByOperationTypeValidator.validate(OperationType.CALLBACK, action) + + +def test_validate_invoke_valid_actions(): + """Test valid actions for INVOKE operations.""" + valid_actions = [ + OperationAction.START, + OperationAction.CANCEL, + ] + for action in valid_actions: + ValidActionsByOperationTypeValidator.validate( + OperationType.CHAINED_INVOKE, action + ) + + +def test_validate_execution_valid_actions(): + """Test valid actions for EXECUTION operations.""" + valid_actions = [ + OperationAction.SUCCEED, + OperationAction.FAIL, + ] + for action in valid_actions: + ValidActionsByOperationTypeValidator.validate(OperationType.EXECUTION, action) + + +def test_validate_invalid_action_for_step(): + """Test invalid action for STEP operation.""" + with pytest.raises( + InvalidParameterValueException, + match="Invalid action for the given operation type", + ): + ValidActionsByOperationTypeValidator.validate( + OperationType.STEP, OperationAction.CANCEL + ) + + +def test_validate_invalid_action_for_context(): + """Test invalid action for CONTEXT operation.""" + with pytest.raises( + InvalidParameterValueException, + match="Invalid action for the given operation type", + ): + ValidActionsByOperationTypeValidator.validate( + OperationType.CONTEXT, OperationAction.RETRY + ) + + +def test_validate_invalid_action_for_wait(): + """Test invalid action for WAIT operation.""" + with pytest.raises( + InvalidParameterValueException, + match="Invalid action for the given operation type", + ): + ValidActionsByOperationTypeValidator.validate( + OperationType.WAIT, OperationAction.SUCCEED + ) + + +def test_validate_invalid_action_for_callback(): + """Test invalid action for CALLBACK operation.""" + with pytest.raises( + InvalidParameterValueException, + match="Invalid action for the given operation type", + ): + ValidActionsByOperationTypeValidator.validate( + OperationType.CALLBACK, OperationAction.FAIL + ) + + +def test_validate_invalid_action_for_invoke(): + """Test invalid action for INVOKE operation.""" + with pytest.raises( + InvalidParameterValueException, + match="Invalid action for the given operation type", + ): + ValidActionsByOperationTypeValidator.validate( + OperationType.CHAINED_INVOKE, OperationAction.RETRY + ) + + +def test_validate_invalid_action_for_execution(): + """Test invalid action for EXECUTION operation.""" + with pytest.raises( + InvalidParameterValueException, + match="Invalid action for the given operation type", + ): + ValidActionsByOperationTypeValidator.validate( + OperationType.EXECUTION, OperationAction.START + ) + + +def test_validate_unknown_operation_type(): + """Test validation with unknown operation type.""" + with pytest.raises(InvalidParameterValueException, match="Unknown operation type"): + ValidActionsByOperationTypeValidator.validate(None, OperationAction.START) diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/cli_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/cli_test.py new file mode 100644 index 0000000..98b53c6 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/cli_test.py @@ -0,0 +1,1117 @@ +"""Tests for the CLI module.""" + +from __future__ import annotations + +import argparse +import json +import logging +import os +import sys +from http.client import HTTPMessage +from io import StringIO, BytesIO +from unittest.mock import Mock, patch + +import pytest +from urllib.error import HTTPError, URLError + +from botocore.exceptions import ConnectionError # type: ignore + +from aws_durable_execution_sdk_python_testing.cli import CliApp, CliConfig, main +from aws_durable_execution_sdk_python_testing.exceptions import ( + DurableFunctionsLocalRunnerError, + InvalidParameterValueException, + ResourceNotFoundException, + ServiceException, + TooManyRequestsException, +) + + +def test_cli_config_has_correct_default_values() -> None: + """Test that CliConfig has correct default values.""" + config = CliConfig() + + assert config.host == "0.0.0.0" # noqa: S104 + assert config.port == 5000 + assert config.log_level == logging.INFO + assert config.lambda_endpoint == "http://127.0.0.1:3001" + assert config.local_runner_endpoint == "http://0.0.0.0:5000" + assert config.local_runner_region == "us-west-2" + assert config.local_runner_mode == "local" + + +def test_cli_config_from_environment_uses_defaults_when_no_env_vars() -> None: + """Test from_environment with no environment variables set.""" + with patch.dict(os.environ, {}, clear=True): + config = CliConfig.from_environment() + + assert config.host == "0.0.0.0" # noqa: S104 + assert config.port == 5000 + assert config.log_level == logging.INFO + assert config.lambda_endpoint == "http://127.0.0.1:3001" + assert config.local_runner_endpoint == "http://0.0.0.0:5000" + assert config.local_runner_region == "us-west-2" + assert config.local_runner_mode == "local" + + +def test_cli_config_from_environment_uses_all_env_vars_when_set() -> None: + """Test from_environment with all environment variables set.""" + env_vars = { + "AWS_DEX_HOST": "127.0.0.1", + "AWS_DEX_PORT": "8080", + "AWS_DEX_LOG_LEVEL": "DEBUG", + "AWS_DEX_LAMBDA_ENDPOINT": "http://localhost:4000", + "AWS_DEX_LOCAL_RUNNER_ENDPOINT": "http://localhost:8080", + "AWS_DEX_LOCAL_RUNNER_REGION": "us-east-1", + "AWS_DEX_LOCAL_RUNNER_MODE": "remote", + } + + with patch.dict(os.environ, env_vars, clear=True): + config = CliConfig.from_environment() + + assert config.host == "127.0.0.1" + assert config.port == 8080 + assert config.log_level == logging.DEBUG + assert config.lambda_endpoint == "http://localhost:4000" + assert config.local_runner_endpoint == "http://localhost:8080" + assert config.local_runner_region == "us-east-1" + assert config.local_runner_mode == "remote" + + +def test_cli_config_from_environment_uses_partial_env_vars_with_defaults() -> None: + """Test from_environment with some environment variables set.""" + env_vars = { + "AWS_DEX_HOST": "192.168.1.1", + "AWS_DEX_PORT": "9000", + } + + with patch.dict(os.environ, env_vars, clear=True): + config = CliConfig.from_environment() + + assert config.host == "192.168.1.1" + assert config.port == 9000 + # Other values should be defaults + assert config.log_level == logging.INFO + assert config.lambda_endpoint == "http://127.0.0.1:3001" + + +def test_cli_app_loads_config_from_environment_on_init() -> None: + """Test that CliApp loads configuration from environment on init.""" + env_vars = {"AWS_DEX_HOST": "test-host", "AWS_DEX_PORT": "7777"} + + with patch.dict(os.environ, env_vars, clear=True): + app = CliApp() + + assert app.config.host == "test-host" + assert app.config.port == 7777 + + +def test_cli_app_shows_help_and_returns_error_when_no_command() -> None: + """Test that running with no command shows help and returns error code.""" + app = CliApp() + + with patch("sys.stderr", new_callable=StringIO) as mock_stderr: + exit_code = app.run([]) + + assert exit_code == 2 # argparse error code + assert "required" in mock_stderr.getvalue().lower() + + +def test_cli_app_shows_usage_information_with_help_flag() -> None: + """Test that --help shows usage information.""" + app = CliApp() + + with patch("sys.stdout", new_callable=StringIO) as mock_stdout: + exit_code = app.run(["--help"]) + + assert exit_code == 0 + output = mock_stdout.getvalue() + assert "dex-local-runner" in output + assert "start-server" in output + assert "invoke" in output + assert "get-durable-execution" in output + assert "get-durable-execution-history" in output + + +def test_cli_app_handles_keyboard_interrupt_gracefully() -> None: + """Test that KeyboardInterrupt is handled gracefully.""" + app = CliApp() + + with patch.object(app, "_create_parsers") as mock_setup: + mock_setup.side_effect = KeyboardInterrupt() + + with patch("sys.stderr", new_callable=StringIO) as mock_stderr: + exit_code = app.run(["start-server"]) + + assert exit_code == 130 + assert "cancelled by user" in mock_stderr.getvalue() + + +def test_start_server_command_parses_arguments_correctly() -> None: + """Test that start-server command parses arguments correctly.""" + app = CliApp() + + # Test with default values + with patch( + "aws_durable_execution_sdk_python_testing.cli.WebRunner" + ) as mock_web_runner: + # Mock runner context manager + mock_runner_instance = mock_web_runner.return_value + mock_runner_instance.__enter__.return_value = mock_runner_instance + mock_runner_instance.__exit__.return_value = None + mock_runner_instance.serve_forever.side_effect = KeyboardInterrupt() + + exit_code = app.run(["start-server"]) + assert exit_code == 130 # KeyboardInterrupt exit code + + # Test with custom values + with patch( + "aws_durable_execution_sdk_python_testing.cli.WebRunner" + ) as mock_web_runner: + # Mock runner context manager + mock_runner_instance = mock_web_runner.return_value + mock_runner_instance.__enter__.return_value = mock_runner_instance + mock_runner_instance.__exit__.return_value = None + mock_runner_instance.serve_forever.side_effect = KeyboardInterrupt() + + exit_code = app.run( + [ + "start-server", + "--host", + "127.0.0.1", + "--port", + "8080", + "--log-level", + "DEBUG", + "--lambda-endpoint", + "http://localhost:4000", + "--local-runner-endpoint", + "http://localhost:8080", + "--local-runner-region", + "us-east-1", + "--local-runner-mode", + "remote", + ] + ) + assert exit_code == 130 # KeyboardInterrupt exit code + + +def test_invoke_command_parses_arguments_correctly() -> None: + """Test that invoke command parses arguments correctly.""" + app = CliApp() + + # Test with required function-name + with patch("sys.stdout", new_callable=StringIO): + exit_code = app.run(["invoke", "--function-name", "test-function"]) + assert exit_code == 1 # Not implemented yet + + # Test with all parameters + with patch("sys.stdout", new_callable=StringIO): + exit_code = app.run( + [ + "invoke", + "--function-name", + "test-function", + "--input", + '{"key": "value"}', + "--durable-execution-name", + "test-execution", + ] + ) + assert exit_code == 1 # Not implemented yet + + +def test_invoke_command_requires_function_name_parameter() -> None: + """Test that invoke command requires function-name parameter.""" + app = CliApp() + + with patch("sys.stderr", new_callable=StringIO) as mock_stderr: + exit_code = app.run(["invoke"]) + + assert exit_code == 2 # argparse error code + assert "required" in mock_stderr.getvalue().lower() + + +def test_invoke_command_validates_json_input_format() -> None: + """Test that invoke command validates JSON input.""" + app = CliApp() + + exit_code = app.run( + [ + "invoke", + "--function-name", + "test-function", + "--input", + "invalid-json", + ] + ) + + assert exit_code == 1 + + +def test_get_durable_execution_command_parses_arguments_correctly() -> None: + """Test that get-durable-execution command parses arguments correctly.""" + app = CliApp() + + with patch.object(app, "_create_boto3_client") as mock_create_client: + mock_client = Mock() + mock_client.get_durable_execution.side_effect = Exception("Connection refused") + mock_create_client.return_value = mock_client + + with patch("sys.stderr", new_callable=StringIO): + exit_code = app.run( + ["get-durable-execution", "--durable-execution-arn", "test-arn"] + ) + assert exit_code == 1 # Connection error + + +def test_get_durable_execution_command_requires_arn_parameter() -> None: + """Test that get-durable-execution command requires ARN parameter.""" + app = CliApp() + + with patch("sys.stderr", new_callable=StringIO) as mock_stderr: + exit_code = app.run(["get-durable-execution"]) + + assert exit_code == 2 # argparse error code + assert "required" in mock_stderr.getvalue().lower() + + +def test_get_durable_execution_history_command_parses_arguments_correctly() -> None: + """Test that get-durable-execution-history command parses arguments correctly.""" + app = CliApp() + + with patch.object(app, "_create_boto3_client") as mock_create_client: + mock_client = Mock() + mock_client.get_durable_execution_history.side_effect = Exception( + "Connection refused" + ) + mock_create_client.return_value = mock_client + + with patch("sys.stderr", new_callable=StringIO): + exit_code = app.run( + ["get-durable-execution-history", "--durable-execution-arn", "test-arn"] + ) + assert exit_code == 1 # Connection error + + +def test_get_durable_execution_history_command_requires_arn_parameter() -> None: + """Test that get-durable-execution-history command requires ARN parameter.""" + app = CliApp() + + with patch("sys.stderr", new_callable=StringIO) as mock_stderr: + exit_code = app.run(["get-durable-execution-history"]) + + assert exit_code == 2 # argparse error code + assert "required" in mock_stderr.getvalue().lower() + + +def test_logging_configuration_uses_specified_log_level() -> None: + """Test that logging is configured based on log level.""" + app = CliApp() + + with patch("logging.basicConfig") as mock_basic_config: + with patch("sys.stdout", new_callable=StringIO): + with patch.object(app, "start_server_command", return_value=0): + app.run(["start-server", "--log-level", "DEBUG"]) + + mock_basic_config.assert_called_once() + call_args = mock_basic_config.call_args + assert call_args[1]["level"] == 10 + + +def test_parser_creation_includes_all_subcommands() -> None: + """Test that parser creation includes all expected subcommands.""" + app = CliApp() + + # Test that all subcommands are available + with patch("sys.stdout", new_callable=StringIO) as mock_stdout: + exit_code = app.run(["--help"]) + assert exit_code == 0 + output = mock_stdout.getvalue() + assert "start-server" in output + assert "invoke" in output + assert "get-durable-execution" in output + assert "get-durable-execution-history" in output + + +def test_start_server_command_works_with_mocked_dependencies() -> None: + """Test start-server command with mocked WebRunner.""" + app = CliApp() + + with patch( + "aws_durable_execution_sdk_python_testing.cli.WebRunner" + ) as mock_web_runner: + # Mock runner context manager + mock_runner_instance = mock_web_runner.return_value + mock_runner_instance.__enter__.return_value = mock_runner_instance + mock_runner_instance.__exit__.return_value = None + + # Mock serve_forever to avoid blocking + mock_runner_instance.serve_forever.side_effect = KeyboardInterrupt() + + exit_code = app.run( + [ + "start-server", + "--host", + "127.0.0.1", + "--port", + "8080", + "--log-level", + "DEBUG", + ] + ) + + assert exit_code == 130 # KeyboardInterrupt exit code + mock_web_runner.assert_called_once() + + # Verify WebRunnerConfig was created with correct values + call_args = mock_web_runner.call_args[0][0] # First positional argument + assert call_args.web_service.host == "127.0.0.1" + assert call_args.web_service.port == 8080 + assert call_args.web_service.log_level == "DEBUG" + + +def test_start_server_command_handles_server_startup_errors() -> None: + """Test start-server command handles server startup errors.""" + app = CliApp() + + with patch( + "aws_durable_execution_sdk_python_testing.cli.WebRunner" + ) as mock_web_runner: + # Make WebRunner constructor raise an exception + mock_web_runner.side_effect = Exception("Server startup failed") + + exit_code = app.run(["start-server"]) + + assert exit_code == 1 + + +def test_start_server_command_creates_correct_web_runner_config() -> None: + """Test that start-server command creates WebRunnerConfig with all CLI arguments.""" + app = CliApp() + + with patch( + "aws_durable_execution_sdk_python_testing.cli.WebRunner" + ) as mock_web_runner: + # Mock runner context manager + mock_runner_instance = mock_web_runner.return_value + mock_runner_instance.__enter__.return_value = mock_runner_instance + mock_runner_instance.__exit__.return_value = None + mock_runner_instance.serve_forever.side_effect = KeyboardInterrupt() + + exit_code = app.run( + [ + "start-server", + "--host", + "192.168.1.100", + "--port", + "9000", + "--log-level", + "WARNING", + "--lambda-endpoint", + "http://custom-lambda:4000", + "--local-runner-endpoint", + "http://custom-runner:9000", + "--local-runner-region", + "eu-west-1", + "--local-runner-mode", + "remote", + ] + ) + + assert exit_code == 130 # KeyboardInterrupt exit code + mock_web_runner.assert_called_once() + + # Verify WebRunnerConfig was created with all custom values + config = mock_web_runner.call_args[0][0] # First positional argument + + # Verify web service configuration + assert config.web_service.host == "192.168.1.100" + assert config.web_service.port == 9000 + assert config.web_service.log_level == "WARNING" + + # Verify Lambda service configuration + assert config.lambda_endpoint == "http://custom-lambda:4000" + assert config.local_runner_endpoint == "http://custom-runner:9000" + assert config.local_runner_region == "eu-west-1" + assert config.local_runner_mode == "remote" + + +def test_start_server_command_uses_context_manager_properly() -> None: + """Test that start-server command uses WebRunner as context manager.""" + app = CliApp() + + with patch( + "aws_durable_execution_sdk_python_testing.cli.WebRunner" + ) as mock_web_runner: + # Mock runner context manager + mock_runner_instance = mock_web_runner.return_value + mock_runner_instance.__enter__.return_value = mock_runner_instance + mock_runner_instance.__exit__.return_value = None + mock_runner_instance.serve_forever.return_value = None + + exit_code = app.run(["start-server"]) + + assert exit_code == 0 + mock_web_runner.assert_called_once() + + # Verify context manager methods were called + mock_runner_instance.__enter__.assert_called_once() + mock_runner_instance.__exit__.assert_called_once() + mock_runner_instance.serve_forever.assert_called_once() + + +def test_start_server_command_handles_runtime_error_from_web_runner() -> None: + """Test that start-server command handles DurableFunctionsLocalRunnerError from WebRunner.""" + app = CliApp() + + with patch( + "aws_durable_execution_sdk_python_testing.cli.WebRunner" + ) as mock_web_runner: + # Mock runner context manager that raises DurableFunctionsLocalRunnerError + mock_runner_instance = mock_web_runner.return_value + mock_runner_instance.__enter__.side_effect = DurableFunctionsLocalRunnerError( + "Server already running" + ) + + exit_code = app.run(["start-server"]) + + assert exit_code == 1 + + +def test_start_server_command_logs_configuration_details() -> None: + """Test that start-server command logs configuration details.""" + app = CliApp() + + with patch( + "aws_durable_execution_sdk_python_testing.cli.WebRunner" + ) as mock_web_runner: + # Mock runner context manager + mock_runner_instance = mock_web_runner.return_value + mock_runner_instance.__enter__.return_value = mock_runner_instance + mock_runner_instance.__exit__.return_value = None + mock_runner_instance.serve_forever.side_effect = KeyboardInterrupt() + + with patch( + "aws_durable_execution_sdk_python_testing.cli.logger" + ) as mock_logger: + exit_code = app.run( + [ + "start-server", + "--host", + "test-host", + "--port", + "8888", + ] + ) + + assert exit_code == 130 + + # Verify configuration logging + mock_logger.info.assert_any_call( + "Starting Durable Functions Local Runner on %s:%s", + "test-host", + 8888, + ) + mock_logger.info.assert_any_call("Configuration:") + mock_logger.info.assert_any_call(" Host: %s", "test-host") + mock_logger.info.assert_any_call(" Port: %s", 8888) + + +def test_start_server_command_maintains_backward_compatible_logging() -> None: + """Test that start-server command maintains backward compatible logging messages.""" + app = CliApp() + + with patch( + "aws_durable_execution_sdk_python_testing.cli.WebRunner" + ) as mock_web_runner: + # Mock runner context manager + mock_runner_instance = mock_web_runner.return_value + mock_runner_instance.__enter__.return_value = mock_runner_instance + mock_runner_instance.__exit__.return_value = None + mock_runner_instance.serve_forever.side_effect = KeyboardInterrupt() + + with patch( + "aws_durable_execution_sdk_python_testing.cli.logger" + ) as mock_logger: + exit_code = app.run(["start-server"]) + + assert exit_code == 130 + + # Verify backward compatible logging messages + mock_logger.info.assert_any_call( + "Server started successfully. Press Ctrl+C to stop." + ) + mock_logger.info.assert_any_call( + "Received shutdown signal, stopping server..." + ) + + +def test_start_server_command_handles_serve_forever_exception() -> None: + """Test that start-server command handles exceptions from serve_forever.""" + app = CliApp() + + with patch( + "aws_durable_execution_sdk_python_testing.cli.WebRunner" + ) as mock_web_runner: + # Mock runner context manager + mock_runner_instance = mock_web_runner.return_value + mock_runner_instance.__enter__.return_value = mock_runner_instance + mock_runner_instance.__exit__.return_value = None + mock_runner_instance.serve_forever.side_effect = ( + DurableFunctionsLocalRunnerError("Server error during operation") + ) + + exit_code = app.run(["start-server"]) + + assert exit_code == 1 + + +def test_main_function_creates_cli_app_and_runs() -> None: + """Test the main function entry point.""" + with patch("aws_durable_execution_sdk_python_testing.cli.CliApp") as mock_cli_app: + mock_app_instance = mock_cli_app.return_value + mock_app_instance.run.return_value = 42 + + exit_code = main() + + mock_cli_app.assert_called_once() + mock_app_instance.run.assert_called_once() + assert exit_code == 42 + + +def test_main_function_works_when_called_as_script() -> None: + """Test that main function works when called as script.""" + original_argv = sys.argv[:] + try: + sys.argv = ["dex-local-runner", "--help"] + + with patch( + "aws_durable_execution_sdk_python_testing.cli.CliApp" + ) as mock_cli_app: + mock_app_instance = mock_cli_app.return_value + mock_app_instance.run.return_value = 0 + + exit_code = main() + + assert exit_code == 0 + mock_app_instance.run.assert_called_once() + finally: + sys.argv = original_argv + + +# Tests for client operation CLI commands + + +def test_invoke_command_makes_http_request_to_start_execution_endpoint() -> None: + """Test that invoke command makes HTTP request to start-durable-execution endpoint.""" + app = CliApp() + + response_body = json.dumps( + { + "ExecutionArn": "arn:aws:lambda:us-west-2:123456789012:function:test-function:execution:test-execution" + } + ).encode("utf-8") + + mock_response = Mock() + mock_response.read.return_value = response_body + mock_response.__enter__ = Mock(return_value=mock_response) + mock_response.__exit__ = Mock(return_value=False) + + with patch( + "aws_durable_execution_sdk_python_testing.cli.urlopen", + return_value=mock_response, + ) as mock_urlopen: + with patch("sys.stdout", new_callable=StringIO) as mock_stdout: + exit_code = app.invoke_command( + argparse.Namespace( + function_name="test-function", + input='{"key": "value"}', + durable_execution_name="test-execution", + ) + ) + + assert exit_code == 0 + mock_urlopen.assert_called_once() + + # Verify the request details + call_args = mock_urlopen.call_args + req = call_args[0][0] + assert req.full_url.endswith("/start-durable-execution") + assert req.get_header("Content-type") == "application/json" + assert call_args[1]["timeout"] == 10 + + # Verify payload structure + payload = json.loads(req.data.decode("utf-8")) + assert payload["FunctionName"] == "test-function" + assert payload["Input"] == '{"key": "value"}' + assert payload["ExecutionName"] == "test-execution" + + # Verify output + output = mock_stdout.getvalue() + assert "ExecutionArn" in output + + +def test_invoke_command_uses_default_execution_name_when_not_provided() -> None: + """Test that invoke command generates default execution name when not provided.""" + app = CliApp() + + response_body = json.dumps({"ExecutionArn": "test-arn"}).encode("utf-8") + mock_response = Mock() + mock_response.read.return_value = response_body + mock_response.__enter__ = Mock(return_value=mock_response) + mock_response.__exit__ = Mock(return_value=False) + + with patch( + "aws_durable_execution_sdk_python_testing.cli.urlopen", + return_value=mock_response, + ) as mock_urlopen: + app.invoke_command( + argparse.Namespace( + function_name="my-function", + input="{}", + durable_execution_name=None, + ) + ) + + # Verify default execution name is generated + req = mock_urlopen.call_args[0][0] + payload = json.loads(req.data.decode("utf-8")) + assert payload["ExecutionName"] == "my-function-execution" + + +def test_invoke_command_handles_connection_error() -> None: + """Test that invoke command handles connection errors gracefully.""" + app = CliApp() + + with patch("aws_durable_execution_sdk_python_testing.cli.urlopen") as mock_urlopen: + mock_urlopen.side_effect = URLError("Connection refused") + + exit_code = app.invoke_command( + argparse.Namespace( + function_name="test-function", + input="{}", + durable_execution_name=None, + ) + ) + + assert exit_code == 1 + + +def test_invoke_command_handles_http_error_response() -> None: + """Test that invoke command handles HTTP error responses.""" + app = CliApp() + + error_body = json.dumps( + { + "ErrorMessage": "Invalid parameter value", + "ErrorType": "InvalidParameterValueException", + } + ).encode("utf-8") + + with patch("aws_durable_execution_sdk_python_testing.cli.urlopen") as mock_urlopen: + mock_urlopen.side_effect = HTTPError( + url="http://0.0.0.0:5000/start-durable-execution", + code=400, + msg="Bad Request", + hdrs=HTTPMessage(), + fp=BytesIO(error_body), + ) + + with patch("sys.stderr", new_callable=StringIO) as mock_stderr: + exit_code = app.invoke_command( + argparse.Namespace( + function_name="test-function", + input="{}", + durable_execution_name=None, + ) + ) + + assert exit_code == 1 + assert "Invalid parameter value" in mock_stderr.getvalue() + + +def test_invoke_command_handles_non_json_error_response() -> None: + """Test that invoke command handles non-JSON error responses.""" + app = CliApp() + + with patch("aws_durable_execution_sdk_python_testing.cli.urlopen") as mock_urlopen: + mock_urlopen.side_effect = HTTPError( + url="http://0.0.0.0:5000/start-durable-execution", + code=500, + msg="Internal Server Error", + hdrs=HTTPMessage(), + fp=BytesIO(b"Internal Server Error"), + ) + + exit_code = app.invoke_command( + argparse.Namespace( + function_name="test-function", + input="{}", + durable_execution_name=None, + ) + ) + + assert exit_code == 1 + + +def test_get_durable_execution_command_uses_boto3_client() -> None: + """Test that get-durable-execution command uses boto3 client.""" + app = CliApp() + + with patch.object(app, "_create_boto3_client") as mock_create_client: + mock_client = mock_create_client.return_value + mock_client.get_durable_execution.return_value = { + "DurableExecutionArn": "test-arn", + "Status": "SUCCEEDED", + "Result": {"output": "success"}, + } + + with patch("sys.stdout", new_callable=StringIO) as mock_stdout: + exit_code = app.get_durable_execution_command( + argparse.Namespace(durable_execution_arn="test-arn") + ) + + assert exit_code == 0 + mock_create_client.assert_called_once() + mock_client.get_durable_execution.assert_called_once_with( + DurableExecutionArn="test-arn" + ) + + # Verify JSON output + output = mock_stdout.getvalue() + assert "test-arn" in output + assert "SUCCEEDED" in output + + +def test_get_durable_execution_command_handles_resource_not_found() -> None: + """Test that get-durable-execution command handles ResourceNotFoundException.""" + app = CliApp() + + with patch.object(app, "_create_boto3_client") as mock_create_client: + mock_client = mock_create_client.return_value + + mock_client.exceptions.ResourceNotFoundException = ResourceNotFoundException + mock_client.exceptions.InvalidParameterValueException = ( + InvalidParameterValueException + ) + mock_client.exceptions.TooManyRequestsException = TooManyRequestsException + mock_client.exceptions.ServiceException = ServiceException + + mock_client.get_durable_execution.side_effect = ResourceNotFoundException( + "Resource not found" + ) + + with patch("sys.stderr", new_callable=StringIO) as mock_stderr: + exit_code = app.get_durable_execution_command( + argparse.Namespace(durable_execution_arn="nonexistent-arn") + ) + + assert exit_code == 1 + assert "Error: Execution not found" in mock_stderr.getvalue() + + +def test_get_durable_execution_command_handles_invalid_parameter() -> None: + """Test that get-durable-execution command handles InvalidParameterValueException.""" + app = CliApp() + + with patch.object(app, "_create_boto3_client") as mock_create_client: + mock_client = mock_create_client.return_value + + mock_client.exceptions.ResourceNotFoundException = ResourceNotFoundException + mock_client.exceptions.InvalidParameterValueException = ( + InvalidParameterValueException + ) + mock_client.exceptions.TooManyRequestsException = TooManyRequestsException + mock_client.exceptions.ServiceException = ServiceException + + mock_client.get_durable_execution.side_effect = InvalidParameterValueException( + "Invalid parameters" + ) + + with patch("sys.stderr", new_callable=StringIO) as mock_stderr: + exit_code = app.get_durable_execution_command( + argparse.Namespace(durable_execution_arn="invalid-arn") + ) + + assert exit_code == 1 + assert "Error: Invalid parameter" in mock_stderr.getvalue() + + +def test_get_durable_execution_command_handles_too_many_requests() -> None: + """Test that get-durable-execution command handles InvalidParameterValueException.""" + app = CliApp() + + with patch.object(app, "_create_boto3_client") as mock_create_client: + mock_client = mock_create_client.return_value + + mock_client.exceptions.ResourceNotFoundException = ResourceNotFoundException + mock_client.exceptions.InvalidParameterValueException = ( + InvalidParameterValueException + ) + mock_client.exceptions.TooManyRequestsException = TooManyRequestsException + mock_client.exceptions.ServiceException = ServiceException + + mock_client.get_durable_execution.side_effect = TooManyRequestsException( + "Too many requests" + ) + + with patch("sys.stderr", new_callable=StringIO) as mock_stderr: + exit_code = app.get_durable_execution_command( + argparse.Namespace(durable_execution_arn="my-arn") + ) + + assert exit_code == 1 + assert "Error: Too many requests" in mock_stderr.getvalue() + + +def test_get_durable_execution_command_handles_service_exception() -> None: + """Test that get-durable-execution command handles InvalidParameterValueException.""" + app = CliApp() + + with patch.object(app, "_create_boto3_client") as mock_create_client: + mock_client = mock_create_client.return_value + + mock_client.exceptions.ResourceNotFoundException = ResourceNotFoundException + mock_client.exceptions.InvalidParameterValueException = ( + InvalidParameterValueException + ) + mock_client.exceptions.TooManyRequestsException = TooManyRequestsException + mock_client.exceptions.ServiceException = ServiceException + + mock_client.get_durable_execution.side_effect = ServiceException( + "Service exception" + ) + + with patch("sys.stderr", new_callable=StringIO) as mock_stderr: + exit_code = app.get_durable_execution_command( + argparse.Namespace(durable_execution_arn="my-arn") + ) + + assert exit_code == 1 + assert "Error: Service error" in mock_stderr.getvalue() + + +def test_get_durable_execution_command_handles_connection_error() -> None: + """Test that get-durable-execution command handles connection errors.""" + app = CliApp() + + with patch.object(app, "_create_boto3_client") as mock_create_client: + mock_client = mock_create_client.return_value + + mock_client.exceptions.ResourceNotFoundException = ResourceNotFoundException + mock_client.exceptions.InvalidParameterValueException = ( + InvalidParameterValueException + ) + mock_client.exceptions.TooManyRequestsException = TooManyRequestsException + mock_client.exceptions.ServiceException = ServiceException + + mock_client.get_durable_execution.side_effect = ConnectionError( + error="Mocked connection error" + ) + + with patch( + "aws_durable_execution_sdk_python_testing.cli.logger" + ) as mock_logger: + exit_code = app.get_durable_execution_command( + argparse.Namespace(durable_execution_arn="my-arn") + ) + + assert exit_code == 1 + mock_logger.exception.assert_called_once_with( + "Error: Could not connect to the local runner server. Is it running?" + ) + + +def test_get_durable_execution_history_command_uses_boto3_client() -> None: + """Test that get-durable-execution-history command uses boto3 client.""" + app = CliApp() + + with patch.object(app, "_create_boto3_client") as mock_create_client: + mock_client = mock_create_client.return_value + mock_client.get_durable_execution_history.return_value = { + "Events": [ + { + "EventType": "ExecutionStarted", + "EventTimestamp": "2024-01-01T00:00:00Z", + }, + { + "EventType": "ExecutionSucceeded", + "EventTimestamp": "2024-01-01T00:01:00Z", + }, + ] + } + + with patch("sys.stdout", new_callable=StringIO) as mock_stdout: + exit_code = app.get_durable_execution_history_command( + argparse.Namespace(durable_execution_arn="test-arn") + ) + + assert exit_code == 0 + mock_create_client.assert_called_once() + mock_client.get_durable_execution_history.assert_called_once_with( + DurableExecutionArn="test-arn" + ) + + # Verify JSON output + output = mock_stdout.getvalue() + assert "ExecutionStarted" in output + assert "ExecutionSucceeded" in output + + +def test_get_durable_execution_history_command_handles_resource_not_found() -> None: + """Test that get-durable-execution-history command handles ResourceNotFoundException.""" + app = CliApp() + + with patch.object(app, "_create_boto3_client") as mock_create_client: + mock_client = mock_create_client.return_value + mock_client.get_durable_execution_history.side_effect = Exception( + "ResourceNotFoundException: Execution not found" + ) + + exit_code = app.get_durable_execution_history_command( + argparse.Namespace(durable_execution_arn="nonexistent-arn") + ) + + assert exit_code == 1 + + +def test_get_durable_execution_history_command_handles_connection_error() -> None: + """Test that get-durable-execution-history command handles connection errors.""" + app = CliApp() + + with patch.object(app, "_create_boto3_client") as mock_create_client: + mock_client = mock_create_client.return_value + mock_client.get_durable_execution_history.side_effect = Exception( + "Connection refused" + ) + + exit_code = app.get_durable_execution_history_command( + argparse.Namespace(durable_execution_arn="test-arn") + ) + + assert exit_code == 1 + + +def test_create_boto3_client_creates_client_correctly() -> None: + """Test that _create_boto3_client creates boto3 client correctly.""" + app = CliApp() + + with patch("boto3.client") as mock_boto3_client: + app._create_boto3_client() # noqa: SLF001 + + # Verify boto3 client is created with correct parameters + mock_boto3_client.assert_called_once_with( + "lambda", + endpoint_url=app.config.local_runner_endpoint, + region_name=app.config.local_runner_region, + ) + + +def test_create_boto3_client_handles_creation_failure() -> None: + """Test that _create_boto3_client handles client creation failures.""" + app = CliApp() + + with patch("boto3.client") as mock_boto3_client: + mock_boto3_client.side_effect = Exception("Client creation failed") + + with pytest.raises(DurableFunctionsLocalRunnerError) as exc_info: + app._create_boto3_client() # noqa: SLF001 + + assert "Failed to create boto3 client" in str(exc_info.value) + assert "Client creation failed" in str(exc_info.value) + + +def test_cli_app_handles_durable_functions_test_error() -> None: + """Test that DurableFunctionsTestError is handled gracefully.""" + app = CliApp() + + with patch.object(app, "_create_parsers") as mock_setup: + from aws_durable_execution_sdk_python_testing.exceptions import ( + DurableFunctionsTestError, + ) + + mock_setup.side_effect = DurableFunctionsTestError("Test error") + + exit_code = app.run(["start-server"]) + + assert exit_code == 1 + + +def test_cli_app_handles_unexpected_exception() -> None: + """Test that unexpected exceptions are handled gracefully.""" + app = CliApp() + + with patch.object(app, "_create_parsers") as mock_setup: + mock_setup.side_effect = RuntimeError("Unexpected error") + + with patch( + "aws_durable_execution_sdk_python_testing.cli.logger" + ) as mock_logger: + exit_code = app.run(["start-server"]) + + assert exit_code == 1 + mock_logger.exception.assert_called_once_with("Unexpected error.") + + +def test_invoke_command_handles_general_exception() -> None: + """Test that invoke command handles general exceptions.""" + app = CliApp() + + with patch("aws_durable_execution_sdk_python_testing.cli.urlopen") as mock_urlopen: + mock_urlopen.side_effect = ValueError("Some unexpected error") + + exit_code = app.invoke_command( + argparse.Namespace( + function_name="test-function", + input="{}", + durable_execution_name=None, + ) + ) + + assert exit_code == 1 + + +def test_get_durable_execution_command_handles_general_exception() -> None: + """Test that get-durable-execution command handles general exceptions.""" + app = CliApp() + + with patch.object(app, "_create_boto3_client") as mock_create_client: + mock_client = mock_create_client.return_value + mock_client.exceptions.ResourceNotFoundException = ResourceNotFoundException + mock_client.exceptions.InvalidParameterValueException = ( + InvalidParameterValueException + ) + mock_client.exceptions.TooManyRequestsException = TooManyRequestsException + mock_client.exceptions.ServiceException = ServiceException + mock_client.get_durable_execution.side_effect = ValueError( + "Some unexpected error" + ) + + with patch( + "aws_durable_execution_sdk_python_testing.cli.logger" + ) as mock_logger: + exit_code = app.get_durable_execution_command( + argparse.Namespace(durable_execution_arn="my-arn") + ) + + assert exit_code == 1 + mock_logger.exception.assert_called_once_with( + "Unexpected error in get-durable-execution command" + ) + + +def test_get_durable_execution_history_command_handles_general_exception() -> None: + """Test that get-durable-execution-history command handles general exceptions.""" + app = CliApp() + + with patch.object(app, "_create_boto3_client") as mock_create_client: + mock_client = mock_create_client.return_value + mock_client.get_durable_execution_history.side_effect = ValueError( + "Some unexpected error" + ) + + exit_code = app.get_durable_execution_history_command( + argparse.Namespace(durable_execution_arn="test-arn") + ) + + assert exit_code == 1 diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/client_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/client_test.py new file mode 100644 index 0000000..15aa366 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/client_test.py @@ -0,0 +1,103 @@ +"""Unit tests for InMemoryServiceClient.""" + +import datetime +from unittest.mock import Mock + +from aws_durable_execution_sdk_python.lambda_service import ( + CheckpointOutput, + OperationAction, + OperationType, + OperationUpdate, + StateOutput, +) + +from aws_durable_execution_sdk_python_testing.client import InMemoryServiceClient + + +def test_checkpoint(): + """Test checkpoint method delegates to processor.""" + processor = Mock() + expected_output = CheckpointOutput( + checkpoint_token="new-token", # noqa: S106 + new_execution_state=Mock(), + ) + processor.process_checkpoint.return_value = expected_output + + client = InMemoryServiceClient(processor) + + updates = [ + OperationUpdate( + operation_id="test-id", + operation_type=OperationType.STEP, + action=OperationAction.START, + ) + ] + + result = client.checkpoint( + "arn:aws:lambda:us-east-1:123456789012:function:test", + "token", + updates, + "client-token", + ) + + assert result == expected_output + processor.process_checkpoint.assert_called_once_with( + "token", updates, "client-token" + ) + + +def test_get_execution_state(): + """Test get_execution_state method delegates to processor.""" + processor = Mock() + expected_output = StateOutput(operations=[], next_marker="marker") + processor.get_execution_state.return_value = expected_output + + client = InMemoryServiceClient(processor) + + result = client.get_execution_state( + "arn:aws:lambda:us-east-1:123456789012:function:test", "token", "marker", 500 + ) + + assert result == expected_output + processor.get_execution_state.assert_called_once_with("token", "marker", 500) + + +def test_get_execution_state_default_max_items(): + """Test get_execution_state with default max_items.""" + processor = Mock() + expected_output = StateOutput(operations=[], next_marker="marker") + processor.get_execution_state.return_value = expected_output + + client = InMemoryServiceClient(processor) + + result = client.get_execution_state( + "arn:aws:lambda:us-east-1:123456789012:function:test", "token", "marker" + ) + + assert result == expected_output + processor.get_execution_state.assert_called_once_with("token", "marker", 1000) + + +def test_stop(): + """Test stop method returns current datetime.""" + processor = Mock() + client = InMemoryServiceClient(processor) + + before = datetime.datetime.now(tz=datetime.UTC) + result = client.stop( + "arn:aws:states:us-east-1:123456789012:execution:test", b"payload" + ) + after = datetime.datetime.now(tz=datetime.UTC) + + assert isinstance(result, datetime.datetime) + assert before <= result <= after + + +def test_stop_with_none_payload(): + """Test stop method with None payload.""" + processor = Mock() + client = InMemoryServiceClient(processor) + + result = client.stop("arn:aws:states:us-east-1:123456789012:execution:test", None) + + assert isinstance(result, datetime.datetime) diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/durable_executions_python_testing_library_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/durable_executions_python_testing_library_test.py new file mode 100644 index 0000000..d827bbf --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/durable_executions_python_testing_library_test.py @@ -0,0 +1,7 @@ +"""Tests for DurableExecutionsPythonTestingLibrary module.""" + +import aws_durable_execution_sdk_python_testing # noqa: F401 + + +def test_aws_durable_execution_sdk_python_testing_importable(): + """Test aws_durable_execution_sdk_python_testing is importable.""" diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/e2e/__init__.py b/packages/aws-durable-execution-sdk-python-testing/tests/e2e/__init__.py new file mode 100644 index 0000000..78d8de9 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/e2e/__init__.py @@ -0,0 +1 @@ +"""Test package""" diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/e2e/basic_success_path_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/e2e/basic_success_path_test.py new file mode 100644 index 0000000..3e93bcf --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/e2e/basic_success_path_test.py @@ -0,0 +1,92 @@ +"""Functional tests, covering end-to-end DurableTestRunner.""" + +import json +from typing import Any + +from aws_durable_execution_sdk_python.context import ( + DurableContext, + durable_step, + durable_with_child_context, +) +from aws_durable_execution_sdk_python.execution import ( + InvocationStatus, + durable_execution, +) +from aws_durable_execution_sdk_python.types import StepContext + +from aws_durable_execution_sdk_python_testing.runner import ( + ContextOperation, + DurableFunctionTestResult, + DurableFunctionTestRunner, + StepOperation, +) +from aws_durable_execution_sdk_python.config import Duration + + +# brazil-test-exec pytest test/runner_int_test.py +def test_basic_durable_function() -> None: + @durable_step + def one(step_context: StepContext, a: int, b: int) -> str: + # print("[DEBUG] one called") + return f"{a} {b}" + + @durable_step + def two_1(step_context: StepContext, a: int, b: int) -> str: + # print("[DEBUG] two_1 called") + return f"{a} {b}" + + @durable_step + def two_2(step_context: StepContext, a: int, b: int) -> str: + # print("[DEBUG] two_2 called") + return f"{b} {a}" + + @durable_with_child_context + def two(ctx: DurableContext, a: int, b: int) -> str: + # print("[DEBUG] two called") + two_1_result: str = ctx.step(two_1(a, b)) + two_2_result: str = ctx.step(two_2(a, b)) + return f"{two_1_result} {two_2_result}" + + @durable_step + def three(step_context: StepContext, a: int, b: int) -> str: + # print("[DEBUG] three called") + return f"{a} {b}" + + @durable_execution + def function_under_test(event: Any, context: DurableContext) -> list[str]: + results: list[str] = [] + + result_one: str = context.step(one(1, 2)) + results.append(result_one) + + context.wait(Duration.from_seconds(1)) + + result_two: str = context.run_in_child_context(two(3, 4)) + results.append(result_two) + + result_three: str = context.step(three(5, 6)) + results.append(result_three) + + return results + + with DurableFunctionTestRunner(handler=function_under_test) as runner: + result: DurableFunctionTestResult = runner.run(input="input str", timeout=10) + + assert result.status is InvocationStatus.SUCCEEDED + assert result.result == json.dumps(["1 2", "3 4 4 3", "5 6"]) + + one_result: StepOperation = result.get_step("one") + assert one_result.result == json.dumps("1 2") + + two_result: ContextOperation = result.get_context("two") + assert two_result.result == json.dumps("3 4 4 3") + + three_result: StepOperation = result.get_step("three") + assert three_result.result == json.dumps("5 6") + + # currently has the optimization where it's not saving child checkpoints after parent done + # prob should unpick that for test + # two_one_op = cast(StepOperation, two_result_op.get_operation_by_name("two_1")) + # assert two_one_op.result == '"3 4"' + + # print("done") diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/event_factory_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/event_factory_test.py new file mode 100644 index 0000000..1f4238e --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/event_factory_test.py @@ -0,0 +1,2002 @@ +"""Tests for Event factory methods. + +This module tests all the event creation factory methods in the Event class. +""" + +from datetime import UTC, datetime +from unittest.mock import Mock + +import pytest +from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, + OperationStatus, + OperationType, + StepDetails, + OperationUpdate, + OperationSubType, + OperationAction, + StepOptions, +) + +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) +from aws_durable_execution_sdk_python_testing.model import ( + CheckpointDurableExecutionRequest, + ErrorResponse, + Event, + EventCreationContext, + EventError, + EventInput, + EventResult, + Execution, + ExecutionStartedDetails, + LambdaContext, + StartDurableExecutionInput, +) + + +# Helper function to create mock operations +def create_mock_operation( + operation_id: str = "op-1", + name: str = "test_op", + parent_id=None, + status: OperationStatus = OperationStatus.STARTED, +): + from unittest.mock import Mock + + op = Mock() + op.operation_id = operation_id + op.name = name + op.parent_id = parent_id + op.status = status + return op + + +# region execution-tests +def test_create_execution_started(): + from unittest.mock import Mock + from aws_durable_execution_sdk_python.lambda_service import ExecutionDetails + + operation = Mock() + operation.operation_id = "op-1" + operation.name = "test_execution" + operation.parent_id = None + operation.status = OperationStatus.STARTED + operation.start_timestamp = datetime.now(UTC) + operation.operation_type = OperationType.EXECUTION + operation.sub_type = None + operation.execution_details = ExecutionDetails(input_payload='{"test": "data"}') + + context = EventCreationContext.create( + operation=operation, + event_id=1, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + include_execution_data=True, + ) + event = Event.create_execution_event(context) + + assert event.event_type == "ExecutionStarted" + assert event.operation_id == "op-1" + assert event.name == "test_execution" + assert event.execution_started_details.input.payload == '{"test": "data"}' + assert event.execution_started_details.execution_timeout == 300 + + +def test_create_execution_succeeded(): + from aws_durable_execution_sdk_python.execution import ( + DurableExecutionInvocationOutput, + InvocationStatus, + ) + + operation = create_mock_operation("op-1", status=OperationStatus.SUCCEEDED) + operation.end_timestamp = datetime.now(UTC) + + result = DurableExecutionInvocationOutput( + status=InvocationStatus.SUCCEEDED, result='{"result": "success"}' + ) + context = EventCreationContext.create( + operation=operation, + event_id=2, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + result=result, + include_execution_data=True, + ) + event = Event.create_execution_event(context) + + assert event.event_type == "ExecutionSucceeded" + assert event.execution_succeeded_details.result.payload == '{"result": "success"}' + + +def test_create_execution_failed(): + from aws_durable_execution_sdk_python.execution import ( + DurableExecutionInvocationOutput, + InvocationStatus, + ) + + operation = create_mock_operation("op-1", status=OperationStatus.FAILED) + operation.end_timestamp = datetime.now(UTC) + + error_result = DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED, + error=ErrorObject.from_message("Execution failed"), + ) + context = EventCreationContext.create( + operation=operation, + event_id=3, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + result=error_result, + include_execution_data=True, + ) + event = Event.create_execution_event(context) + + assert event.event_type == "ExecutionFailed" + assert event.execution_failed_details.error.payload.message == "Execution failed" + + +def test_create_execution_timed_out(): + from aws_durable_execution_sdk_python.execution import ( + DurableExecutionInvocationOutput, + InvocationStatus, + ) + + operation = create_mock_operation("op-1", status=OperationStatus.TIMED_OUT) + operation.end_timestamp = datetime.now(UTC) + + error_result = DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED, + error=ErrorObject.from_message("Execution timed out"), + ) + context = EventCreationContext.create( + operation=operation, + event_id=4, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + result=error_result, + include_execution_data=True, + ) + event = Event.create_execution_event(context) + + assert event.event_type == "ExecutionTimedOut" + assert ( + event.execution_timed_out_details.error.payload.message == "Execution timed out" + ) + + +def test_create_execution_stopped(): + from aws_durable_execution_sdk_python.execution import ( + DurableExecutionInvocationOutput, + InvocationStatus, + ) + + operation = create_mock_operation("op-1", status=OperationStatus.STOPPED) + operation.end_timestamp = datetime.now(UTC) + + error_result = DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED, + error=ErrorObject.from_message("Execution stopped"), + ) + context = EventCreationContext.create( + operation=operation, + event_id=5, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + result=error_result, + include_execution_data=True, + ) + event = Event.create_execution_event(context) + + assert event.event_type == "ExecutionStopped" + assert event.execution_stopped_details.error.payload.message == "Execution stopped" + + +def test_create_execution_invalid_status(): + operation = create_mock_operation("op-1", status=OperationStatus.CANCELLED) + context = EventCreationContext.create( + operation=operation, + event_id=1, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + with pytest.raises( + InvalidParameterValueException, + match="Operation status .* is not valid for execution operations", + ): + Event.create_execution_event(context) + + +# endregion execution-tests + + +# region context-tests +def test_create_context_started(): + operation = create_mock_operation( + "ctx-1", "test_context", status=OperationStatus.STARTED + ) + context = EventCreationContext.create( + operation=operation, + event_id=1, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + event = Event.create_context_event(context) + + assert event.event_type == "ContextStarted" + assert event.operation_id == "ctx-1" + assert event.name == "test_context" + assert event.context_started_details is not None + + +def test_create_context_succeeded(): + operation = create_mock_operation("ctx-1", status=OperationStatus.SUCCEEDED) + operation.context_details = type( + "MockDetails", (), {"result": '{"context": "result"}', "error": None} + )() + context = EventCreationContext.create( + operation=operation, + event_id=2, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + include_execution_data=True, + ) + event = Event.create_context_event(context) + + assert event.event_type == "ContextSucceeded" + assert event.context_succeeded_details.result.payload == '{"context": "result"}' + + +def test_create_context_failed(): + operation = create_mock_operation("ctx-1", status=OperationStatus.FAILED) + error_obj = ErrorObject.from_message("Context failed") + operation.context_details = type( + "MockDetails", (), {"result": None, "error": error_obj} + )() + context = EventCreationContext.create( + operation=operation, + event_id=3, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + event = Event.create_context_event(context) + + assert event.event_type == "ContextFailed" + assert event.context_failed_details.error.payload.message == "Context failed" + + +def test_create_context_invalid_status(): + operation = create_mock_operation("ctx-1", status=OperationStatus.TIMED_OUT) + context = EventCreationContext.create( + operation=operation, + event_id=1, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + with pytest.raises( + InvalidParameterValueException, + match="Operation status .* is not valid for context operations", + ): + Event.create_context_event(context) + + +# endregion context-tests + + +# region wait-tests +def test_create_wait_started(): + operation = create_mock_operation("wait-1", status=OperationStatus.STARTED) + operation.start_timestamp = datetime.fromisoformat("2024-01-01T12:00:00Z") + operation.wait_details = type( + "MockDetails", + (), + {"scheduled_end_timestamp": datetime.fromisoformat("2024-01-01T12:05:00Z")}, + )() + context = EventCreationContext.create( + operation=operation, + event_id=1, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + event = Event.create_wait_event(context) + + assert event.event_type == "WaitStarted" + assert event.wait_started_details.duration == 300 + assert event.wait_started_details.scheduled_end_timestamp == datetime.fromisoformat( + "2024-01-01T12:05:00Z" + ) + + +def test_create_wait_succeeded(): + operation = create_mock_operation("wait-1", status=OperationStatus.SUCCEEDED) + operation.start_timestamp = datetime.fromisoformat("2024-01-01T12:00:00Z") + operation.wait_details = type( + "MockDetails", + (), + {"scheduled_end_timestamp": datetime.fromisoformat("2024-01-01T12:05:00Z")}, + )() + context = EventCreationContext.create( + operation=operation, + event_id=2, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + event = Event.create_wait_event(context) + + assert event.event_type == "WaitSucceeded" + assert event.wait_succeeded_details.duration == 300 + + +def test_create_wait_cancelled(): + operation = create_mock_operation("wait-1", status=OperationStatus.CANCELLED) + operation.wait_details = None + mock_operation_update = Mock() + mock_operation_update.operation_type = OperationType.WAIT + mock_operation_update.operation_update.action = OperationAction.CANCEL + context = EventCreationContext.create( + operation=operation, + event_id=3, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + operation_update=mock_operation_update, + ) + event = Event.create_wait_event(context) + + assert event.event_type == "WaitCancelled" + assert event.wait_cancelled_details is not None + + +def test_create_wait_invalid_status(): + operation = create_mock_operation("wait-1", status=OperationStatus.FAILED) + operation.wait_details.scheduled_end_timestamp = operation.start_timestamp = ( + datetime.fromisoformat("2024-01-01T12:00:00Z") + ) + context = EventCreationContext.create( + operation=operation, + event_id=1, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + with pytest.raises( + InvalidParameterValueException, + match="Operation status .* is not valid for wait operations", + ): + Event.create_wait_event(context) + + +# endregion wait-tests + + +# region step-tests +def test_create_step_started(): + operation = create_mock_operation( + "step-1", "test_step", status=OperationStatus.STARTED + ) + context = EventCreationContext.create( + operation=operation, + event_id=1, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + event = Event.create_step_event(context) + + assert event.event_type == "StepStarted" + assert event.operation_id == "step-1" + assert event.name == "test_step" + assert event.step_started_details is not None + + +def test_create_step_succeeded(): + operation = create_mock_operation("step-1", status=OperationStatus.SUCCEEDED) + operation.step_details = type( + "MockDetails", (), {"result": '{"step": "result"}', "error": None} + )() + context = EventCreationContext.create( + operation=operation, + event_id=2, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + include_execution_data=True, + ) + event = Event.create_step_event(context) + + assert event.event_type == "StepSucceeded" + assert event.step_succeeded_details.result.payload == '{"step": "result"}' + + +def test_create_step_failed(): + operation = create_mock_operation("step-1", status=OperationStatus.FAILED) + error_obj = ErrorObject.from_message("Step failed") + operation.step_details = type( + "MockDetails", (), {"result": None, "error": error_obj} + )() + context = EventCreationContext.create( + operation=operation, + event_id=3, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + event = Event.create_step_event(context) + + assert event.event_type == "StepFailed" + assert event.step_failed_details.error.payload.message == "Step failed" + + +def test_create_step_invalid_status(): + operation = create_mock_operation("step-1", status=OperationStatus.TIMED_OUT) + context = EventCreationContext.create( + operation=operation, + event_id=1, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + with pytest.raises( + InvalidParameterValueException, + match="Operation status .* is not valid for step operations", + ): + Event.create_step_event(context) + + +# endregion step-tests + + +# region chained_invoke +def test_create_chained_invoke_started(): + operation = create_mock_operation( + "invoke-1", "test_invoke", status=OperationStatus.STARTED + ) + context = EventCreationContext.create( + operation=operation, + event_id=1, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + event = Event.create_chained_invoke_event(context) + + assert event.event_type == "ChainedInvokeStarted" + assert event.operation_id == "invoke-1" + assert event.name == "test_invoke" + assert event.chained_invoke_started_details is not None + + +# endregion callback + + +# endregion helpers-test + + +def test_create_chained_invoke_succeeded(): + operation = create_mock_operation("invoke-1", status=OperationStatus.SUCCEEDED) + operation.chained_invoke_details = type( + "MockDetails", (), {"result": '{"invoke": "result"}', "error": None} + )() + context = EventCreationContext.create( + operation=operation, + event_id=2, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + include_execution_data=True, + ) + event = Event.create_chained_invoke_event(context) + + assert event.event_type == "ChainedInvokeSucceeded" + assert ( + event.chained_invoke_succeeded_details.result.payload == '{"invoke": "result"}' + ) + + +def test_create_chained_invoke_failed(): + operation = create_mock_operation("invoke-1", status=OperationStatus.FAILED) + error_obj = ErrorObject.from_message("Invoke failed") + operation.chained_invoke_details = type( + "MockDetails", (), {"result": None, "error": error_obj} + )() + context = EventCreationContext.create( + operation=operation, + event_id=3, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + event = Event.create_chained_invoke_event(context) + + assert event.event_type == "ChainedInvokeFailed" + assert event.chained_invoke_failed_details.error.payload.message == "Invoke failed" + + +def test_create_chained_invoke_timed_out(): + operation = create_mock_operation("invoke-1", status=OperationStatus.TIMED_OUT) + error_obj = ErrorObject.from_message("Invoke timed out") + operation.chained_invoke_details = type( + "MockDetails", (), {"result": None, "error": error_obj} + )() + context = EventCreationContext.create( + operation=operation, + event_id=4, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + event = Event.create_chained_invoke_event(context) + + assert event.event_type == "ChainedInvokeTimedOut" + assert ( + event.chained_invoke_timed_out_details.error.payload.message + == "Invoke timed out" + ) + + +def test_create_chained_invoke_stopped(): + operation = create_mock_operation("invoke-1", status=OperationStatus.STOPPED) + error_obj = ErrorObject.from_message("Invoke stopped") + operation.chained_invoke_details = type( + "MockDetails", (), {"result": None, "error": error_obj} + )() + context = EventCreationContext.create( + operation=operation, + event_id=5, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + event = Event.create_chained_invoke_event(context) + + assert event.event_type == "ChainedInvokeStopped" + assert ( + event.chained_invoke_stopped_details.error.payload.message == "Invoke stopped" + ) + + +def test_create_chained_invoke_invalid_status(): + operation = create_mock_operation("invoke-1", status=OperationStatus.CANCELLED) + context = EventCreationContext.create( + operation=operation, + event_id=1, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + with pytest.raises( + InvalidParameterValueException, + match="Operation status .* is not valid for chained invoke operations", + ): + Event.create_chained_invoke_event(context) + + +# endregion chained_invoke + + +# region callback-tests +def test_create_callback_started(): + operation = create_mock_operation( + "callback-1", "test_callback", status=OperationStatus.STARTED + ) + operation.callback_details = type( + "MockDetails", (), {"callback_id": "cb-123", "result": None, "error": None} + )() + context = EventCreationContext.create( + operation=operation, + event_id=1, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + event = Event.create_callback_event(context) + + assert event.event_type == "CallbackStarted" + assert event.operation_id == "callback-1" + assert event.name == "test_callback" + assert event.callback_started_details.callback_id == "cb-123" + + +def test_create_callback_succeeded(): + operation = create_mock_operation("callback-1", status=OperationStatus.SUCCEEDED) + operation.callback_details = type( + "MockDetails", + (), + {"callback_id": None, "result": '{"callback": "result"}', "error": None}, + )() + context = EventCreationContext.create( + operation=operation, + event_id=2, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + include_execution_data=True, + ) + event = Event.create_callback_event(context) + + assert event.event_type == "CallbackSucceeded" + assert event.callback_succeeded_details.result.payload == '{"callback": "result"}' + + +def test_create_callback_failed(): + operation = create_mock_operation("callback-1", status=OperationStatus.FAILED) + error_obj = ErrorObject.from_message("Callback failed") + operation.callback_details = type( + "MockDetails", (), {"callback_id": None, "result": None, "error": error_obj} + )() + context = EventCreationContext.create( + operation=operation, + event_id=3, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + event = Event.create_callback_event(context) + + assert event.event_type == "CallbackFailed" + assert event.callback_failed_details.error.payload.message == "Callback failed" + + +def test_create_callback_timed_out(): + operation = create_mock_operation("callback-1", status=OperationStatus.TIMED_OUT) + error_obj = ErrorObject.from_message("Callback timed out") + operation.callback_details = type( + "MockDetails", (), {"callback_id": None, "result": None, "error": error_obj} + )() + context = EventCreationContext.create( + operation=operation, + event_id=4, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + event = Event.create_callback_event(context) + + assert event.event_type == "CallbackTimedOut" + assert ( + event.callback_timed_out_details.error.payload.message == "Callback timed out" + ) + + +def test_create_callback_invalid_status(): + operation = create_mock_operation("callback-1", status=OperationStatus.STOPPED) + context = EventCreationContext.create( + operation=operation, + event_id=1, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + with pytest.raises( + InvalidParameterValueException, + match="Operation status .* is not valid for callback operations", + ): + Event.create_callback_event(context) + + +# endregion callback-tests + + +# region model-tests +def test_lambda_context(): + context = LambdaContext(aws_request_id="test-123") + assert context.aws_request_id == "test-123" + assert context.get_remaining_time_in_millis() == 900000 + context.log("test message") # Should not raise + + +def test_start_durable_execution_input_missing_field(): + with pytest.raises( + InvalidParameterValueException, match="Missing required field: AccountId" + ): + StartDurableExecutionInput.from_dict({}) + + +def test_start_durable_execution_input_to_dict_with_optionals(): + input_obj = StartDurableExecutionInput( + account_id="123456789", + function_name="test-func", + function_qualifier="$LATEST", + execution_name="test-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="inv-123", + trace_fields={"key": "value"}, + tenant_id="tenant-123", + input='{"test": "data"}', + ) + result = input_obj.to_dict() + assert result["InvocationId"] == "inv-123" + assert result["TraceFields"] == {"key": "value"} + assert result["TenantId"] == "tenant-123" + assert result["Input"] == '{"test": "data"}' + + +def test_execution_from_dict_empty_function_arn(): + data = { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789:function:test", + "DurableExecutionName": "test-exec", + "Status": "SUCCEEDED", + "StartTimestamp": 1640995200.0, + } + execution = Execution.from_dict(data) + assert execution.function_arn == "" + + +def test_execution_to_dict_with_function_arn(): + execution = Execution( + durable_execution_arn="arn:aws:lambda:us-east-1:123456789:function:test", + durable_execution_name="test-exec", + function_arn="arn:aws:lambda:us-east-1:123456789:function:test", + status="SUCCEEDED", + start_timestamp=1640995200.0, + ) + result = execution.to_dict() + assert "FunctionArn" in result + + +def test_event_input_from_details(): + from aws_durable_execution_sdk_python.lambda_service import ExecutionDetails + + details = ExecutionDetails(input_payload='{"test": "data"}') + event_input = EventInput.from_details(details, include=True) + assert event_input.payload == '{"test": "data"}' + assert not event_input.truncated + + event_input_truncated = EventInput.from_details(details, include=False) + assert event_input_truncated.payload is None + assert event_input_truncated.truncated + + +def test_event_result_from_details(): + from aws_durable_execution_sdk_python.lambda_service import StepDetails + + details = StepDetails(result='{"result": "success"}') + event_result = EventResult.from_details(details, include=True) + assert event_result.payload == '{"result": "success"}' + assert not event_result.truncated + + +def test_event_error_from_details(): + from aws_durable_execution_sdk_python.lambda_service import StepDetails + + error_obj = ErrorObject.from_message("Test error") + details = StepDetails(error=error_obj) + event_error = EventError.from_details(details) + assert event_error.payload.message == "Test error" + + +def test_event_from_dict_with_all_details(): + data = { + "EventType": "ExecutionStarted", + "EventTimestamp": datetime.fromisoformat("2024-01-01T12:00:00Z"), + "EventId": 1, + "Id": "op-1", + "Name": "test", + "ParentId": "parent-1", + "SubType": "test-subtype", + "ExecutionStartedDetails": { + "Input": {"Payload": '{"test": "data"}', "Truncated": False}, + "ExecutionTimeout": 300, + }, + } + event = Event.from_dict(data) + assert event.sub_type == "test-subtype" + assert event.parent_id == "parent-1" + + +def test_event_to_dict_with_all_details(): + event = Event( + event_type="ExecutionStarted", + event_timestamp=datetime.fromisoformat("2024-01-01T12:00:00Z"), + event_id=1, + operation_id="op-1", + name="test", + parent_id="parent-1", + sub_type="test-subtype", + execution_started_details=ExecutionStartedDetails( + input=EventInput(payload='{"test": "data"}', truncated=False), + execution_timeout=300, + ), + ) + result = event.to_dict() + assert result["SubType"] == "test-subtype" + assert result["ParentId"] == "parent-1" + assert result["ExecutionStartedDetails"]["ExecutionTimeout"] == 300 + + +def test_error_response_from_dict_nested(): + data = { + "error": { + "type": "ValidationError", + "message": "Invalid input", + "code": "400", + "requestId": "req-123", + } + } + error_response = ErrorResponse.from_dict(data) + assert error_response.error_type == "ValidationError" + assert error_response.error_message == "Invalid input" + assert error_response.error_code == "400" + assert error_response.request_id == "req-123" + + +def test_error_response_from_dict_flat(): + data = {"type": "ValidationError", "message": "Invalid input"} + error_response = ErrorResponse.from_dict(data) + assert error_response.error_type == "ValidationError" + assert error_response.error_message == "Invalid input" + + +def test_checkpoint_durable_execution_request_from_dict(): + token: str = "token-123" + data = { + "CheckpointToken": token, + "Updates": [ + {"Id": "op-1", "Type": "STEP", "Action": "START", "SubType": "Step"} + ], + } + request = CheckpointDurableExecutionRequest.from_dict(data, "arn:test") + assert request.checkpoint_token == token + assert len(request.updates) == 1 + assert request.updates[0].operation_id == "op-1" + + +# endregion model-tests + + +# region from_operation_started_tests +class TestFromOperationStarted: + """Tests for Event.from_operation_started method.""" + + def test_from_operation_started_execution(self): + """Test converting execution operation to started event.""" + operation = Mock() + operation.operation_id = "exec-123" + operation.name = "test_execution" + operation.parent_id = "parent-123" + operation.operation_type = OperationType.EXECUTION + operation.start_timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC) + + execution_details = Mock() + execution_details.input_payload = '{"test": "data"}' + operation.execution_details = execution_details + + context = EventCreationContext.create( + operation=operation, + event_id=1, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + include_execution_data=True, + ) + event = Event.create_event_started(context) + + assert event.event_type == "ExecutionStarted" + assert event.operation_id == "exec-123" + assert event.name == "test_execution" + assert event.parent_id == "parent-123" + assert event.execution_started_details.input.payload == '{"test": "data"}' + assert not event.execution_started_details.input.truncated + + def test_from_operation_started_execution_no_data(self): + """Test execution operation with include_execution_data=False.""" + operation = Mock() + operation.operation_id = "exec-123" + operation.name = "test_execution" + operation.parent_id = "parent-123" + operation.operation_type = OperationType.EXECUTION + operation.start_timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC) + + execution_details = Mock() + execution_details.input_payload = '{"test": "data"}' + operation.execution_details = execution_details + + context = EventCreationContext.create( + operation=operation, + event_id=1, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + include_execution_data=False, + ) + event = Event.create_event_started(context) + + assert event.event_type == "ExecutionStarted" + assert event.execution_started_details.input.payload is None + assert event.execution_started_details.input.truncated + + def test_from_operation_started_step(self): + """Test converting step operation to started event.""" + operation = Mock() + operation.operation_id = "step-123" + operation.name = "test_step" + operation.parent_id = "parent-123" + operation.operation_type = OperationType.STEP + operation.start_timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC) + + context = EventCreationContext.create( + operation=operation, + event_id=2, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + event = Event.create_event_started(context) + + assert event.event_type == "StepStarted" + assert event.operation_id == "step-123" + assert event.name == "test_step" + assert event.parent_id == "parent-123" + assert event.step_started_details is not None + + def test_from_operation_started_wait(self): + """Test converting wait operation to started event.""" + operation = Mock() + operation.operation_id = "wait-123" + operation.name = "test_wait" + operation.parent_id = "parent-123" + operation.operation_type = OperationType.WAIT + operation.start_timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC) + + wait_details = Mock() + wait_details.scheduled_end_timestamp = datetime( + 2024, 1, 1, 12, 5, 0, tzinfo=UTC + ) + operation.wait_details = wait_details + + context = EventCreationContext.create( + operation=operation, + event_id=3, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + event = Event.create_event_started(context) + + assert event.event_type == "WaitStarted" + assert event.operation_id == "wait-123" + assert event.name == "test_wait" + assert event.parent_id == "parent-123" + assert event.wait_started_details.duration == 300 + assert ( + event.wait_started_details.scheduled_end_timestamp + == datetime.fromisoformat("2024-01-01T12:05:00+00:00") + ) + + def test_from_operation_started_callback(self): + """Test converting callback operation to started event.""" + operation = Mock() + operation.operation_id = "callback-123" + operation.name = "test_callback" + operation.parent_id = "parent-123" + operation.operation_type = OperationType.CALLBACK + operation.start_timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC) + + callback_details = Mock() + callback_details.callback_id = "cb-456" + operation.callback_details = callback_details + + context = EventCreationContext.create( + operation=operation, + event_id=4, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + event = Event.create_event_started(context) + + assert event.event_type == "CallbackStarted" + assert event.operation_id == "callback-123" + assert event.name == "test_callback" + assert event.parent_id == "parent-123" + assert event.callback_started_details.callback_id == "cb-456" + + def test_from_operation_started_chained_invoke(self): + """Test converting chained invoke operation to started event.""" + operation = Mock() + operation.operation_id = "invoke-123" + operation.name = "test_invoke" + operation.parent_id = "parent-123" + operation.operation_type = OperationType.CHAINED_INVOKE + operation.start_timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC) + + context = EventCreationContext.create( + operation=operation, + event_id=5, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + event = Event.create_event_started(context) + + assert event.event_type == "ChainedInvokeStarted" + assert event.operation_id == "invoke-123" + assert event.name == "test_invoke" + assert event.parent_id == "parent-123" + assert event.chained_invoke_started_details is not None + + def test_from_operation_started_context(self): + """Test converting context operation to started event.""" + operation = Mock() + operation.operation_id = "context-123" + operation.name = "test_context" + operation.parent_id = "parent-123" + operation.operation_type = OperationType.CONTEXT + operation.start_timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC) + + context = EventCreationContext.create( + operation=operation, + event_id=6, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + event = Event.create_event_started(context) + + assert event.event_type == "ContextStarted" + assert event.operation_id == "context-123" + assert event.name == "test_context" + assert event.parent_id == "parent-123" + assert event.context_started_details is not None + + def test_from_operation_started_no_timestamp(self): + """Test error when operation has no start timestamp.""" + operation = Mock() + operation.start_timestamp = None + + context = EventCreationContext.create( + operation=operation, + event_id=1, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + with pytest.raises( + InvalidParameterValueException, + match="Operation start timestamp cannot be None", + ): + Event.create_event_started(context) + + def test_from_operation_started_unknown_type(self): + """Test error with unknown operation type.""" + operation = Mock() + operation.operation_type = "UNKNOWN_TYPE" + operation.start_timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC) + + context = EventCreationContext.create( + operation=operation, + event_id=1, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + with pytest.raises( + InvalidParameterValueException, match="Unknown operation type: UNKNOWN_TYPE" + ): + Event.create_event_started(context) + + +# endregion from_operation_started_tests + + +# region from_operation_finished_tests +class TestFromOperationFinished: + """Tests for Event.from_operation_finished method.""" + + def test_from_operation_finished_execution_succeeded(self): + """Test converting succeeded execution operation to finished event.""" + operation = Mock() + operation.operation_id = "exec-123" + operation.name = "test_execution" + operation.parent_id = "parent-123" + operation.operation_type = OperationType.EXECUTION + operation.status = OperationStatus.SUCCEEDED + operation.end_timestamp = datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC) + + context = EventCreationContext.create( + operation=operation, + event_id=1, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + event = Event.create_event_terminated(context) + + assert event.event_type == "ExecutionSucceeded" + assert event.operation_id == "exec-123" + assert event.name == "test_execution" + assert event.parent_id == "parent-123" + + def test_from_operation_finished_execution_failed(self): + """Test converting failed execution operation to finished event.""" + operation = Mock() + operation.operation_id = "exec-123" + operation.name = "test_execution" + operation.parent_id = "parent-123" + operation.operation_type = OperationType.EXECUTION + operation.status = OperationStatus.FAILED + operation.end_timestamp = datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC) + + context = EventCreationContext.create( + operation=operation, + event_id=1, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + event = Event.create_event_terminated(context) + + assert event.event_type == "ExecutionFailed" + assert event.operation_id == "exec-123" + + def test_from_operation_finished_step_with_result(self): + """Test converting succeeded step operation with result.""" + operation = Mock() + operation.operation_id = "step-123" + operation.name = "test_step" + operation.parent_id = "parent-123" + operation.operation_type = OperationType.STEP + operation.status = OperationStatus.SUCCEEDED + operation.end_timestamp = datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC) + + step_details = Mock() + step_details.result = '{"result": "success"}' + step_details.error = None + operation.step_details = step_details + + context = EventCreationContext.create( + operation=operation, + event_id=2, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + include_execution_data=True, + ) + event = Event.create_event_terminated(context) + + assert event.event_type == "StepSucceeded" + assert event.operation_id == "step-123" + assert event.step_succeeded_details.result.payload == '{"result": "success"}' + + def test_from_operation_finished_step_with_error(self): + """Test converting failed step operation with error.""" + operation = Mock() + operation.operation_id = "step-123" + operation.name = "test_step" + operation.parent_id = "parent-123" + operation.operation_type = OperationType.STEP + operation.status = OperationStatus.FAILED + operation.end_timestamp = datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC) + + step_details = Mock() + step_details.result = None + step_details.error = ErrorObject.from_message("Step failed") + operation.step_details = step_details + + context = EventCreationContext.create( + operation=operation, + event_id=2, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + event = Event.create_event_terminated(context) + + assert event.event_type == "StepFailed" + assert event.step_failed_details.error.payload.message == "Step failed" + + def test_from_operation_finished_wait_succeeded(self): + """Test converting succeeded wait operation.""" + operation = Mock() + operation.operation_id = "wait-123" + operation.name = "test_wait" + operation.parent_id = "parent-123" + operation.operation_type = OperationType.WAIT + operation.status = OperationStatus.SUCCEEDED + operation.start_timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC) + operation.end_timestamp = datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC) + + wait_details = Mock() + wait_details.scheduled_end_timestamp = datetime( + 2024, 1, 1, 12, 5, 0, tzinfo=UTC + ) + operation.wait_details = wait_details + + context = EventCreationContext.create( + operation=operation, + event_id=3, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + event = Event.create_event_terminated(context) + + assert event.event_type == "WaitSucceeded" + assert event.wait_succeeded_details.duration == 300 + + def test_from_operation_finished_wait_cancelled(self): + """Test converting cancelled wait operation.""" + operation = Mock() + operation.operation_id = "wait-123" + operation.name = "test_wait" + operation.parent_id = "parent-123" + operation.operation_type = OperationType.WAIT + operation.status = OperationStatus.CANCELLED + operation.end_timestamp = datetime(2024, 1, 1, 12, 3, 0, tzinfo=UTC) + operation.wait_details = None + + context = EventCreationContext.create( + operation=operation, + event_id=3, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + event = Event.create_event_terminated(context) + + assert event.event_type == "WaitCancelled" + assert event.wait_cancelled_details is not None + + def test_from_operation_finished_callback_succeeded(self): + """Test converting succeeded callback operation.""" + operation = Mock() + operation.operation_id = "callback-123" + operation.name = "test_callback" + operation.parent_id = "parent-123" + operation.operation_type = OperationType.CALLBACK + operation.status = OperationStatus.SUCCEEDED + operation.end_timestamp = datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC) + + callback_details = Mock() + callback_details.result = '{"callback": "result"}' + callback_details.error = None + operation.callback_details = callback_details + + context = EventCreationContext.create( + operation=operation, + event_id=4, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + include_execution_data=True, + ) + event = Event.create_event_terminated(context) + + assert event.event_type == "CallbackSucceeded" + assert ( + event.callback_succeeded_details.result.payload == '{"callback": "result"}' + ) + + def test_from_operation_finished_callback_timed_out(self): + """Test converting timed out callback operation.""" + operation = Mock() + operation.operation_id = "callback-123" + operation.name = "test_callback" + operation.parent_id = "parent-123" + operation.operation_type = OperationType.CALLBACK + operation.status = OperationStatus.TIMED_OUT + operation.end_timestamp = datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC) + + callback_details = Mock() + callback_details.result = None + callback_details.error = ErrorObject.from_message("Callback timed out") + operation.callback_details = callback_details + + context = EventCreationContext.create( + operation=operation, + event_id=4, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + event = Event.create_event_terminated(context) + + assert event.event_type == "CallbackTimedOut" + assert ( + event.callback_timed_out_details.error.payload.message + == "Callback timed out" + ) + + def test_from_operation_finished_chained_invoke_succeeded(self): + """Test converting succeeded chained invoke operation.""" + operation = Mock() + operation.operation_id = "invoke-123" + operation.name = "test_invoke" + operation.parent_id = "parent-123" + operation.operation_type = OperationType.CHAINED_INVOKE + operation.status = OperationStatus.SUCCEEDED + operation.end_timestamp = datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC) + + chained_invoke_details = Mock() + chained_invoke_details.result = '{"invoke": "result"}' + chained_invoke_details.error = None + operation.chained_invoke_details = chained_invoke_details + + context = EventCreationContext.create( + operation=operation, + event_id=5, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + include_execution_data=True, + ) + event = Event.create_event_terminated(context) + + assert event.event_type == "ChainedInvokeSucceeded" + assert ( + event.chained_invoke_succeeded_details.result.payload + == '{"invoke": "result"}' + ) + + def test_from_operation_finished_chained_invoke_stopped(self): + """Test converting stopped chained invoke operation.""" + operation = Mock() + operation.operation_id = "invoke-123" + operation.name = "test_invoke" + operation.parent_id = "parent-123" + operation.operation_type = OperationType.CHAINED_INVOKE + operation.status = OperationStatus.STOPPED + operation.end_timestamp = datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC) + + chained_invoke_details = Mock() + chained_invoke_details.result = None + chained_invoke_details.error = ErrorObject.from_message("Invoke stopped") + operation.chained_invoke_details = chained_invoke_details + + context = EventCreationContext.create( + operation=operation, + event_id=5, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + event = Event.create_event_terminated(context) + + assert event.event_type == "ChainedInvokeStopped" + assert ( + event.chained_invoke_stopped_details.error.payload.message + == "Invoke stopped" + ) + + def test_from_operation_finished_context_succeeded(self): + """Test converting succeeded context operation.""" + operation = Mock() + operation.operation_id = "context-123" + operation.name = "test_context" + operation.parent_id = "parent-123" + operation.operation_type = OperationType.CONTEXT + operation.status = OperationStatus.SUCCEEDED + operation.end_timestamp = datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC) + + context_details = Mock() + context_details.result = '{"context": "result"}' + context_details.error = None + operation.context_details = context_details + operation.result = None + operation.error = None + + context = EventCreationContext.create( + operation=operation, + event_id=6, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + include_execution_data=True, + ) + event = Event.create_event_terminated(context) + + assert event.event_type == "ContextSucceeded" + assert event.context_succeeded_details.result.payload == '{"context": "result"}' + + def test_from_operation_finished_context_failed(self): + """Test converting failed context operation.""" + operation = Mock() + operation.operation_id = "context-123" + operation.name = "test_context" + operation.parent_id = "parent-123" + operation.operation_type = OperationType.CONTEXT + operation.status = OperationStatus.FAILED + operation.end_timestamp = datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC) + + context_details = Mock() + context_details.result = None + context_details.error = ErrorObject.from_message("Context failed") + operation.context_details = context_details + operation.result = None + operation.error = None + + context = EventCreationContext.create( + operation=operation, + event_id=6, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + event = Event.create_event_terminated(context) + + assert event.event_type == "ContextFailed" + assert event.context_failed_details.error.payload.message == "Context failed" + + def test_from_operation_finished_no_end_timestamp(self): + """Test error when operation has no end timestamp.""" + operation = Mock() + operation.end_timestamp = None + + context = EventCreationContext.create( + operation=operation, + event_id=1, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + with pytest.raises( + InvalidParameterValueException, + match="Operation end timestamp cannot be None", + ): + Event.create_event_terminated(context) + + def test_from_operation_finished_invalid_status(self): + """Test error with invalid operation status.""" + operation = Mock() + operation.status = OperationStatus.STARTED + operation.end_timestamp = datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC) + + context = EventCreationContext.create( + operation=operation, + event_id=1, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + with pytest.raises( + InvalidParameterValueException, + match="Operation status must be one of SUCCEEDED, FAILED, TIMED_OUT, STOPPED, or CANCELLED", + ): + Event.create_event_terminated(context) + + def test_from_operation_finished_unknown_type(self): + """Test error with unknown operation type.""" + operation = Mock() + operation.operation_type = "UNKNOWN_TYPE" + operation.status = OperationStatus.SUCCEEDED + operation.end_timestamp = datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC) + + context = EventCreationContext.create( + operation=operation, + event_id=1, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + with pytest.raises( + InvalidParameterValueException, match="Unknown operation type: UNKNOWN_TYPE" + ): + Event.create_event_terminated(context) + + def test_from_operation_finished_no_details(self): + """Test operations with no detail objects.""" + operation = Mock() + operation.operation_id = "step-123" + operation.name = "test_step" + operation.parent_id = "parent-123" + operation.operation_type = OperationType.STEP + operation.status = OperationStatus.SUCCEEDED + operation.end_timestamp = datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC) + operation.step_details = None + + context = EventCreationContext.create( + operation=operation, + event_id=2, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + event = Event.create_event_terminated(context) + + assert event.event_type == "StepSucceeded" + assert event.step_succeeded_details.result is None + + +# endregion from_operation_finished_tests + + +def test_chained_invoke_pending_details_from_dict(): + """Test ChainedInvokePendingDetails parsing in Event.from_dict.""" + data = { + "EventType": "ChainedInvokeStarted", + "EventTimestamp": datetime.now(UTC), + "ChainedInvokePendingDetails": { + "Input": {"Payload": "test-input", "Truncated": False}, + "FunctionName": "test-function", + }, + } + + event = Event.from_dict(data) + assert event.chained_invoke_pending_details is not None + assert event.chained_invoke_pending_details.input.payload == "test-input" + assert event.chained_invoke_pending_details.function_name == "test-function" + + +def test_event_creation_context_sub_type_property(): + """Test EventCreationContext.sub_type property with and without sub_type.""" + # Test with sub_type + operation = Mock() + operation.sub_type = OperationSubType.STEP + + context = EventCreationContext.create( + operation=operation, + event_id=1, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + + assert context.sub_type == "Step" + + # Test without sub_type + operation.sub_type = None + context = EventCreationContext.create( + operation=operation, + event_id=1, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + + assert context.sub_type is None + + +def test_event_creation_context_get_retry_details(): + """Test EventCreationContext.get_retry_details method.""" + operation = Mock() + operation.step_details = StepDetails(attempt=2) + + operation_update = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.SUCCEED, + step_options=StepOptions(next_attempt_delay_seconds=30), + ) + + context = EventCreationContext.create( + operation=operation, + event_id=1, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + operation_update=operation_update, + ) + + retry_details = context.get_retry_details() + assert retry_details is not None + assert retry_details.current_attempt == 2 + assert retry_details.next_attempt_delay_seconds == 30 + + # Test with no step_details + operation.step_details = None + context = EventCreationContext.create( + operation=operation, + event_id=1, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + operation_update=operation_update, + ) + + retry_details = context.get_retry_details() + assert retry_details is None + + # Test with no operation_update + operation.step_details = StepDetails(attempt=2) + context = EventCreationContext.create( + operation=operation, + event_id=1, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + ) + + retry_details = context.get_retry_details() + assert retry_details is None + + +def test_create_chained_invoke_event_pending(): + """Test Event.create_chained_invoke_event_pending method.""" + operation = Mock() + operation.operation_id = "invoke-1" + operation.name = "test_invoke" + operation.parent_id = None + operation.status = OperationStatus.PENDING + operation.start_timestamp = datetime.now(UTC) + operation.sub_type = None + + context = EventCreationContext.create( + operation=operation, + event_id=1, + durable_execution_arn="arn:test", + start_input=StartDurableExecutionInput( + account_id="123", + function_name="test", + function_qualifier="$LATEST", + execution_name="test", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ), + include_execution_data=True, + ) + + event = Event.create_chained_invoke_event_pending(context) + + assert event.event_type == "ChainedInvokeStarted" + assert event.operation_id == "invoke-1" + assert event.name == "test_invoke" + assert event.chained_invoke_pending_details is not None + assert event.chained_invoke_pending_details.function_name == "test" diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/exceptions_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/exceptions_test.py new file mode 100644 index 0000000..1f1ea55 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/exceptions_test.py @@ -0,0 +1,953 @@ +"""Tests for AWS-compliant exceptions and their boto3 compatibility. + +This module contains comprehensive tests for all exception types used in the +AWS Durable Execution SDK Python Testing framework, including validation +of boto3 compatibility for proper AWS service integration. +""" + +import json + +import pytest + +from aws_durable_execution_sdk_python_testing import exceptions + + +# ============================================================================= +# Base Exception Tests +# ============================================================================= + + +def test_durable_functions_test_error_base_exception() -> None: + """Test DurableFunctionsTestError base exception.""" + error = exceptions.DurableFunctionsTestError("Base error message") + + assert str(error) == "Base error message" + assert isinstance(error, Exception) + + +def test_durable_functions_local_runner_error_base_exception() -> None: + """Test DurableFunctionsLocalRunnerError base exception.""" + error = exceptions.DurableFunctionsLocalRunnerError("Local runner error") + + assert str(error) == "Local runner error" + assert isinstance(error, Exception) + + +def test_serialization_error() -> None: + """Test SerializationError for serialization failures.""" + error = exceptions.SerializationError("Failed to serialize data") + + assert str(error) == "Failed to serialize data" + assert isinstance(error, exceptions.DurableFunctionsLocalRunnerError) + + +def test_unknown_route_error() -> None: + """Test UnknownRouteError for unknown HTTP routes.""" + error = exceptions.UnknownRouteError("POST", "/unknown/path") + + assert str(error) == "Unknown path pattern: POST /unknown/path" + assert error.method == "POST" + assert error.path == "/unknown/path" + assert isinstance(error, exceptions.DurableFunctionsLocalRunnerError) + + +def test_aws_api_exception_base() -> None: + """Test AwsApiException base class.""" + # AwsApiException is abstract, so we test with a concrete implementation + error = exceptions.ServiceException("Test service error") + + assert isinstance(error, exceptions.AwsApiException) + assert isinstance(error, exceptions.DurableFunctionsLocalRunnerError) + assert error.http_status_code == 500 + + +def test_exception_hierarchy() -> None: + """Test that all custom exceptions inherit from appropriate base exceptions.""" + # Test AWS API exceptions + aws_exceptions = [ + exceptions.IllegalStateException("test"), + exceptions.InvalidParameterValueException("test"), + exceptions.ResourceNotFoundException("test"), + exceptions.ServiceException("test"), + exceptions.CallbackTimeoutException("test"), + ] + + for aws_exception in aws_exceptions: + assert isinstance(aws_exception, exceptions.AwsApiException) + assert isinstance(aws_exception, exceptions.DurableFunctionsLocalRunnerError) + assert isinstance(aws_exception, Exception) + + # Test local runner exceptions + local_exceptions = [ + exceptions.SerializationError("test"), + exceptions.UnknownRouteError("GET", "/test"), + ] + + for local_exception in local_exceptions: + assert isinstance(local_exception, exceptions.DurableFunctionsLocalRunnerError) + assert isinstance(local_exception, Exception) + + # Test testing exceptions + test_error = exceptions.DurableFunctionsTestError("test") + assert isinstance(test_error, Exception) + + +def test_illegal_argument_exception() -> None: + """Test IllegalArgumentException maps to InvalidParameterValueException.""" + error = exceptions.IllegalArgumentException("Invalid argument provided") + + assert str(error) == "Invalid argument provided" + assert isinstance(error, exceptions.AwsApiException) + assert error.http_status_code == 400 + + # Test serialization maps to InvalidParameterValueException + json_dict = error.to_dict() + assert json_dict == { + "Type": "InvalidParameterValueException", + "message": "Invalid argument provided", + } + + +def test_runtime_exception() -> None: + """Test RuntimeException maps to ServiceException.""" + error = exceptions.RuntimeException("Runtime error occurred") + + assert str(error) == "Runtime error occurred" + assert isinstance(error, exceptions.AwsApiException) + assert error.http_status_code == 500 + + # Test serialization maps to ServiceException + json_dict = error.to_dict() + assert json_dict == { + "Type": "ServiceException", + "Message": "Runtime error occurred", + } + + +def test_illegal_state_exception() -> None: + """Test IllegalStateException for invalid state transitions.""" + error = exceptions.IllegalStateException( + "Cannot transition from RUNNING to PENDING" + ) + + assert str(error) == "Cannot transition from RUNNING to PENDING" + assert isinstance(error, exceptions.AwsApiException) + assert error.http_status_code == 500 + + # Test serialization maps to ServiceException + json_dict = error.to_dict() + assert json_dict == { + "Type": "ServiceException", + "Message": "Cannot transition from RUNNING to PENDING", + } + + +# ============================================================================= +# Boto3 Compatibility Tests +# ============================================================================= + + +def test_invalid_parameter_value_exception_boto3_format() -> None: + """Test InvalidParameterValueException produces correct boto3 format.""" + exception = exceptions.InvalidParameterValueException("Invalid parameter value") + + # Test serialization + json_dict = exception.to_dict() + + # Validate structure matches boto3 expectations + assert json_dict == { + "Type": "InvalidParameterValueException", + "message": "Invalid parameter value", + } + + # Test that it can be serialized to JSON and back + json_str = json.dumps(json_dict) + parsed_back = json.loads(json_str) + assert parsed_back == json_dict + + # Validate HTTP status code + assert exception.http_status_code == 400 + + +def test_resource_not_found_exception_boto3_format() -> None: + """Test ResourceNotFoundException produces correct boto3 format.""" + exception = exceptions.ResourceNotFoundException("Resource not found") + + json_dict = exception.to_dict() + + assert json_dict == { + "Type": "ResourceNotFoundException", + "Message": "Resource not found", # Capital M per Smithy definition + } + + # Test JSON serialization + json_str = json.dumps(json_dict) + parsed_back = json.loads(json_str) + assert parsed_back == json_dict + + assert exception.http_status_code == 404 + + +def test_service_exception_boto3_format() -> None: + """Test ServiceException produces correct boto3 format.""" + exception = exceptions.ServiceException("Internal service error") + + json_dict = exception.to_dict() + + assert json_dict == { + "Type": "ServiceException", + "Message": "Internal service error", # Capital M per Smithy definition + } + + # Test JSON serialization + json_str = json.dumps(json_dict) + parsed_back = json.loads(json_str) + assert parsed_back == json_dict + + assert exception.http_status_code == 500 + + +def test_callback_timeout_exception_boto3_format() -> None: + """Test CallbackTimeoutException produces correct boto3 format.""" + exception = exceptions.CallbackTimeoutException("Callback timed out") + + json_dict = exception.to_dict() + + assert json_dict == { + "Type": "CallbackTimeoutException", + "message": "Callback timed out", + } + + # Test JSON serialization + json_str = json.dumps(json_dict) + parsed_back = json.loads(json_str) + assert parsed_back == json_dict + + assert exception.http_status_code == 408 + + +def test_execution_already_started_exception_special_format() -> None: + """Test ExecutionAlreadyStartedException has no Type field (special case).""" + exception = exceptions.ExecutionAlreadyStartedException( + "Execution already started", + "arn:aws:states:us-east-1:123456789012:execution:test", + ) + + json_dict = exception.to_dict() + + # Special case: no Type field for this exception, includes DurableExecutionArn + assert json_dict == { + "message": "Execution already started", + "DurableExecutionArn": "arn:aws:states:us-east-1:123456789012:execution:test", + } + + # Ensure Type field is not present + assert "Type" not in json_dict + + # Test JSON serialization + json_str = json.dumps(json_dict) + parsed_back = json.loads(json_str) + assert parsed_back == json_dict + + assert exception.http_status_code == 409 + + +def test_all_exceptions_have_correct_type_field_values() -> None: + """Test that Type field values match what boto3 expects for exception names.""" + test_cases = [ + ( + exceptions.InvalidParameterValueException("test"), + "InvalidParameterValueException", + ), + (exceptions.ResourceNotFoundException("test"), "ResourceNotFoundException"), + (exceptions.ServiceException("test"), "ServiceException"), + (exceptions.CallbackTimeoutException("test"), "CallbackTimeoutException"), + ] + + for exception, expected_type in test_cases: + json_dict = exception.to_dict() + assert json_dict["Type"] == expected_type + + +def test_message_field_casing_compatibility() -> None: + """Test message field casing matches boto3 deserialization expectations.""" + # InvalidParameterValueException uses lowercase 'message' + exception1 = exceptions.InvalidParameterValueException("Test message") + json_dict1 = exception1.to_dict() + + assert "message" in json_dict1 + assert "Message" not in json_dict1 + assert json_dict1["message"] == "Test message" + + # ResourceNotFoundException uses capital 'Message' + exception2 = exceptions.ResourceNotFoundException("Test message") + json_dict2 = exception2.to_dict() + + assert "Message" in json_dict2 + assert "message" not in json_dict2 + assert json_dict2["Message"] == "Test message" + + +def test_json_serialization_with_special_characters() -> None: + """Test that exceptions with special characters serialize correctly.""" + special_message = 'Error with "quotes", newlines\n, and unicode: 🚀' + exception = exceptions.InvalidParameterValueException(special_message) + + json_dict = exception.to_dict() + + # Test that it can be serialized to JSON + json_str = json.dumps(json_dict) + parsed_back = json.loads(json_str) + + assert parsed_back["message"] == special_message + assert parsed_back["Type"] == "InvalidParameterValueException" + + +def test_empty_message_handling() -> None: + """Test that empty messages are handled correctly.""" + exception = exceptions.InvalidParameterValueException("") + json_dict = exception.to_dict() + + assert json_dict == {"Type": "InvalidParameterValueException", "message": ""} + + +def test_none_message_handling() -> None: + """Test that None messages are converted to empty strings.""" + # This tests the edge case where message might be None + exception = exceptions.InvalidParameterValueException(None) # type: ignore + json_dict = exception.to_dict() + + # Should convert None to string "None" for JSON compatibility + assert json_dict["message"] is None or json_dict["message"] == "None" + + +def test_http_status_codes_match_aws_standards() -> None: + """Test that HTTP status codes match AWS service standards.""" + status_code_tests = [ + (exceptions.InvalidParameterValueException("test"), 400), # Bad Request + (exceptions.ResourceNotFoundException("test"), 404), # Not Found + ( + exceptions.ExecutionAlreadyStartedException( + "test", "arn:aws:states:us-east-1:123456789012:execution:test" + ), + 409, + ), # Conflict + (exceptions.CallbackTimeoutException("test"), 408), # Request Timeout + (exceptions.ServiceException("test"), 500), # Internal Server Error + ] + + for exception, expected_status in status_code_tests: + assert exception.http_status_code == expected_status + + +def test_json_structure_has_no_extra_fields() -> None: + """Test that JSON structure only contains expected fields.""" + exception = exceptions.InvalidParameterValueException("test") + json_dict = exception.to_dict() + + # Should only have Type and message fields + expected_fields = {"Type", "message"} + actual_fields = set(json_dict.keys()) + + assert actual_fields == expected_fields + + +def test_execution_already_started_has_only_message_field() -> None: + """Test that ExecutionAlreadyStartedException only has message field.""" + exception = exceptions.ExecutionAlreadyStartedException( + "test", "arn:aws:states:us-east-1:123456789012:execution:test" + ) + json_dict = exception.to_dict() + + # Should only have message and DurableExecutionArn fields (no Type) + expected_fields = {"message", "DurableExecutionArn"} + actual_fields = set(json_dict.keys()) + + assert actual_fields == expected_fields + + +def test_large_message_serialization() -> None: + """Test that large messages can be serialized correctly.""" + # Create a large message (but not too large to avoid memory issues in tests) + large_message = "Error: " + "x" * 1000 + exception = exceptions.ServiceException(large_message) + + json_dict = exception.to_dict() + json_str = json.dumps(json_dict) + parsed_back = json.loads(json_str) + + assert parsed_back["Message"] == large_message # ServiceException uses capital M + assert len(parsed_back["Message"]) == len(large_message) + + +def test_all_aws_exceptions_are_json_serializable() -> None: + """Test that all AWS exception types can be JSON serialized.""" + test_exceptions = [ + exceptions.InvalidParameterValueException("test"), + exceptions.ResourceNotFoundException("test"), + exceptions.ServiceException("test"), + exceptions.CallbackTimeoutException("test"), + exceptions.ExecutionAlreadyStartedException( + "test", "arn:aws:states:us-east-1:123456789012:execution:test" + ), + ] + + for exception in test_exceptions: + json_dict = exception.to_dict() + + # Should be able to serialize to JSON without errors + json_str = json.dumps(json_dict) + + # Should be able to parse back from JSON + parsed_back = json.loads(json_str) + + # Should match original structure + assert parsed_back == json_dict + + +def test_too_many_requests_exception() -> None: + """Test TooManyRequestsException for rate limiting.""" + exception = exceptions.TooManyRequestsException("Rate limit exceeded") + + assert str(exception) == "Rate limit exceeded" + assert isinstance(exception, exceptions.AwsApiException) + assert exception.http_status_code == 429 + + json_dict = exception.to_dict() + assert json_dict == { + "Type": "TooManyRequestsException", + "message": "Rate limit exceeded", + } + + +def test_execution_conflict_exception() -> None: + """Test ExecutionConflictException for execution conflicts.""" + exception = exceptions.ExecutionConflictException("Execution conflict detected") + + assert str(exception) == "Execution conflict detected" + assert isinstance(exception, exceptions.AwsApiException) + assert exception.http_status_code == 409 + + json_dict = exception.to_dict() + assert json_dict == { + "Type": "ExecutionConflictException", + "message": "Execution conflict detected", + } + + +# ============================================================================= +# AWS Compliant Exception Tests (Comprehensive) +# ============================================================================= + + +def test_base_exception_hierarchy(): + """Test that all AWS exceptions inherit from the correct base classes.""" + # Test base hierarchy + assert issubclass( + exceptions.AwsApiException, exceptions.DurableFunctionsLocalRunnerError + ) + assert issubclass(exceptions.DurableFunctionsLocalRunnerError, Exception) + + # Test all AWS exceptions inherit from AwsApiException + aws_exceptions = [ + exceptions.InvalidParameterValueException, + exceptions.ResourceNotFoundException, + exceptions.ServiceException, + exceptions.ExecutionAlreadyStartedException, + exceptions.ExecutionConflictException, + exceptions.CallbackTimeoutException, + exceptions.TooManyRequestsException, + exceptions.IllegalStateException, + exceptions.RuntimeException, + exceptions.IllegalArgumentException, + ] + + for exception_class in aws_exceptions: + assert issubclass(exception_class, exceptions.AwsApiException) + + +def test_aws_api_exception_abstract_to_dict(): + """Test that AwsApiException.to_dict() raises NotImplementedError.""" + exception = exceptions.AwsApiException("test message") + + with pytest.raises(NotImplementedError): + exception.to_dict() + + +class TestSmithyMappedExceptions: + """Test Smithy-mapped exceptions (defined in Smithy models).""" + + def test_invalid_parameter_value_exception(self): + """Test InvalidParameterValueException serialization and properties.""" + message = "Invalid parameter" + exception = exceptions.InvalidParameterValueException(message) + + # Test properties + assert exception.http_status_code == 400 + assert exception.message == message + assert str(exception) == message + + # Test serialization + expected_json = {"Type": "InvalidParameterValueException", "message": message} + assert exception.to_dict() == expected_json + + def test_resource_not_found_exception(self): + """Test ResourceNotFoundException serialization and properties.""" + message = "Resource not found" + exception = exceptions.ResourceNotFoundException(message) + + # Test properties + assert exception.http_status_code == 404 + assert exception.Message == message # Capital M per Smithy + assert str(exception) == message + + # Test serialization + expected_json = {"Type": "ResourceNotFoundException", "Message": message} + assert exception.to_dict() == expected_json + + def test_service_exception(self): + """Test ServiceException serialization and properties.""" + message = "Service error" + exception = exceptions.ServiceException(message) + + # Test properties + assert exception.http_status_code == 500 + assert exception.Message == message # Capital M per Smithy + assert str(exception) == message + + # Test serialization + expected_json = {"Type": "ServiceException", "Message": message} + assert exception.to_dict() == expected_json + + def test_execution_already_started_exception(self): + """Test ExecutionAlreadyStartedException serialization and properties.""" + message = "Execution already started" + arn = "arn:aws:lambda:us-east-1:123456789012:function:test" + exception = exceptions.ExecutionAlreadyStartedException(message, arn) + + # Test properties + assert exception.http_status_code == 409 + assert exception.message == message + assert exception.DurableExecutionArn == arn + assert str(exception) == message + + # Test serialization (no Type field per Smithy definition) + expected_json = {"message": message, "DurableExecutionArn": arn} + assert exception.to_dict() == expected_json + + def test_callback_timeout_exception(self): + """Test CallbackTimeoutException serialization and properties.""" + message = "Callback timed out" + exception = exceptions.CallbackTimeoutException(message) + + # Test properties + assert exception.http_status_code == 408 + assert exception.message == message + assert str(exception) == message + + # Test serialization + expected_json = {"Type": "CallbackTimeoutException", "message": message} + assert exception.to_dict() == expected_json + + def test_too_many_requests_exception(self): + """Test TooManyRequestsException serialization and properties.""" + message = "Too many requests" + exception = exceptions.TooManyRequestsException(message) + + # Test properties + assert exception.http_status_code == 429 + assert exception.message == message + assert str(exception) == message + + # Test serialization + expected_json = {"Type": "TooManyRequestsException", "message": message} + assert exception.to_dict() == expected_json + + def test_execution_conflict_exception(self): + """Test ExecutionConflictException serialization and properties.""" + message = "Execution conflict" + exception = exceptions.ExecutionConflictException(message) + + # Test properties + assert exception.http_status_code == 409 + assert exception.message == message + assert str(exception) == message + + # Test serialization + expected_json = {"Type": "ExecutionConflictException", "message": message} + assert exception.to_dict() == expected_json + + +class TestUnmappedExceptions: + """Test unmapped exceptions (thrown by services but not in Smithy).""" + + def test_illegal_state_exception(self): + """Test IllegalStateException maps to ServiceException when serialized.""" + message = "Invalid state" + exception = exceptions.IllegalStateException(message) + + # Test properties + assert exception.http_status_code == 500 + assert exception.message == message + assert str(exception) == message + + # Test serialization (maps to ServiceException) + expected_json = {"Type": "ServiceException", "Message": message} + assert exception.to_dict() == expected_json + + def test_runtime_exception(self): + """Test RuntimeException maps to ServiceException when serialized.""" + message = "Runtime error" + exception = exceptions.RuntimeException(message) + + # Test properties + assert exception.http_status_code == 500 + assert exception.message == message + assert str(exception) == message + + # Test serialization (maps to ServiceException) + expected_json = {"Type": "ServiceException", "Message": message} + assert exception.to_dict() == expected_json + + def test_illegal_argument_exception(self): + """Test IllegalArgumentException maps to InvalidParameterValueException when serialized.""" + message = "Invalid argument" + exception = exceptions.IllegalArgumentException(message) + + # Test properties + assert exception.http_status_code == 400 + assert exception.message == message + assert str(exception) == message + + # Test serialization (maps to InvalidParameterValueException) + expected_json = {"Type": "InvalidParameterValueException", "message": message} + assert exception.to_dict() == expected_json + + +class TestHttpStatusCodes: + """Test HTTP status codes match Smithy @httpError annotations.""" + + def test_client_error_status_codes(self): + """Test client error (4xx) status codes.""" + assert exceptions.InvalidParameterValueException("test").http_status_code == 400 + assert exceptions.ResourceNotFoundException("test").http_status_code == 404 + assert exceptions.CallbackTimeoutException("test").http_status_code == 408 + assert ( + exceptions.ExecutionAlreadyStartedException("test", "arn").http_status_code + == 409 + ) + assert exceptions.ExecutionConflictException("test").http_status_code == 409 + assert exceptions.TooManyRequestsException("test").http_status_code == 429 + assert exceptions.IllegalArgumentException("test").http_status_code == 400 + + def test_server_error_status_codes(self): + """Test server error (5xx) status codes.""" + assert exceptions.ServiceException("test").http_status_code == 500 + assert exceptions.IllegalStateException("test").http_status_code == 500 + assert exceptions.RuntimeException("test").http_status_code == 500 + + +class TestFieldNameCasing: + """Test field name casing matches Smithy definitions.""" + + def test_lowercase_message_fields(self): + """Test exceptions that use lowercase 'message' field.""" + # These use lowercase 'message' per Smithy definitions + exceptions_with_lowercase_message = [ + exceptions.InvalidParameterValueException("test"), + exceptions.ExecutionAlreadyStartedException("test", "arn"), + exceptions.ExecutionConflictException("test"), + exceptions.CallbackTimeoutException("test"), + exceptions.TooManyRequestsException("test"), + exceptions.IllegalStateException("test"), + exceptions.RuntimeException("test"), + exceptions.IllegalArgumentException("test"), + ] + + for exception in exceptions_with_lowercase_message: + if hasattr(exception, "message"): + assert exception.message == "test" + + def test_uppercase_message_fields(self): + """Test exceptions that use uppercase 'Message' field.""" + # These use uppercase 'Message' per Smithy definitions + exceptions_with_uppercase_message = [ + exceptions.ResourceNotFoundException("test"), + exceptions.ServiceException("test"), + ] + + for exception in exceptions_with_uppercase_message: + assert exception.Message == "test" + + +class TestBoto3Compatibility: + """Test boto3 compatibility and JSON structure validation.""" + + def test_json_structure_matches_boto3_expectations(self): + """Test that JSON output matches what boto3 error factory expects.""" + # Test that all exceptions produce valid JSON structures + test_cases = [ + ( + exceptions.InvalidParameterValueException("test"), + {"Type": "InvalidParameterValueException", "message": "test"}, + ), + ( + exceptions.ResourceNotFoundException("test"), + {"Type": "ResourceNotFoundException", "Message": "test"}, + ), + ( + exceptions.ServiceException("test"), + {"Type": "ServiceException", "Message": "test"}, + ), + ( + exceptions.ExecutionAlreadyStartedException("test", "arn"), + {"message": "test", "DurableExecutionArn": "arn"}, + ), + ( + exceptions.ExecutionConflictException("test"), + {"Type": "ExecutionConflictException", "message": "test"}, + ), + ( + exceptions.CallbackTimeoutException("test"), + {"Type": "CallbackTimeoutException", "message": "test"}, + ), + ( + exceptions.TooManyRequestsException("test"), + {"Type": "TooManyRequestsException", "message": "test"}, + ), + ] + + for exception, expected_json in test_cases: + actual_json = exception.to_dict() + assert actual_json == expected_json + + # Verify JSON is serializable (no complex objects) + json_str = json.dumps(actual_json) + assert json.loads(json_str) == actual_json + + def test_type_field_values_match_exception_names(self): + """Test that Type field values match what boto3 expects for exception names.""" + type_field_mappings = [ + ( + exceptions.InvalidParameterValueException("test"), + "InvalidParameterValueException", + ), + (exceptions.ResourceNotFoundException("test"), "ResourceNotFoundException"), + (exceptions.ServiceException("test"), "ServiceException"), + ( + exceptions.ExecutionConflictException("test"), + "ExecutionConflictException", + ), + (exceptions.CallbackTimeoutException("test"), "CallbackTimeoutException"), + (exceptions.TooManyRequestsException("test"), "TooManyRequestsException"), + # Unmapped exceptions map to different types + (exceptions.IllegalStateException("test"), "ServiceException"), + (exceptions.RuntimeException("test"), "ServiceException"), + ( + exceptions.IllegalArgumentException("test"), + "InvalidParameterValueException", + ), + ] + + for exception, expected_type in type_field_mappings: + json_output = exception.to_dict() + if ( + "Type" in json_output + ): # ExecutionAlreadyStartedException doesn't have Type field + assert json_output["Type"] == expected_type + + def test_execution_already_started_exception_special_case(self): + """Test ExecutionAlreadyStartedException special case (no Type field).""" + exception = exceptions.ExecutionAlreadyStartedException( + "test message", "test-arn" + ) + json_output = exception.to_dict() + + # Should not have Type field + assert "Type" not in json_output + + # Should have required fields + assert "message" in json_output + assert "DurableExecutionArn" in json_output + assert json_output["message"] == "test message" + assert json_output["DurableExecutionArn"] == "test-arn" + + def test_message_field_casing_compatibility(self): + """Test message field casing compatibility with boto3 deserialization.""" + # Test lowercase 'message' field exceptions + lowercase_exceptions = [ + exceptions.InvalidParameterValueException("test"), + exceptions.ExecutionAlreadyStartedException("test", "arn"), + exceptions.ExecutionConflictException("test"), + exceptions.CallbackTimeoutException("test"), + exceptions.TooManyRequestsException("test"), + ] + + for exception in lowercase_exceptions: + json_output = exception.to_dict() + if "message" in json_output: + assert json_output["message"] == "test" + # Should not have uppercase Message + assert "Message" not in json_output + + # Test uppercase 'Message' field exceptions + uppercase_exceptions = [ + exceptions.ResourceNotFoundException("test"), + exceptions.ServiceException("test"), + ] + + for exception in uppercase_exceptions: + json_output = exception.to_dict() + assert json_output["Message"] == "test" + # Should not have lowercase message + assert "message" not in json_output + + +class TestEdgeCases: + """Test edge cases and error conditions.""" + + def test_empty_message_handling(self): + """Test handling of empty messages.""" + exceptions_list = [ + exceptions.InvalidParameterValueException(""), + exceptions.ResourceNotFoundException(""), + exceptions.ServiceException(""), + exceptions.ExecutionConflictException(""), + exceptions.CallbackTimeoutException(""), + exceptions.TooManyRequestsException(""), + exceptions.IllegalStateException(""), + exceptions.RuntimeException(""), + exceptions.IllegalArgumentException(""), + ] + + for exception in exceptions_list: + # Should not raise exception during serialization + json_output = exception.to_dict() + assert isinstance(json_output, dict) + + def test_special_characters_in_messages(self): + """Test handling of special characters in messages.""" + special_message = 'Test with "quotes", newlines\n, and unicode: 🚀' + + exceptions_list = [ + exceptions.InvalidParameterValueException(special_message), + exceptions.ResourceNotFoundException(special_message), + exceptions.ServiceException(special_message), + ] + + for exception in exceptions_list: + json_output = exception.to_dict() + # Message should be preserved exactly + message_field = "Message" if hasattr(exception, "Message") else "message" + assert json_output[message_field] == special_message + + def test_execution_already_started_with_empty_arn(self): + """Test ExecutionAlreadyStartedException with empty ARN.""" + exception = exceptions.ExecutionAlreadyStartedException("test", "") + json_output = exception.to_dict() + + assert json_output["DurableExecutionArn"] == "" + assert json_output["message"] == "test" + + +def test_exception_test_cases_data_structure(): + """Test that we can create a comprehensive test data structure for all exceptions.""" + # This validates the test data structure mentioned in the design document + exception_test_cases = [ + # Smithy-mapped exceptions + { + "exception_class": exceptions.InvalidParameterValueException, + "args": ["Invalid parameter"], + "expected_json": { + "Type": "InvalidParameterValueException", + "message": "Invalid parameter", + }, + "expected_status": 400, + }, + { + "exception_class": exceptions.ResourceNotFoundException, + "args": ["Resource not found"], + "expected_json": { + "Type": "ResourceNotFoundException", + "Message": "Resource not found", + }, + "expected_status": 404, + }, + { + "exception_class": exceptions.ServiceException, + "args": ["Service error"], + "expected_json": {"Type": "ServiceException", "Message": "Service error"}, + "expected_status": 500, + }, + { + "exception_class": exceptions.ExecutionAlreadyStartedException, + "args": [ + "Already started", + "arn:aws:lambda:us-east-1:123456789012:function:test", + ], + "expected_json": { + "message": "Already started", + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test", + }, + "expected_status": 409, + }, + { + "exception_class": exceptions.ExecutionConflictException, + "args": ["Execution conflict"], + "expected_json": { + "Type": "ExecutionConflictException", + "message": "Execution conflict", + }, + "expected_status": 409, + }, + { + "exception_class": exceptions.CallbackTimeoutException, + "args": ["Callback timeout"], + "expected_json": { + "Type": "CallbackTimeoutException", + "message": "Callback timeout", + }, + "expected_status": 408, + }, + { + "exception_class": exceptions.TooManyRequestsException, + "args": ["Too many requests"], + "expected_json": { + "Type": "TooManyRequestsException", + "message": "Too many requests", + }, + "expected_status": 429, + }, + # Unmapped exceptions + { + "exception_class": exceptions.IllegalStateException, + "args": ["Invalid state"], + "expected_json": {"Type": "ServiceException", "Message": "Invalid state"}, + "expected_status": 500, + }, + { + "exception_class": exceptions.RuntimeException, + "args": ["Runtime error"], + "expected_json": {"Type": "ServiceException", "Message": "Runtime error"}, + "expected_status": 500, + }, + { + "exception_class": exceptions.IllegalArgumentException, + "args": ["Invalid argument"], + "expected_json": { + "Type": "InvalidParameterValueException", + "message": "Invalid argument", + }, + "expected_status": 400, + }, + ] + + # Test each case + for case in exception_test_cases: + exception = case["exception_class"](*case["args"]) + + # Test status code + assert exception.http_status_code == case["expected_status"] + + # Test serialization + assert exception.to_dict() == case["expected_json"] diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/execution_concurrent_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/execution_concurrent_test.py new file mode 100644 index 0000000..6ea2bef --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/execution_concurrent_test.py @@ -0,0 +1,83 @@ +"""Concurrent access tests for Execution class.""" + +import threading +from concurrent.futures import ThreadPoolExecutor, as_completed + +from aws_durable_execution_sdk_python_testing.execution import Execution +from aws_durable_execution_sdk_python_testing.model import StartDurableExecutionInput + + +def test_concurrent_token_generation(): + """Test concurrent checkpoint token generation.""" + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-inv-id", + input='{"test": "data"}', + ) + execution = Execution.new(input_data) + tokens = [] + tokens_lock = threading.Lock() + + def generate_token(): + token = execution.get_new_checkpoint_token() + with tokens_lock: + tokens.append(token) + + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(generate_token) for _ in range(20)] + + for future in as_completed(futures): + future.result() + + # All tokens should be unique and sequential + assert len(tokens) == 20 + assert len(set(tokens)) == 20 # All unique + assert execution.token_sequence == 20 + + +def test_concurrent_operations_modification(): + """Test concurrent operations list modifications.""" + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-inv-id", + input='{"test": "data"}', + ) + execution = Execution.new(input_data) + results = [] + results_lock = threading.Lock() + + def start_execution(): + execution.start() + with results_lock: + results.append("started") + + def get_operations(): + ops = execution.get_navigable_operations() + with results_lock: + results.append(f"ops-{len(ops)}") + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [] + # One start operation + futures.append(executor.submit(start_execution)) + # Multiple read operations + futures.extend([executor.submit(get_operations) for _ in range(4)]) + + for future in as_completed(futures): + future.result() + + assert len(results) == 5 + assert "started" in results + # Should have at least one operation after start + final_ops = execution.get_navigable_operations() + assert len(final_ops) >= 1 diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/execution_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/execution_test.py new file mode 100644 index 0000000..5bde747 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/execution_test.py @@ -0,0 +1,1014 @@ +"""Unit tests for execution module.""" + +from datetime import datetime, timezone +from unittest.mock import patch, Mock + +import pytest +from aws_durable_execution_sdk_python.execution import ( + InvocationStatus, +) +from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, + Operation, + OperationStatus, + OperationType, + StepDetails, + CallbackDetails, +) + +from aws_durable_execution_sdk_python_testing.exceptions import ( + IllegalStateException, +) +from aws_durable_execution_sdk_python_testing.execution import Execution +from aws_durable_execution_sdk_python_testing.model import StartDurableExecutionInput + + +def test_execution_init(): + """Test Execution initialization.""" + arn = "test-arn" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + operations = [] + + execution = Execution(arn, start_input, operations) + + assert execution.durable_execution_arn == arn + assert execution.start_input == start_input + assert execution.operations == operations + assert execution.updates == [] + assert execution.used_tokens == set() + assert execution.token_sequence == 0 + assert execution.is_complete is False + assert execution.consecutive_failed_invocation_attempts == 0 + + +@patch("aws_durable_execution_sdk_python_testing.execution.uuid4") +def test_execution_new(mock_uuid4): + """Test Execution.new static method.""" + mock_uuid = "test-uuid-123" + mock_uuid4.return_value = mock_uuid + + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id-1234", + ) + + execution = Execution.new(start_input) + + assert ( + execution.durable_execution_arn == str(mock_uuid) + "/test-invocation-id-1234" + ) + assert execution.start_input == start_input + assert execution.operations == [] + + +@patch("aws_durable_execution_sdk_python_testing.execution.datetime") +def test_execution_start(mock_datetime): + """Test Execution.start method.""" + mock_now = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + mock_datetime.now.return_value = mock_now + + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + input='{"key": "value"}', + ) + execution = Execution("test-arn", start_input, []) + + execution.start() + + assert len(execution.operations) == 1 + operation = execution.operations[0] + assert operation.operation_id == "test-invocation-id" + assert operation.parent_id is None + assert operation.name == "test-execution" + assert operation.start_timestamp == mock_now + assert operation.operation_type == OperationType.EXECUTION + assert operation.status == OperationStatus.STARTED + assert operation.execution_details.input_payload == '{"key": "value"}' + + +def test_get_operation_execution_started(): + """Test get_operation_execution_started method.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution("test-arn", start_input, []) + execution.start() + + result = execution.get_operation_execution_started() + + assert result == execution.operations[0] + assert result.operation_type == OperationType.EXECUTION + + +def test_get_operation_execution_started_not_started(): + """Test get_operation_execution_started raises error when not started.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution("test-arn", start_input, []) + + with pytest.raises(IllegalStateException, match="execution not started"): + execution.get_operation_execution_started() + + +def test_get_new_checkpoint_token(): + """Test get_new_checkpoint_token method.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-id", + ) + execution = Execution("test-arn", start_input, []) + + token1 = execution.get_new_checkpoint_token() + token2 = execution.get_new_checkpoint_token() + + assert execution.token_sequence == 2 + assert token1 in execution.used_tokens + assert token2 in execution.used_tokens + assert token1 != token2 + + +def test_get_navigable_operations(): + """Test get_navigable_operations method.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + operations = [ + Operation( + operation_id="op1", + parent_id=None, + name="test", + start_timestamp=datetime.now(timezone.utc), + operation_type=OperationType.EXECUTION, + status=OperationStatus.STARTED, + ) + ] + execution = Execution("test-arn", start_input, operations) + + result = execution.get_navigable_operations() + + assert result == operations + + +def test_get_assertable_operations(): + """Test get_assertable_operations method.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution_op = Operation( + operation_id="exec-op", + parent_id=None, + name="execution", + start_timestamp=datetime.now(timezone.utc), + operation_type=OperationType.EXECUTION, + status=OperationStatus.STARTED, + ) + step_op = Operation( + operation_id="step-op", + parent_id=None, + name="step", + start_timestamp=datetime.now(timezone.utc), + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + ) + operations = [execution_op, step_op] + execution = Execution("test-arn", start_input, operations) + + result = execution.get_assertable_operations() + + assert len(result) == 1 + assert result[0] == step_op + + +def test_has_pending_operations_with_pending_step(): + """Test has_pending_operations returns True for pending STEP operations.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + operations = [ + Operation( + operation_id="op1", + parent_id=None, + name="test", + start_timestamp=datetime.now(timezone.utc), + operation_type=OperationType.STEP, + status=OperationStatus.PENDING, + ) + ] + execution = Execution("test-arn", start_input, operations) + + result = execution.has_pending_operations(execution) + + assert result is True + + +def test_has_pending_operations_with_started_wait(): + """Test has_pending_operations returns True for started WAIT operations.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + operations = [ + Operation( + operation_id="op1", + parent_id=None, + name="test", + start_timestamp=datetime.now(timezone.utc), + operation_type=OperationType.WAIT, + status=OperationStatus.STARTED, + ) + ] + execution = Execution("test-arn", start_input, operations) + + result = execution.has_pending_operations(execution) + + assert result is True + + +def test_has_pending_operations_with_started_callback(): + """Test has_pending_operations returns True for started CALLBACK operations.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + operations = [ + Operation( + operation_id="op1", + parent_id=None, + name="test", + start_timestamp=datetime.now(timezone.utc), + operation_type=OperationType.CALLBACK, + status=OperationStatus.STARTED, + ) + ] + execution = Execution("test-arn", start_input, operations) + + result = execution.has_pending_operations(execution) + + assert result is True + + +def test_has_pending_operations_with_started_invoke(): + """Test has_pending_operations returns True for started INVOKE operations.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + operations = [ + Operation( + operation_id="op1", + parent_id=None, + name="test", + start_timestamp=datetime.now(timezone.utc), + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.STARTED, + ) + ] + execution = Execution("test-arn", start_input, operations) + + result = execution.has_pending_operations(execution) + + assert result is True + + +def test_has_pending_operations_no_pending(): + """Test has_pending_operations returns False when no pending operations.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + operations = [ + Operation( + operation_id="op1", + parent_id=None, + name="test", + start_timestamp=datetime.now(timezone.utc), + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + ) + ] + execution = Execution("test-arn", start_input, operations) + + result = execution.has_pending_operations(execution) + + assert result is False + + +def test_complete_success_with_string_result(): + """Test complete_success method with string result.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution("test-arn", start_input, [Mock()]) + + execution.complete_success("success result") + + assert execution.is_complete is True + assert execution.result.status == InvocationStatus.SUCCEEDED + assert execution.result.result == "success result" + + +def test_complete_success_with_none_result(): + """Test complete_success method with None result.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution("test-arn", start_input, [Mock()]) + + execution.complete_success(None) + + assert execution.is_complete is True + assert execution.result.status == InvocationStatus.SUCCEEDED + assert execution.result.result is None + + +def test_complete_fail(): + """Test complete_fail method.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution("test-arn", start_input, [Mock()]) + error = ErrorObject.from_message("Test error message") + + execution.complete_fail(error) + + assert execution.is_complete is True + assert execution.result.status == InvocationStatus.FAILED + assert execution.result.error == error + + +def test_find_operation_exists(): + """Test find_operation method when operation exists.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + operation = Operation( + operation_id="test-op-id", + parent_id=None, + name="test", + start_timestamp=datetime.now(timezone.utc), + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + ) + execution = Execution("test-arn", start_input, [operation]) + + index, found_operation = execution.find_operation("test-op-id") + + assert index == 0 + assert found_operation == operation + + +def test_find_operation_not_exists(): + """Test find_operation method when operation doesn't exist.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution("test-arn", start_input, []) + + with pytest.raises( + IllegalStateException, match="Attempting to update state of an Operation" + ): + execution.find_operation("non-existent-id") + + +@patch("aws_durable_execution_sdk_python_testing.execution.datetime") +def test_complete_wait_success(mock_datetime): + """Test complete_wait method successful completion.""" + mock_now = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + mock_datetime.now.return_value = mock_now + + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + operation = Operation( + operation_id="wait-op-id", + parent_id=None, + name="test-wait", + start_timestamp=datetime.now(timezone.utc), + operation_type=OperationType.WAIT, + status=OperationStatus.STARTED, + ) + execution = Execution("test-arn", start_input, [operation]) + + result = execution.complete_wait("wait-op-id") + + assert result.status == OperationStatus.SUCCEEDED + assert result.end_timestamp == mock_now + assert execution.token_sequence == 1 + assert execution.operations[0] == result + + +def test_complete_wait_wrong_status(): + """Test complete_wait method with wrong operation status.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + operation = Operation( + operation_id="wait-op-id", + parent_id=None, + name="test-wait", + start_timestamp=datetime.now(timezone.utc), + operation_type=OperationType.WAIT, + status=OperationStatus.SUCCEEDED, + ) + execution = Execution("test-arn", start_input, [operation]) + + with pytest.raises( + IllegalStateException, match="Attempting to transition a Wait Operation" + ): + execution.complete_wait("wait-op-id") + + +def test_complete_wait_wrong_type(): + """Test complete_wait method with wrong operation type.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + operation = Operation( + operation_id="step-op-id", + parent_id=None, + name="test-step", + start_timestamp=datetime.now(timezone.utc), + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + ) + execution = Execution("test-arn", start_input, [operation]) + + with pytest.raises(IllegalStateException, match="Expected WAIT operation"): + execution.complete_wait("step-op-id") + + +def test_complete_retry_success(): + """Test complete_retry method successful completion.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + step_details = StepDetails( + next_attempt_timestamp=str(datetime.now(timezone.utc)), + attempt=1, + ) + operation = Operation( + operation_id="step-op-id", + parent_id=None, + name="test-step", + start_timestamp=datetime.now(timezone.utc), + operation_type=OperationType.STEP, + status=OperationStatus.PENDING, + step_details=step_details, + ) + execution = Execution("test-arn", start_input, [operation]) + + result = execution.complete_retry("step-op-id") + + assert result.status == OperationStatus.READY + assert result.step_details.next_attempt_timestamp is None + assert execution.token_sequence == 1 + assert execution.operations[0] == result + + +def test_complete_retry_no_step_details(): + """Test complete_retry method with no step_details.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + operation = Operation( + operation_id="step-op-id", + parent_id=None, + name="test-step", + start_timestamp=datetime.now(timezone.utc), + operation_type=OperationType.STEP, + status=OperationStatus.PENDING, + ) + execution = Execution("test-arn", start_input, [operation]) + + result = execution.complete_retry("step-op-id") + + assert result.status == OperationStatus.READY + assert result.step_details is None + assert execution.token_sequence == 1 + + +def test_complete_retry_wrong_status(): + """Test complete_retry method with wrong operation status.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + operation = Operation( + operation_id="step-op-id", + parent_id=None, + name="test-step", + start_timestamp=datetime.now(timezone.utc), + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + ) + execution = Execution("test-arn", start_input, [operation]) + + with pytest.raises( + IllegalStateException, match="Attempting to transition a Step Operation" + ): + execution.complete_retry("step-op-id") + + +def test_complete_retry_wrong_type(): + """Test complete_retry method with wrong operation type.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + operation = Operation( + operation_id="wait-op-id", + parent_id=None, + name="test-wait", + start_timestamp=datetime.now(timezone.utc), + operation_type=OperationType.WAIT, + status=OperationStatus.PENDING, + ) + execution = Execution("test-arn", start_input, [operation]) + + with pytest.raises(IllegalStateException, match="Expected STEP operation"): + execution.complete_retry("wait-op-id") + + +def test_status_running(): + """Test status property returns RUNNING for incomplete execution.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution("test-arn", start_input, []) + + assert execution.current_status().value == "RUNNING" + + +def test_status_succeeded(): + """Test status property returns SUCCEEDED for successful execution.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution("test-arn", start_input, [Mock()]) + execution.complete_success("success result") + + assert execution.current_status().value == "SUCCEEDED" + + +def test_status_failed(): + """Test status property returns FAILED for failed execution.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution("test-arn", start_input, [Mock()]) + error = ErrorObject.from_message("Test error") + execution.complete_fail(error) + + assert execution.current_status().value == "FAILED" + + +def test_status_timed_out(): + """Test status property returns TIMED_OUT for timeout errors.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-id", + ) + execution = Execution("test-arn", start_input, [Mock()]) + error = ErrorObject( + message="Execution timed out", type="TimeoutError", data=None, stack_trace=None + ) + execution.complete_timeout(error) + + assert execution.current_status().value == "TIMED_OUT" + + +def test_status_stopped(): + """Test status property returns STOPPED for stop errors.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-id", + ) + execution = Execution("test-arn", start_input, [Mock()]) + error = ErrorObject( + message="Execution stopped", type="StopError", data=None, stack_trace=None + ) + execution.complete_stopped(error) + + assert execution.current_status().value == "STOPPED" + + +def test_status_no_result(): + """Test status property returns FAILED for completed execution with no result.""" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-id", + ) + execution = Execution("test-arn", start_input, []) + execution.is_complete = True + execution.result = None + with pytest.raises( + IllegalStateException, + match="close_status must be set when execution is complete", + ): + execution.current_status() + + +def test_complete_retry_with_step_details(): + """Test complete_retry with operation that has step_details.""" + step_details = StepDetails( + attempt=1, next_attempt_timestamp=datetime.now(timezone.utc) + ) + step_op = Operation( + operation_id="op-1", + operation_type=OperationType.STEP, + status=OperationStatus.PENDING, + step_details=step_details, + ) + + execution = Execution("test-arn", Mock(), [step_op]) + + result = execution.complete_retry("op-1") + assert result.status == OperationStatus.READY + assert result.step_details.next_attempt_timestamp is None + + +def test_complete_retry_without_step_details(): + """Test complete_retry with operation that has no step_details.""" + step_op = Operation( + operation_id="op-1", + operation_type=OperationType.STEP, + status=OperationStatus.PENDING, + step_details=None, # No step details + ) + + execution = Execution("test-arn", Mock(), [step_op]) + + result = execution.complete_retry("op-1") + assert result.status == OperationStatus.READY + assert result.step_details is None + + +# endregion retry + + +def test_from_dict_with_none_result(): + """Test from_dict with None result.""" + data = { + "DurableExecutionArn": "test-arn", + "StartInput": {"function_name": "test"}, + "Operations": [], + "Updates": [], + "UsedTokens": [], + "TokenSequence": 0, + "IsComplete": False, + "Result": None, # None result + "ConsecutiveFailedInvocationAttempts": 0, + "CloseStatus": None, + } + + with patch( + "aws_durable_execution_sdk_python_testing.model.StartDurableExecutionInput.from_dict" + ) as mock_from_dict: + mock_from_dict.return_value = Mock() + execution = Execution.from_json_dict(data) + assert execution.result is None + + +# region callback +def test_find_callback_operation_not_found(): + """Test find_callback_operation raises exception when callback not found.""" + execution = Execution("test-arn", Mock(), []) + + with pytest.raises( + IllegalStateException, + match="Callback operation with callback_id \\[nonexistent\\] not found", + ): + execution.find_callback_operation("nonexistent") + + +def test_complete_callback_success_not_started(): + """Test complete_callback_success raises exception when callback not in STARTED state.""" + # Create callback operation in wrong state + callback_op = Operation( + operation_id="op-1", + operation_type=OperationType.CALLBACK, + status=OperationStatus.SUCCEEDED, # Wrong state + callback_details=CallbackDetails(callback_id="test-id"), + ) + + execution = Execution("test-arn", Mock(), [callback_op]) + + with pytest.raises( + IllegalStateException, + match="Callback operation \\[test-id\\] is not in STARTED state", + ): + execution.complete_callback_success("test-id") + + +def test_complete_callback_failure_not_started(): + """Test complete_callback_failure raises exception when callback not in STARTED state.""" + # Create callback operation in wrong state + callback_op = Operation( + operation_id="op-1", + operation_type=OperationType.CALLBACK, + status=OperationStatus.FAILED, # Wrong state + callback_details=CallbackDetails(callback_id="test-id"), + ) + + execution = Execution("test-arn", Mock(), [callback_op]) + error = ErrorObject.from_message("test error") + + with pytest.raises( + IllegalStateException, + match="Callback operation \\[test-id\\] is not in STARTED state", + ): + execution.complete_callback_failure("test-id", error) + + +def test_complete_callback_success_no_callback_details(): + """Test complete_callback_success with operation that has no callback_details.""" + callback_details = CallbackDetails(callback_id="test-id") + callback_op = Operation( + operation_id="op-1", + operation_type=OperationType.CALLBACK, + status=OperationStatus.STARTED, + callback_details=callback_details, + ) + + execution = Execution("test-arn", Mock(), [callback_op]) + + # Test with None result + result = execution.complete_callback_success("test-id", None) + assert result.status == OperationStatus.SUCCEEDED + + +def test_complete_callback_failure_no_callback_details(): + """Test complete_callback_failure with operation that has no callback_details.""" + callback_details = CallbackDetails(callback_id="test-id") + callback_op = Operation( + operation_id="op-1", + operation_type=OperationType.CALLBACK, + status=OperationStatus.STARTED, + callback_details=callback_details, + ) + + execution = Execution("test-arn", Mock(), [callback_op]) + error = ErrorObject.from_message("test error") + + # Test with actual callback details + result = execution.complete_callback_failure("test-id", error) + assert result.status == OperationStatus.FAILED + + +# region callback - details + + +def test_complete_callback_success_with_none_callback_details(): + """Test complete_callback_success when operation has None callback_details.""" + callback_op = Operation( + operation_id="op-1", + operation_type=OperationType.CALLBACK, + status=OperationStatus.STARTED, + callback_details=None, # None callback details + ) + + execution = Execution("test-arn", Mock(), [callback_op]) + + # Mock find_callback_operation to return this operation + execution.find_callback_operation = Mock(return_value=(0, callback_op)) + + result = execution.complete_callback_success("test-id", b"result") + assert result.status == OperationStatus.SUCCEEDED + assert result.callback_details is None + + +def test_complete_callback_failure_with_none_callback_details(): + """Test complete_callback_failure when operation has None callback_details.""" + callback_op = Operation( + operation_id="op-1", + operation_type=OperationType.CALLBACK, + status=OperationStatus.STARTED, + callback_details=None, # None callback details + ) + + execution = Execution("test-arn", Mock(), [callback_op]) + error = ErrorObject.from_message("test error") + + # Mock find_callback_operation to return this operation + execution.find_callback_operation = Mock(return_value=(0, callback_op)) + + result = execution.complete_callback_failure("test-id", error) + assert result.status == OperationStatus.FAILED + assert result.callback_details is None + + +def test_complete_callback_success_with_bytes_result(): + """Test complete_callback_success with bytes result that gets decoded.""" + callback_details = CallbackDetails(callback_id="test-id") + callback_op = Operation( + operation_id="op-1", + operation_type=OperationType.CALLBACK, + status=OperationStatus.STARTED, + callback_details=callback_details, + ) + + execution = Execution("test-arn", Mock(), [callback_op]) + + result = execution.complete_callback_success("test-id", b"test result") + assert result.status == OperationStatus.SUCCEEDED + assert result.callback_details.result == "test result" + + +def test_complete_callback_success_with_none_result(): + """Test complete_callback_success with None result.""" + callback_details = CallbackDetails(callback_id="test-id") + callback_op = Operation( + operation_id="op-1", + operation_type=OperationType.CALLBACK, + status=OperationStatus.STARTED, + callback_details=callback_details, + ) + + execution = Execution("test-arn", Mock(), [callback_op]) + + result = execution.complete_callback_success("test-id", None) + assert result.status == OperationStatus.SUCCEEDED + assert result.callback_details.result is None + + +# endregion callback -details + +# endregion callback diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/execution_wait_retry_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/execution_wait_retry_test.py new file mode 100644 index 0000000..b0c9db3 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/execution_wait_retry_test.py @@ -0,0 +1,80 @@ +"""Additional concurrent tests for wait and retry operations.""" + +import threading +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import UTC, datetime + +from aws_durable_execution_sdk_python.lambda_service import ( + Operation, + OperationStatus, + OperationType, + StepDetails, +) + +from aws_durable_execution_sdk_python_testing.execution import Execution +from aws_durable_execution_sdk_python_testing.model import StartDurableExecutionInput + + +def test_concurrent_wait_and_retry_completion(): + """Test concurrent complete_wait and complete_retry operations.""" + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-inv-id", + input='{"test": "data"}', + ) + execution = Execution.new(input_data) + + # Add WAIT and STEP operations + wait_op = Operation( + operation_id="wait-1", + parent_id=None, + name="test-wait", + start_timestamp=datetime.now(UTC), + operation_type=OperationType.WAIT, + status=OperationStatus.STARTED, + ) + + step_op = Operation( + operation_id="step-1", + parent_id=None, + name="test-step", + start_timestamp=datetime.now(UTC), + operation_type=OperationType.STEP, + status=OperationStatus.PENDING, + step_details=StepDetails(), + ) + + execution.operations.extend([wait_op, step_op]) + + results = [] + results_lock = threading.Lock() + + def complete_wait(): + result = execution.complete_wait("wait-1") + with results_lock: + results.append(f"wait-completed-{result.status.value}") + + def complete_retry(): + result = execution.complete_retry("step-1") + with results_lock: + results.append(f"retry-completed-{result.status.value}") + + with ThreadPoolExecutor(max_workers=2) as executor: + futures = [] + futures.append(executor.submit(complete_wait)) + futures.append(executor.submit(complete_retry)) + + for future in as_completed(futures): + future.result() + + assert len(results) == 2 + assert "wait-completed-SUCCEEDED" in results + assert "retry-completed-READY" in results + + # Verify token sequence was incremented twice + assert execution.token_sequence == 2 diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/executor_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/executor_test.py new file mode 100644 index 0000000..e228a0d --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/executor_test.py @@ -0,0 +1,2881 @@ +"""Unit tests for executor module.""" + +import asyncio +from datetime import UTC, datetime +from unittest.mock import Mock, patch + +import pytest + +from aws_durable_execution_sdk_python.execution import ( + DurableExecutionInvocationOutput, + InvocationStatus, +) +from aws_durable_execution_sdk_python.lambda_service import ( + CallbackOptions, + OperationUpdate, + OperationAction, + OperationType, + Operation, + OperationStatus, + CallbackDetails, +) +from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, + ExecutionDetails, +) +from aws_durable_execution_sdk_python_testing.exceptions import ( + ExecutionAlreadyStartedException, + IllegalStateException, + InvalidParameterValueException, + ResourceNotFoundException, +) +from aws_durable_execution_sdk_python_testing.execution import ( + ExecutionStatus, + Execution, +) +from aws_durable_execution_sdk_python_testing.executor import Executor +from aws_durable_execution_sdk_python_testing.invoker import InvokeResponse +from aws_durable_execution_sdk_python_testing.model import ( + ListDurableExecutionsResponse, + SendDurableExecutionCallbackFailureResponse, + SendDurableExecutionCallbackHeartbeatResponse, + SendDurableExecutionCallbackSuccessResponse, + StartDurableExecutionInput, + StopDurableExecutionResponse, +) +from aws_durable_execution_sdk_python_testing.observer import ( + ExecutionNotifier, + ExecutionObserver, +) +from aws_durable_execution_sdk_python_testing.token import ( + CallbackToken, +) + + +class MockExecutionObserver(ExecutionObserver): + """Mock observer to capture execution events through public callbacks.""" + + def __init__(self): + self.completed_executions = {} + self.failed_executions = {} + self.wait_timers = {} + self.retry_schedules = {} + self.callback_creations = {} + + def on_completed(self, execution_arn: str, result: str | None = None) -> None: + """Capture completion events.""" + self.completed_executions[execution_arn] = result + + def on_failed(self, execution_arn: str, error: ErrorObject) -> None: + """Capture failure events.""" + self.failed_executions[execution_arn] = error + + def on_wait_timer_scheduled( + self, execution_arn: str, operation_id: str, delay: float + ) -> None: + """Capture wait timer scheduling events.""" + self.wait_timers[execution_arn] = {"operation_id": operation_id, "delay": delay} + + def on_step_retry_scheduled( + self, execution_arn: str, operation_id: str, delay: float + ) -> None: + """Capture retry scheduling events.""" + self.retry_schedules[execution_arn] = { + "operation_id": operation_id, + "delay": delay, + } + + def on_callback_created( + self, + execution_arn: str, + operation_id: str, + callback_options: CallbackOptions | None, + callback_token: CallbackToken, + ) -> None: + """Capture callback creation events.""" + self.callback_creations[execution_arn] = { + "operation_id": operation_id, + "callback_id": callback_token.to_str(), + } + + def on_callback_completed( + self, execution_arn: str, operation_id: str, callback_id: str + ) -> None: + """Capture callback completion events.""" + pass # Not needed for current tests + + def on_timed_out(self, execution_arn: str, error: ErrorObject) -> None: + """Capture timeout events.""" + pass # Not needed for current tests + + def on_stopped(self, execution_arn: str, error: ErrorObject) -> None: + """Capture stop events.""" + pass # Not needed for current tests + + +@pytest.fixture +def test_observer(): + return MockExecutionObserver() + + +@pytest.fixture +def mock_store(): + return Mock() + + +@pytest.fixture +def mock_scheduler(): + return Mock() + + +@pytest.fixture +def mock_invoker(): + return Mock() + + +@pytest.fixture +def mock_checkpoint_processor(): + return Mock() + + +@pytest.fixture +def executor(mock_store, mock_scheduler, mock_invoker, mock_checkpoint_processor): + return Executor(mock_store, mock_scheduler, mock_invoker, mock_checkpoint_processor) + + +@pytest.fixture +def start_input(): + return StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ) + + +@pytest.fixture +def mock_execution(): + execution = Mock(spec=Execution) + execution.durable_execution_arn = "arn:aws:lambda:us-east-1:123456789012:function:test-function:execution:test-execution" + execution.is_complete = False + execution.consecutive_failed_invocation_attempts = 0 + execution.start_input = Mock() + execution.start_input.function_name = "test-function" + return execution + + +def test_init(mock_store, mock_scheduler, mock_invoker, mock_checkpoint_processor): + # Test that Executor can be constructed with dependencies + # Dependency injection is implementation detail - test behavior instead + executor = Executor( + mock_store, mock_scheduler, mock_invoker, mock_checkpoint_processor + ) + + # Verify executor is properly initialized by testing it can perform basic operations + assert executor is not None + + # Test that the executor uses the injected dependencies by verifying behavior + # This will be covered by other tests that exercise the executor's functionality + + +@patch("aws_durable_execution_sdk_python_testing.executor.Execution") +def test_start_execution( + mock_execution_class, executor, start_input, mock_store, mock_scheduler +): + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution_class.new.return_value = mock_execution + mock_event = Mock() + mock_scheduler.create_event.return_value = mock_event + + with patch.object(executor, "_invoke_execution") as mock_invoke: + result = executor.start_execution(start_input) + + # Test observable behavior through public API + # The executor should generate an invocation_id if not provided + call_args = mock_execution_class.new.call_args + actual_input = call_args.kwargs["input"] + + # Verify all fields match except invocation_id should be generated + assert actual_input.account_id == start_input.account_id + assert actual_input.function_name == start_input.function_name + assert actual_input.function_qualifier == start_input.function_qualifier + assert actual_input.execution_name == start_input.execution_name + assert ( + actual_input.execution_timeout_seconds == start_input.execution_timeout_seconds + ) + assert ( + actual_input.execution_retention_period_days + == start_input.execution_retention_period_days + ) + assert actual_input.invocation_id is not None # Should be generated + assert actual_input.trace_fields == start_input.trace_fields + assert actual_input.tenant_id == start_input.tenant_id + assert actual_input.input == start_input.input + mock_execution.start.assert_called_once() + mock_store.save.assert_called_once_with(mock_execution) + mock_scheduler.create_event.assert_called_once() + + # Verify execution timeout was scheduled + assert mock_scheduler.call_later.called + timeout_call = mock_scheduler.call_later.call_args + assert timeout_call.kwargs["delay"] == start_input.execution_timeout_seconds + assert timeout_call.kwargs["completion_event"] == mock_event + + mock_invoke.assert_called_once_with("test-arn") + assert result.execution_arn == "test-arn" + + # Test that completion event was created by verifying wait_until_complete works + # This tests the same functionality without accessing private members + mock_event.wait.return_value = True + wait_result = executor.wait_until_complete("test-arn", timeout=1) + assert wait_result is True + mock_event.wait.assert_called_once_with(1) + + +@patch("aws_durable_execution_sdk_python_testing.executor.Execution") +def test_start_execution_with_provided_invocation_id( + mock_execution_class, executor, mock_store, mock_scheduler +): + # Create input with invocation_id already provided + provided_invocation_id = "user-provided-id-123" + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id=provided_invocation_id, + ) + + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution_class.new.return_value = mock_execution + mock_event = Mock() + mock_scheduler.create_event.return_value = mock_event + + with patch.object(executor, "_invoke_execution") as mock_invoke: + result = executor.start_execution(start_input) + + # Should use the provided invocation_id unchanged + mock_execution_class.new.assert_called_once_with(input=start_input) + mock_execution.start.assert_called_once() + mock_store.save.assert_called_once_with(mock_execution) + mock_scheduler.create_event.assert_called_once() + mock_invoke.assert_called_once_with("test-arn") + assert result.execution_arn == "test-arn" + + mock_execution = Mock() + mock_store.load.return_value = mock_execution + + result = executor.get_execution("test-arn") + + mock_store.load.assert_called_once_with("test-arn") + assert result == mock_execution + + +def test_should_complete_workflow_with_error_when_invocation_fails( + executor, mock_store, mock_scheduler, mock_invoker, start_input +): + """Test that failed invocation responses trigger workflow completion with error.""" + # Arrange + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution.is_complete = False + mock_execution.start_input = start_input + mock_execution.consecutive_failed_invocation_attempts = 0 + + # Mock invoker to return failed response + mock_invocation_input = Mock() + mock_invoker.create_invocation_input.return_value = mock_invocation_input + failed_response = DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED, error=ErrorObject.from_message("Test error") + ) + mock_invoker.invoke.return_value = InvokeResponse( + invocation_output=failed_response, request_id="test-request-id" + ) + + # Mock execution creation and store behavior + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution_class.new.return_value = mock_execution + mock_store.load.return_value = mock_execution + + # Mock the workflow completion methods + with patch.object(executor, "fail_execution") as mock_fail: + # Act - trigger invocation through public start_execution method + executor.start_execution(start_input) + + # Get the handler that was passed to the scheduler and execute it manually + assert mock_scheduler.call_later.call_count >= 1 + handler = mock_scheduler.call_later.call_args_list[-1][0][0] + + # Execute the handler to trigger the invocation logic + import asyncio + + asyncio.run(handler()) + + # Assert - verify workflow was completed with error + mock_fail.assert_called_once_with("test-arn", failed_response.error) + + +def test_should_complete_workflow_with_result_when_invocation_succeeds( + executor, mock_store, mock_scheduler, mock_invoker, start_input +): + """Test that successful invocation responses trigger workflow completion with result.""" + # Arrange + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution.is_complete = False + mock_execution.start_input = start_input + mock_execution.consecutive_failed_invocation_attempts = 0 + + # Mock invoker to return successful response + mock_invocation_input = Mock() + mock_invoker.create_invocation_input.return_value = mock_invocation_input + success_response = DurableExecutionInvocationOutput( + status=InvocationStatus.SUCCEEDED, result="success result" + ) + mock_invoker.invoke.return_value = InvokeResponse( + invocation_output=success_response, request_id="test-request-id" + ) + + # Mock execution creation and store behavior + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution_class.new.return_value = mock_execution + mock_store.load.return_value = mock_execution + + # Mock the workflow completion methods + with patch.object(executor, "complete_execution") as mock_complete: + # Act - trigger invocation through public start_execution method + executor.start_execution(start_input) + + # Get the handler that was passed to the scheduler and execute it manually + assert mock_scheduler.call_later.call_count >= 1 + handler = mock_scheduler.call_later.call_args_list[-1][0][0] + + # Execute the handler to trigger the invocation logic + import asyncio + + asyncio.run(handler()) + + # Assert - verify workflow was completed with result + mock_complete.assert_called_once_with("test-arn", "success result") + + +def test_should_handle_pending_status_when_operations_exist( + executor, mock_store, mock_scheduler, mock_invoker, start_input +): + """Test that pending invocation responses are handled when operations exist.""" + # Arrange + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution.is_complete = False + mock_execution.start_input = start_input + mock_execution.consecutive_failed_invocation_attempts = 0 + mock_execution.has_pending_operations.return_value = True + + # Mock invoker to return pending response + mock_invocation_input = Mock() + mock_invoker.create_invocation_input.return_value = mock_invocation_input + pending_response = DurableExecutionInvocationOutput(status=InvocationStatus.PENDING) + mock_invoker.invoke.return_value = InvokeResponse( + invocation_output=pending_response, request_id="test-request-id" + ) + + # Mock execution creation and store behavior + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution_class.new.return_value = mock_execution + mock_store.load.return_value = mock_execution + + # Act - trigger invocation through public start_execution method + executor.start_execution(start_input) + + # Get the handler that was passed to the scheduler and execute it manually + assert mock_scheduler.call_later.call_count >= 1 + handler = mock_scheduler.call_later.call_args_list[-1][0][0] + + # Execute the handler to trigger the invocation logic + import asyncio + + asyncio.run(handler()) + + # Assert - verify pending operations were checked + mock_execution.has_pending_operations.assert_called_once_with(mock_execution) + + +def test_should_ignore_response_when_execution_already_complete( + executor, mock_store, mock_scheduler, mock_invoker, start_input +): + """Test that responses are ignored when execution is already complete.""" + # Arrange + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution.is_complete = True # Already complete + mock_execution.start_input = start_input + + # Mock invoker - this shouldn't be called since execution is complete + mock_invoker.create_invocation_input.return_value = Mock() + mock_invoker.invoke.return_value = ( + DurableExecutionInvocationOutput(status=InvocationStatus.SUCCEEDED), + "test-request-id", + ) + + # Mock execution creation and store behavior + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution_class.new.return_value = mock_execution + mock_store.load.return_value = mock_execution + + # Act - trigger invocation through public start_execution method + executor.start_execution(start_input) + + # Get the handler that was passed to the scheduler and execute it manually + assert mock_scheduler.call_later.call_count >= 1 + handler = mock_scheduler.call_later.call_args_list[-1][0][0] + + # Execute the handler to trigger the invocation logic + import asyncio + + asyncio.run(handler()) + + # Assert - verify invoker was not called since execution was already complete + mock_invoker.create_invocation_input.assert_not_called() + mock_invoker.invoke.assert_not_called() + + +def test_should_retry_when_response_has_no_status( + executor, mock_store, mock_scheduler, mock_invoker, start_input +): + """Test that invocation responses without status trigger retry logic.""" + # Arrange + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution.is_complete = False + mock_execution.start_input = start_input + mock_execution.consecutive_failed_invocation_attempts = 0 + + # Mock invoker to return response without status + mock_invocation_input = Mock() + mock_invoker.create_invocation_input.return_value = mock_invocation_input + no_status_response = DurableExecutionInvocationOutput(status=None) + mock_invoker.invoke.return_value = InvokeResponse( + invocation_output=no_status_response, request_id="test-request-id" + ) + + # Mock execution creation and store behavior + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution_class.new.return_value = mock_execution + mock_store.load.return_value = mock_execution + + # Act - trigger invocation through public start_execution method + executor.start_execution(start_input) + + # Get the handler that was passed to the scheduler and execute it manually + assert mock_scheduler.call_later.call_count >= 1 + handler = mock_scheduler.call_later.call_args_list[-1][0][0] + + # Execute the handler to trigger the invocation logic + asyncio.run(handler()) + + # Assert - verify retry was triggered due to validation error + assert mock_execution.consecutive_failed_invocation_attempts == 1 + mock_store.save.assert_called_with(mock_execution) + # Verify retry was scheduled (call_later should be called 3 times: timeout + initial + retry) + assert mock_scheduler.call_later.call_count == 3 + + +def test_should_retry_when_failed_response_has_result( + executor, mock_store, mock_scheduler, mock_invoker, start_input +): + """Test that failed responses with result trigger retry logic.""" + # Arrange + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution.is_complete = False + mock_execution.start_input = start_input + mock_execution.consecutive_failed_invocation_attempts = 0 + + # Mock invoker to return invalid failed response (with result) + mock_invocation_input = Mock() + mock_invoker.create_invocation_input.return_value = mock_invocation_input + invalid_response = DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED, result="should not have result" + ) + mock_invoker.invoke.return_value = InvokeResponse( + invocation_output=invalid_response, request_id="test-request-id" + ) + + # Mock execution creation and store behavior + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution_class.new.return_value = mock_execution + mock_store.load.return_value = mock_execution + + # Act - trigger invocation through public start_execution method + executor.start_execution(start_input) + + # Get the handler that was passed to the scheduler and execute it manually + assert mock_scheduler.call_later.call_count >= 1 + handler = mock_scheduler.call_later.call_args_list[-1][0][0] + + # Execute the handler to trigger the invocation logic + asyncio.run(handler()) + + # Assert - verify retry was triggered due to validation error + assert mock_execution.consecutive_failed_invocation_attempts == 1 + mock_store.save.assert_called_with(mock_execution) + # Verify retry was scheduled (call_later should be called 3 times: timeout + initial + retry) + assert mock_scheduler.call_later.call_count == 3 + + +def test_should_retry_when_success_response_has_error( + executor, mock_store, mock_scheduler, mock_invoker, start_input +): + """Test that successful responses with error trigger retry logic.""" + # Arrange + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution.is_complete = False + mock_execution.start_input = start_input + mock_execution.consecutive_failed_invocation_attempts = 0 + + # Mock invoker to return invalid success response (with error) + mock_invocation_input = Mock() + mock_invoker.create_invocation_input.return_value = mock_invocation_input + invalid_response = DurableExecutionInvocationOutput( + status=InvocationStatus.SUCCEEDED, + error=ErrorObject.from_message("should not have error"), + ) + mock_invoker.invoke.return_value = InvokeResponse( + invocation_output=invalid_response, request_id="test-request-id" + ) + + # Mock execution creation and store behavior + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution_class.new.return_value = mock_execution + mock_store.load.return_value = mock_execution + + # Act - trigger invocation through public start_execution method + executor.start_execution(start_input) + + # Get the handler that was passed to the scheduler and execute it manually + assert mock_scheduler.call_later.call_count >= 1 + handler = mock_scheduler.call_later.call_args_list[-1][0][0] + + # Execute the handler to trigger the invocation logic + asyncio.run(handler()) + + # Assert - verify retry was triggered due to validation error + assert mock_execution.consecutive_failed_invocation_attempts == 1 + mock_store.save.assert_called_with(mock_execution) + # Verify retry was scheduled (call_later should be called 3 times: timeout + initial + retry) + assert mock_scheduler.call_later.call_count == 3 + + +def test_should_retry_when_pending_response_has_no_operations( + executor, mock_store, mock_scheduler, mock_invoker, start_input +): + """Test that pending responses without operations trigger retry logic.""" + # Arrange + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution.is_complete = False + mock_execution.start_input = start_input + mock_execution.consecutive_failed_invocation_attempts = 0 + mock_execution.has_pending_operations.return_value = False # No pending operations + + # Mock invoker to return pending response + mock_invocation_input = Mock() + mock_invoker.create_invocation_input.return_value = mock_invocation_input + pending_response = DurableExecutionInvocationOutput(status=InvocationStatus.PENDING) + mock_invoker.invoke.return_value = InvokeResponse( + invocation_output=pending_response, request_id="test-request-id" + ) + + # Mock execution creation and store behavior + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution_class.new.return_value = mock_execution + mock_store.load.return_value = mock_execution + + # Act - trigger invocation through public start_execution method + executor.start_execution(start_input) + + # Get the handler that was passed to the scheduler and execute it manually + assert mock_scheduler.call_later.call_count >= 1 + handler = mock_scheduler.call_later.call_args_list[-1][0][0] + + # Execute the handler to trigger the invocation logic + asyncio.run(handler()) + + # Assert - verify retry was triggered due to validation error + assert mock_execution.consecutive_failed_invocation_attempts == 1 + mock_store.save.assert_called_with(mock_execution) + # Verify retry was scheduled (call_later should be called 3 times: timeout + initial + retry) + assert mock_scheduler.call_later.call_count == 3 + + +def test_invoke_handler_success( + executor, mock_store, mock_scheduler, mock_invoker, start_input +): + """Test successful invocation through public API.""" + # Arrange + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution.is_complete = False + mock_execution.start_input = start_input + + mock_invocation_input = Mock() + mock_invoker.create_invocation_input.return_value = mock_invocation_input + mock_response = DurableExecutionInvocationOutput( + status=InvocationStatus.SUCCEEDED, result="test" + ) + mock_invoker.invoke.return_value = InvokeResponse( + invocation_output=mock_response, request_id="test-request-id" + ) + + # Mock execution creation and store behavior + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution_class.new.return_value = mock_execution + mock_store.load.return_value = mock_execution + + # Act - trigger invocation through public start_execution method + executor.start_execution(start_input) + + # Get the handler that was passed to the scheduler and execute it manually + assert mock_scheduler.call_later.call_count >= 1 + handler = mock_scheduler.call_later.call_args_list[-1][0][0] + + # Execute the handler to trigger the invocation logic + asyncio.run(handler()) + + # Verify the invocation process was executed + mock_invoker.create_invocation_input.assert_called_once_with( + execution=mock_execution + ) + mock_invoker.invoke.assert_called_once_with( + "test-function", mock_invocation_input, None + ) + + +def test_invoke_handler_execution_already_complete( + executor, mock_store, mock_scheduler, mock_invoker, start_input +): + """Test that completed executions are handled properly through public API.""" + # Arrange + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution.is_complete = True + mock_execution.start_input = start_input + + # Mock execution creation and store behavior + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution_class.new.return_value = mock_execution + mock_store.load.return_value = mock_execution + + # Act - trigger invocation through public start_execution method + executor.start_execution(start_input) + + # Get the handler that was passed to the scheduler and execute it manually + assert mock_scheduler.call_later.call_count >= 1 + handler = mock_scheduler.call_later.call_args_list[-1][0][0] + + # Execute the handler to trigger the invocation logic + asyncio.run(handler()) + + # Verify store was accessed to check execution status + mock_store.load.assert_called_with("test-arn") + + +def test_invoke_handler_execution_completed_during_invocation( + executor, mock_store, mock_scheduler, mock_invoker, start_input +): + """Test execution completing during invocation through public API.""" + # Arrange + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution.is_complete = False + mock_execution.start_input = start_input + + mock_invocation_input = Mock() + mock_invoker.create_invocation_input.return_value = mock_invocation_input + mock_response = Mock() + mock_invoker.invoke.return_value = InvokeResponse( + invocation_output=mock_response, request_id="test-request-id" + ) + + # Create a completed execution mock + completed_execution = Mock() + completed_execution.durable_execution_arn = "test-arn" + completed_execution.is_complete = True + completed_execution.start_input = start_input + + # First call returns incomplete execution, second call returns completed execution + mock_store.load.side_effect = [mock_execution, completed_execution] + + # Mock execution creation + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution_class.new.return_value = mock_execution + + # Act - trigger invocation through public start_execution method + executor.start_execution(start_input) + + # Get the handler that was passed to the scheduler and execute it manually + assert mock_scheduler.call_later.call_count >= 1 + handler = mock_scheduler.call_later.call_args_list[-1][0][0] + + # Execute the handler to trigger the invocation logic + asyncio.run(handler()) + + # Verify the execution was checked for completion + assert mock_store.load.call_count >= 2 + + +def test_invoke_handler_resource_not_found( + executor, mock_store, mock_scheduler, mock_invoker, start_input +): + """Test resource not found handling causes workflow failure through public API.""" + # Arrange + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution.is_complete = False + mock_execution.start_input = start_input + + mock_invoker.create_invocation_input.side_effect = ResourceNotFoundException( + "Function not found" + ) + + # Mock execution creation and store behavior + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution_class.new.return_value = mock_execution + mock_store.load.return_value = mock_execution + + # Mock the public fail_execution method to verify it gets called + with patch.object(executor, "fail_execution") as mock_fail: + # Act - trigger invocation through public start_execution method + executor.start_execution(start_input) + + # Get the handler that was passed to the scheduler and execute it manually + assert mock_scheduler.call_later.call_count >= 1 + handler = mock_scheduler.call_later.call_args_list[-1][0][0] + + # Execute the handler to trigger the invocation logic + asyncio.run(handler()) + + # Assert - verify workflow failure was triggered through public API + mock_fail.assert_called_once() + # Verify the error contains the expected message + call_args = mock_fail.call_args + assert call_args[0][0] == "test-arn" # execution_arn is first positional arg + assert "Function not found" in str( + call_args[0][1] + ) # error is second positional arg + + +def test_invoke_handler_general_exception( + executor, mock_store, mock_scheduler, mock_invoker, start_input +): + """Test general exception handling triggers retry through public API.""" + # Arrange + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution.is_complete = False + mock_execution.start_input = start_input + mock_execution.consecutive_failed_invocation_attempts = 0 + + # Configure invoker to fail + mock_invoker.create_invocation_input.side_effect = Exception("General error") + + # Mock execution creation and store behavior + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution_class.new.return_value = mock_execution + mock_store.load.return_value = mock_execution + + # Act - trigger invocation through public start_execution method + executor.start_execution(start_input) + + # Get the handler that was passed to the scheduler and execute it manually + assert mock_scheduler.call_later.call_count >= 1 + handler = mock_scheduler.call_later.call_args_list[-1][0][0] + + # Execute the handler to trigger the invocation logic + asyncio.run(handler()) + + # Assert - verify retry was scheduled through observable behavior + assert mock_execution.consecutive_failed_invocation_attempts == 1 + mock_store.save.assert_called_with(mock_execution) + # Verify retry was scheduled (call_later should be called 3 times: timeout + initial + retry) + assert mock_scheduler.call_later.call_count == 3 + + +def test_invoke_execution_through_start_execution( + executor, mock_scheduler, start_input +): + """Test execution invocation behavior through public start_execution method.""" + mock_event = Mock() + mock_scheduler.create_event.return_value = mock_event + + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution_class.new.return_value = mock_execution + + # Start execution which internally calls _invoke_execution + executor.start_execution(start_input) + + # Verify scheduler was called with the completion event + mock_scheduler.call_later.assert_called() + args = mock_scheduler.call_later.call_args + assert args[1]["delay"] == 0 # Initial invocation has no delay + assert args[1]["completion_event"] == mock_event + + +def test_should_complete_workflow_successfully_through_public_api( + executor, mock_store, mock_execution +): + """Test workflow completion through public complete_execution method.""" + # Arrange + mock_execution.result = "test result" # Mock result after completion + mock_store.load.return_value = mock_execution + + with patch.object(executor, "_complete_events") as mock_complete_events: + # Act - Use public API to complete workflow + executor.complete_execution("test-arn", "result") + + # Assert - Verify final execution status and stored results + mock_store.load.assert_called_once_with(execution_arn="test-arn") + mock_execution.complete_success.assert_called_once_with(result="result") + mock_store.update.assert_called_once_with(mock_execution) + mock_complete_events.assert_called_once_with(execution_arn="test-arn") + + +def test_should_complete_workflow_with_failure_through_public_api( + executor, mock_store, mock_execution +): + """Test workflow failure completion through public fail_execution method.""" + # Arrange + error = ErrorObject.from_message("test error") + mock_execution.result = "error result" # Mock result after failure + mock_store.load.return_value = mock_execution + + with patch.object(executor, "_complete_events") as mock_complete_events: + # Act - Use public API to fail workflow + executor.fail_execution("test-arn", error) + + # Assert - Verify final execution status and stored error + mock_store.load.assert_called_once_with(execution_arn="test-arn") + mock_execution.complete_fail.assert_called_once_with(error=error) + mock_store.update.assert_called_once_with(mock_execution) + mock_complete_events.assert_called_once_with(execution_arn="test-arn") + + +def test_should_handle_workflow_completion_state_through_public_api( + executor, mock_store, mock_execution +): + """Test workflow completion behavior and state management through public API.""" + # Arrange + mock_execution.result = "final result" # Mock result after completion + mock_store.load.return_value = mock_execution + + with patch.object(executor, "_complete_events") as mock_complete_events: + # Act - Complete workflow through public API + executor.complete_execution("test-arn", "result") + + # Assert - Verify completion was processed and observer notifications sent + mock_store.load.assert_called_once_with(execution_arn="test-arn") + mock_execution.complete_success.assert_called_once_with(result="result") + mock_store.update.assert_called_once_with(mock_execution) + mock_complete_events.assert_called_once_with(execution_arn="test-arn") + + +def test_should_fail_execution_when_function_not_found( + executor, mock_store, mock_scheduler, mock_invoker, start_input +): + """Test that workflow fails when function is not found during invocation.""" + # Arrange + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution.is_complete = False + mock_execution.start_input = start_input + mock_execution.consecutive_failed_invocation_attempts = 0 + + # Mock invoker to raise function not found error + mock_invoker.create_invocation_input.side_effect = ResourceNotFoundException( + "Function not found: test_function" + ) + + # Mock execution creation and store behavior + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution_class.new.return_value = mock_execution + mock_store.load.return_value = mock_execution + + with patch.object(executor, "fail_execution") as mock_fail: + # Act - trigger invocation through public start_execution method + executor.start_execution(start_input) + + # Get the handler that was passed to the scheduler and execute it manually + assert mock_scheduler.call_later.call_count >= 1 + handler = mock_scheduler.call_later.call_args_list[-1][0][0] + + # Execute the handler to trigger the invocation logic + import asyncio + + asyncio.run(handler()) + + # Assert - verify failure was triggered with correct error + mock_fail.assert_called_once() + call_args = mock_fail.call_args + assert call_args[0][0] == "test-arn" # execution_arn + assert "Function not found" in call_args[0][1].message # error message + + +def test_should_fail_execution_when_retries_exhausted( + executor, mock_store, mock_scheduler, mock_invoker, start_input +): + """Test that workflow fails when maximum retry attempts are exhausted.""" + # Arrange + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution.is_complete = False + mock_execution.start_input = start_input + mock_execution.consecutive_failed_invocation_attempts = ( + executor.MAX_CONSECUTIVE_FAILED_ATTEMPTS + 1 + ) + + # Mock invoker to raise exception (simulating network/invocation failure) + mock_invoker.create_invocation_input.side_effect = Exception("Network error") + + # Mock execution creation and store behavior + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution_class.new.return_value = mock_execution + mock_store.load.return_value = mock_execution + + with patch.object(executor, "fail_execution") as mock_fail: + # Act - trigger invocation through public start_execution method + # This will cause an exception during invocation, which triggers retry logic + executor.start_execution(start_input) + + # Get the handler that was passed to the scheduler and execute it manually + assert mock_scheduler.call_later.call_count >= 1 + handler = mock_scheduler.call_later.call_args_list[-1][0][0] + + # Execute the handler to trigger the invocation logic + import asyncio + + asyncio.run(handler()) + + # Assert - verify failure was triggered when retries exhausted + mock_fail.assert_called_once() + call_args = mock_fail.call_args + assert call_args[0][0] == "test-arn" # execution_arn + + +def test_should_prevent_multiple_workflow_failures_on_complete_execution( + executor, mock_store, mock_scheduler, mock_invoker, start_input +): + """Test that attempting to fail an already completed execution raises an exception.""" + # Arrange - execution starts incomplete but becomes complete during processing + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution.is_complete = False # Initially incomplete + mock_execution.start_input = start_input + mock_execution.consecutive_failed_invocation_attempts = 0 + + # Create a completed execution for the _fail_workflow call + completed_execution = Mock() + completed_execution.is_complete = True + + # Mock invoker to raise ResourceNotFoundException (triggers _fail_workflow) + mock_invoker.create_invocation_input.side_effect = ResourceNotFoundException( + "Function not found" + ) + + # Mock execution creation and store behavior + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution_class.new.return_value = mock_execution + # First load returns incomplete, second load (in _fail_workflow) returns complete + mock_store.load.side_effect = [mock_execution, completed_execution] + + # Act & Assert - triggering workflow failure on completed execution should raise exception + executor.start_execution(start_input) + handler = mock_scheduler.call_later.call_args_list[-1][0][0] + + # Execute the handler to trigger the invocation logic - this should raise the exception + with pytest.raises( + IllegalStateException, match="Cannot make multiple close workflow decisions" + ): + asyncio.run(handler()) + + +def test_should_retry_invocation_when_under_limit_through_public_api( + executor, mock_store, mock_scheduler, mock_invoker, start_input +): + """Test that invocation retries when under limit through public API with final outcome verification.""" + # Arrange - Set up execution that will trigger retry logic + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution.is_complete = False + mock_execution.start_input = start_input + mock_execution.consecutive_failed_invocation_attempts = 3 # Under limit (5 is max) + + # Configure invoker to fail initially with validation error, then succeed on retry + mock_invocation_input = Mock() + mock_invoker.create_invocation_input.return_value = mock_invocation_input + + # First invocation: invalid response triggers retry + invalid_response = DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED, + result="should not have result", # Invalid: failed response with result + ) + # Second invocation: valid success response + success_response = DurableExecutionInvocationOutput( + status=InvocationStatus.SUCCEEDED, result="final success" + ) + mock_invoker.invoke.side_effect = [ + InvokeResponse( + invocation_output=invalid_response, request_id="test-request-id-1" + ), + InvokeResponse( + invocation_output=success_response, request_id="test-request-id-2" + ), + ] + + # Mock execution creation and store behavior + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution_class.new.return_value = mock_execution + mock_store.load.return_value = mock_execution + + # Act - trigger the retry scenario through public API + executor.start_execution(start_input) + + # Simulate scheduler executing the initial invocation handler + initial_handler = mock_scheduler.call_later.call_args_list[-1][0][0] + import asyncio + + asyncio.run(initial_handler()) + + # Verify retry was scheduled due to validation error + assert mock_scheduler.call_later.call_count == 3 # timeout + initial + retry + retry_call = mock_scheduler.call_later.call_args_list[ + 2 + ] # Third call is the retry + retry_handler = retry_call[0][0] + retry_delay = retry_call[1]["delay"] + + # Execute the retry handler to complete the scenario + asyncio.run(retry_handler()) + + # Assert - verify final outcome after retry sequence + assert ( + mock_execution.consecutive_failed_invocation_attempts == 4 + ) # Incremented from 3 to 4 + assert retry_delay == Executor.RETRY_BACKOFF_SECONDS # Correct backoff delay used + mock_store.save.assert_called_with(mock_execution) # Execution state saved + assert mock_invoker.invoke.call_count == 2 # Initial + retry invocation + + +def test_should_fail_workflow_when_retry_limit_exceeded( + executor, mock_store, mock_scheduler, mock_invoker, start_input +): + """Test that workflow fails when retry limit is exceeded through public API.""" + # Arrange + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution.is_complete = False + mock_execution.start_input = start_input + mock_execution.consecutive_failed_invocation_attempts = 6 # Over limit + + # Mock invoker to consistently fail + mock_invoker.create_invocation_input.side_effect = Exception("Persistent error") + mock_store.load.return_value = mock_execution + + # Mock execution creation + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution_class.new.return_value = mock_execution + + # Mock the public fail_execution method to verify it gets called + with patch.object(executor, "fail_execution") as mock_fail: + # Act - trigger execution that will exceed retry limit + executor.start_execution(start_input) + + # Get the handler that was passed to the scheduler and execute it manually + assert mock_scheduler.call_later.call_count >= 1 + handler = mock_scheduler.call_later.call_args_list[-1][0][0] + + # Execute the handler to trigger the invocation logic + asyncio.run(handler()) + + # Assert - verify workflow failed due to retry limit exceeded + mock_fail.assert_called_once() + # Verify the error contains the expected message + call_args = mock_fail.call_args + assert call_args[0][0] == "test-arn" # execution_arn is first positional arg + assert "Persistent error" in str( + call_args[0][1] + ) # error is second positional arg + + +def test_complete_events_through_complete_execution( + executor, mock_store, mock_scheduler +): + """Test completion event behavior through public complete_execution method.""" + mock_execution = Mock() + mock_execution.result = "test result" + mock_store.load.return_value = mock_execution + + # Set up completion event through start_execution + mock_event = Mock() + mock_scheduler.create_event.return_value = mock_event + + # Mock the timeout future that will be created + mock_timeout_future = Mock() + mock_scheduler.call_later.return_value = mock_timeout_future + + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_exec = Mock() + mock_exec.durable_execution_arn = "test-arn" + mock_execution_class.new.return_value = mock_exec + + start_input = Mock() + start_input.execution_timeout_seconds = 300 + executor.start_execution(start_input) + + # Now complete the execution - this should trigger event.set() and cancel timeout + executor.complete_execution("test-arn", "result") + + # Verify the event was set and timeout was cancelled + mock_event.set.assert_called_once() + mock_timeout_future.cancel.assert_called_once() + + +def test_complete_events_no_event_through_public_api(executor, mock_store): + """Test that completing non-existent execution handles missing events gracefully.""" + mock_execution = Mock() + mock_execution.result = "test result" + mock_store.load.return_value = mock_execution + + # Complete execution without setting up completion event first + # Should not raise exception when event doesn't exist + executor.complete_execution("nonexistent-arn", "result") + + +def test_wait_until_complete_success(executor, mock_scheduler): + """Test wait until complete success through public API.""" + mock_event = Mock() + mock_event.wait.return_value = True + mock_scheduler.create_event.return_value = mock_event + + # Set up completion event through start_execution + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution_class.new.return_value = mock_execution + + start_input = Mock() + start_input.execution_timeout_seconds = 0 + executor.start_execution(start_input) + + result = executor.wait_until_complete("test-arn", timeout=10) + + assert result is True + mock_event.wait.assert_called_once_with(10) + + +def test_wait_until_complete_timeout(executor, mock_scheduler): + """Test wait until complete timeout through public API.""" + mock_event = Mock() + mock_event.wait.return_value = False + mock_scheduler.create_event.return_value = mock_event + + # Set up completion event through start_execution + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution_class.new.return_value = mock_execution + + start_input = Mock() + start_input.execution_timeout_seconds = 0 + executor.start_execution(start_input) + + result = executor.wait_until_complete("test-arn", timeout=10) + + assert result is False + + +def test_wait_until_complete_no_event(executor): + with pytest.raises(ResourceNotFoundException, match="execution does not exist"): + executor.wait_until_complete("nonexistent-arn") + + +def test_complete_execution(executor, mock_store, mock_execution): + mock_execution.result = "test result" + mock_store.load.return_value = mock_execution + + with patch.object(executor, "_complete_events") as mock_complete_events: + executor.complete_execution("test-arn", "result") + + mock_store.load.assert_called_once_with(execution_arn="test-arn") + mock_execution.complete_success.assert_called_once_with(result="result") + mock_store.update.assert_called_once_with(mock_execution) + mock_complete_events.assert_called_once_with(execution_arn="test-arn") + + +def test_fail_execution(executor, mock_store, mock_execution): + error = ErrorObject.from_message("test error") + mock_execution.result = "error result" + mock_store.load.return_value = mock_execution + + with patch.object(executor, "_complete_events") as mock_complete_events: + executor.fail_execution("test-arn", error) + + mock_store.load.assert_called_once_with(execution_arn="test-arn") + mock_execution.complete_fail.assert_called_once_with(error=error) + mock_store.update.assert_called_once_with(mock_execution) + mock_complete_events.assert_called_once_with(execution_arn="test-arn") + + +def test_should_schedule_wait_timer_correctly(executor, mock_scheduler): + """Test that wait timer is scheduled correctly through public method.""" + # Arrange + mock_event = Mock() + mock_scheduler.create_event.return_value = mock_event + + # Set up completion event through start_execution + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution_class.new.return_value = mock_execution + + start_input = Mock() + start_input.execution_timeout_seconds = 0 + executor.start_execution(start_input) + + # Act - schedule wait timer through public method + executor.on_wait_timer_scheduled("test-arn", "op-123", delay=5.0) + + # Assert - verify scheduler was called correctly + assert mock_scheduler.call_later.call_count == 2 # start_execution + wait timer + wait_call = mock_scheduler.call_later.call_args_list[1] # Second call is wait timer + assert wait_call[1]["delay"] == 5.0 + assert wait_call[1]["completion_event"] == mock_event + + +def test_should_ignore_wait_completion_for_completed_execution( + executor, mock_store, mock_execution +): + """Test that wait completion logic correctly handles completed executions.""" + # Arrange + mock_execution.is_complete = True + mock_store.load.return_value = mock_execution + + # Act - simulate the wait completion logic for a completed execution + execution = mock_store.load("test-arn") + + # The logic should check if execution is complete before attempting to complete wait + if not execution.is_complete: + execution.complete_wait(operation_id="op-123") + mock_store.update(execution) + + # Assert - verify that complete_wait was not called for completed execution + mock_execution.complete_wait.assert_not_called() + mock_store.update.assert_not_called() + + +def test_should_handle_wait_completion_exception_gracefully( + executor, mock_store, mock_execution +): + """Test that wait completion exceptions are handled through error handling.""" + # Arrange + mock_store.load.return_value = mock_execution + mock_execution.is_complete = False + mock_execution.complete_wait.side_effect = Exception("test error") + + # Act & Assert - test that exception handling works correctly + # This tests the error handling logic without scheduler timing dependencies + execution = mock_store.load("test-arn") + + with pytest.raises(Exception, match="test error"): + execution.complete_wait(operation_id="op-123") + + +def test_should_complete_retry_when_retry_scheduled( + executor, mock_store, mock_scheduler, mock_execution +): + """Test retry completion through public scheduler callback API.""" + # Arrange + mock_store.load.return_value = mock_execution + + # Configure scheduler to immediately execute the callback + def immediate_callback(func, delay=0, count=1, completion_event=None): + func() # Execute the retry handler immediately + return Mock() + + mock_scheduler.call_later.side_effect = immediate_callback + + # Mock _invoke_execution to prevent async warnings + with patch.object(executor, "_invoke_execution"): + # Act - trigger retry through public API + executor.on_step_retry_scheduled("test-arn", "op-123", 10.0) + + # Assert - verify observable behavior + mock_store.load.assert_called_with("test-arn") + mock_execution.complete_retry.assert_called_once_with(operation_id="op-123") + mock_store.update.assert_called_with(mock_execution) + + +def test_should_ignore_retry_when_execution_complete( + executor, mock_store, mock_scheduler, mock_execution +): + """Test that completed executions ignore retry events through public API.""" + # Arrange + mock_execution.is_complete = True + mock_store.load.return_value = mock_execution + + # Configure scheduler to immediately execute the callback + def immediate_callback(func, delay=0, count=1, completion_event=None): + func() # Execute the retry handler immediately + return Mock() + + mock_scheduler.call_later.side_effect = immediate_callback + + # Mock _invoke_execution to prevent async warnings + with patch.object(executor, "_invoke_execution"): + # Act - trigger retry through public API + executor.on_step_retry_scheduled("test-arn", "op-123", 10.0) + + # Assert - verify no retry processing occurs + mock_execution.complete_retry.assert_not_called() + mock_store.update.assert_not_called() + + +def test_should_handle_retry_exception_gracefully( + executor, mock_store, mock_scheduler, mock_execution +): + """Test that retry exceptions are handled gracefully through public API.""" + # Arrange + mock_store.load.return_value = mock_execution + mock_execution.complete_retry.side_effect = Exception("test error") + + # Configure scheduler to immediately execute the callback + def immediate_callback(func, delay=0, count=1, completion_event=None): + func() # Execute the retry handler immediately + return Mock() + + mock_scheduler.call_later.side_effect = immediate_callback + + # Mock _invoke_execution to prevent async warnings + with patch.object(executor, "_invoke_execution"): + # Act - should not raise exception + executor.on_step_retry_scheduled("test-arn", "op-123", 10.0) + + # Assert - verify the retry was attempted but exception was caught + mock_execution.complete_retry.assert_called_once_with(operation_id="op-123") + + +def test_on_completed(executor): + with patch.object(executor, "complete_execution") as mock_complete: + executor.on_completed("test-arn", "result") + + mock_complete.assert_called_once_with("test-arn", "result") + + +def test_on_failed(executor): + error = ErrorObject.from_message("test error") + + with patch.object(executor, "fail_execution") as mock_fail: + executor.on_failed("test-arn", error) + + mock_fail.assert_called_once_with("test-arn", error) + + +def test_on_wait_timer_scheduled(executor, mock_scheduler): + """Test wait timer scheduling through public observer method.""" + mock_event = Mock() + mock_scheduler.create_event.return_value = mock_event + + # Set up completion event through start_execution + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution_class.new.return_value = mock_execution + + start_input = Mock() + start_input.execution_timeout_seconds = 0 + executor.start_execution(start_input) + + with patch.object(executor, "_on_wait_succeeded"): + with patch.object(executor, "_invoke_execution"): + executor.on_wait_timer_scheduled("test-arn", "op-123", 10.0) + + # Verify scheduler was called with correct parameters + assert ( + mock_scheduler.call_later.call_count == 2 + ) # Once for start_execution, once for wait timer + wait_timer_call = mock_scheduler.call_later.call_args_list[ + 1 + ] # Second call is for wait timer + assert wait_timer_call[1]["delay"] == 10.0 + assert wait_timer_call[1]["completion_event"] == mock_event + + +def test_should_retry_when_response_has_unexpected_status( + executor, mock_store, mock_scheduler, mock_invoker, start_input +): + """Test that responses with unexpected status trigger retry logic.""" + # Arrange + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution.is_complete = False + mock_execution.start_input = start_input + mock_execution.consecutive_failed_invocation_attempts = 0 + + # Mock invoker to return response with unexpected status + mock_invocation_input = Mock() + mock_invoker.create_invocation_input.return_value = mock_invocation_input + unexpected_response = Mock() + unexpected_response.status = "UNKNOWN_STATUS" + mock_invoker.invoke.return_value = InvokeResponse( + invocation_output=unexpected_response, request_id="test-request-id" + ) + + # Mock execution creation and store behavior + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution_class.new.return_value = mock_execution + mock_store.load.return_value = mock_execution + + # Act - trigger invocation through public start_execution method + executor.start_execution(start_input) + + # Get the handler that was passed to the scheduler and execute it manually + assert mock_scheduler.call_later.call_count >= 1 + handler = mock_scheduler.call_later.call_args_list[-1][0][0] + + # Execute the handler to trigger the invocation logic + asyncio.run(handler()) + + # Assert - verify retry was triggered due to validation error + assert mock_execution.consecutive_failed_invocation_attempts == 1 + mock_store.save.assert_called_with(mock_execution) + # Verify retry was scheduled (call_later should be called 3 times: timeout + initial + retry) + assert mock_scheduler.call_later.call_count == 3 + + +def test_invoke_handler_execution_completed_during_invocation_async( + executor, mock_store, mock_scheduler, mock_invoker, start_input +): + """Test execution completing during invocation through public API.""" + # First call returns incomplete execution, second call returns completed execution + incomplete_execution = Mock(spec=Execution) + incomplete_execution.is_complete = False + incomplete_execution.start_input = start_input + incomplete_execution.consecutive_failed_invocation_attempts = 0 + incomplete_execution.durable_execution_arn = "test-arn" + + completed_execution = Mock(spec=Execution) + completed_execution.is_complete = True + + mock_store.load.side_effect = [incomplete_execution, completed_execution] + + mock_invocation_input = Mock() + mock_invoker.create_invocation_input.return_value = mock_invocation_input + mock_response = Mock() + mock_invoker.invoke.return_value = InvokeResponse( + invocation_output=mock_response, request_id="test-request-id" + ) + + # Mock execution creation + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution_class.new.return_value = incomplete_execution + + # Act - trigger invocation through public start_execution method + executor.start_execution(start_input) + + # Get the handler that was passed to the scheduler and execute it manually + assert mock_scheduler.call_later.call_count >= 1 + handler = mock_scheduler.call_later.call_args_list[-1][0][0] + + # Execute the handler to trigger the invocation logic + asyncio.run(handler()) + + # Verify the execution was loaded multiple times (before and after invocation) + assert mock_store.load.call_count >= 2 + + +def test_invoke_handler_resource_not_found_async( + executor, mock_store, mock_scheduler, mock_invoker, start_input +): + """Test resource not found handling causes workflow failure through public API (async version).""" + # Arrange + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution.is_complete = False + mock_execution.start_input = start_input + + mock_invoker.create_invocation_input.side_effect = ResourceNotFoundException( + "Function not found" + ) + + # Mock execution creation and store behavior + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution_class.new.return_value = mock_execution + mock_store.load.return_value = mock_execution + + # Mock the public fail_execution method to verify it gets called + with patch.object(executor, "fail_execution") as mock_fail: + # Act - trigger invocation through public start_execution method + executor.start_execution(start_input) + + # Get the handler that was passed to the scheduler and execute it manually + assert mock_scheduler.call_later.call_count >= 1 + handler = mock_scheduler.call_later.call_args_list[-1][0][0] + + # Execute the handler to trigger the invocation logic + asyncio.run(handler()) + + # Assert - verify workflow failure was triggered through public API + mock_fail.assert_called_once() + # Verify the error contains the expected message + call_args = mock_fail.call_args + assert call_args[0][0] == "test-arn" # execution_arn is first positional arg + assert "Function not found" in str( + call_args[0][1] + ) # error is second positional arg + + +def test_invoke_handler_general_exception_async( + executor, mock_store, mock_scheduler, mock_invoker, start_input +): + """Test general exception handling triggers retry through public API (async version).""" + # Arrange + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution.is_complete = False + mock_execution.start_input = start_input + mock_execution.consecutive_failed_invocation_attempts = 0 + + # Configure invoker to fail initially, then succeed on retry + mock_invoker.create_invocation_input.side_effect = [ + Exception("General error"), # First call fails + Mock(), # Second call succeeds (returns invocation input) + ] + + # Mock successful response for retry + success_response = DurableExecutionInvocationOutput( + status=InvocationStatus.SUCCEEDED, result="success" + ) + mock_invoker.invoke.return_value = InvokeResponse( + invocation_output=success_response, request_id="test-request-id" + ) + + # Mock execution creation and store behavior + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution_class.new.return_value = mock_execution + mock_store.load.return_value = mock_execution + + # Act - trigger invocation through public start_execution method + executor.start_execution(start_input) + + # Get the handler that was passed to the scheduler and execute it manually + assert mock_scheduler.call_later.call_count >= 1 + handler = mock_scheduler.call_later.call_args_list[-1][0][0] + + # Execute the handler to trigger the invocation logic + asyncio.run(handler()) + + # Assert - verify retry was scheduled through observable behavior + assert mock_execution.consecutive_failed_invocation_attempts == 1 + mock_store.save.assert_called_with(mock_execution) + # Verify retry was scheduled (call_later should be called 3 times: timeout + initial + retry) + assert mock_scheduler.call_later.call_count == 3 + + +def test_invoke_execution_with_delay_through_wait_timer(executor, mock_scheduler): + """Test execution invocation with delay through wait timer scheduling.""" + mock_event = Mock() + mock_scheduler.create_event.return_value = mock_event + + # Set up completion event through start_execution + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution_class.new.return_value = mock_execution + + start_input = Mock() + start_input.execution_timeout_seconds = 0 + executor.start_execution(start_input) + + # Test delay behavior through wait timer scheduling + with patch.object(executor, "_on_wait_succeeded"): + executor.on_wait_timer_scheduled("test-arn", "op-123", 10.0) + + # Verify scheduler was called with delay for wait timer + wait_timer_call = mock_scheduler.call_later.call_args_list[ + 1 + ] # Second call is for wait timer + assert wait_timer_call[1]["delay"] == 10.0 + + +def test_invoke_execution_no_delay_through_start_execution(executor, mock_scheduler): + """Test execution invocation with no delay through start_execution.""" + mock_event = Mock() + mock_scheduler.create_event.return_value = mock_event + + # Test no delay behavior through start_execution + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution_class.new.return_value = mock_execution + + start_input = Mock() + start_input.execution_timeout_seconds = 0 + executor.start_execution(start_input) + + # Verify scheduler was called with no delay for initial execution + initial_call = mock_scheduler.call_later.call_args_list[ + 0 + ] # First call is for initial execution + assert initial_call[1]["delay"] == 0 + + +def test_on_step_retry_scheduled(executor, mock_scheduler): + """Test step retry scheduling through public observer method.""" + mock_event = Mock() + mock_scheduler.create_event.return_value = mock_event + + # Set up completion event through start_execution + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution_class.new.return_value = mock_execution + + start_input = Mock() + start_input.execution_timeout_seconds = 0 + executor.start_execution(start_input) + + with patch.object(executor, "_on_retry_ready"): + with patch.object(executor, "_invoke_execution"): + executor.on_step_retry_scheduled("test-arn", "op-123", 10.0) + + # Verify scheduler was called with correct parameters + assert ( + mock_scheduler.call_later.call_count == 2 + ) # Once for start_execution, once for retry + retry_call = mock_scheduler.call_later.call_args_list[1] # Second call is for retry + assert retry_call[1]["delay"] == 10.0 + assert retry_call[1]["completion_event"] == mock_event + + +def test_wait_handler_execution(executor, mock_scheduler): + """Test wait handler execution through public observer method.""" + mock_event = Mock() + mock_scheduler.create_event.return_value = mock_event + + # Set up completion event through start_execution + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution_class.new.return_value = mock_execution + + start_input = Mock() + start_input.execution_timeout_seconds = 0 + executor.start_execution(start_input) + + with patch.object(executor, "_on_wait_succeeded") as mock_wait: + with patch.object(executor, "_invoke_execution") as mock_invoke: + executor.on_wait_timer_scheduled("test-arn", "op-123", 10.0) + + # Get the handler that was passed to call_later (second call for wait timer) + wait_timer_call = mock_scheduler.call_later.call_args_list[1] + wait_handler = wait_timer_call[0][0] + + # Execute the handler to test the inner function + wait_handler() + + mock_wait.assert_called_once_with("test-arn", "op-123") + mock_invoke.assert_called_once_with("test-arn", delay=0) + + +def test_retry_handler_execution(executor, mock_scheduler): + """Test retry handler execution through public observer method.""" + mock_event = Mock() + mock_scheduler.create_event.return_value = mock_event + + # Set up completion event through start_execution + with patch( + "aws_durable_execution_sdk_python_testing.executor.Execution" + ) as mock_execution_class: + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + mock_execution_class.new.return_value = mock_execution + + start_input = Mock() + start_input.execution_timeout_seconds = 0 + executor.start_execution(start_input) + + with patch.object(executor, "_on_retry_ready") as mock_retry: + with patch.object(executor, "_invoke_execution") as mock_invoke: + executor.on_step_retry_scheduled("test-arn", "op-123", 10.0) + + # Get the handler that was passed to call_later (second call for retry) + retry_call = mock_scheduler.call_later.call_args_list[1] + retry_handler = retry_call[0][0] + + # Execute the handler to test the inner function + retry_handler() + + mock_retry.assert_called_once_with("test-arn", "op-123") + mock_invoke.assert_called_once_with("test-arn", delay=0) + + +# Tests for new web handler methods + + +def test_get_execution_details(executor, mock_store): + """Test get_execution_details method.""" + + # Create real execution instance with mocked start_input + mock_start_input = Mock() + mock_start_input.execution_name = "test-execution" + mock_start_input.function_name = "test-function" + + execution = Execution( + durable_execution_arn="test-arn", start_input=mock_start_input, operations=[] + ) + execution.is_complete = True + + # Create mock result + mock_result = DurableExecutionInvocationOutput( + status=InvocationStatus.SUCCEEDED, result="test-result" + ) + execution.result = mock_result + execution.close_status = ExecutionStatus.SUCCEEDED + + # Create mock operation and add to execution + mock_operation = Operation( + operation_id="op-1", + parent_id=None, + name="test-execution", + start_timestamp=datetime.now(UTC), + end_timestamp=datetime.now(UTC), + operation_type=OperationType.EXECUTION, + status=OperationStatus.SUCCEEDED, + execution_details=ExecutionDetails(input_payload='{"test": "data"}'), + ) + execution.operations = [mock_operation] + + mock_store.load.return_value = execution + + result = executor.get_execution_details("test-arn") + + assert result.durable_execution_arn == "test-arn" + assert result.durable_execution_name == "test-execution" + assert result.status == "SUCCEEDED" + assert result.result == "test-result" + assert result.error is None + mock_store.load.assert_called_once_with("test-arn") + + +def test_get_execution_details_not_found(executor, mock_store): + """Test get_execution_details with non-existent execution.""" + mock_store.load.side_effect = KeyError("Execution not found") + + with pytest.raises(ResourceNotFoundException, match="Execution test-arn not found"): + executor.get_execution_details("test-arn") + + +def test_get_execution_details_failed_execution(executor, mock_store): + """Test get_execution_details with failed execution.""" + + # Create real execution instance with mocked start_input + mock_start_input = Mock() + mock_start_input.execution_name = "test-execution" + mock_start_input.function_name = "test-function" + + execution = Execution( + durable_execution_arn="test-arn", start_input=mock_start_input, operations=[] + ) + execution.is_complete = True + + error = ErrorObject.from_message("Test error") + mock_result = DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED, error=error + ) + execution.result = mock_result + + # Create mock operation and add to execution + mock_operation = Operation( + operation_id="op-1", + parent_id=None, + name="test-execution", + start_timestamp=datetime.now(UTC), + operation_type=OperationType.EXECUTION, + status=OperationStatus.FAILED, + execution_details=ExecutionDetails(input_payload='{"test": "data"}'), + ) + execution.operations = [mock_operation] + + mock_store.load.return_value = execution + with pytest.raises( + IllegalStateException, + match="close_status must be set when execution is complete", + ): + executor.get_execution_details("test-arn") + execution.close_status = ExecutionStatus.FAILED + result = executor.get_execution_details("test-arn") + assert result.status == "FAILED" + assert result.result is None + assert result.error == error + + +def test_list_executions_empty(executor, mock_store): + """Test list_executions with no executions.""" + query_result = ([], None) + mock_store.query.return_value = query_result + + result = executor.list_executions() + + assert result.durable_executions == [] + assert result.next_marker is None + mock_store.query.assert_called_once() + + +def test_list_executions_with_filtering(executor, mock_store): + """Test list_executions with function name filtering.""" + # Create real execution instance + mock_start_input = Mock() + mock_start_input.execution_name = "exec1" + mock_start_input.function_name = "function1" + + execution1 = Execution( + durable_execution_arn="arn1", start_input=mock_start_input, operations=[] + ) + execution1.is_complete = False + execution1.result = None + + # Create mock operations + op1 = Operation( + operation_id="op-1", + parent_id=None, + name="exec1", + start_timestamp=datetime.now(UTC), + operation_type=OperationType.EXECUTION, + status=OperationStatus.STARTED, + execution_details=ExecutionDetails(input_payload="{}"), + ) + execution1.operations = [op1] + + # Mock the query method to return filtered results + query_result = ([execution1], "1") + mock_store.query.return_value = query_result + + # Test filtering by function name + result = executor.list_executions(function_name="function1") + + assert len(result.durable_executions) == 1 + assert result.durable_executions[0].durable_execution_arn == "arn1" + assert result.durable_executions[0].status == "RUNNING" + + +def test_list_executions_with_pagination(executor, mock_store): + """Test list_executions with pagination.""" + # Create multiple mock executions for first page + executions_page1 = [] + for i in range(2): + execution = Mock() + execution.durable_execution_arn = f"arn{i}" + execution.start_input.execution_name = f"exec{i}" + execution.start_input.function_name = "test-function" + execution.is_complete = False + execution.result = None + + op = Operation( + operation_id=f"op-{i}", + parent_id=None, + name=f"exec{i}", + start_timestamp=datetime.now(UTC), + operation_type=OperationType.EXECUTION, + status=OperationStatus.STARTED, + execution_details=ExecutionDetails(input_payload="{}"), + ) + execution.get_operation_execution_started.return_value = op + executions_page1.append(execution) + + # Create executions for second page + executions_page2 = [] + for i in range(2, 4): + execution = Mock() + execution.durable_execution_arn = f"arn{i}" + execution.start_input.execution_name = f"exec{i}" + execution.start_input.function_name = "test-function" + execution.is_complete = False + execution.result = None + + op = Operation( + operation_id=f"op-{i}", + parent_id=None, + name=f"exec{i}", + start_timestamp=datetime.now(UTC), + operation_type=OperationType.EXECUTION, + status=OperationStatus.STARTED, + execution_details=ExecutionDetails(input_payload="{}"), + ) + execution.get_operation_execution_started.return_value = op + executions_page2.append(execution) + + # Mock query responses for pagination + query_result1 = (executions_page1, "2") + + query_result2 = (executions_page2, "4") + + mock_store.query.side_effect = [query_result1, query_result2] + + # Test pagination with max_items=2 + result = executor.list_executions(max_items=2) + + assert len(result.durable_executions) == 2 + assert result.next_marker == "2" + + # Test second page + result2 = executor.list_executions(max_items=2, marker="2") + + assert len(result2.durable_executions) == 2 + assert result2.next_marker == "4" + + +def test_list_executions_by_function(executor): + """Test list_executions_by_function delegates to list_executions.""" + with patch.object(executor, "list_executions") as mock_list: + mock_response = ListDurableExecutionsResponse( + durable_executions=[], next_marker=None + ) + mock_list.return_value = mock_response + + result = executor.list_executions_by_function( + "test-function", status_filter="RUNNING" + ) + + mock_list.assert_called_once_with( + function_name="test-function", + execution_name=None, + status_filter="RUNNING", + started_after=None, + started_before=None, + marker=None, + max_items=None, + reverse_order=False, + ) + assert result.durable_executions == [] + assert result.next_marker is None + + +def test_stop_execution(executor, mock_store): + """Test stop_execution method.""" + # Create real execution instance with mocked start_input + mock_start_input = Mock() + mock_start_input.execution_name = "test-execution" + mock_start_input.function_name = "test-function" + + execution = Execution( + durable_execution_arn="test-arn", + start_input=mock_start_input, + operations=[Mock()], + ) + execution.is_complete = False + mock_store.load.return_value = execution + + result = executor.stop_execution("test-arn") + + mock_store.load.assert_called_once_with("test-arn") + mock_store.update.assert_called_once_with(execution) + assert result.stop_timestamp is not None + assert execution.is_complete is True + assert execution.close_status == ExecutionStatus.STOPPED + + +def test_stop_execution_already_complete(executor, mock_store): + """Test stop_execution with already completed execution returns idempotent response.""" + mock_execution = Mock() + mock_execution.is_complete = True + mock_execution.durable_execution_arn = "test-arn" + + # Mock the execution operation with end_timestamp + mock_execution_op = Mock() + mock_execution_op.end_timestamp = datetime(2023, 1, 1, 0, 1, 0, tzinfo=UTC) + mock_execution.get_operation_execution_started.return_value = mock_execution_op + + mock_store.load.return_value = mock_execution + + result = executor.stop_execution("test-arn") + + assert isinstance(result, StopDurableExecutionResponse) + assert result.stop_timestamp == datetime(2023, 1, 1, 0, 1, 0, tzinfo=UTC) + + +def test_stop_execution_with_custom_error(executor, mock_store): + """Test stop_execution with custom error.""" + # Create real execution instance with mocked start_input + mock_start_input = Mock() + mock_start_input.execution_name = "test-execution" + mock_start_input.function_name = "test-function" + + execution = Execution( + durable_execution_arn="test-arn", + start_input=mock_start_input, + operations=[Mock()], + ) + execution.is_complete = False + mock_store.load.return_value = execution + + custom_error = ErrorObject.from_message("Custom stop error") + + executor.stop_execution("test-arn", error=custom_error) + + mock_store.load.assert_called_once_with("test-arn") + mock_store.update.assert_called_once_with(execution) + assert execution.is_complete is True + assert execution.close_status == ExecutionStatus.STOPPED + assert execution.result.error == custom_error + + +def test_get_execution_not_found(executor, mock_store): + mock_store.load.side_effect = KeyError("not found") + + with pytest.raises(ResourceNotFoundException): + executor.get_execution("test-arn") + + +def test_get_execution_state(executor, mock_store): + """Test get_execution_state method.""" + + mock_execution = Mock() + mock_execution.used_tokens = {"token1", "token2"} + + # Create mock operations + operations = [ + Operation( + operation_id="op-1", + parent_id=None, + name="step1", + start_timestamp=datetime.now(UTC), + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + ), + Operation( + operation_id="op-2", + parent_id=None, + name="step2", + start_timestamp=datetime.now(UTC), + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + ), + ] + mock_execution.get_assertable_operations.return_value = operations + + mock_store.load.return_value = mock_execution + + result = executor.get_execution_state("test-arn", checkpoint_token="token1") # noqa: S106 + + assert len(result.operations) == 2 + assert result.next_marker is None + mock_store.load.assert_called_once_with("test-arn") + + +def test_get_execution_state_invalid_token(executor, mock_store): + """Test get_execution_state with invalid checkpoint token.""" + mock_execution = Mock() + mock_execution.used_tokens = {"token1", "token2"} + mock_store.load.return_value = mock_execution + + with pytest.raises( + InvalidParameterValueException, match="Invalid checkpoint token" + ): + executor.get_execution_state("test-arn", checkpoint_token="invalid-token") # noqa: S106 + + +def test_get_execution_history(executor, mock_store): + """Test get_execution_history method.""" + mock_execution = Mock() + mock_execution.operations = [] # Empty operations list + mock_execution.updates = [] + mock_execution.invocation_completions = [] + mock_execution.durable_execution_arn = "" + mock_execution.start_input = Mock() + mock_execution.result = Mock() + + mock_store.load.return_value = mock_execution + + result = executor.get_execution_history("test-arn") + + assert result.events == [] + assert result.next_marker is None + mock_store.load.assert_called_once_with("test-arn") + + +def test_get_execution_history_with_events(executor, mock_store): + """Test get_execution_history with actual events.""" + from aws_durable_execution_sdk_python.lambda_service import StepDetails + + # Create operations that will generate events + op1 = Operation( + operation_id="op-1", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + start_timestamp=datetime.now(UTC), + end_timestamp=datetime.now(UTC), + step_details=StepDetails(result="test_result"), + ) + mock_execution = Mock() + mock_execution.operations = [op1] + mock_execution.updates = [] + mock_execution.invocation_completions = [] + mock_execution.durable_execution_arn = "" + mock_execution.start_input = Mock() + mock_execution.result = Mock() + mock_store.load.return_value = mock_execution + + result = executor.get_execution_history("test-arn", include_execution_data=True) + + assert len(result.events) == 2 # Started + Succeeded events + assert result.events[0].event_type == "StepStarted" + assert result.events[1].event_type == "StepSucceeded" + + +def test_get_execution_history_reverse_order(executor, mock_store): + """Test get_execution_history with reverse order.""" + op1 = Operation( + operation_id="op-1", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + start_timestamp=datetime.now(UTC), + end_timestamp=datetime.now(UTC), + ) + + mock_execution = Mock() + mock_execution.operations = [op1] + mock_execution.updates = [] + mock_execution.invocation_completions = [] + mock_execution.durable_execution_arn = "" + mock_execution.start_input = Mock() + mock_execution.result = Mock() + mock_store.load.return_value = mock_execution + + result = executor.get_execution_history("test-arn", reverse_order=True) + + assert len(result.events) == 2 + # In reverse order, succeeded event should come first + assert result.events[0].event_type == "StepSucceeded" + assert result.events[1].event_type == "StepStarted" + + +def test_get_execution_history_pagination(executor, mock_store): + """Test get_execution_history with pagination.""" + # Create multiple operations to generate many events + operations = [] + for i in range(3): + op = Operation( + operation_id=f"op-{i}", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + start_timestamp=datetime.now(UTC), + end_timestamp=datetime.now(UTC), + ) + operations.append(op) + + mock_execution = Mock() + mock_execution.operations = operations + mock_execution.updates = [] + mock_execution.invocation_completions = [] + mock_execution.durable_execution_arn = "" + mock_execution.start_input = Mock() + mock_execution.result = Mock() + mock_store.load.return_value = mock_execution + + # Test with max_items=2 + result = executor.get_execution_history("test-arn", max_items=2) + + assert len(result.events) == 2 + assert result.next_marker == "3" # Next event_id + + +def test_get_execution_history_pagination_with_marker(executor, mock_store): + """Test get_execution_history pagination with marker.""" + operations = [] + for i in range(3): + op = Operation( + operation_id=f"op-{i}", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + start_timestamp=datetime.now(UTC), + end_timestamp=datetime.now(UTC), + ) + operations.append(op) + + mock_execution = Mock() + mock_execution.operations = operations + mock_execution.updates = [] + mock_execution.invocation_completions = [] + mock_execution.durable_execution_arn = "" + mock_execution.start_input = Mock() + mock_execution.result = Mock() + mock_store.load.return_value = mock_execution + + # Test with marker (start from event_id 3) + result = executor.get_execution_history("test-arn", marker="3", max_items=2) + + assert len(result.events) == 2 + # Should get events with event_id >= 3 + + +def test_get_execution_history_invalid_marker(executor, mock_store): + """Test get_execution_history with invalid marker.""" + mock_execution = Mock() + mock_execution.operations = [] + mock_execution.updates = [] + mock_execution.invocation_completions = [] + mock_execution.durable_execution_arn = "" + mock_execution.start_input = Mock() + mock_execution.result = Mock() + mock_store.load.return_value = mock_execution + + # Invalid marker should default to 1 + result = executor.get_execution_history("test-arn", marker="invalid") + + assert result.events == [] + assert result.next_marker is None + + +def test_checkpoint_execution(executor, mock_store): + """Test checkpoint_execution method.""" + mock_execution = Mock() + mock_execution.used_tokens = {"token1", "token2"} + mock_execution.get_new_checkpoint_token.return_value = "new-token" + mock_store.load.return_value = mock_execution + + result = executor.checkpoint_execution("test-arn", "token1") + + assert result.checkpoint_token == "new-token" # noqa: S105 + assert result.new_execution_state is None + mock_store.load.assert_called_once_with("test-arn") + mock_execution.get_new_checkpoint_token.assert_called_once() + + +def test_checkpoint_execution_invalid_token(executor, mock_store): + """Test checkpoint_execution with invalid checkpoint token.""" + mock_execution = Mock() + mock_execution.used_tokens = {"token1", "token2"} + mock_store.load.return_value = mock_execution + + with pytest.raises( + InvalidParameterValueException, match="Invalid checkpoint token" + ): + executor.checkpoint_execution("test-arn", "invalid-token") + + +# Callback method tests + + +def test_send_callback_success(executor, mock_store): + """Test send_callback_success method.""" + from aws_durable_execution_sdk_python_testing.token import CallbackToken + + # Create valid callback token + callback_token = CallbackToken(execution_arn="test-arn", operation_id="op-123") + callback_id = callback_token.to_str() + + # Create mock execution with callback operation + mock_execution = Mock() + mock_execution.find_callback_operation.return_value = (0, Mock()) + mock_execution.complete_callback_success.return_value = Mock() + mock_store.load.return_value = mock_execution + + with patch.object(executor, "_invoke_execution") as mock_invoke: + result = executor.send_callback_success(callback_id, b"success-result") + + assert isinstance(result, SendDurableExecutionCallbackSuccessResponse) + mock_store.load.assert_called_once_with("test-arn") + mock_execution.complete_callback_success.assert_called_once_with( + callback_id, b"success-result" + ) + mock_store.update.assert_called_once_with(mock_execution) + # Verify execution is invoked after callback success + mock_invoke.assert_called_once_with("test-arn") + + +def test_send_callback_success_empty_callback_id(executor): + """Test send_callback_success with empty callback_id.""" + with pytest.raises(InvalidParameterValueException, match="callback_id is required"): + executor.send_callback_success("") + + +def test_send_callback_success_none_callback_id(executor): + """Test send_callback_success with None callback_id.""" + with pytest.raises(InvalidParameterValueException, match="callback_id is required"): + executor.send_callback_success(None) + + +def test_send_callback_success_with_result(executor, mock_store): + """Test send_callback_success with result data.""" + from aws_durable_execution_sdk_python_testing.token import CallbackToken + + # Create valid callback token + callback_token = CallbackToken(execution_arn="test-arn", operation_id="op-123") + callback_id = callback_token.to_str() + + # Create mock execution with callback operation + mock_execution = Mock() + mock_execution.find_callback_operation.return_value = (0, Mock()) + mock_execution.complete_callback_success.return_value = Mock() + mock_store.load.return_value = mock_execution + + with patch.object(executor, "_invoke_execution") as mock_invoke: + result = executor.send_callback_success(callback_id, b"test-result") + + assert isinstance(result, SendDurableExecutionCallbackSuccessResponse) + mock_execution.complete_callback_success.assert_called_once_with( + callback_id, b"test-result" + ) + # Verify execution is invoked after callback success + mock_invoke.assert_called_once_with("test-arn") + + +def test_send_callback_failure(executor, mock_store): + """Test send_callback_failure method.""" + from aws_durable_execution_sdk_python_testing.token import CallbackToken + + # Create valid callback token + callback_token = CallbackToken(execution_arn="test-arn", operation_id="op-123") + callback_id = callback_token.to_str() + + # Create mock execution with callback operation + mock_execution = Mock() + mock_execution.find_callback_operation.return_value = (0, Mock()) + mock_execution.complete_callback_failure.return_value = Mock() + mock_store.load.return_value = mock_execution + + with patch.object(executor, "_invoke_execution") as mock_invoke: + result = executor.send_callback_failure(callback_id) + + assert isinstance(result, SendDurableExecutionCallbackFailureResponse) + mock_store.load.assert_called_once_with("test-arn") + mock_store.update.assert_called_once_with(mock_execution) + # Verify execution is invoked after callback failure + mock_invoke.assert_called_once_with("test-arn") + + +def test_send_callback_failure_empty_callback_id(executor): + """Test send_callback_failure with empty callback_id.""" + with pytest.raises(InvalidParameterValueException, match="callback_id is required"): + executor.send_callback_failure("") + + +def test_send_callback_failure_none_callback_id(executor): + """Test send_callback_failure with None callback_id.""" + with pytest.raises(InvalidParameterValueException, match="callback_id is required"): + executor.send_callback_failure(None) + + +def test_send_callback_failure_with_error(executor, mock_store): + """Test send_callback_failure with error object.""" + # Create valid callback token + callback_token = CallbackToken(execution_arn="test-arn", operation_id="op-123") + callback_id = callback_token.to_str() + + # Create mock execution with callback operation + mock_execution = Mock() + mock_execution.find_callback_operation.return_value = (0, Mock()) + mock_execution.complete_callback_failure.return_value = Mock() + mock_store.load.return_value = mock_execution + + error = ErrorObject.from_message("Test callback error") + with patch.object(executor, "_invoke_execution") as mock_invoke: + result = executor.send_callback_failure(callback_id, error) + + assert isinstance(result, SendDurableExecutionCallbackFailureResponse) + mock_execution.complete_callback_failure.assert_called_once_with(callback_id, error) + # Verify execution is invoked after callback failure + mock_invoke.assert_called_once_with("test-arn") + + +def test_send_callback_heartbeat(executor, mock_store): + """Test send_callback_heartbeat method.""" + # Create valid callback token + callback_token = CallbackToken(execution_arn="test-arn", operation_id="op-123") + callback_id = callback_token.to_str() + + # Create mock execution with callback operation + mock_execution = Mock() + mock_operation = Mock() + mock_operation.status = OperationStatus.STARTED + mock_execution.find_callback_operation.return_value = (0, mock_operation) + mock_execution.updates = [] # No callback options to reset timeout + mock_execution.invocation_completions = [] + mock_store.load.return_value = mock_execution + + result = executor.send_callback_heartbeat(callback_id) + + assert isinstance(result, SendDurableExecutionCallbackHeartbeatResponse) + # Called twice: once in get_execution, once in _reset_callback_heartbeat_timeout + assert mock_store.load.call_count == 2 + mock_execution.find_callback_operation.assert_called_once_with(callback_id) + + +def test_send_callback_heartbeat_empty_callback_id(executor): + """Test send_callback_heartbeat with empty callback_id.""" + with pytest.raises(InvalidParameterValueException, match="callback_id is required"): + executor.send_callback_heartbeat("") + + +def test_send_callback_heartbeat_none_callback_id(executor): + """Test send_callback_heartbeat with None callback_id.""" + with pytest.raises(InvalidParameterValueException, match="callback_id is required"): + executor.send_callback_heartbeat(None) + + +def test_complete_execution_no_result(mock_store, executor): + """Test complete_execution when execution has no result after completion.""" + mock_execution = Mock() + mock_execution.result = None # No result after completion + mock_store.load.return_value = mock_execution + + with patch.object(executor, "_complete_events"): + with pytest.raises(IllegalStateException, match="Execution result is required"): + executor.complete_execution("test-arn", "result") + + +def test_fail_execution_no_result(mock_store, executor): + """Test fail_execution when execution has no result after failure.""" + mock_execution = Mock() + mock_execution.result = None # No result after failure + mock_store.load.return_value = mock_execution + error = ErrorObject.from_message("test error") + + with patch.object(executor, "_complete_events"): + with pytest.raises(IllegalStateException, match="Execution result is required"): + executor.fail_execution("test-arn", error) + + +def test_send_callback_heartbeat_inactive_callback(mock_store, executor): + """Test send_callback_heartbeat with inactive callback.""" + + # Create valid callback token + callback_token = CallbackToken(execution_arn="test-arn", operation_id="op-123") + callback_id = callback_token.to_str() + + # Create mock execution with inactive callback operation + mock_execution = Mock() + mock_operation = Mock() + mock_operation.status = OperationStatus.SUCCEEDED # Not STARTED + mock_execution.find_callback_operation.return_value = (0, mock_operation) + mock_store.load.return_value = mock_execution + + with pytest.raises(ResourceNotFoundException, match="Callback .* is not active"): + executor.send_callback_heartbeat(callback_id) + + +def test_send_callback_success_invalid_token(executor): + """Test send_callback_success with invalid token format.""" + with pytest.raises( + ResourceNotFoundException, match="Failed to process callback success" + ): + executor.send_callback_success("invalid-token") + + +def test_send_callback_failure_invalid_token(executor): + """Test send_callback_failure with invalid token format.""" + with pytest.raises( + ResourceNotFoundException, match="Failed to process callback failure" + ): + executor.send_callback_failure("invalid-token") + + +def test_send_callback_heartbeat_invalid_token(executor): + """Test send_callback_heartbeat with invalid token format.""" + with pytest.raises( + ResourceNotFoundException, match="Failed to process callback heartbeat" + ): + executor.send_callback_heartbeat("invalid-token") + + +def test_complete_events_no_event(executor): + """Test _complete_events when no event exists.""" + # Should not raise exception when event doesn't exist + executor._complete_events("nonexistent-arn") # Should handle gracefully + + +# Tests for callback timeout functionality + + +def test_callback_timeout_scheduling(executor, mock_store, mock_scheduler): + """Test that callback timeouts are scheduled when callback is created.""" + # Create callback options with both timeouts + callback_options = CallbackOptions(timeout_seconds=60, heartbeat_timeout_seconds=30) + + # Set up completion event + executor._completion_events["test-arn"] = Mock() + + # Test the timeout scheduling directly with correct parameters + executor._schedule_callback_timeouts("test-arn", callback_options, "callback-id") + + # Verify scheduler was called for both timeouts + assert mock_scheduler.call_later.call_count == 2 # main timeout + heartbeat timeout + + +def test_callback_timeout_cleanup(executor, mock_store): + """Test that callback timeouts are cleaned up when callback completes.""" + # Create mock timeout events + timeout_event = Mock() + heartbeat_event = Mock() + + executor._callback_timeouts["callback-id"] = timeout_event + executor._callback_heartbeats["callback-id"] = heartbeat_event + + # Trigger cleanup + executor._cleanup_callback_timeouts("callback-id") + + # Verify events were cancelled and removed + timeout_event.cancel.assert_called_once() + heartbeat_event.cancel.assert_called_once() + assert "callback-id" not in executor._callback_timeouts + assert "callback-id" not in executor._callback_heartbeats + + +def test_callback_heartbeat_timeout_reset(executor, mock_store, mock_scheduler): + """Test that heartbeat timeout is reset when heartbeat is received.""" + + # Create callback token + callback_token = CallbackToken(execution_arn="test-arn", operation_id="op-123") + callback_id = callback_token.to_str() + + # Create mock execution with callback options + mock_execution = Mock() + callback_options = CallbackOptions(heartbeat_timeout_seconds=30) + update = OperationUpdate( + operation_id="op-123", + operation_type=OperationType.CALLBACK, + action=OperationAction.START, + callback_options=callback_options, + ) + mock_execution.updates = [update] + + mock_store.load.return_value = mock_execution + mock_scheduler.create_event.return_value = Mock() + + # Set up existing heartbeat event + old_event = Mock() + executor._callback_heartbeats[callback_id] = old_event + + # Reset heartbeat timeout + executor._reset_callback_heartbeat_timeout(callback_id, "test-arn") + + # Verify old event was cancelled and new one scheduled + old_event.cancel.assert_called_once() + mock_scheduler.call_later.assert_called() + + +def test_callback_timeout_handlers(executor, mock_store): + """Test callback timeout and heartbeat timeout handlers.""" + # Create callback token + callback_token = CallbackToken(execution_arn="test-arn", operation_id="op-123") + callback_id = callback_token.to_str() + + # Create mock execution + mock_execution = Mock() + mock_execution.is_complete = False + mock_store.load.return_value = mock_execution + + # Test main timeout handler + executor._on_callback_timeout("test-arn", callback_id) + + # Verify callback was failed with timeout error + mock_execution.complete_callback_timeout.assert_called() + timeout_error = mock_execution.complete_callback_timeout.call_args[0][1] + assert "Callback timed out" in str(timeout_error.message) + + # Reset mocks for heartbeat test + mock_execution.reset_mock() + + # Test heartbeat timeout handler + executor._on_callback_heartbeat_timeout("test-arn", callback_id) + + # Verify callback was failed with heartbeat timeout error + mock_execution.complete_callback_timeout.assert_called() + heartbeat_error = mock_execution.complete_callback_timeout.call_args[0][1] + assert "Callback heartbeat timed out" in str(heartbeat_error.message) + + +def test_callback_timeout_completed_execution(executor, mock_store): + """Test that timeout handlers ignore completed executions.""" + + # Create callback token + callback_token = CallbackToken(execution_arn="test-arn", operation_id="op-123") + callback_id = callback_token.to_str() + + # Create completed execution + mock_execution = Mock() + mock_execution.is_complete = True + mock_store.load.return_value = mock_execution + + # Test timeout handlers with completed execution + executor._on_callback_timeout("test-arn", callback_id) + executor._on_callback_heartbeat_timeout("test-arn", callback_id) + + # Verify no callback operations were performed + mock_execution.complete_callback_timeout.assert_not_called() + mock_store.update.assert_not_called() + + +def test_schedule_callback_timeouts_no_callback_details(executor, mock_store): + """Test _schedule_callback_timeouts when operation has no callback details.""" + + # Create operation without callback details + operation = Operation( + operation_id="op-123", + operation_type=OperationType.CALLBACK, + status=OperationStatus.STARTED, + callback_details=None, + ) + + mock_execution = Mock() + mock_execution.find_operation.return_value = (0, operation) + mock_store.load.return_value = mock_execution + + # Should return early without scheduling + executor._schedule_callback_timeouts("test-arn", "op-123", "callback-id") + + # No scheduler calls should be made + assert len(executor._callback_timeouts) == 0 + assert len(executor._callback_heartbeats) == 0 + + +def test_schedule_callback_timeouts_no_callback_options(executor, mock_store): + """Test _schedule_callback_timeouts when no callback options are found.""" + + # Create operation with callback details but no matching updates + operation = Operation( + operation_id="op-123", + operation_type=OperationType.CALLBACK, + status=OperationStatus.STARTED, + callback_details=CallbackDetails(callback_id="callback-id"), + ) + + mock_execution = Mock() + mock_execution.find_operation.return_value = (0, operation) + mock_execution.updates = [] # No updates with callback options + mock_execution.invocation_completions = [] + mock_store.load.return_value = mock_execution + + # Should return early without scheduling + executor._schedule_callback_timeouts("test-arn", "op-123", "callback-id") + + # No scheduler calls should be made + assert len(executor._callback_timeouts) == 0 + assert len(executor._callback_heartbeats) == 0 + + +def test_schedule_callback_timeouts_zero_timeouts(executor, mock_store, mock_scheduler): + """Test _schedule_callback_timeouts with zero timeout values.""" + # Create operation with callback details + operation = Operation( + operation_id="op-123", + operation_type=OperationType.CALLBACK, + status=OperationStatus.STARTED, + callback_details=CallbackDetails(callback_id="callback-id"), + ) + + mock_execution = Mock() + mock_execution.find_operation.return_value = (0, operation) + + # Create update with zero timeouts (disabled) + callback_options = CallbackOptions(timeout_seconds=0, heartbeat_timeout_seconds=0) + update = OperationUpdate( + operation_id="op-123", + operation_type=OperationType.CALLBACK, + action=OperationAction.START, + callback_options=callback_options, + ) + mock_execution.updates = [update] + + mock_store.load.return_value = mock_execution + executor._completion_events["test-arn"] = Mock() + + # Should not schedule any timeouts + executor._schedule_callback_timeouts("test-arn", "op-123", "callback-id") + + # No scheduler calls should be made + mock_scheduler.call_later.assert_not_called() + assert len(executor._callback_timeouts) == 0 + assert len(executor._callback_heartbeats) == 0 + + +def test_schedule_callback_timeouts_only_main_timeout( + executor, mock_store, mock_scheduler +): + """Test _schedule_callback_timeouts with only main timeout configured.""" + + # Create callback options with only main timeout + callback_options = CallbackOptions(timeout_seconds=60, heartbeat_timeout_seconds=0) + + executor._completion_events["test-arn"] = Mock() + + executor._schedule_callback_timeouts("test-arn", callback_options, "callback-id") + + # Only main timeout should be scheduled + assert mock_scheduler.call_later.call_count == 1 + assert len(executor._callback_timeouts) == 1 + assert len(executor._callback_heartbeats) == 0 + + +def test_schedule_callback_timeouts_only_heartbeat_timeout( + executor, mock_store, mock_scheduler +): + """Test _schedule_callback_timeouts with only heartbeat timeout configured.""" + # Create callback options with only heartbeat timeout + callback_options = CallbackOptions(timeout_seconds=0, heartbeat_timeout_seconds=30) + + executor._completion_events["test-arn"] = Mock() + + executor._schedule_callback_timeouts("test-arn", callback_options, "callback-id") + + # Only heartbeat timeout should be scheduled + assert mock_scheduler.call_later.call_count == 1 + assert len(executor._callback_timeouts) == 0 + assert len(executor._callback_heartbeats) == 1 + + +def test_schedule_callback_timeouts_exception_handling(executor, mock_store): + """Test _schedule_callback_timeouts handles exceptions gracefully.""" + # Make get_execution raise an exception + mock_store.load.side_effect = Exception("Test error") + + # Should not raise exception + executor._schedule_callback_timeouts("test-arn", "op-123", "callback-id") + + # No timeouts should be scheduled + assert len(executor._callback_timeouts) == 0 + assert len(executor._callback_heartbeats) == 0 + + +def test_on_timed_out(executor, mock_store): + """Test on_timed_out method.""" + # Create real execution instance + mock_start_input = Mock() + mock_start_input.execution_name = "test-execution" + mock_start_input.function_name = "test-function" + + execution = Execution( + durable_execution_arn="test-arn", + start_input=mock_start_input, + operations=[Mock()], + ) + execution.is_complete = False + mock_store.load.return_value = execution + + error = ErrorObject.from_message("Execution timeout") + + with patch.object(executor, "_complete_events") as mock_complete_events: + executor.on_timed_out("test-arn", error) + + mock_store.load.assert_called_once_with(execution_arn="test-arn") + mock_store.update.assert_called_once_with(execution) + mock_complete_events.assert_called_once_with(execution_arn="test-arn") + assert execution.is_complete is True + assert execution.close_status == ExecutionStatus.TIMED_OUT + assert execution.result.error == error + + +def test_on_stopped(executor): + """Test on_stopped method.""" + error = ErrorObject.from_message("Execution stopped") + + with patch.object(executor, "fail_execution") as mock_fail: + executor.on_stopped("test-arn", error) + + mock_fail.assert_called_once_with("test-arn", error) + + +def test_notify_timed_out(): + """Test notify_timed_out method.""" + notifier = ExecutionNotifier() + observer = Mock() + notifier.add_observer(observer) + + error = ErrorObject.from_message("Timeout error") + notifier.notify_timed_out("test-arn", error) + + observer.on_timed_out.assert_called_once_with(execution_arn="test-arn", error=error) + + +def test_notify_stopped(): + """Test notify_stopped method.""" + notifier = ExecutionNotifier() + observer = Mock() + notifier.add_observer(observer) + + error = ErrorObject.from_message("Stop error") + notifier.notify_stopped("test-arn", error) + + observer.on_stopped.assert_called_once_with(execution_arn="test-arn", error=error) diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/invoker_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/invoker_test.py new file mode 100644 index 0000000..1270f50 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/invoker_test.py @@ -0,0 +1,652 @@ +"""Tests for invoker module.""" + +import json +from unittest.mock import Mock, patch + +import pytest +from aws_durable_execution_sdk_python.execution import ( + DurableExecutionInvocationInput, + DurableExecutionInvocationInputWithClient, + DurableExecutionInvocationOutput, + InitialExecutionState, + InvocationStatus, +) + +from aws_durable_execution_sdk_python.lambda_service import ( + ExecutionDetails, + Operation, + OperationStatus, + OperationType, +) + +from datetime import datetime, UTC + +from aws_durable_execution_sdk_python_testing.execution import Execution +from aws_durable_execution_sdk_python_testing.invoker import ( + InProcessInvoker, + LambdaInvoker, + _LAMBDA_CLIENT_CONFIG, + create_lambda_client, + create_test_lambda_context, +) +from aws_durable_execution_sdk_python_testing.model import ( + LambdaContext, + StartDurableExecutionInput, +) + + +def test_create_test_lambda_context(): + """Test creating a test lambda context.""" + context = create_test_lambda_context() + + assert ( + context.invoked_function_arn + == "arn:aws:lambda:us-west-2:123456789012:function:test-function" + ) + assert context.tenant_id == "test-tenant-789" + assert context.client_context is not None + + +def test_in_process_invoker_init(): + """Test InProcessInvoker initialization.""" + handler = Mock() + service_client = Mock() + + invoker = InProcessInvoker(handler, service_client) + + assert invoker.handler is handler + assert invoker.service_client is service_client + + +def test_in_process_invoker_create_invocation_input(): + """Test creating invocation input for in-process invoker.""" + handler = Mock() + service_client = Mock() + invoker = InProcessInvoker(handler, service_client) + + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution.new(input_data) + + invocation_input = invoker.create_invocation_input(execution) + + assert isinstance(invocation_input, DurableExecutionInvocationInputWithClient) + assert invocation_input.durable_execution_arn == execution.durable_execution_arn + assert invocation_input.checkpoint_token is not None + assert isinstance(invocation_input.initial_execution_state, InitialExecutionState) + assert invocation_input.service_client is service_client + + +def test_in_process_invoker_invoke(): + """Test invoking function with in-process invoker.""" + # Mock handler that returns a valid response + handler = Mock() + handler.return_value = {"Status": "SUCCEEDED", "Result": "test-result"} + + service_client = Mock() + invoker = InProcessInvoker(handler, service_client) + + input_data = DurableExecutionInvocationInput( + durable_execution_arn="test-arn", + checkpoint_token="test-token", # noqa: S106 + initial_execution_state=InitialExecutionState(operations=[], next_marker=""), + ) + + response = invoker.invoke("test-function", input_data) + + assert isinstance(response.invocation_output, DurableExecutionInvocationOutput) + assert response.invocation_output.status == InvocationStatus.SUCCEEDED + assert response.invocation_output.result == "test-result" + assert isinstance(response.request_id, str) + + # Verify handler was called with correct arguments + handler.assert_called_once() + call_args = handler.call_args[0] + assert isinstance(call_args[0], DurableExecutionInvocationInputWithClient) + assert isinstance(call_args[1], LambdaContext) + + +def test_lambda_invoker_init(): + """Test LambdaInvoker initialization.""" + lambda_client = Mock() + + invoker = LambdaInvoker(lambda_client) + + assert invoker.lambda_client is lambda_client + + +def test_lambda_invoker_create(): + """Test creating LambdaInvoker with boto3 client.""" + with patch("aws_durable_execution_sdk_python_testing.invoker.boto3") as mock_boto3: + mock_client = Mock() + mock_boto3.client.return_value = mock_client + + invoker = LambdaInvoker.create("http://localhost:3001", "us-west-2") + + assert isinstance(invoker, LambdaInvoker) + assert invoker.lambda_client is mock_client + mock_boto3.client.assert_called_once_with( + "lambda", + endpoint_url="http://localhost:3001", + region_name="us-west-2", + config=_LAMBDA_CLIENT_CONFIG, + ) + + +def test_lambda_invoker_create_invocation_input(): + """Test creating invocation input for lambda invoker.""" + lambda_client = Mock() + invoker = LambdaInvoker(lambda_client) + + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation", + ) + execution = Execution.new(input_data) + + invocation_input = invoker.create_invocation_input(execution) + + assert isinstance(invocation_input, DurableExecutionInvocationInput) + assert invocation_input.durable_execution_arn == execution.durable_execution_arn + assert invocation_input.checkpoint_token is not None + assert isinstance(invocation_input.initial_execution_state, InitialExecutionState) + + +def test_lambda_invoker_invoke_success(): + """Test successful lambda invocation.""" + lambda_client = Mock() + + # Mock successful response + mock_payload = Mock() + mock_payload.read.return_value = json.dumps( + {"Status": "SUCCEEDED", "Result": "lambda-result"} + ).encode("utf-8") + + lambda_client.invoke.return_value = { + "StatusCode": 200, + "Payload": mock_payload, + "ResponseMetadata": {"HTTPHeaders": {"x-amzn-RequestId": "test-request-id"}}, + } + + invoker = LambdaInvoker(lambda_client) + + mock_operation = Operation( + operation_id="op-1", + parent_id=None, + name="test-execution", + start_timestamp=datetime.now(UTC), + end_timestamp=datetime.now(UTC), + operation_type=OperationType.EXECUTION, + status=OperationStatus.SUCCEEDED, + execution_details=ExecutionDetails(input_payload='{"test": "data"}'), + ) + + input_data = DurableExecutionInvocationInput( + durable_execution_arn="test-arn", + checkpoint_token="test-token", # noqa: S106 + initial_execution_state=InitialExecutionState( + operations=[mock_operation], next_marker="" + ), + ) + + response = invoker.invoke("test-function", input_data) + + assert isinstance(response.invocation_output, DurableExecutionInvocationOutput) + assert response.invocation_output.status == InvocationStatus.SUCCEEDED + assert response.invocation_output.result == "lambda-result" + assert response.request_id == "test-request-id" + + # Verify lambda client was called correctly + lambda_client.invoke.assert_called_once_with( + FunctionName="test-function", + InvocationType="RequestResponse", + Payload=json.dumps(input_data.to_json_dict()), + ) + + +def test_lambda_invoker_invoke_failure(): + """Test lambda invocation failure.""" + from aws_durable_execution_sdk_python_testing.exceptions import ( + DurableFunctionsTestError, + ) + + lambda_client, _ = _create_mock_lambda_client_with_exceptions() + + # Mock failed response + mock_payload = Mock() + lambda_client.invoke.return_value = { + "StatusCode": 500, + "Payload": mock_payload, + } + + invoker = LambdaInvoker(lambda_client) + + input_data = DurableExecutionInvocationInput( + durable_execution_arn="test-arn", + checkpoint_token="test-token", # noqa: S106 + initial_execution_state=InitialExecutionState(operations=[], next_marker=""), + ) + + with pytest.raises( + DurableFunctionsTestError, + match="Lambda invocation failed with status code: 500", + ): + invoker.invoke("test-function", input_data) + + +def test_in_process_invoker_invoke_with_execution_operations(): + """Test in-process invoker with execution that has operations.""" + handler = Mock() + handler.return_value = {"Status": "SUCCEEDED", "Result": None} + + service_client = Mock() + invoker = InProcessInvoker(handler, service_client) + + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation", + ) + execution = Execution.new(input_data) + execution.start() # This adds operations + + invocation_input = invoker.create_invocation_input(execution) + response = invoker.invoke("test-function", invocation_input) + + assert isinstance(response.invocation_output, DurableExecutionInvocationOutput) + assert isinstance(response.request_id, str) + assert response.invocation_output.status == InvocationStatus.SUCCEEDED + assert len(invocation_input.initial_execution_state.operations) > 0 + + +def test_lambda_invoker_create_invocation_input_with_operations(): + """Test lambda invoker creating input with execution operations.""" + lambda_client = Mock() + invoker = LambdaInvoker(lambda_client) + + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation", + ) + execution = Execution.new(input_data) + execution.start() # This adds operations + + invocation_input = invoker.create_invocation_input(execution) + + assert isinstance(invocation_input, DurableExecutionInvocationInput) + assert len(invocation_input.initial_execution_state.operations) > 0 + assert invocation_input.initial_execution_state.next_marker == "" + + +def test_lambda_invoker_invoke_empty_function_name(): + """Test lambda invocation with empty function name.""" + from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, + ) + + lambda_client = Mock() + invoker = LambdaInvoker(lambda_client) + + input_data = DurableExecutionInvocationInput( + durable_execution_arn="test-arn", + checkpoint_token="test-token", + initial_execution_state=InitialExecutionState(operations=[], next_marker=""), + ) + + with pytest.raises( + InvalidParameterValueException, match="Function name is required" + ): + invoker.invoke("", input_data) + + +def test_lambda_invoker_invoke_whitespace_function_name(): + """Test lambda invocation with whitespace-only function name.""" + from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, + ) + + lambda_client = Mock() + invoker = LambdaInvoker(lambda_client) + + input_data = DurableExecutionInvocationInput( + durable_execution_arn="test-arn", + checkpoint_token="test-token", + initial_execution_state=InitialExecutionState(operations=[], next_marker=""), + ) + + with pytest.raises( + InvalidParameterValueException, match="Function name is required" + ): + invoker.invoke(" ", input_data) + + +def test_lambda_invoker_invoke_status_202(): + """Test lambda invocation with status code 202.""" + lambda_client = Mock() + + mock_payload = Mock() + mock_payload.read.return_value = json.dumps( + {"Status": "SUCCEEDED", "Result": "async-result"} + ).encode("utf-8") + + lambda_client.invoke.return_value = { + "StatusCode": 202, + "Payload": mock_payload, + "ResponseMetadata": { + "HTTPHeaders": {"x-amzn-RequestId": "test-request-id-202"} + }, + } + + invoker = LambdaInvoker(lambda_client) + + input_data = DurableExecutionInvocationInput( + durable_execution_arn="test-arn", + checkpoint_token="test-token", + initial_execution_state=InitialExecutionState(operations=[], next_marker=""), + ) + + response = invoker.invoke("test-function", input_data) + assert isinstance(response.invocation_output, DurableExecutionInvocationOutput) + assert response.request_id == "test-request-id-202" + + +def test_lambda_invoker_invoke_function_error(): + """Test lambda invocation with function error.""" + from aws_durable_execution_sdk_python_testing.exceptions import ( + DurableFunctionsTestError, + ) + + lambda_client, _ = _create_mock_lambda_client_with_exceptions() + + mock_payload = Mock() + mock_payload.read.return_value = b'{"errorMessage": "Function failed"}' + + lambda_client.invoke.return_value = { + "StatusCode": 200, + "FunctionError": "Unhandled", + "Payload": mock_payload, + } + + invoker = LambdaInvoker(lambda_client) + + input_data = DurableExecutionInvocationInput( + durable_execution_arn="test-arn", + checkpoint_token="test-token", + initial_execution_state=InitialExecutionState(operations=[], next_marker=""), + ) + + with pytest.raises( + DurableFunctionsTestError, match="Lambda invocation failed with status 200" + ): + invoker.invoke("test-function", input_data) + + +def _create_mock_lambda_client_with_exceptions(): + """Helper to create mock lambda client with all exception types.""" + lambda_client = Mock() + + class MockException(Exception): + pass + + exceptions_mock = Mock() + for exc_name in [ + "ResourceNotFoundException", + "InvalidParameterValueException", + "TooManyRequestsException", + "ServiceException", + "ResourceConflictException", + "InvalidRequestContentException", + "RequestTooLargeException", + "UnsupportedMediaTypeException", + "InvalidRuntimeException", + "InvalidZipFileException", + "ResourceNotReadyException", + "SnapStartTimeoutException", + "SnapStartNotReadyException", + "SnapStartException", + "RecursiveInvocationException", + "InvalidSecurityGroupIDException", + "EC2ThrottledException", + "EFSMountConnectivityException", + "SubnetIPAddressLimitReachedException", + "EC2UnexpectedException", + "InvalidSubnetIDException", + "EC2AccessDeniedException", + "EFSIOException", + "ENILimitReachedException", + "EFSMountTimeoutException", + "EFSMountFailureException", + "KMSAccessDeniedException", + "KMSDisabledException", + "KMSNotFoundException", + "KMSInvalidStateException", + ]: + setattr(exceptions_mock, exc_name, MockException) + + lambda_client.exceptions = exceptions_mock + return lambda_client, MockException + + +def test_lambda_invoker_invoke_resource_not_found(): + """Test lambda invocation with ResourceNotFoundException.""" + from aws_durable_execution_sdk_python_testing.exceptions import ( + ResourceNotFoundException, + ) + + lambda_client, _ = _create_mock_lambda_client_with_exceptions() + + # Create specific exception for ResourceNotFoundException + class MockResourceNotFoundException(Exception): + pass + + lambda_client.exceptions.ResourceNotFoundException = MockResourceNotFoundException + + lambda_client.invoke.side_effect = MockResourceNotFoundException( + "Function not found" + ) + + invoker = LambdaInvoker(lambda_client) + + input_data = DurableExecutionInvocationInput( + durable_execution_arn="test-arn", + checkpoint_token="test-token", + initial_execution_state=InitialExecutionState(operations=[], next_marker=""), + ) + + with pytest.raises( + ResourceNotFoundException, match="Function not found: test-function" + ): + invoker.invoke("test-function", input_data) + + +def test_lambda_invoker_invoke_invalid_parameter(): + """Test lambda invocation with InvalidParameterValueException.""" + from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, + ) + + lambda_client, MockException = _create_mock_lambda_client_with_exceptions() + + # Override specific exception for this test + class MockInvalidParameterValueException(Exception): + pass + + lambda_client.exceptions.InvalidParameterValueException = ( + MockInvalidParameterValueException + ) + + lambda_client.invoke.side_effect = MockInvalidParameterValueException( + "Invalid param" + ) + + invoker = LambdaInvoker(lambda_client) + + input_data = DurableExecutionInvocationInput( + durable_execution_arn="test-arn", + checkpoint_token="test-token", + initial_execution_state=InitialExecutionState(operations=[], next_marker=""), + ) + + with pytest.raises(InvalidParameterValueException, match="Invalid parameter"): + invoker.invoke("test-function", input_data) + + +def test_lambda_invoker_invoke_service_exception(): + """Test lambda invocation with ServiceException.""" + from aws_durable_execution_sdk_python_testing.exceptions import ( + DurableFunctionsTestError, + ) + + lambda_client, _ = _create_mock_lambda_client_with_exceptions() + + # Create specific exception for ServiceException + class MockServiceException(Exception): + pass + + lambda_client.exceptions.ServiceException = MockServiceException + + lambda_client.invoke.side_effect = MockServiceException("Service error") + + invoker = LambdaInvoker(lambda_client) + + input_data = DurableExecutionInvocationInput( + durable_execution_arn="test-arn", + checkpoint_token="test-token", + initial_execution_state=InitialExecutionState(operations=[], next_marker=""), + ) + + with pytest.raises(DurableFunctionsTestError, match="Lambda invocation failed"): + invoker.invoke("test-function", input_data) + + +def test_lambda_invoker_invoke_ec2_exception(): + """Test lambda invocation with EC2 exception.""" + from aws_durable_execution_sdk_python_testing.exceptions import ( + DurableFunctionsTestError, + ) + + lambda_client, _ = _create_mock_lambda_client_with_exceptions() + + # Create specific exception for EC2AccessDeniedException + class MockEC2Exception(Exception): + pass + + lambda_client.exceptions.EC2AccessDeniedException = MockEC2Exception + + lambda_client.invoke.side_effect = MockEC2Exception("Access denied") + + invoker = LambdaInvoker(lambda_client) + + input_data = DurableExecutionInvocationInput( + durable_execution_arn="test-arn", + checkpoint_token="test-token", + initial_execution_state=InitialExecutionState(operations=[], next_marker=""), + ) + + with pytest.raises(DurableFunctionsTestError, match="Lambda infrastructure error"): + invoker.invoke("test-function", input_data) + + +def test_lambda_invoker_invoke_kms_exception(): + """Test lambda invocation with KMS exception.""" + from aws_durable_execution_sdk_python_testing.exceptions import ( + DurableFunctionsTestError, + ) + + lambda_client, _ = _create_mock_lambda_client_with_exceptions() + + # Create specific exception for KMSAccessDeniedException + class MockKMSException(Exception): + pass + + lambda_client.exceptions.KMSAccessDeniedException = MockKMSException + + lambda_client.invoke.side_effect = MockKMSException("KMS access denied") + + invoker = LambdaInvoker(lambda_client) + + input_data = DurableExecutionInvocationInput( + durable_execution_arn="test-arn", + checkpoint_token="test-token", + initial_execution_state=InitialExecutionState(operations=[], next_marker=""), + ) + + with pytest.raises(DurableFunctionsTestError, match="Lambda KMS error"): + invoker.invoke("test-function", input_data) + + +def test_lambda_invoker_invoke_durable_execution_already_started(): + """Test lambda invocation with DurableExecutionAlreadyStartedException.""" + from aws_durable_execution_sdk_python_testing.exceptions import ( + DurableFunctionsTestError, + ) + + lambda_client, _ = _create_mock_lambda_client_with_exceptions() + + class MockDurableExecutionAlreadyStartedException(Exception): + pass + + MockDurableExecutionAlreadyStartedException.__name__ = ( + "DurableExecutionAlreadyStartedException" + ) + + lambda_client.invoke.side_effect = MockDurableExecutionAlreadyStartedException( + "Already started" + ) + + invoker = LambdaInvoker(lambda_client) + + input_data = DurableExecutionInvocationInput( + durable_execution_arn="test-arn", + checkpoint_token="test-token", + initial_execution_state=InitialExecutionState(operations=[], next_marker=""), + ) + + with pytest.raises( + DurableFunctionsTestError, match="Durable execution already started" + ): + invoker.invoke("test-function", input_data) + + +def test_lambda_invoker_invoke_unexpected_exception(): + """Test lambda invocation with unexpected exception.""" + from aws_durable_execution_sdk_python_testing.exceptions import ( + DurableFunctionsTestError, + ) + + lambda_client, _ = _create_mock_lambda_client_with_exceptions() + lambda_client.invoke.side_effect = RuntimeError("Unexpected error") + + invoker = LambdaInvoker(lambda_client) + + input_data = DurableExecutionInvocationInput( + durable_execution_arn="test-arn", + checkpoint_token="test-token", + initial_execution_state=InitialExecutionState(operations=[], next_marker=""), + ) + + with pytest.raises( + DurableFunctionsTestError, match="Unexpected error during Lambda invocation" + ): + invoker.invoke("test-function", input_data) diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/model_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/model_test.py new file mode 100644 index 0000000..10076c0 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/model_test.py @@ -0,0 +1,3667 @@ +"""Tests for model serialization dataclasses.""" + +from __future__ import annotations + +import datetime +import json + +import pytest +from aws_durable_execution_sdk_python.lambda_service import ( + OperationStatus, + OperationType, +) + +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) +from aws_durable_execution_sdk_python_testing.model import ( + CallbackFailedDetails, + CallbackStartedDetails, + CallbackSucceededDetails, + CallbackTimedOutDetails, + ChainedInvokeFailedDetails, + ChainedInvokeStartedDetails, + ChainedInvokeStoppedDetails, + ChainedInvokeSucceededDetails, + ChainedInvokeTimedOutDetails, + CheckpointDurableExecutionRequest, + CheckpointDurableExecutionResponse, + CheckpointUpdatedExecutionState, + ContextFailedDetails, + ContextStartedDetails, + ContextSucceededDetails, + ErrorResponse, + Event, + EventError, + EventInput, + EventResult, + Execution, + ExecutionFailedDetails, + ExecutionStartedDetails, + ExecutionStoppedDetails, + ExecutionSucceededDetails, + ExecutionTimedOutDetails, + GetDurableExecutionHistoryRequest, + GetDurableExecutionHistoryResponse, + GetDurableExecutionRequest, + GetDurableExecutionResponse, + GetDurableExecutionStateRequest, + GetDurableExecutionStateResponse, + InvocationCompletedDetails, + ListDurableExecutionsByFunctionRequest, + ListDurableExecutionsByFunctionResponse, + ListDurableExecutionsRequest, + ListDurableExecutionsResponse, + RetryDetails, + SendDurableExecutionCallbackFailureRequest, + SendDurableExecutionCallbackFailureResponse, + SendDurableExecutionCallbackHeartbeatRequest, + SendDurableExecutionCallbackHeartbeatResponse, + SendDurableExecutionCallbackSuccessRequest, + SendDurableExecutionCallbackSuccessResponse, + StartDurableExecutionInput, + StartDurableExecutionOutput, + StepFailedDetails, + StepStartedDetails, + StepSucceededDetails, + StopDurableExecutionRequest, + StopDurableExecutionResponse, + WaitCancelledDetails, + WaitStartedDetails, + WaitSucceededDetails, + events_to_operations, +) + + +# Test timestamp constants +TIMESTAMP_2023_01_01_00_00 = datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC) +TIMESTAMP_2023_01_01_00_01 = datetime.datetime(2023, 1, 1, 0, 1, 0, tzinfo=datetime.UTC) +TIMESTAMP_2023_01_01_00_02 = datetime.datetime(2023, 1, 1, 0, 2, 0, tzinfo=datetime.UTC) +TIMESTAMP_2023_01_02_00_00 = datetime.datetime(2023, 1, 2, 0, 0, 0, tzinfo=datetime.UTC) + +DEFAULT_START_DURABLE_EXECUTION_INPUT_DATA = { + "AccountId": "123456789012", + "FunctionName": "my-function", + "FunctionQualifier": "$LATEST", + "ExecutionName": "test-execution", + "ExecutionTimeoutSeconds": 300, + "ExecutionRetentionPeriodDays": 7, + "InvocationId": "invocation-123", + "TraceFields": {"key": "value"}, + "TenantId": "tenant-123", + "Input": "test-input", +} + + +def test_start_durable_execution_input_serialization(): + """Test StartDurableExecutionInput from_dict/to_dict round-trip.""" + data = DEFAULT_START_DURABLE_EXECUTION_INPUT_DATA + + # Test from_dict + input_obj = StartDurableExecutionInput.from_dict(data) + assert input_obj.account_id == "123456789012" + assert input_obj.function_name == "my-function" + assert input_obj.function_qualifier == "$LATEST" + assert input_obj.execution_name == "test-execution" + assert input_obj.execution_timeout_seconds == 300 + assert input_obj.execution_retention_period_days == 7 + assert input_obj.invocation_id == "invocation-123" + assert input_obj.trace_fields == {"key": "value"} + assert input_obj.tenant_id == "tenant-123" + assert input_obj.input == "test-input" + + # Test to_dict + result_data = input_obj.to_dict() + assert result_data == data + + # Test round-trip + round_trip = StartDurableExecutionInput.from_dict(result_data) + assert round_trip == input_obj + + +def test_start_durable_execution_input_get_input_json_input(): + """Test StartDurableExecutionInput from_dict/to_dict round-trip.""" + data = DEFAULT_START_DURABLE_EXECUTION_INPUT_DATA + data["Input"] = '{"message": "hello"}' + + input_obj = StartDurableExecutionInput.from_dict(data) + assert '{"message": "hello"}' == input_obj.get_normalized_input() + + +def test_start_durable_execution_input_get_input_str_non_json_input(): + """Test StartDurableExecutionInput from_dict/to_dict round-trip.""" + data = DEFAULT_START_DURABLE_EXECUTION_INPUT_DATA + data["Input"] = "hello" + + input_obj = StartDurableExecutionInput.from_dict(data) + assert '"hello"' == input_obj.get_normalized_input() + + +def test_start_durable_execution_input_get_input_str_json_input(): + """Test StartDurableExecutionInput from_dict/to_dict round-trip.""" + data = DEFAULT_START_DURABLE_EXECUTION_INPUT_DATA + data["Input"] = '"hello"' + + input_obj = StartDurableExecutionInput.from_dict(data) + assert '"hello"' == input_obj.get_normalized_input() + + +def test_start_durable_execution_input_get_input_list_json_input(): + """Test StartDurableExecutionInput from_dict/to_dict round-trip.""" + data = DEFAULT_START_DURABLE_EXECUTION_INPUT_DATA + data["Input"] = "[1,2,3]" + + input_obj = StartDurableExecutionInput.from_dict(data) + assert "[1,2,3]" == input_obj.get_normalized_input() + + +def test_start_durable_execution_input_minimal(): + """Test StartDurableExecutionInput with only required fields.""" + data = { + "AccountId": "123456789012", + "FunctionName": "my-function", + "FunctionQualifier": "$LATEST", + "ExecutionName": "test-execution", + "ExecutionTimeoutSeconds": 300, + "ExecutionRetentionPeriodDays": 7, + } + + input_obj = StartDurableExecutionInput.from_dict(data) + assert input_obj.invocation_id is None + assert input_obj.trace_fields is None + assert input_obj.tenant_id is None + assert input_obj.input is None + + result_data = input_obj.to_dict() + assert result_data == data + + +def test_start_durable_execution_output_serialization(): + """Test StartDurableExecutionOutput from_dict/to_dict round-trip.""" + data = { + "ExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test" + } + + output_obj = StartDurableExecutionOutput.from_dict(data) + assert ( + output_obj.execution_arn + == "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test" + ) + + result_data = output_obj.to_dict() + assert result_data == data + + # Test round-trip + round_trip = StartDurableExecutionOutput.from_dict(result_data) + assert round_trip == output_obj + + +def test_start_durable_execution_output_empty(): + """Test StartDurableExecutionOutput with empty data.""" + data = {} + + output_obj = StartDurableExecutionOutput.from_dict(data) + assert output_obj.execution_arn is None + + result_data = output_obj.to_dict() + assert result_data == {} + + +def test_get_durable_execution_request_serialization(): + """Test GetDurableExecutionRequest from_dict/to_dict round-trip.""" + data = { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test" + } + + request_obj = GetDurableExecutionRequest.from_dict(data) + assert ( + request_obj.durable_execution_arn + == "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test" + ) + + result_data = request_obj.to_dict() + assert result_data == data + + # Test round-trip + round_trip = GetDurableExecutionRequest.from_dict(result_data) + assert round_trip == request_obj + + +def test_get_durable_execution_response_serialization(): + """Test GetDurableExecutionResponse from_dict/to_dict round-trip.""" + data = { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test", + "DurableExecutionName": "test-execution", + "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function", + "Status": "SUCCEEDED", + "StartTimestamp": TIMESTAMP_2023_01_01_00_00, + "InputPayload": "test-input", + "Result": "test-result", + "Error": {"ErrorMessage": "test error"}, + "EndTimestamp": TIMESTAMP_2023_01_01_00_01, + "Version": "1.0", + } + + response_obj = GetDurableExecutionResponse.from_dict(data) + assert ( + response_obj.durable_execution_arn + == "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test" + ) + assert response_obj.durable_execution_name == "test-execution" + assert ( + response_obj.function_arn + == "arn:aws:lambda:us-east-1:123456789012:function:my-function" + ) + assert response_obj.status == "SUCCEEDED" + assert response_obj.start_timestamp == TIMESTAMP_2023_01_01_00_00 + assert response_obj.input_payload == "test-input" + assert response_obj.result == "test-result" + assert response_obj.error.message == "test error" + assert response_obj.end_timestamp == TIMESTAMP_2023_01_01_00_01 + assert response_obj.version == "1.0" + + result_data = response_obj.to_dict() + assert result_data == data + + # Test round-trip + round_trip = GetDurableExecutionResponse.from_dict(result_data) + assert round_trip == response_obj + + +def test_get_durable_execution_response_minimal(): + """Test GetDurableExecutionResponse with only required fields.""" + data = { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test", + "DurableExecutionName": "test-execution", + "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function", + "Status": "RUNNING", + "StartTimestamp": TIMESTAMP_2023_01_01_00_00, + } + + response_obj = GetDurableExecutionResponse.from_dict(data) + assert response_obj.input_payload is None + assert response_obj.result is None + assert response_obj.error is None + assert response_obj.end_timestamp is None + assert response_obj.version is None + + result_data = response_obj.to_dict() + assert result_data == data + + +def test_list_durable_executions_request_serialization(): + """Test ListDurableExecutionsRequest from_dict/to_dict round-trip.""" + data = { + "FunctionName": "my-function", + "FunctionVersion": "$LATEST", + "DurableExecutionName": "test-execution", + "StatusFilter": ["RUNNING", "SUCCEEDED"], + "StartedAfter": TIMESTAMP_2023_01_01_00_00, + "StartedBefore": TIMESTAMP_2023_01_02_00_00, + "Marker": "marker-123", + "MaxItems": 10, + "ReverseOrder": True, + } + + request_obj = ListDurableExecutionsRequest.from_dict(data) + assert request_obj.function_name == "my-function" + assert request_obj.function_version == "$LATEST" + assert request_obj.durable_execution_name == "test-execution" + assert request_obj.status_filter == ["RUNNING", "SUCCEEDED"] + assert request_obj.started_after == TIMESTAMP_2023_01_01_00_00 + assert request_obj.started_before == TIMESTAMP_2023_01_02_00_00 + assert request_obj.marker == "marker-123" + assert request_obj.max_items == 10 + assert request_obj.reverse_order is True + + result_data = request_obj.to_dict() + assert result_data == data + + # Test round-trip + round_trip = ListDurableExecutionsRequest.from_dict(result_data) + assert round_trip == request_obj + + +def test_list_durable_executions_request_empty(): + """Test ListDurableExecutionsRequest with empty data.""" + data = {} + + request_obj = ListDurableExecutionsRequest.from_dict(data) + assert request_obj.function_name is None + assert request_obj.function_version is None + assert request_obj.durable_execution_name is None + assert request_obj.status_filter is None + assert request_obj.started_after is None + assert request_obj.started_before is None + assert request_obj.marker is None + assert request_obj.max_items == 0 # Default value from Smithy + assert request_obj.reverse_order is None + + result_data = request_obj.to_dict() + # The result should include the default MaxItems + expected_data = {"MaxItems": 0} + assert result_data == expected_data + + +def test_durable_execution_summary_serialization(): + """Test Execution from_dict/to_dict round-trip.""" + data = { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test", + "DurableExecutionName": "test-execution", + "Status": "SUCCEEDED", + "StartTimestamp": TIMESTAMP_2023_01_01_00_00, + "EndTimestamp": TIMESTAMP_2023_01_01_00_01, + } + + summary_obj = Execution.from_dict(data) + assert ( + summary_obj.durable_execution_arn + == "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test" + ) + assert summary_obj.durable_execution_name == "test-execution" + assert summary_obj.status == "SUCCEEDED" + assert summary_obj.start_timestamp == TIMESTAMP_2023_01_01_00_00 + assert summary_obj.end_timestamp == TIMESTAMP_2023_01_01_00_01 + + result_data = summary_obj.to_dict() + assert result_data == data + + # Test round-trip + round_trip = Execution.from_dict(result_data) + assert round_trip == summary_obj + + +def test_durable_execution_summary_no_end_timestamp(): + """Test Execution without end timestamp.""" + data = { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test", + "DurableExecutionName": "test-execution", + "Status": "RUNNING", + "StartTimestamp": TIMESTAMP_2023_01_01_00_00, + } + + summary_obj = Execution.from_dict(data) + assert summary_obj.end_timestamp is None + + result_data = summary_obj.to_dict() + assert result_data == data + + +def test_list_durable_executions_response_serialization(): + """Test ListDurableExecutionsResponse from_dict/to_dict round-trip.""" + data = { + "DurableExecutions": [ + { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test1", + "DurableExecutionName": "test-execution-1", + "Status": "SUCCEEDED", + "StartTimestamp": TIMESTAMP_2023_01_01_00_00, + "EndTimestamp": TIMESTAMP_2023_01_01_00_01, + }, + { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test2", + "DurableExecutionName": "test-execution-2", + "Status": "RUNNING", + "StartTimestamp": TIMESTAMP_2023_01_01_00_02, + }, + ], + "NextMarker": "next-marker-123", + } + + response_obj = ListDurableExecutionsResponse.from_dict(data) + assert len(response_obj.durable_executions) == 2 + assert ( + response_obj.durable_executions[0].durable_execution_name == "test-execution-1" + ) + assert ( + response_obj.durable_executions[1].durable_execution_name == "test-execution-2" + ) + assert response_obj.next_marker == "next-marker-123" + + result_data = response_obj.to_dict() + assert result_data == data + + # Test round-trip + round_trip = ListDurableExecutionsResponse.from_dict(result_data) + assert round_trip == response_obj + + +def test_list_durable_executions_response_empty(): + """Test ListDurableExecutionsResponse with empty executions.""" + data = {"DurableExecutions": []} + + response_obj = ListDurableExecutionsResponse.from_dict(data) + assert len(response_obj.durable_executions) == 0 + assert response_obj.next_marker is None + + result_data = response_obj.to_dict() + assert result_data == {"DurableExecutions": []} + + +def test_stop_durable_execution_request_serialization(): + """Test StopDurableExecutionRequest from_dict/to_dict round-trip.""" + data = { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test", + "Error": {"ErrorMessage": "Stopped by user"}, + } + + request_obj = StopDurableExecutionRequest.from_dict(data) + assert ( + request_obj.durable_execution_arn + == "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test" + ) + assert request_obj.error.message == "Stopped by user" + + result_data = request_obj.to_dict() + assert result_data == data + + # Test round-trip + round_trip = StopDurableExecutionRequest.from_dict(result_data) + assert round_trip == request_obj + + +def test_stop_durable_execution_request_minimal(): + """Test StopDurableExecutionRequest with only required fields.""" + data = { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test" + } + + request_obj = StopDurableExecutionRequest.from_dict(data) + assert request_obj.error is None + + result_data = request_obj.to_dict() + assert result_data == data + + +def test_stop_durable_execution_response_serialization(): + """Test StopDurableExecutionResponse from_dict/to_dict round-trip.""" + data = {"StopTimestamp": "2023-01-01T00:01:00Z"} + + response_obj = StopDurableExecutionResponse.from_dict(data) + assert response_obj.stop_timestamp == "2023-01-01T00:01:00Z" + + result_data = response_obj.to_dict() + assert result_data == data + + # Test round-trip + round_trip = StopDurableExecutionResponse.from_dict(result_data) + assert round_trip == response_obj + + +def test_get_durable_execution_state_request_serialization(): + """Test GetDurableExecutionStateRequest from_dict/to_dict round-trip.""" + data = { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test", + "CheckpointToken": "checkpoint-123", + "Marker": "marker-123", + "MaxItems": 10, + } + + request_obj = GetDurableExecutionStateRequest.from_dict(data) + assert ( + request_obj.durable_execution_arn + == "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test" + ) + assert request_obj.checkpoint_token == "checkpoint-123" # noqa: S105 + assert request_obj.marker == "marker-123" + assert request_obj.max_items == 10 + + result_data = request_obj.to_dict() + assert result_data == data + + # Test round-trip + round_trip = GetDurableExecutionStateRequest.from_dict(result_data) + assert round_trip == request_obj + + +def test_get_durable_execution_state_request_minimal(): + """Test GetDurableExecutionStateRequest with only required fields.""" + data = { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test", + "CheckpointToken": "checkpoint-123", + } + + request_obj = GetDurableExecutionStateRequest.from_dict(data) + assert request_obj.marker is None + assert request_obj.max_items == 0 # Default value from Smithy + + result_data = request_obj.to_dict() + # The result should include the default MaxItems + expected_data = { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test", + "CheckpointToken": "checkpoint-123", + "MaxItems": 0, + } + assert result_data == expected_data + + +def test_get_durable_execution_state_response_serialization(): + """Test GetDurableExecutionStateResponse from_dict/to_dict round-trip.""" + data = { + "Operations": [ + {"Id": "op-1", "Type": "STEP", "Status": "SUCCEEDED"}, + {"Id": "op-2", "Type": "CONTEXT", "Status": "STARTED"}, + ], + "NextMarker": "next-marker-123", + } + + response_obj = GetDurableExecutionStateResponse.from_dict(data) + assert len(response_obj.operations) == 2 + assert response_obj.operations[0].operation_id == "op-1" + assert response_obj.operations[0].operation_type.value == "STEP" + assert response_obj.operations[1].operation_id == "op-2" + assert response_obj.operations[1].operation_type.value == "CONTEXT" + assert response_obj.next_marker == "next-marker-123" + + result_data = response_obj.to_dict() + assert result_data == data + + # Test round-trip + round_trip = GetDurableExecutionStateResponse.from_dict(result_data) + assert round_trip == response_obj + + +def test_get_durable_execution_state_response_empty(): + """Test GetDurableExecutionStateResponse with empty operations.""" + data = {"Operations": []} + + response_obj = GetDurableExecutionStateResponse.from_dict(data) + assert len(response_obj.operations) == 0 + assert response_obj.next_marker is None + + result_data = response_obj.to_dict() + assert result_data == {"Operations": []} + + +def test_get_durable_execution_history_request_serialization(): + """Test GetDurableExecutionHistoryRequest from_dict/to_dict round-trip.""" + data = { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test", + "IncludeExecutionData": True, + "ReverseOrder": False, + "Marker": "marker-123", + "MaxItems": 20, + } + + request_obj = GetDurableExecutionHistoryRequest.from_dict(data) + assert ( + request_obj.durable_execution_arn + == "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test" + ) + assert request_obj.include_execution_data is True + assert request_obj.reverse_order is False + assert request_obj.marker == "marker-123" + assert request_obj.max_items == 20 + + result_data = request_obj.to_dict() + assert result_data == data + + # Test round-trip + round_trip = GetDurableExecutionHistoryRequest.from_dict(result_data) + assert round_trip == request_obj + + +def test_get_durable_execution_history_request_minimal(): + """Test GetDurableExecutionHistoryRequest with only required fields.""" + data = { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test" + } + + request_obj = GetDurableExecutionHistoryRequest.from_dict(data) + assert request_obj.include_execution_data is None + assert request_obj.reverse_order is None + assert request_obj.marker is None + assert request_obj.max_items == 0 # Default value from Smithy + + result_data = request_obj.to_dict() + # The result should include the default MaxItems + expected_data = { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test", + "MaxItems": 0, + } + assert result_data == expected_data + + +def test_execution_event_serialization(): + """Test Event from_dict/to_dict round-trip.""" + data = { + "EventType": "ExecutionStarted", + "EventId": 123, + "EventTimestamp": TIMESTAMP_2023_01_01_00_00, + "SubType": "UserInitiated", + "Id": "op-123", + "Name": "test-operation", + "ParentId": "parent-op-123", + "ExecutionStartedDetails": { + "Input": {"Payload": "test-input", "Truncated": False}, + "ExecutionTimeout": 300, + }, + } + + event_obj = Event.from_dict(data) + assert event_obj.event_type == "ExecutionStarted" + assert event_obj.event_id == 123 + assert event_obj.event_timestamp == TIMESTAMP_2023_01_01_00_00 + assert event_obj.sub_type == "UserInitiated" + assert event_obj.operation_id == "op-123" + assert event_obj.name == "test-operation" + assert event_obj.parent_id == "parent-op-123" + assert event_obj.execution_started_details is not None + assert event_obj.execution_started_details.input.payload == "test-input" + assert event_obj.execution_started_details.execution_timeout == 300 + + result_data = event_obj.to_dict() + assert result_data == data + + # Test round-trip + round_trip = Event.from_dict(result_data) + assert round_trip == event_obj + + +def test_execution_event_minimal(): + """Test Event with only required fields.""" + data = { + "EventType": "ExecutionStarted", + "EventTimestamp": TIMESTAMP_2023_01_01_00_00, + } + + event_obj = Event.from_dict(data) + assert event_obj.event_id == 1 # Default value from Smithy + assert event_obj.sub_type is None + assert event_obj.operation_id is None + assert event_obj.name is None + assert event_obj.parent_id is None + assert event_obj.execution_started_details is None + + result_data = event_obj.to_dict() + # The result should include the default EventId + expected_data = { + "EventType": "ExecutionStarted", + "EventTimestamp": TIMESTAMP_2023_01_01_00_00, + "EventId": 1, + } + assert result_data == expected_data + + +def test_get_durable_execution_history_response_serialization(): + """Test GetDurableExecutionHistoryResponse from_dict/to_dict round-trip.""" + data = { + "Events": [ + { + "EventType": "ExecutionStarted", + "EventId": 1, + "EventTimestamp": TIMESTAMP_2023_01_01_00_00, + }, + { + "EventType": "ExecutionSucceeded", + "EventId": 2, + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "ExecutionSucceededDetails": { + "Result": {"Payload": "success", "Truncated": False} + }, + }, + ], + "NextMarker": "next-marker-123", + } + + response_obj = GetDurableExecutionHistoryResponse.from_dict(data) + assert len(response_obj.events) == 2 + assert response_obj.events[0].event_type == "ExecutionStarted" + assert response_obj.events[1].event_type == "ExecutionSucceeded" + assert response_obj.events[1].execution_succeeded_details is not None + assert ( + response_obj.events[1].execution_succeeded_details.result.payload == "success" + ) + assert response_obj.next_marker == "next-marker-123" + + result_data = response_obj.to_dict() + assert result_data == data + + # Test round-trip + round_trip = GetDurableExecutionHistoryResponse.from_dict(result_data) + assert round_trip == response_obj + + +def test_get_durable_execution_history_response_empty(): + """Test GetDurableExecutionHistoryResponse with empty events.""" + data = {"Events": []} + + response_obj = GetDurableExecutionHistoryResponse.from_dict(data) + assert len(response_obj.events) == 0 + assert response_obj.next_marker is None + + result_data = response_obj.to_dict() + assert result_data == {"Events": []} + + +def test_list_durable_executions_by_function_request_serialization(): + """Test ListDurableExecutionsByFunctionRequest from_dict/to_dict round-trip.""" + data = { + "FunctionName": "my-function", + "Qualifier": "$LATEST", + "StatusFilter": ["RUNNING", "SUCCEEDED"], + "StartedAfter": TIMESTAMP_2023_01_01_00_00, + "StartedBefore": TIMESTAMP_2023_01_02_00_00, + "Marker": "marker-123", + "MaxItems": 10, + "ReverseOrder": True, + } + + request_obj = ListDurableExecutionsByFunctionRequest.from_dict(data) + assert request_obj.function_name == "my-function" + assert request_obj.qualifier == "$LATEST" + assert request_obj.status_filter == ["RUNNING", "SUCCEEDED"] + assert request_obj.started_after == TIMESTAMP_2023_01_01_00_00 + assert request_obj.started_before == TIMESTAMP_2023_01_02_00_00 + assert request_obj.marker == "marker-123" + assert request_obj.max_items == 10 + assert request_obj.reverse_order is True + + result_data = request_obj.to_dict() + assert result_data == data + + # Test round-trip + round_trip = ListDurableExecutionsByFunctionRequest.from_dict(result_data) + assert round_trip == request_obj + + +def test_list_durable_executions_by_function_request_minimal(): + """Test ListDurableExecutionsByFunctionRequest with only required fields.""" + data = {"FunctionName": "my-function"} + + request_obj = ListDurableExecutionsByFunctionRequest.from_dict(data) + assert request_obj.qualifier is None + assert request_obj.status_filter is None + assert request_obj.started_after is None + assert request_obj.started_before is None + assert request_obj.marker is None + assert request_obj.max_items == 0 # Default value from Smithy + assert request_obj.reverse_order is None + + result_data = request_obj.to_dict() + # The result should include the default MaxItems + expected_data = {"FunctionName": "my-function", "MaxItems": 0} + assert result_data == expected_data + + +def test_list_durable_executions_by_function_response_serialization(): + """Test ListDurableExecutionsByFunctionResponse from_dict/to_dict round-trip.""" + data = { + "DurableExecutions": [ + { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test1", + "DurableExecutionName": "test-execution-1", + "Status": "SUCCEEDED", + "StartTimestamp": TIMESTAMP_2023_01_01_00_00, + "EndTimestamp": TIMESTAMP_2023_01_01_00_01, + } + ], + "NextMarker": "next-marker-123", + } + + response_obj = ListDurableExecutionsByFunctionResponse.from_dict(data) + assert len(response_obj.durable_executions) == 1 + assert ( + response_obj.durable_executions[0].durable_execution_name == "test-execution-1" + ) + assert response_obj.next_marker == "next-marker-123" + + result_data = response_obj.to_dict() + assert result_data == data + + # Test round-trip + round_trip = ListDurableExecutionsByFunctionResponse.from_dict(result_data) + assert round_trip == response_obj + + +def test_send_durable_execution_callback_success_request_serialization(): + """Test SendDurableExecutionCallbackSuccessRequest from_dict/to_dict round-trip.""" + data = { + "CallbackId": "callback-123", + "Result": "success-result", + } + + request_obj = SendDurableExecutionCallbackSuccessRequest.from_dict(data) + assert request_obj.callback_id == "callback-123" + assert request_obj.result == "success-result" + + result_data = request_obj.to_dict() + assert result_data == data + + # Test round-trip + round_trip = SendDurableExecutionCallbackSuccessRequest.from_dict(result_data) + assert round_trip == request_obj + + +def test_send_durable_execution_callback_success_request_minimal(): + """Test SendDurableExecutionCallbackSuccessRequest with only required fields.""" + data = {"CallbackId": "callback-123"} + + request_obj = SendDurableExecutionCallbackSuccessRequest.from_dict(data) + assert request_obj.result is None + + result_data = request_obj.to_dict() + assert result_data == data + + +def test_send_durable_execution_callback_success_response_creation(): + """Test SendDurableExecutionCallbackSuccessResponse creation.""" + response_obj = SendDurableExecutionCallbackSuccessResponse() + assert isinstance(response_obj, SendDurableExecutionCallbackSuccessResponse) + + +def test_send_durable_execution_callback_failure_request_serialization(): + """Test SendDurableExecutionCallbackFailureRequest from_dict/to_dict round-trip.""" + data = {"ErrorMessage": "callback failed"} + + request_obj = SendDurableExecutionCallbackFailureRequest.from_dict( + data, "callback-123" + ) + assert request_obj.callback_id == "callback-123" + assert request_obj.error.message == "callback failed" + + result_data = request_obj.to_dict() + expected_data = { + "CallbackId": "callback-123", + "Error": {"ErrorMessage": "callback failed"}, + } + assert result_data == expected_data + + # Test round-trip + round_trip = SendDurableExecutionCallbackFailureRequest.from_dict( + result_data.get("Error", {}), result_data["CallbackId"] + ) + assert round_trip == request_obj + + +def test_send_durable_execution_callback_failure_request_minimal(): + """Test SendDurableExecutionCallbackFailureRequest with only required fields.""" + + request_obj = SendDurableExecutionCallbackFailureRequest.from_dict( + {}, "callback-123" + ) + assert request_obj.error is None + + result_data = request_obj.to_dict() + assert result_data == {"CallbackId": "callback-123"} + + +def test_send_durable_execution_callback_failure_response_creation(): + """Test SendDurableExecutionCallbackFailureResponse creation.""" + response_obj = SendDurableExecutionCallbackFailureResponse() + assert isinstance(response_obj, SendDurableExecutionCallbackFailureResponse) + + +def test_send_durable_execution_callback_heartbeat_request_serialization(): + """Test SendDurableExecutionCallbackHeartbeatRequest from_dict/to_dict round-trip.""" + data = {"CallbackId": "callback-123"} + + request_obj = SendDurableExecutionCallbackHeartbeatRequest.from_dict(data) + assert request_obj.callback_id == "callback-123" + + result_data = request_obj.to_dict() + assert result_data == data + + # Test round-trip + round_trip = SendDurableExecutionCallbackHeartbeatRequest.from_dict(result_data) + assert round_trip == request_obj + + +def test_send_durable_execution_callback_heartbeat_response_creation(): + """Test SendDurableExecutionCallbackHeartbeatResponse creation.""" + response_obj = SendDurableExecutionCallbackHeartbeatResponse() + assert isinstance(response_obj, SendDurableExecutionCallbackHeartbeatResponse) + + +def test_checkpoint_durable_execution_request_serialization(): + """Test CheckpointDurableExecutionRequest from_dict/to_dict round-trip.""" + execution_arn = ( + "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test" + ) + data = { + "CheckpointToken": "checkpoint-123", + "Updates": [ + {"Id": "op-1", "Type": "STEP", "Action": "SUCCEED"}, + {"Id": "op-2", "Type": "CONTEXT", "Action": "START"}, + ], + "ClientToken": "client-token-123", + } + + request_obj = CheckpointDurableExecutionRequest.from_dict(data, execution_arn) + assert request_obj.durable_execution_arn == execution_arn + assert request_obj.checkpoint_token == "checkpoint-123" # noqa: S105 + assert len(request_obj.updates) == 2 + assert request_obj.updates[0].operation_id == "op-1" + assert request_obj.updates[0].operation_type.value == "STEP" + assert request_obj.updates[0].action.value == "SUCCEED" + assert request_obj.updates[1].operation_id == "op-2" + assert request_obj.updates[1].operation_type.value == "CONTEXT" + assert request_obj.updates[1].action.value == "START" + assert request_obj.client_token == "client-token-123" # noqa: S105 + + result_data = request_obj.to_dict() + expected_data = {"DurableExecutionArn": execution_arn, **data} + assert result_data == expected_data + + # Test round-trip + round_trip = CheckpointDurableExecutionRequest.from_dict(result_data, execution_arn) + assert round_trip == request_obj + + +def test_checkpoint_durable_execution_request_minimal(): + """Test CheckpointDurableExecutionRequest with only required fields.""" + execution_arn = ( + "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test" + ) + data = { + "CheckpointToken": "checkpoint-123", + } + + request_obj = CheckpointDurableExecutionRequest.from_dict(data, execution_arn) + assert request_obj.updates is None + assert request_obj.client_token is None + + result_data = request_obj.to_dict() + expected_data = {"DurableExecutionArn": execution_arn, **data} + assert result_data == expected_data + + +def test_checkpoint_durable_execution_response_serialization(): + """Test CheckpointDurableExecutionResponse from_dict/to_dict round-trip.""" + data = { + "CheckpointToken": "new-checkpoint-123", + "NewExecutionState": { + "Operations": [{"Id": "op-1", "Type": "STEP", "Status": "SUCCEEDED"}], + "NextMarker": "marker-123", + }, + } + + response_obj = CheckpointDurableExecutionResponse.from_dict(data) + assert response_obj.checkpoint_token == "new-checkpoint-123" # noqa: S105 + assert response_obj.new_execution_state is not None + assert len(response_obj.new_execution_state.operations) == 1 + assert response_obj.new_execution_state.operations[0].operation_id == "op-1" + assert response_obj.new_execution_state.next_marker == "marker-123" + + result_data = response_obj.to_dict() + assert result_data == data + + # Test round-trip + round_trip = CheckpointDurableExecutionResponse.from_dict(result_data) + assert round_trip == response_obj + + +def test_checkpoint_durable_execution_response_minimal(): + """Test CheckpointDurableExecutionResponse with only required fields.""" + data = {"CheckpointToken": "new-checkpoint-123"} + + response_obj = CheckpointDurableExecutionResponse.from_dict(data) + assert response_obj.new_execution_state is None + + result_data = response_obj.to_dict() + assert result_data == data + + +def test_error_response_creation(): + """Test ErrorResponse creation with all fields.""" + error_response = ErrorResponse( + error_type="InvalidParameterValueException", + error_message="Invalid parameter value", + error_code="INVALID_PARAMETER", + request_id="req-123", + ) + + assert error_response.error_type == "InvalidParameterValueException" + assert error_response.error_message == "Invalid parameter value" + assert error_response.error_code == "INVALID_PARAMETER" + assert error_response.request_id == "req-123" + + +def test_error_response_creation_minimal(): + """Test ErrorResponse creation with minimal fields.""" + error_response = ErrorResponse( + error_type="ServiceException", + error_message="Internal server error", + ) + + assert error_response.error_type == "ServiceException" + assert error_response.error_message == "Internal server error" + assert error_response.error_code is None + assert error_response.request_id is None + + +def test_error_response_to_dict_complete(): + """Test ErrorResponse.to_dict() with all fields.""" + error_response = ErrorResponse( + error_type="ResourceNotFoundException", + error_message="Resource not found", + error_code="RESOURCE_NOT_FOUND", + request_id="req-456", + ) + + result = error_response.to_dict() + + expected = { + "error": { + "type": "ResourceNotFoundException", + "message": "Resource not found", + "code": "RESOURCE_NOT_FOUND", + "requestId": "req-456", + } + } + + assert result == expected + + +def test_error_response_to_dict_minimal(): + """Test ErrorResponse.to_dict() with minimal fields.""" + error_response = ErrorResponse( + error_type="ConflictException", + error_message="Resource conflict", + ) + + result = error_response.to_dict() + + expected = { + "error": { + "type": "ConflictException", + "message": "Resource conflict", + } + } + + assert result == expected + + +def test_error_response_from_dict_nested(): + """Test ErrorResponse.from_dict() with nested error structure.""" + data = { + "error": { + "type": "InvalidParameterValueException", + "message": "Invalid input", + "code": "INVALID_INPUT", + "requestId": "req-789", + } + } + + error_response = ErrorResponse.from_dict(data) + + assert error_response.error_type == "InvalidParameterValueException" + assert error_response.error_message == "Invalid input" + assert error_response.error_code == "INVALID_INPUT" + assert error_response.request_id == "req-789" + + +def test_error_response_from_dict_flat(): + """Test ErrorResponse.from_dict() with flat error structure.""" + data = { + "type": "ServiceException", + "message": "Internal error", + "code": "INTERNAL_ERROR", + } + + error_response = ErrorResponse.from_dict(data) + + assert error_response.error_type == "ServiceException" + assert error_response.error_message == "Internal error" + assert error_response.error_code == "INTERNAL_ERROR" + assert error_response.request_id is None + + +def test_error_response_from_dict_minimal(): + """Test ErrorResponse.from_dict() with minimal fields.""" + data = { + "error": { + "type": "TooManyRequestsException", + "message": "Rate limit exceeded", + } + } + + error_response = ErrorResponse.from_dict(data) + + assert error_response.error_type == "TooManyRequestsException" + assert error_response.error_message == "Rate limit exceeded" + assert error_response.error_code is None + assert error_response.request_id is None + + +def test_error_response_round_trip(): + """Test ErrorResponse round-trip serialization.""" + original = ErrorResponse( + error_type="ExecutionAlreadyStartedException", + error_message="Execution already exists", + error_code="EXECUTION_ALREADY_STARTED", + request_id="req-round-trip", + ) + + # Convert to dict and back + data = original.to_dict() + restored = ErrorResponse.from_dict(data) + + assert restored.error_type == original.error_type + assert restored.error_message == original.error_message + assert restored.error_code == original.error_code + assert restored.request_id == original.request_id + + +def test_error_response_immutable(): + """Test that ErrorResponse is immutable (frozen dataclass).""" + error_response = ErrorResponse( + error_type="TestException", + error_message="Test message", + ) + + with pytest.raises(AttributeError): + error_response.error_type = "ModifiedException" # type: ignore + + +# Tests for missing coverage in StartDurableExecutionInput +def test_start_durable_execution_input_missing_required_fields(): + """Test StartDurableExecutionInput validation with missing required fields.""" + # Test missing AccountId + data = { + "FunctionName": "my-function", + "FunctionQualifier": "$LATEST", + "ExecutionName": "test-execution", + "ExecutionTimeoutSeconds": 300, + "ExecutionRetentionPeriodDays": 7, + } + + with pytest.raises(InvalidParameterValueException) as exc_info: + StartDurableExecutionInput.from_dict(data) + assert "Missing required field: AccountId" in str(exc_info.value) + + # Test missing FunctionName + data = { + "AccountId": "123456789012", + "FunctionQualifier": "$LATEST", + "ExecutionName": "test-execution", + "ExecutionTimeoutSeconds": 300, + "ExecutionRetentionPeriodDays": 7, + } + + with pytest.raises(InvalidParameterValueException) as exc_info: + StartDurableExecutionInput.from_dict(data) + assert "Missing required field: FunctionName" in str(exc_info.value) + + # Test missing FunctionQualifier + data = { + "AccountId": "123456789012", + "FunctionName": "my-function", + "ExecutionName": "test-execution", + "ExecutionTimeoutSeconds": 300, + "ExecutionRetentionPeriodDays": 7, + } + + with pytest.raises(InvalidParameterValueException) as exc_info: + StartDurableExecutionInput.from_dict(data) + assert "Missing required field: FunctionQualifier" in str(exc_info.value) + + # Test missing ExecutionName + data = { + "AccountId": "123456789012", + "FunctionName": "my-function", + "FunctionQualifier": "$LATEST", + "ExecutionTimeoutSeconds": 300, + "ExecutionRetentionPeriodDays": 7, + } + + with pytest.raises(InvalidParameterValueException) as exc_info: + StartDurableExecutionInput.from_dict(data) + assert "Missing required field: ExecutionName" in str(exc_info.value) + + # Test missing ExecutionTimeoutSeconds + data = { + "AccountId": "123456789012", + "FunctionName": "my-function", + "FunctionQualifier": "$LATEST", + "ExecutionName": "test-execution", + "ExecutionRetentionPeriodDays": 7, + } + + with pytest.raises(InvalidParameterValueException) as exc_info: + StartDurableExecutionInput.from_dict(data) + assert "Missing required field: ExecutionTimeoutSeconds" in str(exc_info.value) + + # Test missing ExecutionRetentionPeriodDays + data = { + "AccountId": "123456789012", + "FunctionName": "my-function", + "FunctionQualifier": "$LATEST", + "ExecutionName": "test-execution", + "ExecutionTimeoutSeconds": 300, + } + + with pytest.raises(InvalidParameterValueException) as exc_info: + StartDurableExecutionInput.from_dict(data) + assert "Missing required field: ExecutionRetentionPeriodDays" in str(exc_info.value) + + +# Tests for Execution backward compatibility +def test_execution_backward_compatibility_empty_function_arn(): + """Test Execution with empty FunctionArn for backward compatibility.""" + data = { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test", + "DurableExecutionName": "test-execution", + "Status": "SUCCEEDED", + "StartTimestamp": TIMESTAMP_2023_01_01_00_00, + "EndTimestamp": TIMESTAMP_2023_01_01_00_01, + } + + execution_obj = Execution.from_dict(data) + assert ( + execution_obj.function_arn == "" + ) # Default empty string for backward compatibility + + result_data = execution_obj.to_dict() + # Empty function_arn should not be included in output + expected_data = { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test", + "DurableExecutionName": "test-execution", + "Status": "SUCCEEDED", + "StartTimestamp": TIMESTAMP_2023_01_01_00_00, + "EndTimestamp": TIMESTAMP_2023_01_01_00_01, + } + assert result_data == expected_data + + +def test_execution_with_function_arn(): + """Test Execution with non-empty FunctionArn.""" + data = { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test", + "DurableExecutionName": "test-execution", + "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function", + "Status": "SUCCEEDED", + "StartTimestamp": TIMESTAMP_2023_01_01_00_00, + "EndTimestamp": TIMESTAMP_2023_01_01_00_01, + } + + execution_obj = Execution.from_dict(data) + assert ( + execution_obj.function_arn + == "arn:aws:lambda:us-east-1:123456789012:function:my-function" + ) + + result_data = execution_obj.to_dict() + assert result_data == data + + +# Tests for ListDurableExecutionsRequest with all optional fields +def test_list_durable_executions_request_all_optional_fields(): + """Test ListDurableExecutionsRequest to_dict with all optional fields as None.""" + request_obj = ListDurableExecutionsRequest( + function_name=None, + function_version=None, + durable_execution_name=None, + status_filter=None, + started_after=None, + started_before=None, + marker=None, + max_items=None, + reverse_order=None, + ) + + result_data = request_obj.to_dict() + # Only non-None fields should be included + expected_data = {} + assert result_data == expected_data + + +def test_list_durable_executions_request_partial_fields(): + """Test ListDurableExecutionsRequest to_dict with some optional fields.""" + request_obj = ListDurableExecutionsRequest( + function_name="my-function", + function_version=None, + durable_execution_name="test-execution", + status_filter=None, + started_after=TIMESTAMP_2023_01_01_00_00, + started_before=None, + marker="marker-123", + max_items=10, + reverse_order=None, + ) + + result_data = request_obj.to_dict() + expected_data = { + "FunctionName": "my-function", + "DurableExecutionName": "test-execution", + "StartedAfter": TIMESTAMP_2023_01_01_00_00, + "Marker": "marker-123", + "MaxItems": 10, + } + assert result_data == expected_data + + +# Tests for GetDurableExecutionStateRequest with all optional fields +def test_get_durable_execution_state_request_all_optional_fields(): + """Test GetDurableExecutionStateRequest to_dict with all optional fields as None.""" + request_obj = GetDurableExecutionStateRequest( + durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test", + checkpoint_token="checkpoint-123", # noqa: S106 + marker=None, + max_items=None, + ) + + result_data = request_obj.to_dict() + expected_data = { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test", + "CheckpointToken": "checkpoint-123", + } + assert result_data == expected_data + + +# Tests for EventInput +def test_event_input_serialization(): + """Test EventInput from_dict/to_dict round-trip.""" + data = { + "Payload": "test-payload", + "Truncated": True, + } + + event_input = EventInput.from_dict(data) + assert event_input.payload == "test-payload" + assert event_input.truncated is True + + result_data = event_input.to_dict() + assert result_data == data + + +def test_event_input_minimal(): + """Test EventInput with minimal data.""" + data = {} + + event_input = EventInput.from_dict(data) + assert event_input.payload is None + assert event_input.truncated is False + + result_data = event_input.to_dict() + assert result_data == {"Truncated": False} + + +def test_event_input_with_payload_only(): + """Test EventInput with payload but default truncated.""" + data = {"Payload": "test-payload"} + + event_input = EventInput.from_dict(data) + assert event_input.payload == "test-payload" + assert event_input.truncated is False + + result_data = event_input.to_dict() + assert result_data == {"Payload": "test-payload", "Truncated": False} + + +# Tests for EventResult +def test_event_result_serialization(): + """Test EventResult from_dict/to_dict round-trip.""" + data = { + "Payload": "test-result", + "Truncated": True, + } + + event_result = EventResult.from_dict(data) + assert event_result.payload == "test-result" + assert event_result.truncated is True + + result_data = event_result.to_dict() + assert result_data == data + + +def test_event_result_minimal(): + """Test EventResult with minimal data.""" + data = {} + + event_result = EventResult.from_dict(data) + assert event_result.payload is None + assert event_result.truncated is False + + result_data = event_result.to_dict() + assert result_data == {"Truncated": False} + + +# Tests for EventError +def test_event_error_serialization(): + """Test EventError from_dict/to_dict round-trip.""" + data = { + "Payload": {"ErrorMessage": "test error"}, + "Truncated": True, + } + + event_error = EventError.from_dict(data) + assert event_error.payload.message == "test error" + assert event_error.truncated is True + + result_data = event_error.to_dict() + assert result_data == data + + +def test_event_error_minimal(): + """Test EventError with minimal data.""" + data = {} + + event_error = EventError.from_dict(data) + assert event_error.payload is None + assert event_error.truncated is False + + result_data = event_error.to_dict() + assert result_data == {"Truncated": False} + + +def test_event_error_with_payload_only(): + """Test EventError with payload but default truncated.""" + data = {"Payload": {"ErrorMessage": "test error"}} + + event_error = EventError.from_dict(data) + assert event_error.payload.message == "test error" + assert event_error.truncated is False + + result_data = event_error.to_dict() + assert result_data == { + "Payload": {"ErrorMessage": "test error"}, + "Truncated": False, + } + + +# Tests for RetryDetails +def test_retry_details_serialization(): + """Test RetryDetails from_dict/to_dict round-trip.""" + data = { + "CurrentAttempt": 3, + "NextAttemptDelaySeconds": 60, + } + + retry_details = RetryDetails.from_dict(data) + assert retry_details.current_attempt == 3 + assert retry_details.next_attempt_delay_seconds == 60 + + result_data = retry_details.to_dict() + assert result_data == data + + +def test_retry_details_minimal(): + """Test RetryDetails with minimal data.""" + data = {} + + retry_details = RetryDetails.from_dict(data) + assert retry_details.current_attempt == 0 + assert retry_details.next_attempt_delay_seconds is None + + result_data = retry_details.to_dict() + assert result_data == {"CurrentAttempt": 0} + + +def test_retry_details_with_current_attempt_only(): + """Test RetryDetails with current attempt but no delay.""" + data = {"CurrentAttempt": 2} + + retry_details = RetryDetails.from_dict(data) + assert retry_details.current_attempt == 2 + assert retry_details.next_attempt_delay_seconds is None + + result_data = retry_details.to_dict() + assert result_data == {"CurrentAttempt": 2} + + +# Tests for ExecutionStartedDetails +def test_execution_started_details_serialization(): + """Test ExecutionStartedDetails from_dict/to_dict round-trip.""" + data = { + "Input": {"Payload": "test-input", "Truncated": False}, + "ExecutionTimeout": 300, + } + + details = ExecutionStartedDetails.from_dict(data) + assert details.input.payload == "test-input" + assert details.execution_timeout == 300 + + result_data = details.to_dict() + assert result_data == data + + +def test_execution_started_details_minimal(): + """Test ExecutionStartedDetails with minimal data.""" + data = {} + + details = ExecutionStartedDetails.from_dict(data) + assert details.input is None + assert details.execution_timeout is None + + result_data = details.to_dict() + assert result_data == {} + + +def test_execution_started_details_with_input_only(): + """Test ExecutionStartedDetails with input but no timeout.""" + data = {"Input": {"Payload": "test-input", "Truncated": False}} + + details = ExecutionStartedDetails.from_dict(data) + assert details.input.payload == "test-input" + assert details.execution_timeout is None + + result_data = details.to_dict() + assert result_data == {"Input": {"Payload": "test-input", "Truncated": False}} + + +# Tests for ExecutionSucceededDetails +def test_execution_succeeded_details_serialization(): + """Test ExecutionSucceededDetails from_dict/to_dict round-trip.""" + data = { + "Result": {"Payload": "success-result", "Truncated": False}, + } + + details = ExecutionSucceededDetails.from_dict(data) + assert details.result.payload == "success-result" + + result_data = details.to_dict() + assert result_data == data + + +def test_execution_succeeded_details_minimal(): + """Test ExecutionSucceededDetails with minimal data.""" + data = {} + + details = ExecutionSucceededDetails.from_dict(data) + assert details.result is None + + result_data = details.to_dict() + assert result_data == {} + + +# Tests for ExecutionFailedDetails +def test_execution_failed_details_serialization(): + """Test ExecutionFailedDetails from_dict/to_dict round-trip.""" + data = { + "Error": {"Payload": {"ErrorMessage": "execution failed"}, "Truncated": False}, + } + + details = ExecutionFailedDetails.from_dict(data) + assert details.error.payload.message == "execution failed" + + result_data = details.to_dict() + assert result_data == data + + +def test_execution_failed_details_minimal(): + """Test ExecutionFailedDetails with minimal data.""" + data = {} + + details = ExecutionFailedDetails.from_dict(data) + assert details.error is None + + result_data = details.to_dict() + assert result_data == {} + + +# Tests for ExecutionTimedOutDetails +def test_execution_timed_out_details_serialization(): + """Test ExecutionTimedOutDetails from_dict/to_dict round-trip.""" + data = { + "Error": { + "Payload": {"ErrorMessage": "execution timed out"}, + "Truncated": False, + }, + } + + details = ExecutionTimedOutDetails.from_dict(data) + assert details.error.payload.message == "execution timed out" + + result_data = details.to_dict() + assert result_data == data + + +def test_execution_timed_out_details_minimal(): + """Test ExecutionTimedOutDetails with minimal data.""" + data = {} + + details = ExecutionTimedOutDetails.from_dict(data) + assert details.error is None + + result_data = details.to_dict() + assert result_data == {} + + +# Tests for ExecutionStoppedDetails +def test_execution_stopped_details_serialization(): + """Test ExecutionStoppedDetails from_dict/to_dict round-trip.""" + data = { + "Error": {"Payload": {"ErrorMessage": "execution stopped"}, "Truncated": False}, + } + + details = ExecutionStoppedDetails.from_dict(data) + assert details.error.payload.message == "execution stopped" + + result_data = details.to_dict() + assert result_data == data + + +def test_execution_stopped_details_minimal(): + """Test ExecutionStoppedDetails with minimal data.""" + data = {} + + details = ExecutionStoppedDetails.from_dict(data) + assert details.error is None + + result_data = details.to_dict() + assert result_data == {} + + +# Tests for ContextStartedDetails +def test_context_started_details_serialization(): + """Test ContextStartedDetails from_dict/to_dict round-trip.""" + # ContextStartedDetails ignores input data and always returns empty dict + data = {"dummy": "value"} # Can provide any data + + details = ContextStartedDetails.from_dict(data) + assert isinstance(details, ContextStartedDetails) + + result_data = details.to_dict() + assert result_data == {} + + +# Tests for ContextSucceededDetails +def test_context_succeeded_details_serialization(): + """Test ContextSucceededDetails from_dict/to_dict round-trip.""" + data = { + "Result": {"Payload": "context-result", "Truncated": False}, + } + + details = ContextSucceededDetails.from_dict(data) + assert details.result.payload == "context-result" + + result_data = details.to_dict() + assert result_data == data + + +def test_context_succeeded_details_minimal(): + """Test ContextSucceededDetails with minimal data.""" + data = {} + + details = ContextSucceededDetails.from_dict(data) + assert details.result is None + + result_data = details.to_dict() + assert result_data == {} + + +# Tests for ContextFailedDetails +def test_context_failed_details_serialization(): + """Test ContextFailedDetails from_dict/to_dict round-trip.""" + data = { + "Error": {"Payload": {"ErrorMessage": "context failed"}, "Truncated": False}, + } + + details = ContextFailedDetails.from_dict(data) + assert details.error.payload.message == "context failed" + + result_data = details.to_dict() + assert result_data == data + + +def test_context_failed_details_minimal(): + """Test ContextFailedDetails with minimal data.""" + data = {} + + details = ContextFailedDetails.from_dict(data) + assert details.error is None + + result_data = details.to_dict() + assert result_data == {} + + +# Tests for WaitStartedDetails +def test_wait_started_details_serialization(): + """Test WaitStartedDetails from_dict/to_dict round-trip.""" + data = { + "Duration": 60, + "ScheduledEndTimestamp": TIMESTAMP_2023_01_01_00_01, + } + + details = WaitStartedDetails.from_dict(data) + assert details.duration == 60 + assert details.scheduled_end_timestamp == TIMESTAMP_2023_01_01_00_01 + + result_data = details.to_dict() + assert result_data == data + + +def test_wait_started_details_minimal(): + """Test WaitStartedDetails with minimal data.""" + data = {} + + details = WaitStartedDetails.from_dict(data) + assert details.duration is None + assert details.scheduled_end_timestamp is None + + result_data = details.to_dict() + assert result_data == {} + + +def test_wait_started_details_with_duration_only(): + """Test WaitStartedDetails with duration but no timestamp.""" + data = {"Duration": 30} + + details = WaitStartedDetails.from_dict(data) + assert details.duration == 30 + assert details.scheduled_end_timestamp is None + + result_data = details.to_dict() + assert result_data == {"Duration": 30} + + +# Tests for WaitSucceededDetails +def test_wait_succeeded_details_serialization(): + """Test WaitSucceededDetails from_dict/to_dict round-trip.""" + data = {"Duration": 60} + + details = WaitSucceededDetails.from_dict(data) + assert details.duration == 60 + + result_data = details.to_dict() + assert result_data == data + + +def test_wait_succeeded_details_minimal(): + """Test WaitSucceededDetails with minimal data.""" + data = {} + + details = WaitSucceededDetails.from_dict(data) + assert details.duration is None + + result_data = details.to_dict() + assert result_data == {} + + +# Tests for WaitCancelledDetails +def test_wait_cancelled_details_serialization(): + """Test WaitCancelledDetails from_dict/to_dict round-trip.""" + data = { + "Error": {"Payload": {"ErrorMessage": "wait cancelled"}, "Truncated": False}, + } + + details = WaitCancelledDetails.from_dict(data) + assert details.error.payload.message == "wait cancelled" + + result_data = details.to_dict() + assert result_data == data + + +def test_wait_cancelled_details_minimal(): + """Test WaitCancelledDetails with minimal data.""" + data = {} + + details = WaitCancelledDetails.from_dict(data) + assert details.error is None + + result_data = details.to_dict() + assert result_data == {} + + +# Tests for StepStartedDetails +def test_step_started_details_serialization(): + """Test StepStartedDetails from_dict/to_dict round-trip.""" + # StepStartedDetails ignores input data and always returns empty dict + data = {"dummy": "value"} # Can provide any data + + details = StepStartedDetails.from_dict(data) + assert isinstance(details, StepStartedDetails) + + result_data = details.to_dict() + assert result_data == {} + + +# Tests for StepSucceededDetails +def test_step_succeeded_details_serialization(): + """Test StepSucceededDetails from_dict/to_dict round-trip.""" + data = { + "Result": {"Payload": "step-result", "Truncated": False}, + "RetryDetails": {"CurrentAttempt": 2, "NextAttemptDelaySeconds": 30}, + } + + details = StepSucceededDetails.from_dict(data) + assert details.result.payload == "step-result" + assert details.retry_details.current_attempt == 2 + + result_data = details.to_dict() + assert result_data == data + + +def test_step_succeeded_details_minimal(): + """Test StepSucceededDetails with minimal data.""" + data = {} + + details = StepSucceededDetails.from_dict(data) + assert details.result is None + assert details.retry_details is None + + result_data = details.to_dict() + assert result_data == {} + + +def test_step_succeeded_details_with_result_only(): + """Test StepSucceededDetails with result but no retry details.""" + data = {"Result": {"Payload": "step-result", "Truncated": False}} + + details = StepSucceededDetails.from_dict(data) + assert details.result.payload == "step-result" + assert details.retry_details is None + + result_data = details.to_dict() + assert result_data == {"Result": {"Payload": "step-result", "Truncated": False}} + + +# Tests for StepFailedDetails +def test_step_failed_details_serialization(): + """Test StepFailedDetails from_dict/to_dict round-trip.""" + data = { + "Error": {"Payload": {"ErrorMessage": "step failed"}, "Truncated": False}, + "RetryDetails": {"CurrentAttempt": 1, "NextAttemptDelaySeconds": 15}, + } + + details = StepFailedDetails.from_dict(data) + assert details.error.payload.message == "step failed" + assert details.retry_details.current_attempt == 1 + + result_data = details.to_dict() + assert result_data == data + + +def test_step_failed_details_minimal(): + """Test StepFailedDetails with minimal data.""" + data = {} + + details = StepFailedDetails.from_dict(data) + assert details.error is None + assert details.retry_details is None + + result_data = details.to_dict() + assert result_data == {} + + +def test_step_failed_details_with_error_only(): + """Test StepFailedDetails with error but no retry details.""" + data = {"Error": {"Payload": {"ErrorMessage": "step failed"}, "Truncated": False}} + + details = StepFailedDetails.from_dict(data) + assert details.error.payload.message == "step failed" + assert details.retry_details is None + + result_data = details.to_dict() + assert result_data == { + "Error": {"Payload": {"ErrorMessage": "step failed"}, "Truncated": False} + } + + +# Tests for ChainedInvokeStartedDetails +def test_invoke_started_details_serialization(): + """Test ChainedInvokeStartedDetails from_dict/to_dict round-trip.""" + data = { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test", + } + + details = ChainedInvokeStartedDetails.from_dict(data) + assert ( + details.durable_execution_arn + == "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test" + ) + + result_data = details.to_dict() + assert result_data == data + + +def test_invoke_started_details_minimal(): + """Test ChainedInvokeStartedDetails with minimal data.""" + data = {} + + details = ChainedInvokeStartedDetails.from_dict(data) + assert details.durable_execution_arn is None + + result_data = details.to_dict() + assert result_data == {} + + +def test_invoke_started_details_partial(): + """Test ChainedInvokeStartedDetails with partial data.""" + data = {} + details = ChainedInvokeStartedDetails.from_dict(data) + assert details.durable_execution_arn is None + + +# Tests for ChainedInvokeSucceededDetails +def test_invoke_succeeded_details_serialization(): + """Test ChainedInvokeSucceededDetails from_dict/to_dict round-trip.""" + data = { + "Result": {"Payload": "invoke-result", "Truncated": False}, + } + + details = ChainedInvokeSucceededDetails.from_dict(data) + assert details.result.payload == "invoke-result" + + result_data = details.to_dict() + assert result_data == data + + +def test_invoke_succeeded_details_minimal(): + """Test ChainedInvokeSucceededDetails with minimal data.""" + data = {} + + details = ChainedInvokeSucceededDetails.from_dict(data) + assert details.result is None + + result_data = details.to_dict() + assert result_data == {} + + +# Tests for ChainedInvokeFailedDetails +def test_invoke_failed_details_serialization(): + """Test ChainedInvokeFailedDetails from_dict/to_dict round-trip.""" + data = { + "Error": {"Payload": {"ErrorMessage": "invoke failed"}, "Truncated": False}, + } + + details = ChainedInvokeFailedDetails.from_dict(data) + assert details.error.payload.message == "invoke failed" + + result_data = details.to_dict() + assert result_data == data + + +def test_invoke_failed_details_minimal(): + """Test ChainedInvokeFailedDetails with minimal data.""" + data = {} + + details = ChainedInvokeFailedDetails.from_dict(data) + assert details.error is None + + result_data = details.to_dict() + assert result_data == {} + + +# Tests for ChainedInvokeTimedOutDetails +def test_invoke_timed_out_details_serialization(): + """Test ChainedInvokeTimedOutDetails from_dict/to_dict round-trip.""" + data = { + "Error": {"Payload": {"ErrorMessage": "invoke timed out"}, "Truncated": False}, + } + + details = ChainedInvokeTimedOutDetails.from_dict(data) + assert details.error.payload.message == "invoke timed out" + + result_data = details.to_dict() + assert result_data == data + + +def test_invoke_timed_out_details_minimal(): + """Test ChainedInvokeTimedOutDetails with minimal data.""" + data = {} + + details = ChainedInvokeTimedOutDetails.from_dict(data) + assert details.error is None + + result_data = details.to_dict() + assert result_data == {} + + +# Tests for ChainedInvokeStoppedDetails +def test_invoke_stopped_details_serialization(): + """Test ChainedInvokeStoppedDetails from_dict/to_dict round-trip.""" + data = { + "Error": {"Payload": {"ErrorMessage": "invoke stopped"}, "Truncated": False}, + } + + details = ChainedInvokeStoppedDetails.from_dict(data) + assert details.error.payload.message == "invoke stopped" + + result_data = details.to_dict() + assert result_data == data + + +def test_invoke_stopped_details_minimal(): + """Test ChainedInvokeStoppedDetails with minimal data.""" + data = {} + + details = ChainedInvokeStoppedDetails.from_dict(data) + assert details.error is None + + result_data = details.to_dict() + assert result_data == {} + + +# Tests for CallbackStartedDetails +def test_callback_started_details_serialization(): + """Test CallbackStartedDetails from_dict/to_dict round-trip.""" + data = { + "CallbackId": "callback-123", + "HeartbeatTimeout": 60, + "Timeout": 300, + } + + details = CallbackStartedDetails.from_dict(data) + assert details.callback_id == "callback-123" + assert details.heartbeat_timeout == 60 + assert details.timeout == 300 + + result_data = details.to_dict() + assert result_data == data + + +def test_callback_started_details_minimal(): + """Test CallbackStartedDetails with minimal data.""" + data = {} + + details = CallbackStartedDetails.from_dict(data) + assert details.callback_id is None + assert details.heartbeat_timeout is None + assert details.timeout is None + + result_data = details.to_dict() + assert result_data == {} + + +def test_callback_started_details_partial(): + """Test CallbackStartedDetails with partial data.""" + data = { + "CallbackId": "callback-123", + "Timeout": 300, + } + + details = CallbackStartedDetails.from_dict(data) + assert details.callback_id == "callback-123" + assert details.heartbeat_timeout is None + assert details.timeout == 300 + + result_data = details.to_dict() + assert result_data == data + + +# Tests for CallbackSucceededDetails +def test_callback_succeeded_details_serialization(): + """Test CallbackSucceededDetails from_dict/to_dict round-trip.""" + data = { + "Result": {"Payload": "callback-result", "Truncated": False}, + } + + details = CallbackSucceededDetails.from_dict(data) + assert details.result.payload == "callback-result" + + result_data = details.to_dict() + assert result_data == data + + +def test_callback_succeeded_details_minimal(): + """Test CallbackSucceededDetails with minimal data.""" + data = {} + + details = CallbackSucceededDetails.from_dict(data) + assert details.result is None + + result_data = details.to_dict() + assert result_data == {} + + +# Tests for CallbackFailedDetails +def test_callback_failed_details_serialization(): + """Test CallbackFailedDetails from_dict/to_dict round-trip.""" + data = { + "Error": {"Payload": {"ErrorMessage": "callback failed"}, "Truncated": False}, + } + + details = CallbackFailedDetails.from_dict(data) + assert details.error.payload.message == "callback failed" + + result_data = details.to_dict() + assert result_data == data + + +def test_callback_failed_details_minimal(): + """Test CallbackFailedDetails with minimal data.""" + data = {} + + details = CallbackFailedDetails.from_dict(data) + assert details.error is None + + result_data = details.to_dict() + assert result_data == {} + + +# Tests for CallbackTimedOutDetails +def test_callback_timed_out_details_serialization(): + """Test CallbackTimedOutDetails from_dict/to_dict round-trip.""" + data = { + "Error": { + "Payload": {"ErrorMessage": "callback timed out"}, + "Truncated": False, + }, + } + + details = CallbackTimedOutDetails.from_dict(data) + assert details.error.payload.message == "callback timed out" + + result_data = details.to_dict() + assert result_data == data + + +def test_callback_timed_out_details_minimal(): + """Test CallbackTimedOutDetails with minimal data.""" + data = {} + + details = CallbackTimedOutDetails.from_dict(data) + assert details.error is None + + result_data = details.to_dict() + assert result_data == {} + + +# Tests for Event class with all detail types +def test_event_with_execution_succeeded_details(): + """Test Event with ExecutionSucceededDetails.""" + data = { + "EventType": "ExecutionSucceeded", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "ExecutionSucceededDetails": { + "Result": {"Payload": "success", "Truncated": False} + }, + } + + event_obj = Event.from_dict(data) + assert event_obj.event_type == "ExecutionSucceeded" + assert event_obj.execution_succeeded_details is not None + assert event_obj.execution_succeeded_details.result.payload == "success" + + result_data = event_obj.to_dict() + expected_data = { + "EventType": "ExecutionSucceeded", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "EventId": 1, # Default value + "ExecutionSucceededDetails": { + "Result": {"Payload": "success", "Truncated": False} + }, + } + assert result_data == expected_data + + +def test_event_with_execution_failed_details(): + """Test Event with ExecutionFailedDetails.""" + data = { + "EventType": "ExecutionFailed", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "ExecutionFailedDetails": { + "Error": { + "Payload": {"ErrorMessage": "execution failed"}, + "Truncated": False, + } + }, + } + + event_obj = Event.from_dict(data) + assert event_obj.event_type == "ExecutionFailed" + assert event_obj.execution_failed_details is not None + assert ( + event_obj.execution_failed_details.error.payload.message == "execution failed" + ) + + result_data = event_obj.to_dict() + expected_data = { + "EventType": "ExecutionFailed", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "EventId": 1, + "ExecutionFailedDetails": { + "Error": { + "Payload": {"ErrorMessage": "execution failed"}, + "Truncated": False, + } + }, + } + assert result_data == expected_data + + +def test_event_with_execution_timed_out_details(): + """Test Event with ExecutionTimedOutDetails.""" + data = { + "EventType": "ExecutionTimedOut", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "ExecutionTimedOutDetails": { + "Error": { + "Payload": {"ErrorMessage": "execution timed out"}, + "Truncated": False, + } + }, + } + + event_obj = Event.from_dict(data) + assert event_obj.event_type == "ExecutionTimedOut" + assert event_obj.execution_timed_out_details is not None + assert ( + event_obj.execution_timed_out_details.error.payload.message + == "execution timed out" + ) + + result_data = event_obj.to_dict() + expected_data = { + "EventType": "ExecutionTimedOut", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "EventId": 1, + "ExecutionTimedOutDetails": { + "Error": { + "Payload": {"ErrorMessage": "execution timed out"}, + "Truncated": False, + } + }, + } + assert result_data == expected_data + + +def test_event_with_execution_stopped_details(): + """Test Event with ExecutionStoppedDetails.""" + data = { + "EventType": "ExecutionStopped", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "ExecutionStoppedDetails": { + "Error": { + "Payload": {"ErrorMessage": "execution stopped"}, + "Truncated": False, + } + }, + } + + event_obj = Event.from_dict(data) + assert event_obj.event_type == "ExecutionStopped" + assert event_obj.execution_stopped_details is not None + assert ( + event_obj.execution_stopped_details.error.payload.message == "execution stopped" + ) + + result_data = event_obj.to_dict() + expected_data = { + "EventType": "ExecutionStopped", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "EventId": 1, + "ExecutionStoppedDetails": { + "Error": { + "Payload": {"ErrorMessage": "execution stopped"}, + "Truncated": False, + } + }, + } + assert result_data == expected_data + + +def test_event_with_context_started_details(): + """Test Event with ContextStartedDetails.""" + # Since ContextStartedDetails has no fields and empty dict is falsy, + # we need to provide a non-empty dict or test without the key + data = { + "EventType": "ContextStarted", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "ContextStartedDetails": {"dummy": "value"}, # Non-empty to be truthy + } + + event_obj = Event.from_dict(data) + assert event_obj.event_type == "ContextStarted" + assert event_obj.context_started_details is not None + + result_data = event_obj.to_dict() + expected_data = { + "EventType": "ContextStarted", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "EventId": 1, + "ContextStartedDetails": {}, # to_dict() returns empty dict + } + assert result_data == expected_data + + +def test_event_with_context_succeeded_details(): + """Test Event with ContextSucceededDetails.""" + data = { + "EventType": "ContextSucceeded", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "ContextSucceededDetails": { + "Result": {"Payload": "context result", "Truncated": False} + }, + } + + event_obj = Event.from_dict(data) + assert event_obj.event_type == "ContextSucceeded" + assert event_obj.context_succeeded_details is not None + assert event_obj.context_succeeded_details.result.payload == "context result" + + result_data = event_obj.to_dict() + expected_data = { + "EventType": "ContextSucceeded", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "EventId": 1, + "ContextSucceededDetails": { + "Result": {"Payload": "context result", "Truncated": False} + }, + } + assert result_data == expected_data + + +def test_event_with_context_failed_details(): + """Test Event with ContextFailedDetails.""" + data = { + "EventType": "ContextFailed", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "ContextFailedDetails": { + "Error": {"Payload": {"ErrorMessage": "context failed"}, "Truncated": False} + }, + } + + event_obj = Event.from_dict(data) + assert event_obj.event_type == "ContextFailed" + assert event_obj.context_failed_details is not None + assert event_obj.context_failed_details.error.payload.message == "context failed" + + result_data = event_obj.to_dict() + expected_data = { + "EventType": "ContextFailed", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "EventId": 1, + "ContextFailedDetails": { + "Error": {"Payload": {"ErrorMessage": "context failed"}, "Truncated": False} + }, + } + assert result_data == expected_data + + +def test_event_with_wait_started_details(): + """Test Event with WaitStartedDetails.""" + data = { + "EventType": "WaitStarted", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "WaitStartedDetails": { + "Duration": 60, + "ScheduledEndTimestamp": TIMESTAMP_2023_01_01_00_02, + }, + } + + event_obj = Event.from_dict(data) + assert event_obj.event_type == "WaitStarted" + assert event_obj.wait_started_details is not None + assert event_obj.wait_started_details.duration == 60 + + result_data = event_obj.to_dict() + expected_data = { + "EventType": "WaitStarted", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "EventId": 1, + "WaitStartedDetails": { + "Duration": 60, + "ScheduledEndTimestamp": TIMESTAMP_2023_01_01_00_02, + }, + } + assert result_data == expected_data + + +def test_event_with_wait_succeeded_details(): + """Test Event with WaitSucceededDetails.""" + data = { + "EventType": "WaitSucceeded", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "WaitSucceededDetails": {"Duration": 60}, + } + + event_obj = Event.from_dict(data) + assert event_obj.event_type == "WaitSucceeded" + assert event_obj.wait_succeeded_details is not None + assert event_obj.wait_succeeded_details.duration == 60 + + result_data = event_obj.to_dict() + expected_data = { + "EventType": "WaitSucceeded", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "EventId": 1, + "WaitSucceededDetails": {"Duration": 60}, + } + assert result_data == expected_data + + +def test_event_with_wait_cancelled_details(): + """Test Event with WaitCancelledDetails.""" + data = { + "EventType": "WaitCancelled", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "WaitCancelledDetails": { + "Error": {"Payload": {"ErrorMessage": "wait cancelled"}, "Truncated": False} + }, + } + + event_obj = Event.from_dict(data) + assert event_obj.event_type == "WaitCancelled" + assert event_obj.wait_cancelled_details is not None + assert event_obj.wait_cancelled_details.error.payload.message == "wait cancelled" + + result_data = event_obj.to_dict() + expected_data = { + "EventType": "WaitCancelled", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "EventId": 1, + "WaitCancelledDetails": { + "Error": {"Payload": {"ErrorMessage": "wait cancelled"}, "Truncated": False} + }, + } + assert result_data == expected_data + + +def test_event_with_step_started_details(): + """Test Event with StepStartedDetails.""" + # Since StepStartedDetails has no fields and empty dict is falsy, + # we need to provide a non-empty dict or test without the key + data = { + "EventType": "StepStarted", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "StepStartedDetails": {"dummy": "value"}, # Non-empty to be truthy + } + + event_obj = Event.from_dict(data) + assert event_obj.event_type == "StepStarted" + assert event_obj.step_started_details is not None + + result_data = event_obj.to_dict() + expected_data = { + "EventType": "StepStarted", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "EventId": 1, + "StepStartedDetails": {}, # to_dict() returns empty dict + } + assert result_data == expected_data + + +def test_event_with_step_succeeded_details(): + """Test Event with StepSucceededDetails.""" + data = { + "EventType": "StepSucceeded", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "StepSucceededDetails": { + "Result": {"Payload": "step result", "Truncated": False}, + "RetryDetails": {"CurrentAttempt": 1}, + }, + } + + event_obj = Event.from_dict(data) + assert event_obj.event_type == "StepSucceeded" + assert event_obj.step_succeeded_details is not None + assert event_obj.step_succeeded_details.result.payload == "step result" + + result_data = event_obj.to_dict() + expected_data = { + "EventType": "StepSucceeded", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "EventId": 1, + "StepSucceededDetails": { + "Result": {"Payload": "step result", "Truncated": False}, + "RetryDetails": {"CurrentAttempt": 1}, + }, + } + assert result_data == expected_data + + +def test_event_with_step_failed_details(): + """Test Event with StepFailedDetails.""" + data = { + "EventType": "StepFailed", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "StepFailedDetails": { + "Error": {"Payload": {"ErrorMessage": "step failed"}, "Truncated": False}, + "RetryDetails": {"CurrentAttempt": 2, "NextAttemptDelaySeconds": 30}, + }, + } + + event_obj = Event.from_dict(data) + assert event_obj.event_type == "StepFailed" + assert event_obj.step_failed_details is not None + assert event_obj.step_failed_details.error.payload.message == "step failed" + + result_data = event_obj.to_dict() + expected_data = { + "EventType": "StepFailed", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "EventId": 1, + "StepFailedDetails": { + "Error": {"Payload": {"ErrorMessage": "step failed"}, "Truncated": False}, + "RetryDetails": {"CurrentAttempt": 2, "NextAttemptDelaySeconds": 30}, + }, + } + assert result_data == expected_data + + +def test_event_with_invoke_started_details(): + """Test Event with ChainedInvokeStartedDetails.""" + data = { + "EventType": "ChainedInvokeStarted", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "ChainedInvokeStartedDetails": { + "DurableExecutionArn": "arn:aws:durable-execution:us-east-1:123456789012:execution:my-execution:1234567890", + }, + } + + event_obj = Event.from_dict(data) + assert event_obj.event_type == "ChainedInvokeStarted" + assert event_obj.chained_invoke_started_details is not None + + result_data = event_obj.to_dict() + expected_data = { + "EventType": "ChainedInvokeStarted", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "EventId": 1, + "ChainedInvokeStartedDetails": { + "DurableExecutionArn": "arn:aws:durable-execution:us-east-1:123456789012:execution:my-execution:1234567890", + }, + } + assert result_data == expected_data + + +def test_event_with_invoke_succeeded_details(): + """Test Event with ChainedInvokeSucceededDetails.""" + data = { + "EventType": "ChainedInvokeSucceeded", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "ChainedInvokeSucceededDetails": { + "Result": {"Payload": "invoke result", "Truncated": False} + }, + } + + event_obj = Event.from_dict(data) + assert event_obj.event_type == "ChainedInvokeSucceeded" + assert event_obj.chained_invoke_succeeded_details is not None + assert event_obj.chained_invoke_succeeded_details.result.payload == "invoke result" + + result_data = event_obj.to_dict() + expected_data = { + "EventType": "ChainedInvokeSucceeded", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "EventId": 1, + "ChainedInvokeSucceededDetails": { + "Result": {"Payload": "invoke result", "Truncated": False} + }, + } + assert result_data == expected_data + + +def test_event_with_invoke_failed_details(): + """Test Event with ChainedInvokeFailedDetails.""" + data = { + "EventType": "ChainedInvokeFailed", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "ChainedInvokeFailedDetails": { + "Error": {"Payload": {"ErrorMessage": "invoke failed"}, "Truncated": False} + }, + } + + event_obj = Event.from_dict(data) + assert event_obj.event_type == "ChainedInvokeFailed" + assert event_obj.chained_invoke_failed_details is not None + assert ( + event_obj.chained_invoke_failed_details.error.payload.message == "invoke failed" + ) + + result_data = event_obj.to_dict() + expected_data = { + "EventType": "ChainedInvokeFailed", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "EventId": 1, + "ChainedInvokeFailedDetails": { + "Error": {"Payload": {"ErrorMessage": "invoke failed"}, "Truncated": False} + }, + } + assert result_data == expected_data + + +def test_event_with_invoke_timed_out_details(): + """Test Event with ChainedInvokeTimedOutDetails.""" + data = { + "EventType": "ChainedInvokeTimedOut", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "ChainedInvokeTimedOutDetails": { + "Error": { + "Payload": {"ErrorMessage": "invoke timed out"}, + "Truncated": False, + } + }, + } + + event_obj = Event.from_dict(data) + assert event_obj.event_type == "ChainedInvokeTimedOut" + assert event_obj.chained_invoke_timed_out_details is not None + assert ( + event_obj.chained_invoke_timed_out_details.error.payload.message + == "invoke timed out" + ) + + result_data = event_obj.to_dict() + expected_data = { + "EventType": "ChainedInvokeTimedOut", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "EventId": 1, + "ChainedInvokeTimedOutDetails": { + "Error": { + "Payload": {"ErrorMessage": "invoke timed out"}, + "Truncated": False, + } + }, + } + assert result_data == expected_data + + +def test_event_with_invoke_stopped_details(): + """Test Event with ChainedInvokeStoppedDetails.""" + data = { + "EventType": "ChainedInvokeStopped", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "ChainedInvokeStoppedDetails": { + "Error": {"Payload": {"ErrorMessage": "invoke stopped"}, "Truncated": False} + }, + } + + event_obj = Event.from_dict(data) + assert event_obj.event_type == "ChainedInvokeStopped" + assert event_obj.chained_invoke_stopped_details is not None + assert ( + event_obj.chained_invoke_stopped_details.error.payload.message + == "invoke stopped" + ) + + result_data = event_obj.to_dict() + expected_data = { + "EventType": "ChainedInvokeStopped", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "EventId": 1, + "ChainedInvokeStoppedDetails": { + "Error": {"Payload": {"ErrorMessage": "invoke stopped"}, "Truncated": False} + }, + } + assert result_data == expected_data + + +def test_event_with_callback_started_details(): + """Test Event with CallbackStartedDetails.""" + data = { + "EventType": "CallbackStarted", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "CallbackStartedDetails": { + "CallbackId": "callback-123", + "HeartbeatTimeout": 60, + "Timeout": 300, + }, + } + + event_obj = Event.from_dict(data) + assert event_obj.event_type == "CallbackStarted" + assert event_obj.callback_started_details is not None + assert event_obj.callback_started_details.callback_id == "callback-123" + + result_data = event_obj.to_dict() + expected_data = { + "EventType": "CallbackStarted", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "EventId": 1, + "CallbackStartedDetails": { + "CallbackId": "callback-123", + "HeartbeatTimeout": 60, + "Timeout": 300, + }, + } + assert result_data == expected_data + + +def test_event_with_callback_succeeded_details(): + """Test Event with CallbackSucceededDetails.""" + data = { + "EventType": "CallbackSucceeded", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "CallbackSucceededDetails": { + "Result": {"Payload": "callback result", "Truncated": False} + }, + } + + event_obj = Event.from_dict(data) + assert event_obj.event_type == "CallbackSucceeded" + assert event_obj.callback_succeeded_details is not None + assert event_obj.callback_succeeded_details.result.payload == "callback result" + + result_data = event_obj.to_dict() + expected_data = { + "EventType": "CallbackSucceeded", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "EventId": 1, + "CallbackSucceededDetails": { + "Result": {"Payload": "callback result", "Truncated": False} + }, + } + assert result_data == expected_data + + +def test_event_with_callback_failed_details(): + """Test Event with CallbackFailedDetails.""" + data = { + "EventType": "CallbackFailed", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "CallbackFailedDetails": { + "Error": { + "Payload": {"ErrorMessage": "callback failed"}, + "Truncated": False, + } + }, + } + + event_obj = Event.from_dict(data) + assert event_obj.event_type == "CallbackFailed" + assert event_obj.callback_failed_details is not None + assert event_obj.callback_failed_details.error.payload.message == "callback failed" + + result_data = event_obj.to_dict() + expected_data = { + "EventType": "CallbackFailed", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "EventId": 1, + "CallbackFailedDetails": { + "Error": { + "Payload": {"ErrorMessage": "callback failed"}, + "Truncated": False, + } + }, + } + assert result_data == expected_data + + +def test_event_with_callback_timed_out_details(): + """Test Event with CallbackTimedOutDetails.""" + data = { + "EventType": "CallbackTimedOut", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "CallbackTimedOutDetails": { + "Error": { + "Payload": {"ErrorMessage": "callback timed out"}, + "Truncated": False, + } + }, + } + + event_obj = Event.from_dict(data) + assert event_obj.event_type == "CallbackTimedOut" + assert event_obj.callback_timed_out_details is not None + assert ( + event_obj.callback_timed_out_details.error.payload.message + == "callback timed out" + ) + + result_data = event_obj.to_dict() + expected_data = { + "EventType": "CallbackTimedOut", + "EventTimestamp": TIMESTAMP_2023_01_01_00_01, + "EventId": 1, + "CallbackTimedOutDetails": { + "Error": { + "Payload": {"ErrorMessage": "callback timed out"}, + "Truncated": False, + } + }, + } + assert result_data == expected_data + + +# Tests for GetDurableExecutionHistoryRequest with all optional fields +def test_get_durable_execution_history_request_all_optional_fields(): + """Test GetDurableExecutionHistoryRequest to_dict with all optional fields as None.""" + request_obj = GetDurableExecutionHistoryRequest( + durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test", + include_execution_data=None, + reverse_order=None, + marker=None, + max_items=None, + ) + + result_data = request_obj.to_dict() + expected_data = { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test", + } + assert result_data == expected_data + + +def test_get_durable_execution_history_request_partial_fields(): + """Test GetDurableExecutionHistoryRequest to_dict with some optional fields.""" + request_obj = GetDurableExecutionHistoryRequest( + durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test", + include_execution_data=True, + reverse_order=None, + marker="marker-123", + max_items=20, + ) + + result_data = request_obj.to_dict() + expected_data = { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-function:execution:test", + "IncludeExecutionData": True, + "Marker": "marker-123", + "MaxItems": 20, + } + assert result_data == expected_data + + +# Tests for ListDurableExecutionsByFunctionRequest with all optional fields +def test_list_durable_executions_by_function_request_all_optional_fields(): + """Test ListDurableExecutionsByFunctionRequest to_dict with all optional fields as None.""" + request_obj = ListDurableExecutionsByFunctionRequest( + function_name="my-function", + qualifier=None, + status_filter=None, + started_after=None, + started_before=None, + marker=None, + max_items=None, + reverse_order=None, + ) + + result_data = request_obj.to_dict() + expected_data = { + "FunctionName": "my-function", + } + assert result_data == expected_data + + +def test_list_durable_executions_by_function_request_partial_fields(): + """Test ListDurableExecutionsByFunctionRequest to_dict with some optional fields.""" + request_obj = ListDurableExecutionsByFunctionRequest( + function_name="my-function", + qualifier="$LATEST", + status_filter=["RUNNING"], + started_after=None, + started_before=TIMESTAMP_2023_01_02_00_00, + marker=None, + max_items=15, + reverse_order=True, + ) + + result_data = request_obj.to_dict() + expected_data = { + "FunctionName": "my-function", + "Qualifier": "$LATEST", + "StatusFilter": ["RUNNING"], + "StartedBefore": TIMESTAMP_2023_01_02_00_00, + "MaxItems": 15, + "ReverseOrder": True, + } + assert result_data == expected_data + + +# Tests for SendDurableExecutionCallbackSuccessRequest with optional result +def test_send_durable_execution_callback_success_request_with_result(): + """Test SendDurableExecutionCallbackSuccessRequest to_dict with result.""" + request_obj = SendDurableExecutionCallbackSuccessRequest( + callback_id="callback-123", + result="success-result", + ) + + result_data = request_obj.to_dict() + expected_data = { + "CallbackId": "callback-123", + "Result": "success-result", + } + assert result_data == expected_data + + +# Tests for SendDurableExecutionCallbackFailureRequest with optional error +def test_send_durable_execution_callback_failure_request_with_error(): + """Test SendDurableExecutionCallbackFailureRequest to_dict with error.""" + request_obj = SendDurableExecutionCallbackFailureRequest( + callback_id="callback-123", + error=None, + ) + + result_data = request_obj.to_dict() + expected_data = { + "CallbackId": "callback-123", + } + assert result_data == expected_data + + +# Test for missing coverage in ListDurableExecutionsByFunctionRequest +def test_list_durable_executions_by_function_request_with_durable_execution_name(): + """Test ListDurableExecutionsByFunctionRequest to_dict with durable_execution_name.""" + request_obj = ListDurableExecutionsByFunctionRequest( + function_name="my-function", + qualifier=None, + durable_execution_name="specific-execution", + status_filter=None, + started_after=None, + started_before=None, + marker=None, + max_items=None, + reverse_order=None, + ) + + result_data = request_obj.to_dict() + expected_data = { + "FunctionName": "my-function", + "DurableExecutionName": "specific-execution", + } + assert result_data == expected_data + + +# Test for missing branch coverage in CheckpointDurableExecutionResponse +def test_checkpoint_updated_execution_state_with_next_marker(): + """Test CheckpointUpdatedExecutionState to_dict with next_marker.""" + from aws_durable_execution_sdk_python.lambda_service import ( + Operation, + OperationStatus, + OperationType, + ) + + operation = Operation( + operation_id="op-1", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + ) + + state_obj = CheckpointUpdatedExecutionState( + operations=[operation], + next_marker="next-marker-123", + ) + + result_data = state_obj.to_dict() + expected_data = { + "Operations": [{"Id": "op-1", "Type": "STEP", "Status": "SUCCEEDED"}], + "NextMarker": "next-marker-123", + } + assert result_data == expected_data + + +# Tests for events_to_operations function + + +def test_events_to_operations_empty_list(): + """Test events_to_operations with empty event list.""" + from aws_durable_execution_sdk_python_testing.model import events_to_operations + + operations = events_to_operations([]) + assert operations == [] + + +def test_events_to_operations_execution_started(): + """Test events_to_operations with ExecutionStarted event.""" + event = Event( + event_type="ExecutionStarted", + event_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC), + operation_id="exec-1", + execution_started_details=ExecutionStartedDetails( + input=EventInput(payload="test-input", truncated=False), + execution_timeout=300, + ), + ) + + operations = events_to_operations([event]) + + assert len(operations) == 1 + assert operations[0].operation_id == "exec-1" + assert operations[0].operation_type == OperationType.EXECUTION + assert operations[0].status == OperationStatus.STARTED + assert operations[0].execution_details.input_payload == "test-input" + + +def test_events_to_operations_callback_lifecycle(): + """Test events_to_operations with complete callback lifecycle.""" + + started_event = Event( + event_type="CallbackStarted", + event_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC), + operation_id="cb-1", + name="test-callback", + callback_started_details=CallbackStartedDetails(callback_id="callback-123"), + ) + + succeeded_event = Event( + event_type="CallbackSucceeded", + event_timestamp=datetime.datetime(2023, 1, 1, 0, 1, 0, tzinfo=datetime.UTC), + operation_id="cb-1", + callback_succeeded_details=CallbackSucceededDetails( + result=EventResult(payload="callback-result", truncated=False) + ), + ) + + operations = events_to_operations([started_event, succeeded_event]) + + assert len(operations) == 1 + assert operations[0].operation_id == "cb-1" + assert operations[0].operation_type == OperationType.CALLBACK + assert operations[0].status == OperationStatus.SUCCEEDED + assert operations[0].name == "test-callback" + assert operations[0].callback_details.callback_id == "callback-123" + assert operations[0].callback_details.result == "callback-result" + assert operations[0].callback_details.error is None + + +def test_events_to_operations_missing_event_type(): + """Test events_to_operations raises error for missing event_type.""" + event = Event( + event_type=None, + event_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC), + ) + + with pytest.raises( + InvalidParameterValueException, match="Missing required 'event_type' field" + ): + events_to_operations([event]) + + +def test_events_to_operations_unknown_event_type(): + """Test events_to_operations raises error for unknown event type.""" + event = Event( + event_type="UnknownEventType", + event_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC), + operation_id="op-1", + ) + + with pytest.raises( + InvalidParameterValueException, match="Unknown event type: UnknownEventType" + ): + events_to_operations([event]) + + +def test_events_to_operations_missing_operation_id(): + """Test events_to_operations raises error for missing operation_id.""" + event = Event( + event_type="StepStarted", + event_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC), + operation_id=None, + ) + + with pytest.raises( + InvalidParameterValueException, match="Missing required 'operation_id' field" + ): + events_to_operations([event]) + + +def test_events_to_operations_step_with_retry(): + """Test events_to_operations with step retry details.""" + import datetime + + from aws_durable_execution_sdk_python.lambda_service import ( + OperationStatus, + OperationType, + ) + + from aws_durable_execution_sdk_python_testing.model import ( + Event, + EventResult, + RetryDetails, + StepSucceededDetails, + events_to_operations, + ) + + succeeded_event = Event( + event_type="StepSucceeded", + event_timestamp=datetime.datetime(2023, 1, 1, 0, 1, 0, tzinfo=datetime.UTC), + operation_id="step-1", + name="test-step", + step_succeeded_details=StepSucceededDetails( + result=EventResult(payload="step-result", truncated=False), + retry_details=RetryDetails(current_attempt=2), + ), + ) + + operations = events_to_operations([succeeded_event]) + + assert len(operations) == 1 + assert operations[0].operation_type == OperationType.STEP + assert operations[0].status == OperationStatus.SUCCEEDED + assert operations[0].step_details.result == "step-result" + assert operations[0].step_details.attempt == 2 + + +def test_events_to_operations_step_failed_with_next_attempt(): + """Test events_to_operations with failed step and next attempt timestamp.""" + import datetime + + from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, + OperationStatus, + ) + + from aws_durable_execution_sdk_python_testing.model import ( + Event, + EventError, + RetryDetails, + StepFailedDetails, + events_to_operations, + ) + + event_time = datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC) + failed_event = Event( + event_type="StepFailed", + event_timestamp=event_time, + operation_id="step-1", + step_failed_details=StepFailedDetails( + error=EventError( + payload=ErrorObject( + message="step failed", type=None, data=None, stack_trace=None + ) + ), + retry_details=RetryDetails( + current_attempt=1, next_attempt_delay_seconds=10 + ), + ), + ) + + operations = events_to_operations([failed_event]) + + assert len(operations) == 1 + assert operations[0].status == OperationStatus.FAILED + assert operations[0].step_details.error.message == "step failed" + assert operations[0].step_details.attempt == 1 + expected_next_attempt = event_time + datetime.timedelta(seconds=10) + assert operations[0].step_details.next_attempt_timestamp == expected_next_attempt + + +def test_events_to_operations_context_succeeded(): + """Test events_to_operations with successful context.""" + import datetime + + from aws_durable_execution_sdk_python.lambda_service import ( + OperationStatus, + OperationType, + ) + + from aws_durable_execution_sdk_python_testing.model import ( + ContextSucceededDetails, + Event, + EventResult, + events_to_operations, + ) + + succeeded_event = Event( + event_type="ContextSucceeded", + event_timestamp=datetime.datetime(2023, 1, 1, 0, 1, 0, tzinfo=datetime.UTC), + operation_id="ctx-1", + name="test-context", + context_succeeded_details=ContextSucceededDetails( + result=EventResult(payload="context-result", truncated=False) + ), + ) + + operations = events_to_operations([succeeded_event]) + + assert len(operations) == 1 + assert operations[0].operation_type == OperationType.CONTEXT + assert operations[0].status == OperationStatus.SUCCEEDED + assert operations[0].context_details.result == "context-result" + assert operations[0].context_details.error is None + + +def test_events_to_operations_chained_invoke_succeeded(): + """Test events_to_operations with successful chained invoke.""" + import datetime + + from aws_durable_execution_sdk_python.lambda_service import ( + OperationStatus, + OperationType, + ) + + from aws_durable_execution_sdk_python_testing.model import ( + ChainedInvokeSucceededDetails, + Event, + EventResult, + events_to_operations, + ) + + succeeded_event = Event( + event_type="ChainedInvokeSucceeded", + event_timestamp=datetime.datetime(2023, 1, 1, 0, 1, 0, tzinfo=datetime.UTC), + operation_id="invoke-1", + name="test-invoke", + chained_invoke_succeeded_details=ChainedInvokeSucceededDetails( + result=EventResult(payload="invoke-result", truncated=False) + ), + ) + + operations = events_to_operations([succeeded_event]) + + assert len(operations) == 1 + assert operations[0].operation_type == OperationType.CHAINED_INVOKE + assert operations[0].status == OperationStatus.SUCCEEDED + assert operations[0].chained_invoke_details.result == "invoke-result" + assert operations[0].chained_invoke_details.error is None + + +def test_events_to_operations_skips_invocation_completed(): + """Test events_to_operations skips InvocationCompleted events.""" + invocation_event = Event( + event_type="InvocationCompleted", + event_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC), + operation_id="invocation-1", + ) + + operations = events_to_operations([invocation_event]) + assert len(operations) == 0 + + +def test_events_to_operations_callback_failed(): + """Test events_to_operations with failed callback.""" + import datetime + + from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, + OperationStatus, + ) + + from aws_durable_execution_sdk_python_testing.model import ( + CallbackFailedDetails, + CallbackStartedDetails, + Event, + EventError, + events_to_operations, + ) + + started_event = Event( + event_type="CallbackStarted", + event_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC), + operation_id="cb-1", + callback_started_details=CallbackStartedDetails(callback_id="callback-123"), + ) + + failed_event = Event( + event_type="CallbackFailed", + event_timestamp=datetime.datetime(2023, 1, 1, 0, 1, 0, tzinfo=datetime.UTC), + operation_id="cb-1", + callback_failed_details=CallbackFailedDetails( + error=EventError( + payload=ErrorObject( + message="callback failed", type=None, data=None, stack_trace=None + ) + ) + ), + ) + + operations = events_to_operations([started_event, failed_event]) + + assert len(operations) == 1 + assert operations[0].status == OperationStatus.FAILED + assert operations[0].callback_details.error.message == "callback failed" + assert operations[0].callback_details.result is None + + +def test_events_to_operations_callback_timed_out(): + """Test events_to_operations with timed out callback.""" + import datetime + + from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, + OperationStatus, + ) + + from aws_durable_execution_sdk_python_testing.model import ( + CallbackStartedDetails, + CallbackTimedOutDetails, + Event, + EventError, + events_to_operations, + ) + + started_event = Event( + event_type="CallbackStarted", + event_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC), + operation_id="cb-1", + callback_started_details=CallbackStartedDetails(callback_id="callback-123"), + ) + + timed_out_event = Event( + event_type="CallbackTimedOut", + event_timestamp=datetime.datetime(2023, 1, 1, 0, 1, 0, tzinfo=datetime.UTC), + operation_id="cb-1", + callback_timed_out_details=CallbackTimedOutDetails( + error=EventError( + payload=ErrorObject( + message="callback timed out", type=None, data=None, stack_trace=None + ) + ) + ), + ) + + operations = events_to_operations([started_event, timed_out_event]) + + assert len(operations) == 1 + assert operations[0].status == OperationStatus.TIMED_OUT + assert operations[0].callback_details.error.message == "callback timed out" + + +def test_events_to_operations_wait_started(): + """Test events_to_operations with wait operation.""" + import datetime + + from aws_durable_execution_sdk_python.lambda_service import ( + OperationStatus, + OperationType, + ) + + from aws_durable_execution_sdk_python_testing.model import ( + Event, + WaitStartedDetails, + events_to_operations, + ) + + scheduled_time = datetime.datetime(2023, 1, 1, 1, 0, 0, tzinfo=datetime.UTC) + wait_event = Event( + event_type="WaitStarted", + event_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC), + operation_id="wait-1", + name="test-wait", + wait_started_details=WaitStartedDetails( + duration=3600, scheduled_end_timestamp=scheduled_time + ), + ) + + operations = events_to_operations([wait_event]) + + assert len(operations) == 1 + assert operations[0].operation_type == OperationType.WAIT + assert operations[0].status == OperationStatus.STARTED + assert operations[0].wait_details.scheduled_end_timestamp == scheduled_time + + +def test_events_to_operations_context_failed(): + """Test events_to_operations with failed context.""" + import datetime + + from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, + OperationStatus, + ) + + from aws_durable_execution_sdk_python_testing.model import ( + ContextFailedDetails, + Event, + EventError, + events_to_operations, + ) + + failed_event = Event( + event_type="ContextFailed", + event_timestamp=datetime.datetime(2023, 1, 1, 0, 1, 0, tzinfo=datetime.UTC), + operation_id="ctx-1", + context_failed_details=ContextFailedDetails( + error=EventError( + payload=ErrorObject( + message="context failed", type=None, data=None, stack_trace=None + ) + ) + ), + ) + + operations = events_to_operations([failed_event]) + + assert len(operations) == 1 + assert operations[0].status == OperationStatus.FAILED + assert operations[0].context_details.error.message == "context failed" + assert operations[0].context_details.result is None + + +def test_events_to_operations_chained_invoke_failed(): + """Test events_to_operations with failed chained invoke.""" + import datetime + + from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, + OperationStatus, + ) + + from aws_durable_execution_sdk_python_testing.model import ( + ChainedInvokeFailedDetails, + Event, + EventError, + events_to_operations, + ) + + failed_event = Event( + event_type="ChainedInvokeFailed", + event_timestamp=datetime.datetime(2023, 1, 1, 0, 1, 0, tzinfo=datetime.UTC), + operation_id="invoke-1", + chained_invoke_failed_details=ChainedInvokeFailedDetails( + error=EventError( + payload=ErrorObject( + message="invoke failed", type=None, data=None, stack_trace=None + ) + ) + ), + ) + + operations = events_to_operations([failed_event]) + + assert len(operations) == 1 + assert operations[0].status == OperationStatus.FAILED + assert operations[0].chained_invoke_details.error.message == "invoke failed" + assert operations[0].chained_invoke_details.result is None + + +def test_events_to_operations_multiple_operations(): + """Test events_to_operations with multiple different operations.""" + import datetime + + from aws_durable_execution_sdk_python.lambda_service import ( + OperationStatus, + OperationType, + ) + + from aws_durable_execution_sdk_python_testing.model import ( + Event, + EventResult, + StepSucceededDetails, + events_to_operations, + ) + + events = [ + Event( + event_type="StepStarted", + event_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC), + operation_id="step-1", + name="step-one", + ), + Event( + event_type="StepSucceeded", + event_timestamp=datetime.datetime(2023, 1, 1, 0, 1, 0, tzinfo=datetime.UTC), + operation_id="step-1", + step_succeeded_details=StepSucceededDetails( + result=EventResult(payload="result-1", truncated=False) + ), + ), + Event( + event_type="WaitStarted", + event_timestamp=datetime.datetime(2023, 1, 1, 0, 2, 0, tzinfo=datetime.UTC), + operation_id="wait-1", + name="wait-one", + ), + ] + + operations = events_to_operations(events) + + assert len(operations) == 2 + step_op = next(op for op in operations if op.operation_id == "step-1") + wait_op = next(op for op in operations if op.operation_id == "wait-1") + + assert step_op.operation_type == OperationType.STEP + assert step_op.status == OperationStatus.SUCCEEDED + assert step_op.name == "step-one" + + assert wait_op.operation_type == OperationType.WAIT + assert wait_op.status == OperationStatus.STARTED + assert wait_op.name == "wait-one" + + +def test_events_to_operations_merges_timestamps(): + """Test events_to_operations correctly merges start and end timestamps.""" + import datetime + + from aws_durable_execution_sdk_python_testing.model import ( + Event, + EventResult, + StepSucceededDetails, + events_to_operations, + ) + + start_time = datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC) + end_time = datetime.datetime(2023, 1, 1, 0, 1, 0, tzinfo=datetime.UTC) + + events = [ + Event( + event_type="StepStarted", + event_timestamp=start_time, + operation_id="step-1", + ), + Event( + event_type="StepSucceeded", + event_timestamp=end_time, + operation_id="step-1", + step_succeeded_details=StepSucceededDetails( + result=EventResult(payload="result", truncated=False) + ), + ), + ] + + operations = events_to_operations(events) + + assert len(operations) == 1 + assert operations[0].start_timestamp == start_time + assert operations[0].end_timestamp == end_time + + +def test_events_to_operations_preserves_parent_id(): + """Test events_to_operations preserves parent_id from events.""" + event = Event( + event_type="StepStarted", + event_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC), + operation_id="step-1", + parent_id="parent-ctx", + name="child-step", + ) + + operations = events_to_operations([event]) + + assert len(operations) == 1 + assert operations[0].parent_id == "parent-ctx" + + +def test_events_to_operations_preserves_sub_type(): + """Test events_to_operations preserves sub_type from events.""" + event = Event( + event_type="StepStarted", + event_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC), + operation_id="step-1", + sub_type="Step", + ) + + operations = events_to_operations([event]) + + assert len(operations) == 1 + assert operations[0].sub_type is not None + assert operations[0].sub_type.value == "Step" + + +def test_events_to_operations_invalid_sub_type(): + """Test events_to_operations raises InvalidParameterValueException when sub_type is invalid.""" + invalid_sub_type: str = "INVALID_SUB_TYPE" + event = Event( + event_type="StepStarted", + event_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC), + operation_id="step-1", + sub_type=invalid_sub_type, + ) + + with pytest.raises( + InvalidParameterValueException, + match=f"'{invalid_sub_type}' is not a valid OperationSubType", + ): + events_to_operations([event]) + + +def test_invocation_completed_details_to_json_dict(): + """Test InvocationCompletedDetails.to_json_dict() converts datetime to Unix milliseconds.""" + start_time = datetime.datetime(2023, 1, 1, 0, 0, 0, 123456, tzinfo=datetime.UTC) + end_time = datetime.datetime(2023, 1, 1, 0, 1, 0, 456789, tzinfo=datetime.UTC) + + details = InvocationCompletedDetails( + start_timestamp=start_time, end_timestamp=end_time, request_id="req-123" + ) + + json_dict = details.to_json_dict() + + # Verify timestamps are converted to Unix milliseconds (integers) + assert json_dict["StartTimestamp"] == 1672531200123 + assert json_dict["EndTimestamp"] == 1672531260456 + assert json_dict["RequestId"] == "req-123" + + # Verify all values are JSON-serializable + json_str = json.dumps(json_dict) + assert json_str is not None + + +def test_invocation_completed_details_from_json_dict(): + """Test InvocationCompletedDetails.from_json_dict() converts Unix milliseconds to datetime.""" + json_dict = { + "StartTimestamp": 1672531200123, + "EndTimestamp": 1672531260456, + "RequestId": "req-456", + } + + details = InvocationCompletedDetails.from_json_dict(json_dict) + + # Verify timestamps are converted to datetime objects + assert details.start_timestamp == datetime.datetime( + 2023, 1, 1, 0, 0, 0, 123000, tzinfo=datetime.UTC + ) + assert details.end_timestamp == datetime.datetime( + 2023, 1, 1, 0, 1, 0, 456000, tzinfo=datetime.UTC + ) + assert details.request_id == "req-456" + + +def test_invocation_completed_details_json_round_trip(): + """Test InvocationCompletedDetails to_json_dict/from_json_dict round-trip.""" + original = InvocationCompletedDetails( + start_timestamp=datetime.datetime( + 2023, 6, 15, 12, 30, 45, 678000, tzinfo=datetime.UTC + ), + end_timestamp=datetime.datetime( + 2023, 6, 15, 12, 31, 50, 123000, tzinfo=datetime.UTC + ), + request_id="round-trip-test", + ) + + # Serialize to JSON dict + json_dict = original.to_json_dict() + + # Deserialize back + restored = InvocationCompletedDetails.from_json_dict(json_dict) + + # Verify round-trip preserves data + assert restored.start_timestamp == original.start_timestamp + assert restored.end_timestamp == original.end_timestamp + assert restored.request_id == original.request_id + + +def test_invocation_completed_details_to_dict_preserves_datetime(): + """Test InvocationCompletedDetails.to_dict() preserves datetime objects (not converted).""" + start_time = datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC) + end_time = datetime.datetime(2023, 1, 1, 0, 1, 0, tzinfo=datetime.UTC) + + details = InvocationCompletedDetails( + start_timestamp=start_time, end_timestamp=end_time, request_id="req-789" + ) + + regular_dict = details.to_dict() + + # Verify to_dict() preserves datetime objects (not converted to Unix milliseconds) + assert regular_dict["StartTimestamp"] == start_time + assert regular_dict["EndTimestamp"] == end_time + assert isinstance(regular_dict["StartTimestamp"], datetime.datetime) + assert isinstance(regular_dict["EndTimestamp"], datetime.datetime) + + +def test_invocation_completed_details_from_json_dict_invalid_timestamp(): + """Test InvocationCompletedDetails.from_json_dict() raises error for invalid timestamps.""" + # Test with invalid timestamp that would return None + json_dict = { + "StartTimestamp": None, + "EndTimestamp": 1672531260456, + "RequestId": "req-error", + } + + with pytest.raises( + InvalidParameterValueException, + match="StartTimestamp and EndTimestamp cannot be null", + ): + InvocationCompletedDetails.from_json_dict(json_dict) diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/observer_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/observer_test.py new file mode 100644 index 0000000..193f395 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/observer_test.py @@ -0,0 +1,355 @@ +"""Tests for observer module.""" + +import inspect +import threading +from unittest.mock import Mock + +import pytest +from aws_durable_execution_sdk_python.lambda_service import ErrorObject, CallbackOptions + +from aws_durable_execution_sdk_python_testing.observer import ( + ExecutionNotifier, + ExecutionObserver, +) +from aws_durable_execution_sdk_python_testing.token import CallbackToken + + +class MockExecutionObserver(ExecutionObserver): + """Mock implementation of ExecutionObserver for testing.""" + + def __init__(self): + self.on_completed_calls = [] + self.on_failed_calls = [] + self.on_timed_out_calls = [] + self.on_stopped_calls = [] + self.on_wait_timer_scheduled_calls = [] + self.on_step_retry_scheduled_calls = [] + self.on_callback_created_calls = [] + + def on_completed(self, execution_arn: str, result: str | None = None) -> None: + self.on_completed_calls.append((execution_arn, result)) + + def on_failed(self, execution_arn: str, error: ErrorObject) -> None: + self.on_failed_calls.append((execution_arn, error)) + + def on_timed_out(self, execution_arn: str, error: ErrorObject) -> None: + self.on_timed_out_calls.append((execution_arn, error)) + + def on_stopped(self, execution_arn: str, error: ErrorObject) -> None: + self.on_stopped_calls.append((execution_arn, error)) + + def on_wait_timer_scheduled( + self, execution_arn: str, operation_id: str, delay: float + ) -> None: + self.on_wait_timer_scheduled_calls.append((execution_arn, operation_id, delay)) + + def on_step_retry_scheduled( + self, execution_arn: str, operation_id: str, delay: float + ) -> None: + self.on_step_retry_scheduled_calls.append((execution_arn, operation_id, delay)) + + def on_callback_created( + self, + execution_arn: str, + operation_id: str, + callback_options: CallbackOptions | None, + callback_token: CallbackToken, + ) -> None: + self.on_callback_created_calls.append( + (execution_arn, operation_id, callback_options, callback_token) + ) + + +def test_execution_notifier_init(): + """Test ExecutionNotifier initialization.""" + notifier = ExecutionNotifier() + + assert notifier._observers == [] # noqa: SLF001 + assert notifier._lock is not None # noqa: SLF001 + + +def test_execution_notifier_add_observer(): + """Test adding an observer to ExecutionNotifier.""" + notifier = ExecutionNotifier() + observer = MockExecutionObserver() + + notifier.add_observer(observer) + + assert len(notifier._observers) == 1 # noqa: SLF001 + assert notifier._observers[0] is observer # noqa: SLF001 + + +def test_execution_notifier_add_multiple_observers(): + """Test adding multiple observers to ExecutionNotifier.""" + notifier = ExecutionNotifier() + observer1 = MockExecutionObserver() + observer2 = MockExecutionObserver() + + notifier.add_observer(observer1) + notifier.add_observer(observer2) + + assert len(notifier._observers) == 2 # noqa: SLF001 + assert observer1 in notifier._observers # noqa: SLF001 + assert observer2 in notifier._observers # noqa: SLF001 + + +def test_execution_notifier_notify_completed(): + """Test notifying observers about execution completion.""" + notifier = ExecutionNotifier() + observer = MockExecutionObserver() + notifier.add_observer(observer) + + execution_arn = "test-arn" + result = "test-result" + + notifier.notify_completed(execution_arn, result) + + assert len(observer.on_completed_calls) == 1 + assert observer.on_completed_calls[0] == (execution_arn, result) + + +def test_execution_notifier_notify_completed_no_result(): + """Test notifying observers about execution completion with no result.""" + notifier = ExecutionNotifier() + observer = MockExecutionObserver() + notifier.add_observer(observer) + + execution_arn = "test-arn" + + notifier.notify_completed(execution_arn) + + assert len(observer.on_completed_calls) == 1 + assert observer.on_completed_calls[0] == (execution_arn, None) + + +def test_execution_notifier_notify_failed(): + """Test notifying observers about execution failure.""" + notifier = ExecutionNotifier() + observer = MockExecutionObserver() + notifier.add_observer(observer) + + execution_arn = "test-arn" + error = ErrorObject( + "TestError", "Test error message", "test-data", ["stack", "trace"] + ) + + notifier.notify_failed(execution_arn, error) + + assert len(observer.on_failed_calls) == 1 + assert observer.on_failed_calls[0] == (execution_arn, error) + + +def test_execution_notifier_notify_wait_timer_scheduled(): + """Test notifying observers about wait timer scheduling.""" + notifier = ExecutionNotifier() + observer = MockExecutionObserver() + notifier.add_observer(observer) + + execution_arn = "test-arn" + operation_id = "test-operation" + delay = 5.0 + + notifier.notify_wait_timer_scheduled(execution_arn, operation_id, delay) + + assert len(observer.on_wait_timer_scheduled_calls) == 1 + assert observer.on_wait_timer_scheduled_calls[0] == ( + execution_arn, + operation_id, + delay, + ) + + +def test_execution_notifier_notify_step_retry_scheduled(): + """Test notifying observers about step retry scheduling.""" + notifier = ExecutionNotifier() + observer = MockExecutionObserver() + notifier.add_observer(observer) + + execution_arn = "test-arn" + operation_id = "test-operation" + delay = 10.0 + + notifier.notify_step_retry_scheduled(execution_arn, operation_id, delay) + + assert len(observer.on_step_retry_scheduled_calls) == 1 + assert observer.on_step_retry_scheduled_calls[0] == ( + execution_arn, + operation_id, + delay, + ) + + +def test_execution_notifier_multiple_observers_all_notified(): + """Test that all observers are notified when multiple are registered.""" + notifier = ExecutionNotifier() + observer1 = MockExecutionObserver() + observer2 = MockExecutionObserver() + + notifier.add_observer(observer1) + notifier.add_observer(observer2) + + execution_arn = "test-arn" + result = "test-result" + + notifier.notify_completed(execution_arn, result) + + # Both observers should be notified + assert len(observer1.on_completed_calls) == 1 + assert observer1.on_completed_calls[0] == (execution_arn, result) + assert len(observer2.on_completed_calls) == 1 + assert observer2.on_completed_calls[0] == (execution_arn, result) + + +def test_execution_notifier_no_observers(): + """Test that notifications work even with no observers.""" + notifier = ExecutionNotifier() + + # Should not raise any exceptions + notifier.notify_completed("test-arn", "result") + notifier.notify_failed( + "test-arn", ErrorObject("Error", "Message", "data", ["trace"]) + ) + notifier.notify_wait_timer_scheduled("test-arn", "op-id", 1.0) + notifier.notify_step_retry_scheduled("test-arn", "op-id", 2.0) + + +def test_execution_notifier_thread_safety(): + """Test that ExecutionNotifier is thread-safe.""" + notifier = ExecutionNotifier() + observer = MockExecutionObserver() + notifier.add_observer(observer) + + # Test concurrent access + def add_observer_thread(): + new_observer = MockExecutionObserver() + notifier.add_observer(new_observer) + + def notify_thread(): + notifier.notify_completed("test-arn", "result") + + threads = [] + for _ in range(5): + threads.append(threading.Thread(target=add_observer_thread)) + threads.append(threading.Thread(target=notify_thread)) + + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + # Should have original observer plus 5 more + assert len(notifier._observers) == 6 # noqa: SLF001 + # Original observer should have been notified multiple times + assert len(observer.on_completed_calls) >= 1 + + +def test_execution_observer_abstract_methods(): + """Test that ExecutionObserver is abstract and cannot be instantiated.""" + with pytest.raises(TypeError): + ExecutionObserver() + + +def test_mock_execution_observer_implementation(): + """Test that MockExecutionObserver properly implements all abstract methods.""" + observer = MockExecutionObserver() + + # Test all methods can be called + error = ErrorObject("Error", "Message", "data", ["trace"]) + observer.on_completed("arn", "result") + observer.on_failed("arn", error) + observer.on_timed_out("arn", error) + observer.on_stopped("arn", error) + observer.on_wait_timer_scheduled("arn", "op", 1.0) + observer.on_step_retry_scheduled("arn", "op", 2.0) + + # Verify calls were recorded + assert len(observer.on_completed_calls) == 1 + assert len(observer.on_failed_calls) == 1 + assert len(observer.on_timed_out_calls) == 1 + assert len(observer.on_stopped_calls) == 1 + assert len(observer.on_wait_timer_scheduled_calls) == 1 + assert len(observer.on_step_retry_scheduled_calls) == 1 + + +def test_execution_notifier_notify_observers_with_exception(): + """Test that exceptions in one observer don't affect others.""" + notifier = ExecutionNotifier() + + # Create a mock observer that raises an exception + failing_observer = Mock(spec=ExecutionObserver) + failing_observer.on_completed.side_effect = ValueError("Test exception") + + # Create a normal observer + normal_observer = MockExecutionObserver() + + notifier.add_observer(failing_observer) + notifier.add_observer(normal_observer) + + # This should raise an exception from the failing observer + with pytest.raises(ValueError, match="Test exception"): + notifier.notify_completed("test-arn", "result") + + # The normal observer should still have been called before the exception + failing_observer.on_completed.assert_called_once_with( + execution_arn="test-arn", result="result" + ) + + +def test_execution_observer_abstract_method_coverage(): + """Test coverage of abstract methods in ExecutionObserver.""" + # This test ensures we cover the abstract method definitions + # by checking they exist and have the correct signatures + + methods = inspect.getmembers(ExecutionObserver, predicate=inspect.isfunction) + method_names = [name for name, _ in methods] + + assert "on_completed" in method_names + assert "on_failed" in method_names + assert "on_timed_out" in method_names + assert "on_stopped" in method_names + assert "on_wait_timer_scheduled" in method_names + assert "on_step_retry_scheduled" in method_names + + +def test_execution_notifier_notify_observers_internal(): + """Test the internal _notify_observers method behavior.""" + notifier = ExecutionNotifier() + observer = MockExecutionObserver() + notifier.add_observer(observer) + + # Test that _notify_observers correctly calls the method on observers + notifier._notify_observers( # noqa: SLF001 + ExecutionObserver.on_completed, execution_arn="test", result="success" + ) + + assert len(observer.on_completed_calls) == 1 + assert observer.on_completed_calls[0] == ("test", "success") + + +def test_execution_notifier_all_notification_methods(): + """Test all notification methods with various parameter combinations.""" + notifier = ExecutionNotifier() + observer = MockExecutionObserver() + notifier.add_observer(observer) + + # Test notify_completed with positional args + notifier.notify_completed("arn1", "result1") + assert observer.on_completed_calls[-1] == ("arn1", "result1") + + # Test notify_completed with keyword args + notifier.notify_completed(execution_arn="arn2", result="result2") + assert observer.on_completed_calls[-1] == ("arn2", "result2") + + # Test notify_failed + error = ErrorObject("TestError", "Message", "data", ["trace"]) + notifier.notify_failed("arn3", error) + assert observer.on_failed_calls[-1] == ("arn3", error) + + # Test notify_wait_timer_scheduled + notifier.notify_wait_timer_scheduled("arn4", "op1", 5.5) + assert observer.on_wait_timer_scheduled_calls[-1] == ("arn4", "op1", 5.5) + + # Test notify_step_retry_scheduled + notifier.notify_step_retry_scheduled("arn5", "op2", 10.5) + assert observer.on_step_retry_scheduled_calls[-1] == ("arn5", "op2", 10.5) diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/pending_operation_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/pending_operation_test.py new file mode 100644 index 0000000..d508fbc --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/pending_operation_test.py @@ -0,0 +1,129 @@ +# """Test for pending operation handling in get_execution_history.""" +# +# from datetime import UTC, datetime +# from unittest.mock import Mock +# +# from aws_durable_execution_sdk_python.lambda_service import ( +# OperationStatus, +# OperationType, +# ) +# +# from aws_durable_execution_sdk_python_testing.executor import Executor +# from aws_durable_execution_sdk_python_testing.model import StartDurableExecutionInput +# +# +# def test_get_execution_history_with_pending_chained_invoke(): +# """Test get_execution_history handles pending CHAINED_INVOKE operations correctly.""" +# # Create mocks +# mock_store = Mock() +# mock_scheduler = Mock() +# mock_invoker = Mock() +# mock_checkpoint_processor = Mock() +# +# executor = Executor(mock_store, mock_scheduler, mock_invoker, mock_checkpoint_processor) +# +# # Create mock execution +# mock_execution = Mock() +# mock_execution.durable_execution_arn = "test-arn" +# mock_execution.start_input = StartDurableExecutionInput( +# account_id="123", +# function_name="test", +# function_qualifier="$LATEST", +# execution_name="test", +# execution_timeout_seconds=300, +# execution_retention_period_days=7, +# ) +# mock_execution.result = None +# mock_execution.updates = [] +# +# # Create a pending CHAINED_INVOKE operation with start_timestamp +# pending_op = Mock() +# pending_op.operation_id = "invoke-1" +# pending_op.operation_type = OperationType.CHAINED_INVOKE +# pending_op.status = OperationStatus.PENDING +# pending_op.start_timestamp = datetime.now(UTC) +# pending_op.end_timestamp = None +# +# # Create a non-CHAINED_INVOKE pending operation (should be skipped) +# pending_step = Mock() +# pending_step.operation_id = "step-1" +# pending_step.operation_type = OperationType.STEP +# pending_step.status = OperationStatus.PENDING +# pending_step.start_timestamp = datetime.now(UTC) +# pending_step.end_timestamp = None +# +# # Create a CHAINED_INVOKE pending operation without start_timestamp (should be skipped) +# pending_invoke_no_timestamp = Mock() +# pending_invoke_no_timestamp.operation_id = "invoke-2" +# pending_invoke_no_timestamp.operation_type = OperationType.CHAINED_INVOKE +# pending_invoke_no_timestamp.status = OperationStatus.PENDING +# pending_invoke_no_timestamp.start_timestamp = None +# pending_invoke_no_timestamp.end_timestamp = None +# +# mock_execution.operations = [pending_op, pending_step, pending_invoke_no_timestamp] +# mock_store.load.return_value = mock_execution +# +# # Call get_execution_history +# result = executor.get_execution_history("test-arn", include_execution_data=True) +# +# # Should have 2 events: 1 pending event + 1 started event for the valid pending CHAINED_INVOKE +# assert len(result.events) == 2 +# +# # First event should be the pending event +# assert result.events[0].event_type == "ChainedInvokeStarted" +# assert result.events[0].operation_id == "invoke-1" +# assert result.events[0].chained_invoke_pending_details is not None +# +# # Second event should be the started event +# assert result.events[1].event_type == "ChainedInvokeStarted" +# assert result.events[1].operation_id == "invoke-1" +# assert result.events[1].chained_invoke_started_details is not None +# +# +# def test_get_execution_history_skips_invalid_pending_operations(): +# """Test that invalid pending operations are skipped.""" +# # Create mocks +# mock_store = Mock() +# mock_scheduler = Mock() +# mock_invoker = Mock() +# mock_checkpoint_processor = Mock() +# +# executor = Executor(mock_store, mock_scheduler, mock_invoker, mock_checkpoint_processor) +# +# # Create mock execution +# mock_execution = Mock() +# mock_execution.durable_execution_arn = "test-arn" +# mock_execution.start_input = StartDurableExecutionInput( +# account_id="123", +# function_name="test", +# function_qualifier="$LATEST", +# execution_name="test", +# execution_timeout_seconds=300, +# execution_retention_period_days=7, +# ) +# mock_execution.result = None +# mock_execution.updates = [] +# +# # Create operations that should be skipped +# # 1. Non-CHAINED_INVOKE pending operation +# pending_step = Mock() +# pending_step.operation_id = "step-1" +# pending_step.operation_type = OperationType.STEP +# pending_step.status = OperationStatus.PENDING +# pending_step.start_timestamp = datetime.now(UTC) +# +# # 2. CHAINED_INVOKE pending operation without start_timestamp +# pending_invoke_no_timestamp = Mock() +# pending_invoke_no_timestamp.operation_id = "invoke-1" +# pending_invoke_no_timestamp.operation_type = OperationType.CHAINED_INVOKE +# pending_invoke_no_timestamp.status = OperationStatus.PENDING +# pending_invoke_no_timestamp.start_timestamp = None +# +# mock_execution.operations = [pending_step, pending_invoke_no_timestamp] +# mock_store.load.return_value = mock_execution +# +# # Call get_execution_history +# result = executor.get_execution_history("test-arn") +# +# # Should have no events since all pending operations are invalid +# assert len(result.events) == 0 diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/runner_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/runner_test.py new file mode 100644 index 0000000..3b81269 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/runner_test.py @@ -0,0 +1,2180 @@ +"""Unit tests for runner module.""" + +import datetime +import json +from unittest.mock import Mock, patch + +import pytest +from aws_durable_execution_sdk_python.execution import InvocationStatus +from aws_durable_execution_sdk_python.lambda_service import ( + CallbackDetails, + ChainedInvokeDetails, + ContextDetails, + ExecutionDetails, + OperationStatus, + OperationType, + StepDetails, + WaitDetails, +) +from aws_durable_execution_sdk_python.lambda_service import Operation as SvcOperation + +from aws_durable_execution_sdk_python_testing.exceptions import ( + DurableFunctionsTestError, + InvalidParameterValueException, + ResourceNotFoundException, +) +from aws_durable_execution_sdk_python_testing.execution import Execution +from aws_durable_execution_sdk_python_testing.model import ( + StartDurableExecutionInput, + StartDurableExecutionOutput, + GetDurableExecutionHistoryResponse, +) +from aws_durable_execution_sdk_python_testing.runner import ( + OPERATION_FACTORIES, + CallbackOperation, + ContextOperation, + DurableChildContextTestRunner, + DurableFunctionTestResult, + DurableFunctionTestRunner, + ExecutionOperation, + InvokeOperation, + Operation, + StepOperation, + WaitOperation, + create_operation, +) + + +def test_operation_creation(): + """Test basic Operation creation.""" + op = Operation( + operation_id="test-id", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + parent_id="parent-id", + name="test-name", + sub_type="test-subtype", + start_timestamp=datetime.datetime.now(tz=datetime.UTC), + end_timestamp=datetime.datetime.now(tz=datetime.UTC), + ) + + assert op.operation_id == "test-id" + assert op.operation_type is OperationType.STEP + assert op.status is OperationStatus.SUCCEEDED + assert op.parent_id == "parent-id" + assert op.name == "test-name" + assert op.sub_type == "test-subtype" + + +def test_execution_operation_from_svc_operation(): + """Test ExecutionOperation creation from service operation.""" + execution_details = ExecutionDetails(input_payload="test-input") + svc_op = SvcOperation( + operation_id="exec-id", + operation_type=OperationType.EXECUTION, + status=OperationStatus.SUCCEEDED, + execution_details=execution_details, + ) + + exec_op = ExecutionOperation.from_svc_operation(svc_op) + + assert exec_op.operation_id == "exec-id" + assert exec_op.operation_type is OperationType.EXECUTION + assert exec_op.input_payload == "test-input" + + +def test_execution_operation_wrong_type(): + """Test ExecutionOperation raises error for wrong operation type.""" + svc_op = SvcOperation( + operation_id="test-id", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + ) + + with pytest.raises( + InvalidParameterValueException, + match="Expected EXECUTION operation, got OperationType.STEP", + ): + ExecutionOperation.from_svc_operation(svc_op) + + +def test_context_operation_from_svc_operation(): + """Test ContextOperation creation from service operation.""" + context_details = ContextDetails(result=json.dumps("test-result"), error=None) + svc_op = SvcOperation( + operation_id="ctx-id", + operation_type=OperationType.CONTEXT, + status=OperationStatus.SUCCEEDED, + context_details=context_details, + ) + + ctx_op = ContextOperation.from_svc_operation(svc_op) + + assert ctx_op.operation_id == "ctx-id" + assert ctx_op.operation_type is OperationType.CONTEXT + assert ctx_op.result == json.dumps("test-result") + assert ctx_op.child_operations == [] + + +def test_context_operation_with_children(): + """Test ContextOperation with child operations.""" + parent_op = SvcOperation( + operation_id="parent-id", + operation_type=OperationType.CONTEXT, + status=OperationStatus.SUCCEEDED, + context_details=ContextDetails(result=json.dumps("parent-result")), + ) + + child_op = SvcOperation( + operation_id="child-id", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + parent_id="parent-id", + name="child-step", + step_details=StepDetails(result=json.dumps("child-result")), + ) + + all_ops = [parent_op, child_op] + ctx_op = ContextOperation.from_svc_operation(parent_op, all_ops) + + assert len(ctx_op.child_operations) == 1 + assert ctx_op.child_operations[0].name == "child-step" + + +def test_context_operation_get_operation_by_name(): + """Test ContextOperation get_operation_by_name method.""" + child_op = Operation( + operation_id="child-id", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + name="test-child", + ) + + ctx_op = ContextOperation( + operation_id="ctx-id", + operation_type=OperationType.CONTEXT, + status=OperationStatus.SUCCEEDED, + child_operations=[child_op], + ) + + found_op = ctx_op.get_operation_by_name("test-child") + assert found_op == child_op + + +def test_context_operation_get_operation_by_name_not_found(): + """Test ContextOperation get_operation_by_name raises error when not found.""" + ctx_op = ContextOperation( + operation_id="ctx-id", + operation_type=OperationType.CONTEXT, + status=OperationStatus.SUCCEEDED, + child_operations=[], + ) + + with pytest.raises( + DurableFunctionsTestError, match="Child Operation with name 'missing' not found" + ): + ctx_op.get_operation_by_name("missing") + + +def test_context_operation_get_step(): + """Test ContextOperation get_step method.""" + step_op = StepOperation( + operation_id="step-id", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + name="test-step", + child_operations=[], + ) + + ctx_op = ContextOperation( + operation_id="ctx-id", + operation_type=OperationType.CONTEXT, + status=OperationStatus.SUCCEEDED, + child_operations=[step_op], + ) + + found_step = ctx_op.get_step("test-step") + assert isinstance(found_step, StepOperation) + assert found_step.name == "test-step" + + +def test_context_operation_get_wait(): + """Test ContextOperation get_wait method.""" + wait_op = WaitOperation( + operation_id="wait-id", + operation_type=OperationType.WAIT, + status=OperationStatus.SUCCEEDED, + name="test-wait", + ) + + ctx_op = ContextOperation( + operation_id="ctx-id", + operation_type=OperationType.CONTEXT, + status=OperationStatus.SUCCEEDED, + child_operations=[wait_op], + ) + + found_wait = ctx_op.get_wait("test-wait") + assert isinstance(found_wait, WaitOperation) + assert found_wait.name == "test-wait" + + +def test_context_operation_get_context(): + """Test ContextOperation get_context method.""" + nested_ctx_op = ContextOperation( + operation_id="nested-ctx-id", + operation_type=OperationType.CONTEXT, + status=OperationStatus.SUCCEEDED, + name="nested-context", + child_operations=[], + ) + + ctx_op = ContextOperation( + operation_id="ctx-id", + operation_type=OperationType.CONTEXT, + status=OperationStatus.SUCCEEDED, + child_operations=[nested_ctx_op], + ) + + found_ctx = ctx_op.get_context("nested-context") + assert isinstance(found_ctx, ContextOperation) + assert found_ctx.name == "nested-context" + + +def test_context_operation_get_callback(): + """Test ContextOperation get_callback method.""" + callback_op = CallbackOperation( + operation_id="callback-id", + operation_type=OperationType.CALLBACK, + status=OperationStatus.SUCCEEDED, + name="test-callback", + child_operations=[], + ) + + ctx_op = ContextOperation( + operation_id="ctx-id", + operation_type=OperationType.CONTEXT, + status=OperationStatus.SUCCEEDED, + child_operations=[callback_op], + ) + + found_callback = ctx_op.get_callback("test-callback") + assert isinstance(found_callback, CallbackOperation) + assert found_callback.name == "test-callback" + + +def test_context_operation_get_invoke(): + """Test ContextOperation get_invoke method.""" + invoke_op = InvokeOperation( + operation_id="invoke-id", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.SUCCEEDED, + name="test-invoke", + ) + + ctx_op = ContextOperation( + operation_id="ctx-id", + operation_type=OperationType.CONTEXT, + status=OperationStatus.SUCCEEDED, + child_operations=[invoke_op], + ) + + found_invoke = ctx_op.get_invoke("test-invoke") + assert isinstance(found_invoke, InvokeOperation) + assert found_invoke.name == "test-invoke" + + +def test_context_operation_get_execution(): + """Test ContextOperation get_execution method.""" + exec_op = ExecutionOperation( + operation_id="exec-id", + operation_type=OperationType.EXECUTION, + status=OperationStatus.SUCCEEDED, + name="test-execution", + ) + + ctx_op = ContextOperation( + operation_id="ctx-id", + operation_type=OperationType.CONTEXT, + status=OperationStatus.SUCCEEDED, + child_operations=[exec_op], + ) + + found_exec = ctx_op.get_execution("test-execution") + assert isinstance(found_exec, ExecutionOperation) + assert found_exec.name == "test-execution" + + +def test_step_operation_from_svc_operation(): + """Test StepOperation creation from service operation.""" + step_details = StepDetails(attempt=2, result=json.dumps("step-result"), error=None) + svc_op = SvcOperation( + operation_id="step-id", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + step_details=step_details, + ) + + step_op = StepOperation.from_svc_operation(svc_op) + + assert step_op.operation_id == "step-id" + assert step_op.operation_type is OperationType.STEP + assert step_op.attempt == 2 + assert step_op.result == json.dumps("step-result") + + +def test_step_operation_wrong_type(): + """Test StepOperation raises error for wrong operation type.""" + svc_op = SvcOperation( + operation_id="test-id", + operation_type=OperationType.CONTEXT, + status=OperationStatus.SUCCEEDED, + ) + + with pytest.raises( + InvalidParameterValueException, + match="Expected STEP operation, got OperationType.CONTEXT", + ): + StepOperation.from_svc_operation(svc_op) + + +def test_wait_operation_from_svc_operation(): + """Test WaitOperation creation from service operation.""" + scheduled_time = datetime.datetime.now(tz=datetime.UTC) + wait_details = WaitDetails(scheduled_end_timestamp=scheduled_time) + svc_op = SvcOperation( + operation_id="wait-id", + operation_type=OperationType.WAIT, + status=OperationStatus.SUCCEEDED, + wait_details=wait_details, + ) + + wait_op = WaitOperation.from_svc_operation(svc_op) + + assert wait_op.operation_id == "wait-id" + assert wait_op.operation_type is OperationType.WAIT + assert wait_op.scheduled_end_timestamp == scheduled_time + + +def test_wait_operation_wrong_type(): + """Test WaitOperation raises error for wrong operation type.""" + svc_op = SvcOperation( + operation_id="test-id", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + ) + + with pytest.raises( + InvalidParameterValueException, + match="Expected WAIT operation, got OperationType.STEP", + ): + WaitOperation.from_svc_operation(svc_op) + + +def test_callback_operation_from_svc_operation(): + """Test CallbackOperation creation from service operation.""" + callback_details = CallbackDetails( + callback_id="cb-123", result=json.dumps("callback-result") + ) + svc_op = SvcOperation( + operation_id="callback-id", + operation_type=OperationType.CALLBACK, + status=OperationStatus.SUCCEEDED, + callback_details=callback_details, + ) + + callback_op = CallbackOperation.from_svc_operation(svc_op) + + assert callback_op.operation_id == "callback-id" + assert callback_op.operation_type is OperationType.CALLBACK + assert callback_op.callback_id == "cb-123" + assert callback_op.result == json.dumps("callback-result") + + +def test_callback_operation_wrong_type(): + """Test CallbackOperation raises error for wrong operation type.""" + svc_op = SvcOperation( + operation_id="test-id", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + ) + + with pytest.raises( + InvalidParameterValueException, + match="Expected CALLBACK operation, got OperationType.STEP", + ): + CallbackOperation.from_svc_operation(svc_op) + + +def test_invoke_operation_from_svc_operation(): + """Test InvokeOperation creation from service operation.""" + invoke_details = ChainedInvokeDetails( + result=json.dumps("invoke-result"), + ) + svc_op = SvcOperation( + operation_id="invoke-id", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.SUCCEEDED, + chained_invoke_details=invoke_details, + ) + + invoke_op = InvokeOperation.from_svc_operation(svc_op) + + assert invoke_op.operation_id == "invoke-id" + assert invoke_op.operation_type is OperationType.CHAINED_INVOKE + assert invoke_op.result == json.dumps("invoke-result") + + +def test_invoke_operation_wrong_type(): + """Test InvokeOperation raises error for wrong operation type.""" + svc_op = SvcOperation( + operation_id="test-id", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + ) + + with pytest.raises( + InvalidParameterValueException, + match="Expected INVOKE operation, got OperationType.STEP", + ): + InvokeOperation.from_svc_operation(svc_op) + + +def test_operation_factories_mapping(): + """Test OPERATION_FACTORIES contains all expected mappings.""" + expected_types = { + OperationType.EXECUTION: ExecutionOperation, + OperationType.CONTEXT: ContextOperation, + OperationType.STEP: StepOperation, + OperationType.WAIT: WaitOperation, + OperationType.CHAINED_INVOKE: InvokeOperation, + OperationType.CALLBACK: CallbackOperation, + } + + assert expected_types == OPERATION_FACTORIES + + +def test_create_operation_step(): + """Test create_operation function with STEP operation.""" + svc_op = SvcOperation( + operation_id="step-id", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + step_details=StepDetails(result=json.dumps("test-result")), + ) + + operation = create_operation(svc_op) + + assert isinstance(operation, StepOperation) + assert operation.operation_id == "step-id" + + +def test_create_operation_unknown_type(): + """Test create_operation raises error for unknown operation type.""" + # Create a mock operation with an invalid type + svc_op = Mock() + svc_op.operation_type = "UNKNOWN_TYPE" + + with pytest.raises( + DurableFunctionsTestError, match="Unknown operation type: UNKNOWN_TYPE" + ): + create_operation(svc_op) + + +def test_durable_function_test_result_create(): + """Test DurableFunctionTestResult.create method.""" + # Create mock execution with operations + execution = Mock(spec=Execution) + + # Create mock operations - one EXECUTION (should be filtered) and one STEP + exec_op = Mock() + exec_op.operation_type = OperationType.EXECUTION + exec_op.parent_id = None + + step_op = Mock() + step_op.operation_type = OperationType.STEP + step_op.parent_id = None + step_op.operation_id = "step-id" + step_op.status = OperationStatus.SUCCEEDED + step_op.name = "test-step" + step_op.step_details = StepDetails(result=json.dumps("step-result")) + + execution.operations = [exec_op, step_op] + + # Mock execution result + execution.result = Mock() + execution.result.status = InvocationStatus.SUCCEEDED + execution.result.result = json.dumps("test-result") + execution.result.error = None + + result = DurableFunctionTestResult.create(execution) + + assert result.status is InvocationStatus.SUCCEEDED + assert result.result == json.dumps("test-result") + assert result.error is None + assert len(result.operations) == 1 # EXECUTION operation filtered out + + +def test_durable_function_test_result_get_operation_by_name(): + """Test DurableFunctionTestResult get_operation_by_name method.""" + step_op = StepOperation( + operation_id="step-id", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + name="test-step", + child_operations=[], + ) + + result = DurableFunctionTestResult( + status=InvocationStatus.SUCCEEDED, + operations=[step_op], + ) + + found_op = result.get_operation_by_name("test-step") + assert found_op == step_op + + +def test_durable_function_test_result_get_operation_by_name_not_found(): + """Test DurableFunctionTestResult get_operation_by_name raises error when not found.""" + result = DurableFunctionTestResult( + status=InvocationStatus.SUCCEEDED, + operations=[], + ) + + with pytest.raises( + DurableFunctionsTestError, match="Operation with name 'missing' not found" + ): + result.get_operation_by_name("missing") + + +def test_durable_function_test_result_get_step(): + """Test DurableFunctionTestResult get_step method.""" + step_op = StepOperation( + operation_id="step-id", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + name="test-step", + child_operations=[], + ) + + result = DurableFunctionTestResult( + status=InvocationStatus.SUCCEEDED, + operations=[step_op], + ) + + found_step = result.get_step("test-step") + assert isinstance(found_step, StepOperation) + assert found_step.name == "test-step" + + +def test_durable_function_test_result_get_wait(): + """Test DurableFunctionTestResult get_wait method.""" + wait_op = WaitOperation( + operation_id="wait-id", + operation_type=OperationType.WAIT, + status=OperationStatus.SUCCEEDED, + name="test-wait", + ) + + result = DurableFunctionTestResult( + status=InvocationStatus.SUCCEEDED, + operations=[wait_op], + ) + + found_wait = result.get_wait("test-wait") + assert isinstance(found_wait, WaitOperation) + assert found_wait.name == "test-wait" + + +def test_durable_function_test_result_get_context(): + """Test DurableFunctionTestResult get_context method.""" + ctx_op = ContextOperation( + operation_id="ctx-id", + operation_type=OperationType.CONTEXT, + status=OperationStatus.SUCCEEDED, + name="test-context", + child_operations=[], + ) + + result = DurableFunctionTestResult( + status=InvocationStatus.SUCCEEDED, + operations=[ctx_op], + ) + + found_ctx = result.get_context("test-context") + assert isinstance(found_ctx, ContextOperation) + assert found_ctx.name == "test-context" + + +def test_durable_function_test_result_get_callback(): + """Test DurableFunctionTestResult get_callback method.""" + callback_op = CallbackOperation( + operation_id="callback-id", + operation_type=OperationType.CALLBACK, + status=OperationStatus.SUCCEEDED, + name="test-callback", + child_operations=[], + ) + + result = DurableFunctionTestResult( + status=InvocationStatus.SUCCEEDED, + operations=[callback_op], + ) + + found_callback = result.get_callback("test-callback") + assert isinstance(found_callback, CallbackOperation) + assert found_callback.name == "test-callback" + + +def test_durable_function_test_result_get_invoke(): + """Test DurableFunctionTestResult get_invoke method.""" + invoke_op = InvokeOperation( + operation_id="invoke-id", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.SUCCEEDED, + name="test-invoke", + ) + + result = DurableFunctionTestResult( + status=InvocationStatus.SUCCEEDED, + operations=[invoke_op], + ) + + found_invoke = result.get_invoke("test-invoke") + assert isinstance(found_invoke, InvokeOperation) + assert found_invoke.name == "test-invoke" + + +def test_durable_function_test_result_get_execution(): + """Test DurableFunctionTestResult get_execution method.""" + exec_op = ExecutionOperation( + operation_id="exec-id", + operation_type=OperationType.EXECUTION, + status=OperationStatus.SUCCEEDED, + name="test-execution", + ) + + result = DurableFunctionTestResult( + status=InvocationStatus.SUCCEEDED, + operations=[exec_op], + ) + + found_exec = result.get_execution("test-execution") + assert isinstance(found_exec, ExecutionOperation) + assert found_exec.name == "test-execution" + + +@patch("aws_durable_execution_sdk_python_testing.runner.Scheduler") +@patch("aws_durable_execution_sdk_python_testing.runner.InMemoryExecutionStore") +@patch("aws_durable_execution_sdk_python_testing.runner.CheckpointProcessor") +@patch("aws_durable_execution_sdk_python_testing.runner.InMemoryServiceClient") +@patch("aws_durable_execution_sdk_python_testing.runner.InProcessInvoker") +@patch("aws_durable_execution_sdk_python_testing.runner.Executor") +def test_durable_function_test_runner_init( + mock_executor, mock_invoker, mock_client, mock_processor, mock_store, mock_scheduler +): + """Test DurableFunctionTestRunner initialization.""" + handler = Mock() + + DurableFunctionTestRunner(handler) + + # Verify all components are initialized + mock_scheduler.assert_called_once() + mock_scheduler.return_value.start.assert_called_once() + mock_store.assert_called_once() + mock_processor.assert_called_once() + mock_client.assert_called_once() + mock_invoker.assert_called_once_with(handler, mock_client.return_value) + mock_executor.assert_called_once() + + # Verify observer pattern setup + mock_processor.return_value.add_execution_observer.assert_called_once_with( + mock_executor.return_value + ) + + +def test_durable_function_test_runner_context_manager(): + """Test DurableFunctionTestRunner context manager.""" + handler = Mock() + + with patch.object(DurableFunctionTestRunner, "__init__", return_value=None): + with patch.object(DurableFunctionTestRunner, "close") as mock_close: + runner = DurableFunctionTestRunner(handler) + + with runner: + pass + + mock_close.assert_called_once() + + +@patch("aws_durable_execution_sdk_python_testing.runner.Scheduler") +def test_durable_function_test_runner_close(mock_scheduler): + """Test DurableFunctionTestRunner close method.""" + handler = Mock() + + # Let the constructor run normally with mocked dependencies + mock_scheduler_instance = Mock() + mock_scheduler.return_value = mock_scheduler_instance + + runner = DurableFunctionTestRunner(handler) + runner.close() + + # Verify scheduler.stop() was called + mock_scheduler_instance.stop.assert_called_once() + + +@patch("aws_durable_execution_sdk_python_testing.runner.Executor") +@patch("aws_durable_execution_sdk_python_testing.runner.InMemoryExecutionStore") +def test_durable_function_test_runner_run(mock_store_class, mock_executor_class): + """Test DurableFunctionTestRunner run method.""" + handler = Mock() + + # Mock the class instances + mock_executor = Mock() + mock_store = Mock() + mock_executor_class.return_value = mock_executor + mock_store_class.return_value = mock_store + + # Mock execution output + output = StartDurableExecutionOutput(execution_arn="test-arn") + mock_executor.start_execution.return_value = output + mock_executor.wait_until_complete.return_value = True + + # Mock execution for result creation + mock_execution = Mock(spec=Execution) + mock_execution.operations = [] + mock_execution.result = Mock() + mock_execution.result.status = InvocationStatus.SUCCEEDED + mock_execution.result.result = json.dumps("test-result") + mock_execution.result.error = None + mock_store.load.return_value = mock_execution + + runner = DurableFunctionTestRunner(handler) + result = runner.run("test-input") + + # Verify start_execution was called with correct input + mock_executor.start_execution.assert_called_once() + start_input = mock_executor.start_execution.call_args[0][0] + assert isinstance(start_input, StartDurableExecutionInput) + assert start_input.input == "test-input" + assert start_input.function_name == "test-function" + assert start_input.execution_name == "execution-name" + assert start_input.account_id == "123456789012" + + # Verify wait_until_complete was called + mock_executor.wait_until_complete.assert_called_once_with("test-arn", 900) + + # Verify store.load was called + mock_store.load.assert_called_once_with("test-arn") + + # Verify result + assert isinstance(result, DurableFunctionTestResult) + assert result.status is InvocationStatus.SUCCEEDED + + +@patch("aws_durable_execution_sdk_python_testing.runner.Executor") +@patch("aws_durable_execution_sdk_python_testing.runner.InMemoryExecutionStore") +def test_durable_function_test_runner_run_with_custom_params( + mock_store_class, mock_executor_class +): + """Test DurableFunctionTestRunner run method with custom parameters.""" + handler = Mock() + + # Mock the class instances + mock_executor = Mock() + mock_store = Mock() + mock_executor_class.return_value = mock_executor + mock_store_class.return_value = mock_store + + # Mock execution output + output = StartDurableExecutionOutput(execution_arn="test-arn") + mock_executor.start_execution.return_value = output + mock_executor.wait_until_complete.return_value = True + + # Mock execution for result creation + mock_execution = Mock(spec=Execution) + mock_execution.operations = [] + mock_execution.result = Mock() + mock_execution.result.status = InvocationStatus.SUCCEEDED + mock_execution.result.result = json.dumps("test-result") + mock_execution.result.error = None + mock_store.load.return_value = mock_execution + + runner = DurableFunctionTestRunner(handler) + result = runner.run( + input="custom-input", + timeout=1800, + function_name="custom-function", + execution_name="custom-execution", + account_id="987654321098", + ) + + # Verify start_execution was called with custom parameters + start_input = mock_executor.start_execution.call_args[0][0] + assert start_input.input == "custom-input" + assert start_input.function_name == "custom-function" + assert start_input.execution_name == "custom-execution" + assert start_input.account_id == "987654321098" + assert start_input.execution_timeout_seconds == 1800 + + # Verify wait_until_complete was called with custom timeout + mock_executor.wait_until_complete.assert_called_once_with("test-arn", 1800) + + assert result.status is InvocationStatus.SUCCEEDED + + +@patch("aws_durable_execution_sdk_python_testing.runner.Executor") +def test_durable_function_test_runner_run_timeout(mock_executor_class): + """Test DurableFunctionTestRunner run method with timeout.""" + handler = Mock() + + # Mock the class instance + mock_executor = Mock() + mock_executor_class.return_value = mock_executor + + # Mock execution output + output = StartDurableExecutionOutput(execution_arn="test-arn") + mock_executor.start_execution.return_value = output + mock_executor.wait_until_complete.return_value = False # Timeout + + runner = DurableFunctionTestRunner(handler) + + with pytest.raises(TimeoutError, match="Execution did not complete within timeout"): + runner.run("test-input") + + +def test_context_operation_wrong_type(): + """Test ContextOperation raises error for wrong operation type.""" + svc_op = SvcOperation( + operation_id="test-id", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + ) + + with pytest.raises( + InvalidParameterValueException, + match="Expected CONTEXT operation, got OperationType.STEP", + ): + ContextOperation.from_svc_operation(svc_op) + + +def test_context_operation_with_child_operations_none(): + """Test ContextOperation with None child operations.""" + svc_op = SvcOperation( + operation_id="ctx-id", + operation_type=OperationType.CONTEXT, + status=OperationStatus.SUCCEEDED, + context_details=ContextDetails(result=json.dumps("test-result")), + ) + + ctx_op = ContextOperation.from_svc_operation(svc_op, None) + + assert ctx_op.child_operations == [] + + +def test_callback_operation_with_child_operations_none(): + """Test CallbackOperation with None child operations.""" + svc_op = SvcOperation( + operation_id="callback-id", + operation_type=OperationType.CALLBACK, + status=OperationStatus.SUCCEEDED, + callback_details=CallbackDetails(callback_id="cb-123"), + ) + + callback_op = CallbackOperation.from_svc_operation(svc_op, None) + + assert callback_op.child_operations == [] + + +def test_step_operation_with_child_operations_none(): + """Test StepOperation with None child operations.""" + svc_op = SvcOperation( + operation_id="step-id", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + step_details=StepDetails(result=json.dumps("step-result")), + ) + + step_op = StepOperation.from_svc_operation(svc_op, None) + + assert step_op.child_operations == [] + + +def test_durable_function_test_result_create_with_parent_operations(): + """Test DurableFunctionTestResult.create with operations that have parent_id.""" + execution = Mock(spec=Execution) + + # Create operation with parent_id (should be filtered out) + child_op = Mock() + child_op.operation_type = OperationType.STEP + child_op.parent_id = "parent-id" + + # Create operation without parent_id (should be included) + root_op = Mock() + root_op.operation_type = OperationType.STEP + root_op.parent_id = None + root_op.operation_id = "root-id" + root_op.status = OperationStatus.SUCCEEDED + root_op.name = "root-step" + root_op.step_details = StepDetails(result=json.dumps("root-result")) + + execution.operations = [child_op, root_op] + execution.result = Mock() + execution.result.status = InvocationStatus.SUCCEEDED + execution.result.result = json.dumps("test-result") + execution.result.error = None + + result = DurableFunctionTestResult.create(execution) + + assert len(result.operations) == 1 # Only root operation included + + +@patch("aws_durable_execution_sdk_python_testing.runner.Scheduler") +@patch("aws_durable_execution_sdk_python_testing.runner.InMemoryExecutionStore") +@patch("aws_durable_execution_sdk_python_testing.runner.CheckpointProcessor") +@patch("aws_durable_execution_sdk_python_testing.runner.InMemoryServiceClient") +@patch("aws_durable_execution_sdk_python_testing.runner.InProcessInvoker") +@patch("aws_durable_execution_sdk_python_testing.runner.Executor") +@patch("aws_durable_execution_sdk_python_testing.runner.durable_execution") +def test_durable_context_test_runner_init( + mock_durable_execution_handler, + mock_executor, + mock_invoker, + mock_client, + mock_processor, + mock_store, + mock_scheduler, +): + """Test DurableContextTestRunner initialization.""" + handler = Mock() + decorated_handler = Mock() + mock_durable_execution_handler.return_value = decorated_handler + + DurableChildContextTestRunner(handler) # type: ignore + + # Verify all components are initialized + mock_scheduler.assert_called_once() + mock_scheduler.return_value.start.assert_called_once() + mock_store.assert_called_once() + mock_processor.assert_called_once() + mock_client.assert_called_once() + mock_invoker.assert_called_once_with(decorated_handler, mock_client.return_value) + mock_executor.assert_called_once() + + # Verify observer pattern setup + mock_processor.return_value.add_execution_observer.assert_called_once_with( + mock_executor.return_value + ) + + # Verify durable_execution was called (with internal lambda function) + mock_durable_execution_handler.assert_called_once() + + # Verify the lambda function calls our handler + durable_execution_func = mock_durable_execution_handler.call_args.args[0] + assert callable(durable_execution_func) + + # verify handler is called when durable function is invoked + durable_execution_func(Mock(), Mock()) + handler.assert_called_once() + + +@patch("aws_durable_execution_sdk_python_testing.runner.Scheduler") +@patch("aws_durable_execution_sdk_python_testing.runner.InMemoryExecutionStore") +@patch("aws_durable_execution_sdk_python_testing.runner.CheckpointProcessor") +@patch("aws_durable_execution_sdk_python_testing.runner.InMemoryServiceClient") +@patch("aws_durable_execution_sdk_python_testing.runner.InProcessInvoker") +@patch("aws_durable_execution_sdk_python_testing.runner.Executor") +@patch("aws_durable_execution_sdk_python_testing.runner.durable_execution") +def test_durable_child_context_test_runner_init_with_args( + mock_durable_execution_handler, + mock_executor, + mock_invoker, + mock_client, + mock_processor, + mock_store, + mock_scheduler, +): + """Test DurableChildContextTestRunner initialization with additional args.""" + handler = Mock() + decorated_handler = Mock() + mock_durable_execution_handler.return_value = decorated_handler + + str_input = "a random string input" + num_input = 10 + DurableChildContextTestRunner(handler, str_input, num=num_input) # type: ignore + + # Verify all components are initialized + mock_scheduler.assert_called_once() + mock_scheduler.return_value.start.assert_called_once() + mock_store.assert_called_once() + mock_processor.assert_called_once() + mock_client.assert_called_once() + mock_invoker.assert_called_once_with(decorated_handler, mock_client.return_value) + mock_executor.assert_called_once() + + # Verify observer pattern setup + mock_processor.return_value.add_execution_observer.assert_called_once_with( + mock_executor.return_value + ) + + # Verify durable_execution was called (with internal lambda function) + mock_durable_execution_handler.assert_called_once() + # Verify the lambda function calls our handler + durable_execution_func = mock_durable_execution_handler.call_args.args[0] + assert callable(durable_execution_func) + + # verify that handler is called with expected args when durable function is invoked + durable_execution_func(Mock(), Mock()) + handler.assert_called_once_with(str_input, num=num_input) + + +# Tests for DurableFunctionCloudTestRunner and from_execution_history + + +def test_durable_function_test_result_from_execution_history(): + """Test DurableFunctionTestResult.from_execution_history factory method.""" + import datetime + + from aws_durable_execution_sdk_python.execution import InvocationStatus + + from aws_durable_execution_sdk_python_testing.model import ( + Event, + EventResult, + GetDurableExecutionHistoryResponse, + GetDurableExecutionResponse, + StepSucceededDetails, + ) + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionTestResult, + ) + + execution_response = GetDurableExecutionResponse( + durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1", + durable_execution_name="test-execution", + function_arn="arn:aws:lambda:us-east-1:123456789012:function:test", + status="SUCCEEDED", + start_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC), + end_timestamp=datetime.datetime(2023, 1, 1, 0, 1, 0, tzinfo=datetime.UTC), + result="test-result", + error=None, + ) + + history_response = GetDurableExecutionHistoryResponse( + events=[ + Event( + event_type="ExecutionStarted", + event_timestamp=datetime.datetime( + 2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC + ), + operation_id="exec-1", + ), + Event( + event_type="StepStarted", + event_timestamp=datetime.datetime( + 2023, 1, 1, 0, 0, 10, tzinfo=datetime.UTC + ), + operation_id="step-1", + name="test-step", + ), + Event( + event_type="StepSucceeded", + event_timestamp=datetime.datetime( + 2023, 1, 1, 0, 0, 20, tzinfo=datetime.UTC + ), + operation_id="step-1", + step_succeeded_details=StepSucceededDetails( + result=EventResult(payload="step-result", truncated=False) + ), + ), + ] + ) + + result = DurableFunctionTestResult.from_execution_history( + execution_response, history_response + ) + + assert result.status == InvocationStatus.SUCCEEDED + assert result.result == "test-result" + assert result.error is None + assert len(result.operations) == 1 + assert result.operations[0].name == "test-step" + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +def test_cloud_runner_init(mock_boto3): + """Test DurableFunctionCloudTestRunner initialization.""" + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + + runner = DurableFunctionCloudTestRunner( + function_name="test-function", + region="us-west-2", + poll_interval=0.5, + ) + + assert runner.function_name == "test-function" + assert runner.region == "us-west-2" + assert runner.poll_interval == 0.5 + mock_boto3.client.assert_called_once() + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +def test_cloud_runner_run_success(mock_boto3): + """Test DurableFunctionCloudTestRunner.run with successful execution.""" + from aws_durable_execution_sdk_python.execution import InvocationStatus + + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + + mock_client.invoke.return_value = { + "StatusCode": 200, + "Payload": Mock(read=lambda: b'{"result": "success"}'), + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1", + } + + mock_client.get_durable_execution.return_value = { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1", + "DurableExecutionName": "test-execution", + "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:test", + "Status": "SUCCEEDED", + "StartTimestamp": "2023-01-01T00:00:00Z", + "EndTimestamp": "2023-01-01T00:01:00Z", + "Result": "test-result", + } + + mock_client.get_durable_execution_history.return_value = { + "Events": [ + { + "EventType": "ExecutionStarted", + "EventTimestamp": "2023-01-01T00:00:00Z", + "Id": "exec-1", + } + ] + } + + runner = DurableFunctionCloudTestRunner( + function_name="test-function", poll_interval=0.01 + ) + + result = runner.run(input="test-input", timeout=10) + + assert result.status == InvocationStatus.SUCCEEDED + assert result.result == "test-result" + mock_client.invoke.assert_called_once() + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +def test_cloud_runner_run_invoke_failure(mock_boto3): + """Test DurableFunctionCloudTestRunner.run with invoke failure.""" + from aws_durable_execution_sdk_python_testing.exceptions import ( + DurableFunctionsTestError, + ) + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + mock_client.invoke.side_effect = Exception("Invoke failed") + + runner = DurableFunctionCloudTestRunner(function_name="test-function") + + with pytest.raises( + DurableFunctionsTestError, match="Failed to invoke Lambda function" + ): + runner.run(input="test-input") + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +@patch("aws_durable_execution_sdk_python_testing.runner.time") +def test_cloud_runner_wait_for_completion_timeout(mock_time, mock_boto3): + """Test DurableFunctionCloudTestRunner._wait_for_completion with timeout.""" + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + mock_time.time.side_effect = [0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0] + + mock_client.get_durable_execution.return_value = { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1", + "DurableExecutionName": "test-execution", + "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:test", + "Status": "RUNNING", + "StartTimestamp": "2023-01-01T00:00:00Z", + } + + runner = DurableFunctionCloudTestRunner( + function_name="test-function", poll_interval=0.01 + ) + + with pytest.raises(TimeoutError, match="Execution did not complete within"): + runner._wait_for_completion("test-arn", timeout=2) + + +def test_durable_function_test_result_from_execution_history_with_exception(): + """Test from_execution_history handles events_to_operations exception.""" + import datetime + + from aws_durable_execution_sdk_python.execution import InvocationStatus + + from aws_durable_execution_sdk_python_testing.model import ( + Event, + GetDurableExecutionHistoryResponse, + GetDurableExecutionResponse, + ) + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionTestResult, + ) + + execution_response = GetDurableExecutionResponse( + durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1", + durable_execution_name="test-execution", + function_arn="arn:aws:lambda:us-east-1:123456789012:function:test", + status="SUCCEEDED", + start_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC), + ) + + history_response = GetDurableExecutionHistoryResponse( + events=[ + Event( + event_type="StepStarted", + event_timestamp=datetime.datetime( + 2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC + ), + operation_id=None, + ) + ] + ) + + result = DurableFunctionTestResult.from_execution_history( + execution_response, history_response + ) + + assert result.status == InvocationStatus.SUCCEEDED + assert len(result.operations) == 0 + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +def test_cloud_runner_wait_for_completion_failed_status(mock_boto3): + """Test DurableFunctionCloudTestRunner._wait_for_completion with FAILED status.""" + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + + mock_client.get_durable_execution.return_value = { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1", + "DurableExecutionName": "test-execution", + "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:test", + "Status": "FAILED", + "StartTimestamp": "2023-01-01T00:00:00Z", + "EndTimestamp": "2023-01-01T00:01:00Z", + "Error": {"ErrorMessage": "execution failed"}, + } + + runner = DurableFunctionCloudTestRunner(function_name="test-function") + result = runner._wait_for_completion("test-arn", timeout=10) + + assert result.status == "FAILED" + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +def test_cloud_runner_run_bad_status_code(mock_boto3): + """Test DurableFunctionCloudTestRunner.run with bad HTTP status code.""" + from aws_durable_execution_sdk_python_testing.exceptions import ( + DurableFunctionsTestError, + ) + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + + mock_client.invoke.return_value = { + "StatusCode": 500, + "Payload": Mock(read=lambda: b"Internal Server Error"), + } + + runner = DurableFunctionCloudTestRunner(function_name="test-function") + + with pytest.raises( + DurableFunctionsTestError, match="Lambda invocation failed with status 500" + ): + runner.run(input="test-input") + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +def test_cloud_runner_run_function_error(mock_boto3): + """Test DurableFunctionCloudTestRunner.run with function error.""" + from aws_durable_execution_sdk_python_testing.exceptions import ( + DurableFunctionsTestError, + ) + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + + mock_client.invoke.return_value = { + "StatusCode": 200, + "FunctionError": "Unhandled", + "Payload": Mock(read=lambda: b'{"errorMessage": "Function failed"}'), + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1", + } + + mock_client.get_durable_execution.return_value = { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1", + "DurableExecutionName": "test-execution", + "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:test", + "Status": "FAILED", + "StartTimestamp": "2023-01-01T00:00:00Z", + "EndTimestamp": "2023-01-01T00:01:00Z", + "Error": {"ErrorMessage": "execution failed"}, + } + + mock_client.get_durable_execution_history.return_value = { + "Events": [ + { + "EventType": "ExecutionStarted", + "EventTimestamp": "2023-01-01T00:00:00Z", + "Id": "exec-1", + } + ] + } + runner = DurableFunctionCloudTestRunner(function_name="test-function") + result = runner.run(input="test-input") + assert result.status is InvocationStatus.FAILED + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +def test_cloud_runner_run_missing_execution_arn(mock_boto3): + """Test DurableFunctionCloudTestRunner.run with missing execution ARN.""" + from aws_durable_execution_sdk_python_testing.exceptions import ( + DurableFunctionsTestError, + ) + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + + mock_client.invoke.return_value = { + "StatusCode": 200, + "Payload": Mock(read=lambda: b'{"result": "success"}'), + } + + runner = DurableFunctionCloudTestRunner(function_name="test-function") + + with pytest.raises( + DurableFunctionsTestError, match="No DurableExecutionArn in response" + ): + runner.run(input="test-input") + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +def test_cloud_runner_wait_for_completion_get_execution_failure(mock_boto3): + """Test DurableFunctionCloudTestRunner._wait_for_completion with API failure.""" + from aws_durable_execution_sdk_python_testing.exceptions import ( + DurableFunctionsTestError, + ) + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + mock_client.get_durable_execution.side_effect = Exception("API error") + + runner = DurableFunctionCloudTestRunner(function_name="test-function") + + with pytest.raises( + DurableFunctionsTestError, match="Failed to get execution status" + ): + runner._wait_for_completion("test-arn", timeout=10) + + +def test_durable_function_test_result_from_execution_history_filters_execution_type(): + """Test from_execution_history filters out EXECUTION type operations.""" + import datetime + + from aws_durable_execution_sdk_python.execution import InvocationStatus + + from aws_durable_execution_sdk_python_testing.model import ( + Event, + GetDurableExecutionHistoryResponse, + GetDurableExecutionResponse, + ) + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionTestResult, + ) + + execution_response = GetDurableExecutionResponse( + durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1", + durable_execution_name="test-execution", + function_arn="arn:aws:lambda:us-east-1:123456789012:function:test", + status="SUCCEEDED", + start_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC), + ) + + history_response = GetDurableExecutionHistoryResponse( + events=[ + Event( + event_type="ExecutionStarted", + event_timestamp=datetime.datetime( + 2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC + ), + operation_id="exec-1", + ), + ] + ) + + result = DurableFunctionTestResult.from_execution_history( + execution_response, history_response + ) + + assert len(result.operations) == 0 + + +def test_durable_function_test_result_from_execution_history_unknown_status(): + """Test from_execution_history with unknown status defaults to FAILED.""" + import datetime + + from aws_durable_execution_sdk_python.execution import InvocationStatus + + from aws_durable_execution_sdk_python_testing.model import ( + GetDurableExecutionHistoryResponse, + GetDurableExecutionResponse, + ) + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionTestResult, + ) + + execution_response = GetDurableExecutionResponse( + durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1", + durable_execution_name="test-execution", + function_arn="arn:aws:lambda:us-east-1:123456789012:function:test", + status="UNKNOWN_STATUS", + start_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC), + ) + + history_response = GetDurableExecutionHistoryResponse(events=[]) + + result = DurableFunctionTestResult.from_execution_history( + execution_response, history_response + ) + + assert result.status == InvocationStatus.FAILED + + +def test_durable_function_test_result_from_execution_history_with_parent_operations(): + """Test from_execution_history filters operations with parent_id.""" + import datetime + + from aws_durable_execution_sdk_python.execution import InvocationStatus + + from aws_durable_execution_sdk_python_testing.model import ( + Event, + GetDurableExecutionHistoryResponse, + GetDurableExecutionResponse, + ) + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionTestResult, + ) + + execution_response = GetDurableExecutionResponse( + durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1", + durable_execution_name="test-execution", + function_arn="arn:aws:lambda:us-east-1:123456789012:function:test", + status="SUCCEEDED", + start_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC), + ) + + history_response = GetDurableExecutionHistoryResponse( + events=[ + Event( + event_type="StepStarted", + event_timestamp=datetime.datetime( + 2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC + ), + operation_id="step-1", + name="parent-step", + ), + Event( + event_type="StepStarted", + event_timestamp=datetime.datetime( + 2023, 1, 1, 0, 0, 10, tzinfo=datetime.UTC + ), + operation_id="step-2", + name="child-step", + parent_id="step-1", + ), + ] + ) + + result = DurableFunctionTestResult.from_execution_history( + execution_response, history_response + ) + + assert len(result.operations) == 1 + assert result.operations[0].name == "parent-step" + + +def test_durable_function_test_result_from_execution_history_failed(): + """Test from_execution_history with failed execution.""" + import datetime + + from aws_durable_execution_sdk_python.execution import InvocationStatus + from aws_durable_execution_sdk_python.lambda_service import ErrorObject + + from aws_durable_execution_sdk_python_testing.model import ( + GetDurableExecutionHistoryResponse, + GetDurableExecutionResponse, + ) + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionTestResult, + ) + + execution_response = GetDurableExecutionResponse( + durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1", + durable_execution_name="test-execution", + function_arn="arn:aws:lambda:us-east-1:123456789012:function:test", + status="FAILED", + start_timestamp=datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.UTC), + end_timestamp=datetime.datetime(2023, 1, 1, 0, 1, 0, tzinfo=datetime.UTC), + error=ErrorObject( + message="execution failed", type=None, data=None, stack_trace=None + ), + ) + + history_response = GetDurableExecutionHistoryResponse(events=[]) + + result = DurableFunctionTestResult.from_execution_history( + execution_response, history_response + ) + + assert result.status == InvocationStatus.FAILED + assert result.error.message == "execution failed" + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +def test_cloud_runner_wait_for_completion_timed_out_status(mock_boto3): + """Test DurableFunctionCloudTestRunner._wait_for_completion with TIMED_OUT status.""" + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + + mock_client.get_durable_execution.return_value = { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1", + "DurableExecutionName": "test-execution", + "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:test", + "Status": "TIMED_OUT", + "StartTimestamp": "2023-01-01T00:00:00Z", + "EndTimestamp": "2023-01-01T00:01:00Z", + } + + runner = DurableFunctionCloudTestRunner(function_name="test-function") + result = runner._wait_for_completion("test-arn", timeout=10) + + assert result.status == "TIMED_OUT" + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +def test_cloud_runner_wait_for_completion_aborted_status(mock_boto3): + """Test DurableFunctionCloudTestRunner._wait_for_completion with ABORTED status.""" + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + + mock_client.get_durable_execution.return_value = { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1", + "DurableExecutionName": "test-execution", + "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:test", + "Status": "ABORTED", + "StartTimestamp": "2023-01-01T00:00:00Z", + "EndTimestamp": "2023-01-01T00:01:00Z", + } + + runner = DurableFunctionCloudTestRunner(function_name="test-function") + result = runner._wait_for_completion("test-arn", timeout=10) + + assert result.status == "ABORTED" + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +def test_cloud_runner_run_async_success(mock_boto3): + """Test DurableFunctionCloudTestRunner.run_async with successful invocation.""" + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + + mock_client.invoke.return_value = { + "StatusCode": 202, + "Payload": Mock(read=lambda: b'{"result": "success"}'), + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1", + } + + runner = DurableFunctionCloudTestRunner(function_name="test-function") + execution_arn = runner.run_async(input="test-input") + + assert ( + execution_arn + == "arn:aws:lambda:us-east-1:123456789012:function:test:execution:exec-1" + ) + mock_client.invoke.assert_called_once_with( + FunctionName="test-function", + InvocationType="Event", + Payload='"test-input"', + ) + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +def test_cloud_runner_run_async_with_400(mock_boto3): + """Test DurableFunctionCloudTestRunner.run_async with successful invocation.""" + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + + mock_client.invoke.return_value = { + "StatusCode": 400, + "Payload": Mock(read=lambda: b'{"result": "failed"}'), + } + + runner = DurableFunctionCloudTestRunner(function_name="test-function") + + with pytest.raises( + DurableFunctionsTestError, match="Lambda invocation failed with status 400" + ): + runner.run_async(input="test-input") + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +def test_cloud_runner_run_async_failure(mock_boto3): + """Test DurableFunctionCloudTestRunner.run_async with invocation failure.""" + from aws_durable_execution_sdk_python_testing.exceptions import ( + DurableFunctionsTestError, + ) + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + mock_client.invoke.side_effect = Exception("Async invoke failed") + + runner = DurableFunctionCloudTestRunner(function_name="test-function") + + with pytest.raises( + DurableFunctionsTestError, match="Failed to invoke Lambda function" + ): + runner.run_async(input="test-input") + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +def test_cloud_runner_send_callback_success(mock_boto3): + """Test DurableFunctionCloudTestRunner.send_callback_success.""" + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + + runner = DurableFunctionCloudTestRunner(function_name="test-function") + runner.send_callback_success("callback-123") + + mock_client.send_durable_execution_callback_success.assert_called_once_with( + CallbackId="callback-123", Result=None + ) + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +def test_cloud_runner_send_callback_failure(mock_boto3): + """Test DurableFunctionCloudTestRunner.send_callback_failure.""" + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + + runner = DurableFunctionCloudTestRunner(function_name="test-function") + runner.send_callback_failure("callback-123") + + mock_client.send_durable_execution_callback_failure.assert_called_once_with( + CallbackId="callback-123", Error=None + ) + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +def test_cloud_runner_send_callback_heartbeat(mock_boto3): + """Test DurableFunctionCloudTestRunner.send_callback_heartbeat.""" + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + + runner = DurableFunctionCloudTestRunner(function_name="test-function") + runner.send_callback_heartbeat("callback-123") + + mock_client.send_durable_execution_callback_heartbeat.assert_called_once_with( + CallbackId="callback-123" + ) + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +def test_cloud_runner_send_callback_error(mock_boto3): + """Test DurableFunctionCloudTestRunner callback methods with API errors.""" + from aws_durable_execution_sdk_python_testing.exceptions import ( + DurableFunctionsTestError, + ) + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + mock_client.send_durable_execution_callback_success.side_effect = Exception( + "API error" + ) + + runner = DurableFunctionCloudTestRunner(function_name="test-function") + + with pytest.raises( + DurableFunctionsTestError, match="Failed to send callback success" + ): + runner.send_callback_success("callback-123") + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +def test_cloud_runner_wait_for_callback_success(mock_boto3): + """Test DurableFunctionCloudTestRunner.wait_for_callback success.""" + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + + mock_client.get_durable_execution_history.return_value = { + "Events": [ + { + "EventType": "CallbackStarted", + "EventTimestamp": "2023-01-01T00:00:00Z", + "Id": "callback-event-1", + "Name": "test-callback", + "CallbackStartedDetails": {"CallbackId": "callback-123"}, + } + ] + } + + runner = DurableFunctionCloudTestRunner( + function_name="test-function", poll_interval=0.01 + ) + callback_id = runner.wait_for_callback("test-arn", name="test-callback", timeout=10) + + assert callback_id == "callback-123" + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +def test_cloud_runner_wait_for_callback_none(mock_boto3): + """Test DurableFunctionCloudTestRunner.wait_for_callback none.""" + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + + mock_client.get_durable_execution_history.return_value = { + "Events": [ + { + "EventType": "CallbackStarted", + "EventTimestamp": "2023-01-01T00:00:00Z", + "Id": "callback-event-1", + "Name": "test-callback", + "CallbackStartedDetails": {"CallbackId": "callback-123"}, + } + ] + } + + runner = DurableFunctionCloudTestRunner( + function_name="test-function", poll_interval=0.01 + ) + + with pytest.raises(TimeoutError, match="Callback did not available within"): + runner.wait_for_callback("test-arn", name="test-callback1", timeout=2) + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +def test_cloud_runner_wait_for_callback_success_without_name(mock_boto3): + """Test DurableFunctionCloudTestRunner.wait_for_callback success.""" + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + + mock_client.get_durable_execution_history.return_value = { + "Events": [ + { + "EventType": "CallbackStarted", + "EventTimestamp": "2023-01-01T00:00:00Z", + "Id": "callback-event-1", + "Name": "test-callback", + "CallbackStartedDetails": {"CallbackId": "callback-123"}, + } + ] + } + + runner = DurableFunctionCloudTestRunner( + function_name="test-function", poll_interval=0.01 + ) + callback_id = runner.wait_for_callback("test-arn") + + assert callback_id == "callback-123" + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +def test_cloud_runner_wait_for_callback_all_done_without_name(mock_boto3): + """Test DurableFunctionCloudTestRunner.wait_for_callback all_done_without_name.""" + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + + mock_client.get_durable_execution_history.return_value = { + "Events": [ + { + "EventType": "CallbackStarted", + "EventTimestamp": "2023-01-01T00:00:00Z", + "Id": "callback-event-1", + "Name": "test-callback", + "CallbackStartedDetails": {"CallbackId": "callback-123"}, + }, + { + "EventType": "CallbackSucceeded", + "EventTimestamp": "2023-01-01T00:05:00Z", + "Id": "callback-event-1", + "Name": "test-callback", + }, + ] + } + + runner = DurableFunctionCloudTestRunner( + function_name="test-function", poll_interval=0.01 + ) + with pytest.raises(TimeoutError, match="Callback did not available within"): + runner.wait_for_callback("test-arn", timeout=2) + + +@patch("aws_durable_execution_sdk_python_testing.runner.Executor") +def test_local_runner_wait_for_callback_all_done_without_name(mock_executor_class): + """Test DurableFunctionCloudTestRunner.wait_for_callback all_done_without_name.""" + handler = Mock() + mock_executor = Mock() + mock_executor_class.return_value = mock_executor + mock_executor.get_execution_history.return_value = ( + GetDurableExecutionHistoryResponse.from_dict( + { + "Events": [ + { + "EventType": "CallbackStarted", + "EventTimestamp": "2023-01-01T00:00:00Z", + "Id": "callback-event-1", + "Name": "test-callback", + "CallbackStartedDetails": {"CallbackId": "callback-123"}, + }, + { + "EventType": "CallbackSucceeded", + "EventTimestamp": "2023-01-01T00:05:00Z", + "Id": "callback-event-1", + "Name": "test-callback", + }, + ] + } + ) + ) + + runner = DurableFunctionTestRunner(handler) + with pytest.raises(TimeoutError, match="Callback did not available within"): + runner.wait_for_callback("test-arn", timeout=2) + + +@patch("aws_durable_execution_sdk_python_testing.runner.Executor") +def test_local_runner_wait_for_callback_with_exception(mock_executor_class): + """Test DurableFunctionCloudTestRunner.wait_for_callback with exception""" + handler = Mock() + mock_executor = Mock() + mock_executor_class.return_value = mock_executor + mock_executor.get_execution_history.side_effect = Exception("error") + + runner = DurableFunctionTestRunner(handler) + with pytest.raises( + DurableFunctionsTestError, match="Failed to fetch execution history" + ): + runner.wait_for_callback("test-arn", timeout=10) + + +@patch("aws_durable_execution_sdk_python_testing.runner.Executor") +def test_local_runner_wait_for_callback_with_resource_not_found_exception( + mock_executor_class, +): + """Test DurableFunctionCloudTestRunner.wait_for_callback with resource_not_found exception""" + handler = Mock() + mock_executor = Mock() + mock_executor_class.return_value = mock_executor + mock_executor.get_execution_history.side_effect = ResourceNotFoundException("error") + + runner = DurableFunctionTestRunner(handler) + with pytest.raises(TimeoutError, match="Callback did not available within"): + runner.wait_for_callback("test-arn", timeout=2) + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +@patch("aws_durable_execution_sdk_python_testing.runner.time") +def test_cloud_runner_wait_for_callback_timeout(mock_time, mock_boto3): + """Test DurableFunctionCloudTestRunner.wait_for_callback timeout.""" + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + mock_time.time.side_effect = [0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0] + + mock_client.get_durable_execution_history.return_value = {"Events": []} + + runner = DurableFunctionCloudTestRunner( + function_name="test-function", poll_interval=0.01 + ) + + with pytest.raises(TimeoutError, match="Callback did not available within"): + runner.wait_for_callback("test-arn", timeout=2) + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +def test_cloud_runner_wait_for_callback_already_completed(mock_boto3): + """Test DurableFunctionCloudTestRunner.wait_for_callback already completed.""" + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + + mock_client.get_durable_execution_history.return_value = { + "Events": [ + { + "EventType": "CallbackStarted", + "EventTimestamp": "2023-01-01T00:00:00Z", + "Id": "callback-event-1", + "Name": "test-callback", + "CallbackStartedDetails": {"CallbackId": "callback-123"}, + }, + { + "EventType": "CallbackSucceeded", + "EventTimestamp": "2023-01-01T00:05:00Z", + "Id": "callback-event-1", + "Name": "test-callback", + }, + ] + } + + runner = DurableFunctionCloudTestRunner( + function_name="test-function", poll_interval=0.01 + ) + + with pytest.raises( + DurableFunctionsTestError, match="Callback test-callback has already completed" + ): + runner.wait_for_callback("test-arn", "test-callback", timeout=2) + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +def test_cloud_runner_wait_for_callback_client_error_retryable(mock_boto3): + """Test wait_for_callback with retryable ClientError.""" + from botocore.exceptions import ClientError # type: ignore + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + + # First call raises ResourceNotFoundException, second succeeds + mock_client.get_durable_execution_history.side_effect = [ + ClientError( + error_response={"Error": {"Code": "ResourceNotFoundException"}}, + operation_name="GetDurableExecutionHistory", + ), + { + "Events": [ + { + "EventType": "CallbackStarted", + "EventTimestamp": "2023-01-01T00:00:00Z", + "Id": "callback-event-1", + "Name": "test-callback", + "CallbackStartedDetails": {"CallbackId": "callback-123"}, + } + ] + }, + ] + + runner = DurableFunctionCloudTestRunner( + function_name="test-function", poll_interval=0.01 + ) + callback_id = runner.wait_for_callback("test-arn", name="test-callback", timeout=10) + + assert callback_id == "callback-123" + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +def test_cloud_runner_wait_for_callback_client_error_non_retryable( + mock_boto3, +): + """Test wait_for_callback with non-retryable ClientError.""" + from botocore.exceptions import ClientError + from aws_durable_execution_sdk_python_testing.exceptions import ( + DurableFunctionsTestError, + ) + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + + mock_client.get_durable_execution_history.side_effect = ClientError( + error_response={"Error": {"Code": "AccessDeniedException"}}, + operation_name="GetDurableExecutionHistory", + ) + + runner = DurableFunctionCloudTestRunner(function_name="test-function") + + with pytest.raises( + DurableFunctionsTestError, match="Failed to fetch execution history" + ): + runner.wait_for_callback("test-arn", timeout=10) + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +def test_cloud_runner_wait_for_callback_generic_exception(mock_boto3): + """Test wait_for_callback with generic Exception.""" + from aws_durable_execution_sdk_python_testing.exceptions import ( + DurableFunctionsTestError, + ) + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + + mock_client.get_durable_execution_history.side_effect = Exception("Network error") + + runner = DurableFunctionCloudTestRunner(function_name="test-function") + + with pytest.raises( + DurableFunctionsTestError, match="Failed to fetch execution history" + ): + runner.wait_for_callback("test-arn", timeout=10) + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +def test_cloud_runner_wait_for_result_fetch_history_exception(mock_boto3): + """Test wait_for_result with exception in _fetch_execution_history.""" + from aws_durable_execution_sdk_python_testing.exceptions import ( + DurableFunctionsTestError, + ) + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + + # Mock successful _wait_for_completion + mock_execution_response = Mock() + mock_execution_response.status = "SUCCEEDED" + + # Mock _fetch_execution_history to raise exception + runner = DurableFunctionCloudTestRunner(function_name="test-function") + runner._wait_for_completion = Mock(return_value=mock_execution_response) + runner._fetch_execution_history = Mock( + side_effect=Exception("History fetch failed") + ) + + with pytest.raises( + DurableFunctionsTestError, + match="Failed to fetch execution history: History fetch failed", + ): + runner.wait_for_result("test-arn", timeout=60) + + +@patch("aws_durable_execution_sdk_python_testing.runner.boto3") +def test_cloud_runner_wait_for_result_success(mock_boto3): + """Test wait_for_result successful execution.""" + from aws_durable_execution_sdk_python.execution import InvocationStatus + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + ) + + mock_client = Mock() + mock_boto3.client.return_value = mock_client + + # Mock successful responses + mock_execution_response = Mock() + mock_execution_response.status = "SUCCEEDED" + mock_history_response = Mock() + mock_history_response.events = [] + + runner = DurableFunctionCloudTestRunner(function_name="test-function") + runner._wait_for_completion = Mock(return_value=mock_execution_response) + runner._fetch_execution_history = Mock(return_value=mock_history_response) + + # Mock the from_execution_history method + with patch( + "aws_durable_execution_sdk_python_testing.runner.DurableFunctionTestResult.from_execution_history" + ) as mock_from_history: + mock_result = Mock() + mock_result.status = InvocationStatus.SUCCEEDED + mock_from_history.return_value = mock_result + + result = runner.wait_for_result("test-arn", timeout=60) + + assert result.status == InvocationStatus.SUCCEEDED + mock_from_history.assert_called_once_with( + mock_execution_response, mock_history_response + ) diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/runner_web_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/runner_web_test.py new file mode 100644 index 0000000..f810bf3 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/runner_web_test.py @@ -0,0 +1,1753 @@ +"""Unit tests for web runner components in runner module.""" + +from __future__ import annotations + +import logging +import os +from unittest.mock import Mock, patch + +import pytest + +from aws_durable_execution_sdk_python_testing.cli import CliApp +from aws_durable_execution_sdk_python_testing.exceptions import ( + DurableFunctionsLocalRunnerError, +) +from aws_durable_execution_sdk_python_testing.invoker import _LAMBDA_CLIENT_CONFIG +from aws_durable_execution_sdk_python_testing.runner import ( + WebRunner, + WebRunnerConfig, +) +from aws_durable_execution_sdk_python_testing.web.server import WebServiceConfig + + +def test_should_create_config_with_web_service_and_defaults(): + """Test creating WebRunnerConfig with WebServiceConfig and default Lambda settings.""" + # Arrange + web_config = WebServiceConfig( + host="localhost", + port=8080, + log_level=logging.DEBUG, + max_request_size=5 * 1024 * 1024, + ) + + # Act + config = WebRunnerConfig(web_service=web_config) + + # Assert + assert config.web_service == web_config + assert config.lambda_endpoint == "http://127.0.0.1:3001" + assert config.local_runner_endpoint == "http://0.0.0.0:5000" + assert config.local_runner_region == "us-west-2" + assert config.local_runner_mode == "local" + + +def test_should_create_config_with_custom_lambda_settings(): + """Test creating WebRunnerConfig with custom Lambda configuration.""" + # Arrange + web_config = WebServiceConfig(host="0.0.0.0", port=5000) # noqa: S104 + custom_lambda_endpoint = "http://custom-lambda:4000" + custom_runner_endpoint = "http://custom-runner:6000" + custom_region = "us-east-1" + custom_mode = "remote" + + # Act + config = WebRunnerConfig( + web_service=web_config, + lambda_endpoint=custom_lambda_endpoint, + local_runner_endpoint=custom_runner_endpoint, + local_runner_region=custom_region, + local_runner_mode=custom_mode, + ) + + # Assert + assert config.web_service == web_config + assert config.lambda_endpoint == custom_lambda_endpoint + assert config.local_runner_endpoint == custom_runner_endpoint + assert config.local_runner_region == custom_region + assert config.local_runner_mode == custom_mode + + +def test_should_access_web_service_config_fields(): + """Test accessing WebServiceConfig fields through composition.""" + # Arrange + web_config = WebServiceConfig( + host="test-host", + port=9999, + log_level=logging.WARNING, + max_request_size=1024, + ) + config = WebRunnerConfig(web_service=web_config) + + # Act & Assert + assert config.web_service.host == "test-host" + assert config.web_service.port == 9999 + assert config.web_service.log_level == logging.WARNING + assert config.web_service.max_request_size == 1024 + + +def test_should_be_immutable_frozen_dataclass(): + """Test that WebRunnerConfig is immutable (frozen=True).""" + # Arrange + web_config = WebServiceConfig() + config = WebRunnerConfig(web_service=web_config) + + # Act & Assert - attempting to modify should raise FrozenInstanceError + with pytest.raises( + AttributeError + ): # dataclass frozen raises AttributeError in Python 3.13+ + config.lambda_endpoint = "http://new-endpoint:8000" + + with pytest.raises(AttributeError): + config.web_service = WebServiceConfig(host="new-host") + + +def test_should_support_equality_comparison(): + """Test that WebRunnerConfig supports equality comparison.""" + # Arrange + web_config1 = WebServiceConfig(host="host1", port=5000) + web_config2 = WebServiceConfig(host="host1", port=5000) + web_config3 = WebServiceConfig(host="host2", port=5000) + + config1 = WebRunnerConfig( + web_service=web_config1, + lambda_endpoint="http://lambda:3001", + ) + config2 = WebRunnerConfig( + web_service=web_config2, + lambda_endpoint="http://lambda:3001", + ) + config3 = WebRunnerConfig( + web_service=web_config3, + lambda_endpoint="http://lambda:3001", + ) + + # Act & Assert + assert config1 == config2 # Same values should be equal + assert config1 != config3 # Different web_service should not be equal + assert config2 != config3 # Different web_service should not be equal + + +def test_should_support_hash_for_use_in_sets_and_dicts(): + """Test that WebRunnerConfig is hashable for use in sets and dicts.""" + # Arrange + web_config = WebServiceConfig(host="test", port=8080) + config1 = WebRunnerConfig(web_service=web_config) + config2 = WebRunnerConfig(web_service=web_config) + + # Act - should not raise exception + config_set = {config1, config2} + config_dict = {config1: "value1", config2: "value2"} + + # Assert + assert len(config_set) == 1 # Same configs should deduplicate in set + assert len(config_dict) == 1 # Same configs should overwrite in dict + + +def test_should_create_config_with_minimal_web_service(): + """Test creating config with minimal WebServiceConfig using defaults.""" + # Arrange + web_config = WebServiceConfig() # Uses all defaults + + # Act + config = WebRunnerConfig(web_service=web_config) + + # Assert + assert config.web_service.host == "localhost" + assert config.web_service.port == 5000 + assert config.web_service.log_level == logging.INFO + assert config.web_service.max_request_size == 10 * 1024 * 1024 + + +def test_should_have_proper_type_annotations(): + """Test that all fields have proper type annotations.""" + # Arrange & Act + annotations = WebRunnerConfig.__annotations__ + + # Assert + assert "web_service" in annotations + assert "lambda_endpoint" in annotations + assert "local_runner_endpoint" in annotations + assert "local_runner_region" in annotations + assert "local_runner_mode" in annotations + + # Check that the annotations are the expected string representations + assert annotations["web_service"] == "WebServiceConfig" + assert annotations["lambda_endpoint"] == "str" + assert annotations["local_runner_endpoint"] == "str" + assert annotations["local_runner_region"] == "str" + assert annotations["local_runner_mode"] == "str" + + +def test_should_create_config_with_keyword_arguments(): + """Test creating config using keyword arguments for all fields.""" + # Arrange + web_config = WebServiceConfig(host="kw-host", port=7777) + + # Act + config = WebRunnerConfig( + web_service=web_config, + lambda_endpoint="http://kw-lambda:2000", + local_runner_endpoint="http://kw-runner:3000", + local_runner_region="eu-west-1", + local_runner_mode="test", + ) + + # Assert + assert config.web_service == web_config + assert config.lambda_endpoint == "http://kw-lambda:2000" + assert config.local_runner_endpoint == "http://kw-runner:3000" + assert config.local_runner_region == "eu-west-1" + assert config.local_runner_mode == "test" + + +def test_should_represent_config_as_string(): + """Test string representation of WebRunnerConfig.""" + # Arrange + web_config = WebServiceConfig(host="repr-host", port=1234) + config = WebRunnerConfig( + web_service=web_config, + lambda_endpoint="http://repr-lambda:5000", + ) + + # Act + config_str = str(config) + + # Assert + assert "WebRunnerConfig" in config_str + assert "repr-host" in config_str + assert "1234" in config_str + assert "http://repr-lambda:5000" in config_str + + +# WebRunner class tests + + +def test_should_create_web_runner_with_config(): + """Test creating WebRunner with WebRunnerConfig.""" + # Arrange + web_config = WebServiceConfig(host="test-host", port=8080) + config = WebRunnerConfig(web_service=web_config) + + # Act + runner = WebRunner(config) + + # Assert - Test through public behavior only + assert isinstance(runner, WebRunner) + # Verify runner can be used as context manager (public API) + assert hasattr(runner, "__enter__") + assert hasattr(runner, "__exit__") + assert callable(runner.start) + assert callable(runner.stop) + assert callable(runner.serve_forever) + + +def test_should_support_context_manager_protocol(): + """Test WebRunner context manager protocol.""" + # Arrange + web_config = WebServiceConfig() + config = WebRunnerConfig(web_service=web_config) + + # Act & Assert - should not raise exception + with WebRunner(config) as runner: + assert isinstance(runner, WebRunner) + assert runner._config == config # noqa: SLF001 + + +def test_should_return_self_from_context_manager_enter(): + """Test that __enter__ returns self.""" + # Arrange + web_config = WebServiceConfig() + config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(config) + + # Act + result = runner.__enter__() + + # Assert + assert result is runner + + +def test_should_call_start_and_stop_on_context_manager(): + """Test that context manager calls start on entry and stop on exit.""" + # Arrange + web_config = WebServiceConfig() + config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(config) + + # Mock the start and stop methods to verify they're called + with ( + patch.object(runner, "start") as mock_start, + patch.object(runner, "stop") as mock_stop, + ): + # Act + with runner as context_runner: + assert context_runner is runner + mock_start.assert_called_once() + + # Assert + mock_stop.assert_called_once() + + +def test_should_handle_context_manager_exit_with_exception(): + """Test context manager exit with exception parameters.""" + # Arrange + web_config = WebServiceConfig() + config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(config) + + # Act & Assert - should not raise exception + runner.__exit__(ValueError, ValueError("test"), None) + + +def test_should_have_proper_method_signatures(): + """Test that WebRunner has all required methods with proper signatures.""" + # Arrange + web_config = WebServiceConfig() + config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(config) + + # Assert methods exist and are callable + assert callable(runner.start) + assert callable(runner.serve_forever) + assert callable(runner.stop) + assert callable(runner.__enter__) + assert callable(runner.__exit__) + + +def test_should_initialize_runner_in_stopped_state(): + """Test that WebRunner initializes in a stopped state.""" + # Arrange + web_config = WebServiceConfig() + config = WebRunnerConfig(web_service=web_config) + + # Act + runner = WebRunner(config) + + # Assert - Test through public behavior + # Should raise DurableFunctionsLocalRunnerError when trying to serve before starting + with pytest.raises(DurableFunctionsLocalRunnerError, match="Server not started"): + runner.serve_forever() + + # Should be safe to call stop multiple times (no-op when not started) + runner.stop() + runner.stop() + + +def test_should_store_config_reference(): + """Test that WebRunner can be created with config and used properly.""" + # Arrange + web_config = WebServiceConfig(host="config-test", port=9999) + config = WebRunnerConfig( + web_service=web_config, + lambda_endpoint="http://test:1234", + ) + + # Act + runner = WebRunner(config) + + # Assert - Test through public behavior + assert isinstance(runner, WebRunner) + + # Verify the runner can be started and stopped (public behavior) + with ( + patch("boto3.client") as mock_boto3_client, + patch( + "aws_durable_execution_sdk_python_testing.runner.WebServer" + ) as mock_web_server_class, + ): + mock_client = Mock() + mock_server = Mock() + mock_boto3_client.return_value = mock_client + mock_web_server_class.return_value = mock_server + + runner.start() + + # Verify server was started (public behavior - no exception on serve_forever call) + runner.serve_forever() + mock_server.serve_forever.assert_called_once() + + runner.stop() + mock_server.server_close.assert_called_once() + + +# Integration Tests - Testing Public Behavior + + +def test_should_handle_start_with_boto3_client_creation(): + """Test that start() properly handles boto3 client creation through public API.""" + # Arrange + web_config = WebServiceConfig(host="localhost", port=5000) + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Mock boto3.client to avoid actual client creation + with patch("boto3.client") as mock_boto3_client: + mock_client = Mock() + mock_boto3_client.return_value = mock_client + + # Act - Test public behavior + runner.start() + + # Assert - Verify public behavior + # Should be able to call serve_forever after start (public API) + with patch.object(runner, "serve_forever") as mock_serve: + runner.serve_forever() + mock_serve.assert_called_once() + + # Should be able to stop after start (public API) + runner.stop() + + +def test_should_handle_boto3_client_creation_with_custom_config(): + """Test that start() uses custom configuration for boto3 client through public API.""" + # Arrange + web_config = WebServiceConfig(host="localhost", port=5000) + runner_config = WebRunnerConfig( + web_service=web_config, + lambda_endpoint="http://custom-endpoint:8080", + local_runner_region="eu-west-1", + ) + runner = WebRunner(runner_config) + + with patch("boto3.client") as mock_boto3_client: + mock_client = Mock() + mock_boto3_client.return_value = mock_client + + # Act - Test public behavior + runner.start() + + # Assert - Verify boto3 client was called with correct parameters + mock_boto3_client.assert_called_once_with( + "lambda", + endpoint_url="http://custom-endpoint:8080", + region_name="eu-west-1", + config=_LAMBDA_CLIENT_CONFIG, + ) + + # Verify public behavior works + runner.stop() + + +def test_should_handle_boto3_client_creation_with_defaults(): + """Test that start() uses default configuration values through public API.""" + # Arrange + web_config = WebServiceConfig(host="localhost", port=5000) + runner_config = WebRunnerConfig(web_service=web_config) # Use defaults + runner = WebRunner(runner_config) + + with patch("boto3.client") as mock_boto3_client: + mock_client = Mock() + mock_boto3_client.return_value = mock_client + + # Act - Test public behavior + runner.start() + + # Assert - Verify boto3 client was called with default parameters + mock_boto3_client.assert_called_once_with( + "lambda", + endpoint_url="http://127.0.0.1:3001", # Default lambda_endpoint value + region_name="us-west-2", # Default value + config=_LAMBDA_CLIENT_CONFIG, + ) + + # Verify public behavior works + runner.stop() + + +def test_should_propagate_boto3_client_creation_exceptions(): + """Test that start() propagates boto3 client creation exceptions through public API.""" + # Arrange + web_config = WebServiceConfig(host="localhost", port=5000) + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Mock boto3.client to raise an exception + with patch("boto3.client") as mock_boto3_client: + mock_boto3_client.side_effect = Exception("Connection failed") + + # Act & Assert - Test public behavior + with pytest.raises(Exception, match="Connection failed"): + runner.start() + + +def test_should_create_boto3_client_during_start(): + """Test that start() creates boto3 client correctly through public API.""" + # Arrange + web_config = WebServiceConfig(host="localhost", port=5000) + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + with patch("boto3.client") as mock_boto3_client: + mock_client = Mock() + mock_boto3_client.return_value = mock_client + + # Act - Test public behavior + runner.start() + + # Assert - Verify boto3 client was created + mock_boto3_client.assert_called_once() + + # Verify public behavior works + runner.stop() + + +# Error Condition Tests + + +def test_should_raise_runtime_error_on_double_start(): + """Test that calling start() twice raises DurableFunctionsLocalRunnerError.""" + # Arrange + web_config = WebServiceConfig() + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Mock dependencies to allow first start + with ( + patch("boto3.client") as mock_boto3_client, + patch( + "aws_durable_execution_sdk_python_testing.runner.WebServer" + ) as mock_web_server_class, + ): + mock_client = Mock() + mock_server = Mock() + mock_boto3_client.return_value = mock_client + mock_web_server_class.return_value = mock_server + + # First start should succeed + runner.start() + + # Act & Assert - Second start should raise DurableFunctionsLocalRunnerError + with pytest.raises( + DurableFunctionsLocalRunnerError, match="Server is already running" + ): + runner.start() + + # Cleanup + runner.stop() + + +def test_should_raise_runtime_error_when_serve_before_start(): + """Test that calling serve_forever() before start() raises DurableFunctionsLocalRunnerError.""" + # Arrange + web_config = WebServiceConfig() + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Act & Assert - serve_forever before start should raise DurableFunctionsLocalRunnerError + with pytest.raises(DurableFunctionsLocalRunnerError, match="Server not started"): + runner.serve_forever() + + +def test_should_propagate_boto3_client_creation_failures(): + """Test that boto3 client creation failures are propagated as exceptions.""" + # Arrange + web_config = WebServiceConfig() + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Mock boto3.client to raise various exceptions + test_cases = [ + Exception("Connection refused"), + ConnectionError("Network error"), + ValueError("Invalid endpoint URL"), + RuntimeError("AWS credentials not found"), + ] + + for exception in test_cases: + with patch("boto3.client") as mock_boto3_client: + mock_boto3_client.side_effect = exception + + # Act & Assert - Exception should propagate + with pytest.raises(type(exception), match=str(exception)): + runner.start() + + +def test_should_handle_web_server_creation_failures(): + """Test that WebServer creation failures are propagated as exceptions.""" + # Arrange + web_config = WebServiceConfig() + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Mock boto3 client to succeed but WebServer to fail + with ( + patch("boto3.client") as mock_boto3_client, + patch( + "aws_durable_execution_sdk_python_testing.runner.WebServer" + ) as mock_web_server_class, + ): + mock_client = Mock() + mock_boto3_client.return_value = mock_client + mock_web_server_class.side_effect = Exception("Failed to bind to port") + + # Act & Assert - WebServer creation failure should propagate + with pytest.raises(Exception, match="Failed to bind to port"): + runner.start() + + +def test_should_handle_scheduler_creation_failures(): + """Test that Scheduler creation failures are propagated as exceptions.""" + # Arrange + web_config = WebServiceConfig() + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Mock boto3 client to succeed but Scheduler to fail + with ( + patch("boto3.client") as mock_boto3_client, + patch( + "aws_durable_execution_sdk_python_testing.runner.Scheduler" + ) as mock_scheduler_class, + ): + mock_client = Mock() + mock_boto3_client.return_value = mock_client + mock_scheduler_class.side_effect = Exception("Scheduler initialization failed") + + # Act & Assert - Scheduler creation failure should propagate + with pytest.raises(Exception, match="Scheduler initialization failed"): + runner.start() + + +def test_should_handle_executor_creation_failures(): + """Test that Executor creation failures are propagated as exceptions.""" + # Arrange + web_config = WebServiceConfig() + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Mock dependencies to succeed but Executor to fail + with ( + patch("boto3.client") as mock_boto3_client, + patch( + "aws_durable_execution_sdk_python_testing.runner.Executor" + ) as mock_executor_class, + ): + mock_client = Mock() + mock_boto3_client.return_value = mock_client + mock_executor_class.side_effect = Exception("Executor initialization failed") + + # Act & Assert - Executor creation failure should propagate + with pytest.raises(Exception, match="Executor initialization failed"): + runner.start() + + +# Dependency Creation and Wiring Tests + + +def test_should_create_all_required_dependencies_during_start(): + """Test that start() creates all required dependencies with proper wiring.""" + # Arrange + web_config = WebServiceConfig(host="test-host", port=8080) + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Mock all dependency classes + with ( + patch("boto3.client") as mock_boto3_client, + patch( + "aws_durable_execution_sdk_python_testing.runner.InMemoryExecutionStore" + ) as mock_store_class, + patch( + "aws_durable_execution_sdk_python_testing.runner.Scheduler" + ) as mock_scheduler_class, + patch( + "aws_durable_execution_sdk_python_testing.runner.LambdaInvoker" + ) as mock_invoker_class, + patch( + "aws_durable_execution_sdk_python_testing.runner.Executor" + ) as mock_executor_class, + patch( + "aws_durable_execution_sdk_python_testing.runner.WebServer" + ) as mock_web_server_class, + ): + # Setup mocks + mock_client = Mock() + mock_store = Mock() + mock_scheduler = Mock() + mock_invoker = Mock() + mock_executor = Mock() + mock_server = Mock() + + mock_boto3_client.return_value = mock_client + mock_store_class.return_value = mock_store + mock_scheduler_class.return_value = mock_scheduler + mock_invoker_class.return_value = mock_invoker + mock_executor_class.return_value = mock_executor + mock_web_server_class.return_value = mock_server + + # Act + runner.start() + + # Assert - Verify all dependencies were created + mock_store_class.assert_called_once() + mock_scheduler_class.assert_called_once() + mock_invoker_class.assert_called_once_with(mock_client) + # Verify Executor was called with the expected parameters including checkpoint_processor + assert mock_executor_class.call_count == 1 + call_args = mock_executor_class.call_args + assert call_args.kwargs["store"] == mock_store + assert call_args.kwargs["scheduler"] == mock_scheduler + assert call_args.kwargs["invoker"] == mock_invoker + assert "checkpoint_processor" in call_args.kwargs + mock_web_server_class.assert_called_once_with( + config=web_config, executor=mock_executor + ) + + # Verify scheduler was started + mock_scheduler.start.assert_called_once() + + # Cleanup + runner.stop() + + +def test_should_pass_correct_configuration_to_web_server(): + """Test that WebServer receives correct configuration from WebRunnerConfig.""" + # Arrange + web_config = WebServiceConfig( + host="custom-host", port=9999, log_level="WARNING", max_request_size=2048 + ) + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Mock dependencies + with ( + patch("boto3.client") as mock_boto3_client, + patch( + "aws_durable_execution_sdk_python_testing.runner.WebServer" + ) as mock_web_server_class, + ): + mock_client = Mock() + mock_server = Mock() + mock_boto3_client.return_value = mock_client + mock_web_server_class.return_value = mock_server + + # Act + runner.start() + + # Assert - Verify WebServer was created with correct config + mock_web_server_class.assert_called_once() + call_args = mock_web_server_class.call_args + + # Verify the web service config was passed correctly + passed_config = call_args[1]["config"] + assert passed_config == web_config + assert passed_config.host == "custom-host" + assert passed_config.port == 9999 + assert passed_config.log_level == "WARNING" + assert passed_config.max_request_size == 2048 + + # Cleanup + runner.stop() + + +def test_should_pass_correct_boto3_client_to_lambda_invoker(): + """Test that LambdaInvoker receives correct boto3 client configuration.""" + # Arrange + web_config = WebServiceConfig() + runner_config = WebRunnerConfig( + web_service=web_config, + lambda_endpoint="http://test-endpoint:7777", + local_runner_region="ap-southeast-2", + ) + runner = WebRunner(runner_config) + + # Mock dependencies + with ( + patch("boto3.client") as mock_boto3_client, + patch( + "aws_durable_execution_sdk_python_testing.runner.LambdaInvoker" + ) as mock_invoker_class, + ): + mock_client = Mock() + mock_invoker = Mock() + mock_boto3_client.return_value = mock_client + mock_invoker_class.return_value = mock_invoker + + # Act + runner.start() + + # Assert - Verify boto3 client was created with correct parameters + mock_boto3_client.assert_called_once_with( + "lambda", + endpoint_url="http://test-endpoint:7777", + region_name="ap-southeast-2", + config=_LAMBDA_CLIENT_CONFIG, + ) + + # Verify LambdaInvoker was created with the client + mock_invoker_class.assert_called_once_with(mock_client) + + # Cleanup + runner.stop() + + +def test_should_wire_dependencies_correctly_in_executor(): + """Test that Executor receives correctly wired dependencies.""" + # Arrange + web_config = WebServiceConfig() + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Mock dependencies + with ( + patch("boto3.client") as mock_boto3_client, + patch( + "aws_durable_execution_sdk_python_testing.runner.InMemoryExecutionStore" + ) as mock_store_class, + patch( + "aws_durable_execution_sdk_python_testing.runner.Scheduler" + ) as mock_scheduler_class, + patch( + "aws_durable_execution_sdk_python_testing.runner.LambdaInvoker" + ) as mock_invoker_class, + patch( + "aws_durable_execution_sdk_python_testing.runner.Executor" + ) as mock_executor_class, + patch( + "aws_durable_execution_sdk_python_testing.runner.WebServer" + ) as mock_web_server_class, + ): + mock_client = Mock() + mock_store = Mock() + mock_scheduler = Mock() + mock_invoker = Mock() + mock_executor = Mock() + mock_web_server = Mock() + + mock_boto3_client.return_value = mock_client + mock_store_class.return_value = mock_store + mock_scheduler_class.return_value = mock_scheduler + mock_invoker_class.return_value = mock_invoker + mock_executor_class.return_value = mock_executor + mock_web_server_class.return_value = mock_web_server + + # Act + runner.start() + + # Assert - Verify Executor was created with correct dependencies + assert mock_executor_class.call_count == 1 + call_args = mock_executor_class.call_args + assert call_args.kwargs["store"] == mock_store + assert call_args.kwargs["scheduler"] == mock_scheduler + assert call_args.kwargs["invoker"] == mock_invoker + assert "checkpoint_processor" in call_args.kwargs + + # Cleanup + runner.stop() + + +# WebServer Lifecycle and Configuration Tests + + +def test_should_delegate_serve_forever_to_web_server(): + """Test that serve_forever() properly delegates to WebServer.serve_forever().""" + # Arrange + web_config = WebServiceConfig() + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Mock dependencies + with ( + patch("boto3.client") as mock_boto3_client, + patch( + "aws_durable_execution_sdk_python_testing.runner.WebServer" + ) as mock_web_server_class, + ): + mock_client = Mock() + mock_server = Mock() + mock_boto3_client.return_value = mock_client + mock_web_server_class.return_value = mock_server + + # Start the runner + runner.start() + + # Act + runner.serve_forever() + + # Assert - Verify WebServer.serve_forever was called + mock_server.serve_forever.assert_called_once() + + # Cleanup + runner.stop() + + +def test_should_call_server_close_during_stop(): + """Test that stop() calls server_close() on WebServer.""" + # Arrange + web_config = WebServiceConfig() + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Mock dependencies + with ( + patch("boto3.client") as mock_boto3_client, + patch( + "aws_durable_execution_sdk_python_testing.runner.WebServer" + ) as mock_web_server_class, + patch( + "aws_durable_execution_sdk_python_testing.runner.Scheduler" + ) as mock_scheduler_class, + ): + mock_client = Mock() + mock_server = Mock() + mock_scheduler = Mock() + mock_boto3_client.return_value = mock_client + mock_web_server_class.return_value = mock_server + mock_scheduler_class.return_value = mock_scheduler + + # Start the runner + runner.start() + + # Act + runner.stop() + + # Assert - Verify cleanup methods were called + mock_server.server_close.assert_called_once() + mock_scheduler.stop.assert_called_once() + + +def test_should_handle_web_server_serve_forever_exceptions(): + """Test that exceptions from WebServer.serve_forever() are propagated.""" + # Arrange + web_config = WebServiceConfig() + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Mock dependencies + with ( + patch("boto3.client") as mock_boto3_client, + patch( + "aws_durable_execution_sdk_python_testing.runner.WebServer" + ) as mock_web_server_class, + ): + mock_client = Mock() + mock_server = Mock() + mock_boto3_client.return_value = mock_client + mock_web_server_class.return_value = mock_server + + # Make serve_forever raise an exception + mock_server.serve_forever.side_effect = Exception("Server error") + + # Start the runner + runner.start() + + # Act & Assert - Exception should propagate + with pytest.raises(Exception, match="Server error"): + runner.serve_forever() + + # Cleanup + runner.stop() + + +def test_should_handle_web_server_close_exceptions_gracefully(): + """Test that exceptions from server_close() are handled gracefully.""" + # Arrange + web_config = WebServiceConfig() + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Mock dependencies + with ( + patch("boto3.client") as mock_boto3_client, + patch( + "aws_durable_execution_sdk_python_testing.runner.WebServer" + ) as mock_web_server_class, + patch( + "aws_durable_execution_sdk_python_testing.runner.Scheduler" + ) as mock_scheduler_class, + ): + mock_client = Mock() + mock_server = Mock() + mock_scheduler = Mock() + mock_boto3_client.return_value = mock_client + mock_web_server_class.return_value = mock_server + mock_scheduler_class.return_value = mock_scheduler + + # Make server_close raise an exception + mock_server.server_close.side_effect = Exception("Close error") + + # Start the runner + runner.start() + + # Act - stop() should not raise exception despite server_close error + runner.stop() + + # Assert - Verify both cleanup methods were attempted + mock_server.server_close.assert_called_once() + mock_scheduler.stop.assert_called_once() + + +# Exception Handling Tests + + +def test_should_handle_standard_runtime_errors(): + """Test that standard RuntimeError exceptions are handled properly.""" + # Arrange + web_config = WebServiceConfig() + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Test RuntimeError during start + with patch("boto3.client") as mock_boto3_client: + mock_boto3_client.side_effect = RuntimeError("Runtime error during start") + + with pytest.raises(RuntimeError, match="Runtime error during start"): + runner.start() + + # Test DurableFunctionsLocalRunnerError when serve_forever called before start + with pytest.raises(DurableFunctionsLocalRunnerError, match="Server not started"): + runner.serve_forever() + + +def test_should_handle_value_errors_during_initialization(): + """Test that ValueError exceptions during initialization are propagated.""" + # Arrange + web_config = WebServiceConfig() + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Mock boto3 client to raise ValueError + with patch("boto3.client") as mock_boto3_client: + mock_boto3_client.side_effect = ValueError("Invalid configuration") + + # Act & Assert + with pytest.raises(ValueError, match="Invalid configuration"): + runner.start() + + +def test_should_handle_connection_errors_during_initialization(): + """Test that ConnectionError exceptions during initialization are propagated.""" + # Arrange + web_config = WebServiceConfig() + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Mock boto3 client to raise ConnectionError + with patch("boto3.client") as mock_boto3_client: + mock_boto3_client.side_effect = ConnectionError("Network connection failed") + + # Act & Assert + with pytest.raises(ConnectionError, match="Network connection failed"): + runner.start() + + +def test_should_handle_keyboard_interrupt_during_serve_forever(): + """Test that KeyboardInterrupt during serve_forever is propagated.""" + # Arrange + web_config = WebServiceConfig() + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Mock dependencies + with ( + patch("boto3.client") as mock_boto3_client, + patch( + "aws_durable_execution_sdk_python_testing.runner.WebServer" + ) as mock_web_server_class, + ): + mock_client = Mock() + mock_server = Mock() + mock_boto3_client.return_value = mock_client + mock_web_server_class.return_value = mock_server + + # Make serve_forever raise KeyboardInterrupt + mock_server.serve_forever.side_effect = KeyboardInterrupt() + + # Start the runner + runner.start() + + # Act & Assert - KeyboardInterrupt should propagate + with pytest.raises(KeyboardInterrupt): + runner.serve_forever() + + # Cleanup + runner.stop() + + +# Lifecycle Management Tests + + +def test_start_creates_dependencies_and_server(): + """Test that start() creates all dependencies and WebServer through public API.""" + # Arrange + web_config = WebServiceConfig(host="localhost", port=5000) + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Mock dependencies and WebServer + with ( + patch("boto3.client") as mock_boto3_client, + patch( + "aws_durable_execution_sdk_python_testing.runner.WebServer" + ) as mock_web_server_class, + patch( + "aws_durable_execution_sdk_python_testing.runner.Scheduler" + ) as mock_scheduler_class, + ): + mock_client = Mock() + mock_server = Mock() + mock_scheduler = Mock() + + mock_boto3_client.return_value = mock_client + mock_web_server_class.return_value = mock_server + mock_scheduler_class.return_value = mock_scheduler + + # Act + runner.start() + + # Assert - Test through public behavior + # Should be able to call serve_forever after start + runner.serve_forever() + mock_server.serve_forever.assert_called_once() + + # Should be able to stop after start + runner.stop() + mock_server.server_close.assert_called_once() + mock_scheduler.stop.assert_called_once() + + +def test_start_raises_runtime_error_if_already_started(): + """Test that start() raises DurableFunctionsLocalRunnerError if server is already running.""" + # Arrange + web_config = WebServiceConfig() + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Set server to simulate already started state + runner._server = Mock() # noqa: SLF001 + + # Act & Assert + with pytest.raises( + DurableFunctionsLocalRunnerError, match="Server is already running" + ): + runner.start() + + +def test_serve_forever_delegates_to_web_server(): + """Test that serve_forever() delegates to WebServer.serve_forever() through public API.""" + # Arrange + web_config = WebServiceConfig() + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Mock dependencies to allow start + with ( + patch("boto3.client") as mock_boto3_client, + patch( + "aws_durable_execution_sdk_python_testing.runner.WebServer" + ) as mock_web_server_class, + ): + mock_client = Mock() + mock_server = Mock() + mock_boto3_client.return_value = mock_client + mock_web_server_class.return_value = mock_server + + # Start the runner first (public API) + runner.start() + + # Act + runner.serve_forever() + + # Assert + mock_server.serve_forever.assert_called_once() + + # Cleanup + runner.stop() + + +def test_serve_forever_raises_runtime_error_if_not_started(): + """Test that serve_forever() raises DurableFunctionsLocalRunnerError if server not started.""" + # Arrange + web_config = WebServiceConfig() + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Ensure server is None (not started) + assert runner._server is None # noqa: SLF001 + + # Act & Assert + with pytest.raises(DurableFunctionsLocalRunnerError, match="Server not started"): + runner.serve_forever() + + +def test_stop_cleans_up_server_and_scheduler(): + """Test that stop() properly cleans up server and scheduler resources through public API.""" + # Arrange + web_config = WebServiceConfig() + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Mock dependencies to allow start + with ( + patch("boto3.client") as mock_boto3_client, + patch( + "aws_durable_execution_sdk_python_testing.runner.WebServer" + ) as mock_web_server_class, + patch( + "aws_durable_execution_sdk_python_testing.runner.Scheduler" + ) as mock_scheduler_class, + ): + mock_client = Mock() + mock_server = Mock() + mock_scheduler = Mock() + + mock_boto3_client.return_value = mock_client + mock_web_server_class.return_value = mock_server + mock_scheduler_class.return_value = mock_scheduler + + # Start the runner first (public API) + runner.start() + + # Act + runner.stop() + + # Assert - Verify cleanup was called + mock_server.server_close.assert_called_once() + mock_scheduler.stop.assert_called_once() + + # Verify runner is back to stopped state (public behavior) + with pytest.raises( + DurableFunctionsLocalRunnerError, match="Server not started" + ): + runner.serve_forever() + + +def test_stop_is_safe_to_call_multiple_times(): + """Test that stop() can be called multiple times safely through public API.""" + # Arrange + web_config = WebServiceConfig() + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Mock dependencies to allow start + with ( + patch("boto3.client") as mock_boto3_client, + patch( + "aws_durable_execution_sdk_python_testing.runner.WebServer" + ) as mock_web_server_class, + patch( + "aws_durable_execution_sdk_python_testing.runner.Scheduler" + ) as mock_scheduler_class, + ): + mock_client = Mock() + mock_server = Mock() + mock_scheduler = Mock() + + mock_boto3_client.return_value = mock_client + mock_web_server_class.return_value = mock_server + mock_scheduler_class.return_value = mock_scheduler + + # Start the runner first (public API) + runner.start() + + # Act - call stop multiple times + runner.stop() + runner.stop() + runner.stop() + + # Assert - should only be called once (first time) + mock_server.server_close.assert_called_once() + mock_scheduler.stop.assert_called_once() + + # Verify runner remains in stopped state (public behavior) + with pytest.raises( + DurableFunctionsLocalRunnerError, match="Server not started" + ): + runner.serve_forever() + + +# Integration Tests - CLI to WebRunner Flow + + +def test_should_integrate_with_cli_start_server_command(): + """Test complete integration from CLI start-server command to WebRunner.""" + # This test verifies the complete flow from CLI argument parsing + # through WebRunnerConfig creation to WebRunner execution + + # Arrange + app = CliApp() + + # Mock WebRunner to verify it receives correct configuration + with patch( + "aws_durable_execution_sdk_python_testing.cli.WebRunner" + ) as mock_web_runner_class: + # Setup mock runner instance with context manager support + mock_runner = Mock() + mock_runner.__enter__ = Mock(return_value=mock_runner) + mock_runner.__exit__ = Mock(return_value=None) + mock_web_runner_class.return_value = mock_runner + mock_runner.serve_forever.side_effect = KeyboardInterrupt() + + # Act - Run CLI command with custom arguments + exit_code = app.run( + [ + "start-server", + "--host", + "integration-host", + "--port", + "7777", + "--log-level", + "WARNING", + "--lambda-endpoint", + "http://integration-lambda:4000", + "--local-runner-endpoint", + "http://integration-runner:8000", + "--local-runner-region", + "eu-central-1", + "--local-runner-mode", + "integration", + ] + ) + + # Assert - Verify CLI handled KeyboardInterrupt correctly + assert exit_code == 130 + + # Verify WebRunner was created with correct configuration + mock_web_runner_class.assert_called_once() + config = mock_web_runner_class.call_args[0][0] + + # Verify web service configuration + assert config.web_service.host == "integration-host" + assert config.web_service.port == 7777 + assert config.web_service.log_level == "WARNING" + + # Verify Lambda service configuration + assert config.lambda_endpoint == "http://integration-lambda:4000" + assert config.local_runner_endpoint == "http://integration-runner:8000" + assert config.local_runner_region == "eu-central-1" + assert config.local_runner_mode == "integration" + + # Verify context manager protocol was used + mock_runner.__enter__.assert_called_once() + mock_runner.__exit__.assert_called_once() + mock_runner.serve_forever.assert_called_once() + + +def test_should_handle_cli_to_web_runner_startup_errors(): + """Test integration error handling from CLI to WebRunner startup failures.""" + # Arrange + app = CliApp() + + # Mock WebRunner to raise exception during creation + with patch( + "aws_durable_execution_sdk_python_testing.cli.WebRunner" + ) as mock_web_runner_class: + mock_web_runner_class.side_effect = Exception("WebRunner startup failed") + + with patch( + "aws_durable_execution_sdk_python_testing.cli.logger" + ) as mock_logger: + # Act + exit_code = app.run(["start-server"]) + + # Assert - Verify CLI handled WebRunner exception correctly + assert exit_code == 1 + mock_logger.exception.assert_called_with("Failed to start server") + + +def test_should_handle_cli_to_web_runner_context_manager_errors(): + """Test integration error handling for WebRunner context manager failures.""" + # Arrange + app = CliApp() + + # Mock WebRunner context manager to raise exception + with patch( + "aws_durable_execution_sdk_python_testing.cli.WebRunner" + ) as mock_web_runner_class: + mock_runner = Mock() + mock_runner.__enter__ = Mock( + side_effect=DurableFunctionsLocalRunnerError("Context manager failed") + ) + mock_runner.__exit__ = Mock(return_value=None) + mock_web_runner_class.return_value = mock_runner + + with patch( + "aws_durable_execution_sdk_python_testing.cli.logger" + ) as mock_logger: + # Act + exit_code = app.run(["start-server"]) + + # Assert - Verify CLI handled context manager exception correctly + assert exit_code == 1 + mock_logger.exception.assert_called_with("Failed to start server") + + +def test_should_handle_cli_to_web_runner_serve_forever_errors(): + """Test integration error handling for WebRunner serve_forever failures.""" + # Arrange + app = CliApp() + + # Mock WebRunner serve_forever to raise exception + with patch( + "aws_durable_execution_sdk_python_testing.cli.WebRunner" + ) as mock_web_runner_class: + mock_runner = Mock() + mock_runner.__enter__ = Mock(return_value=mock_runner) + mock_runner.__exit__ = Mock(return_value=None) + mock_web_runner_class.return_value = mock_runner + mock_runner.serve_forever.side_effect = Exception("Server runtime error") + + with patch( + "aws_durable_execution_sdk_python_testing.cli.logger" + ) as mock_logger: + # Act + exit_code = app.run(["start-server"]) + + # Assert - Verify CLI handled serve_forever exception correctly + assert exit_code == 1 + mock_logger.exception.assert_called_with("Failed to start server") + + +def test_should_preserve_cli_configuration_through_web_runner(): + """Test that CLI configuration is preserved through WebRunner creation.""" + # This test verifies that all CLI arguments are correctly passed through + # the WebRunnerConfig to the WebRunner and its dependencies + + # Arrange + app = CliApp() + + # Mock all WebRunner dependencies to verify configuration flow + with ( + patch( + "aws_durable_execution_sdk_python_testing.cli.WebRunner" + ) as mock_web_runner_class, + patch("boto3.client"), + patch("aws_durable_execution_sdk_python_testing.runner.WebServer"), + ): + # Setup mocks + mock_runner = Mock() + + mock_runner.__enter__ = Mock(return_value=mock_runner) + mock_runner.__exit__ = Mock(return_value=None) + mock_web_runner_class.return_value = mock_runner + mock_runner.serve_forever.return_value = None + + # No need to mock internal behavior, just verify configuration passing + + # Act - Run CLI with comprehensive configuration + exit_code = app.run( + [ + "start-server", + "--host", + "config-test-host", + "--port", + "9999", + "--log-level", + "ERROR", # ERROR level + "--lambda-endpoint", + "http://config-lambda:5000", + "--local-runner-endpoint", + "http://config-runner:9000", + "--local-runner-region", + "ap-northeast-1", + "--local-runner-mode", + "config-test", + ] + ) + + # Assert - Verify successful execution + assert exit_code == 0 + + # Verify WebRunner was created with correct configuration + mock_web_runner_class.assert_called_once() + config = mock_web_runner_class.call_args[0][0] + + # Verify web service configuration + assert config.web_service.host == "config-test-host" + assert config.web_service.port == 9999 + assert config.web_service.log_level == "ERROR" + + # Verify Lambda service configuration + assert config.lambda_endpoint == "http://config-lambda:5000" + assert config.local_runner_endpoint == "http://config-runner:9000" + assert config.local_runner_region == "ap-northeast-1" + assert config.local_runner_mode == "config-test" + + # Verify context manager protocol was used + mock_runner.__enter__.assert_called_once() + mock_runner.serve_forever.assert_called_once() + mock_runner.__exit__.assert_called_once() + + +def test_should_handle_environment_variable_integration(): + """Test integration with environment variables through CLI to WebRunner.""" + # Set environment variables + env_vars = { + "AWS_DEX_HOST": "env-host", + "AWS_DEX_PORT": "8888", + "AWS_DEX_LOG_LEVEL": "CRITICAL", # CRITICAL level + "AWS_DEX_LAMBDA_ENDPOINT": "http://env-lambda:6000", + "AWS_DEX_LOCAL_RUNNER_ENDPOINT": "http://env-runner:7000", + "AWS_DEX_LOCAL_RUNNER_REGION": "sa-east-1", + "AWS_DEX_LOCAL_RUNNER_MODE": "env-test", + } + + with patch.dict(os.environ, env_vars, clear=True): + app = CliApp() + + # Mock WebRunner to verify environment configuration + with patch( + "aws_durable_execution_sdk_python_testing.cli.WebRunner" + ) as mock_web_runner_class: + mock_runner = Mock() + mock_web_runner_class.return_value = mock_runner + mock_runner.__enter__ = Mock(return_value=mock_runner) + mock_runner.__exit__ = Mock(return_value=None) + mock_runner.serve_forever.return_value = None + + # Act - Run CLI without arguments (should use environment) + exit_code = app.run(["start-server"]) + + # Assert - Verify successful execution + assert exit_code == 0 + + # Verify WebRunner was created with environment configuration + mock_web_runner_class.assert_called_once() + config = mock_web_runner_class.call_args[0][0] + + # Verify environment variables were used + assert config.web_service.host == "env-host" + assert config.web_service.port == 8888 + assert config.web_service.log_level == "CRITICAL" + assert config.lambda_endpoint == "http://env-lambda:6000" + assert config.local_runner_endpoint == "http://env-runner:7000" + assert config.local_runner_region == "sa-east-1" + assert config.local_runner_mode == "env-test" + + +def test_should_handle_cli_argument_override_of_environment(): + """Test that CLI arguments override environment variables in integration.""" + # Set environment variables + env_vars = { + "AWS_DEX_HOST": "env-host", + "AWS_DEX_PORT": "8888", + } + + with patch.dict(os.environ, env_vars, clear=True): + app = CliApp() + + # Mock WebRunner to verify argument override + with patch( + "aws_durable_execution_sdk_python_testing.cli.WebRunner" + ) as mock_web_runner_class: + mock_runner = Mock() + mock_web_runner_class.return_value = mock_runner + mock_runner.__enter__ = Mock(return_value=mock_runner) + mock_runner.__exit__ = Mock(return_value=None) + mock_runner.serve_forever.return_value = None + + # Act - Run CLI with arguments that should override environment + exit_code = app.run( + [ + "start-server", + "--host", + "cli-override-host", + "--port", + "7777", + ] + ) + + # Assert - Verify successful execution + assert exit_code == 0 + + # Verify CLI arguments overrode environment variables + config = mock_web_runner_class.call_args[0][0] + assert config.web_service.host == "cli-override-host" # CLI override + assert config.web_service.port == 7777 # CLI override + + +def test_should_maintain_backward_compatibility_in_integration(): + """Test that integration maintains backward compatibility with existing behavior.""" + # This test ensures that the refactored CLI-to-WebRunner flow + # maintains the same external behavior as the original implementation + app = CliApp() + + # Mock WebRunner to simulate successful operation + with patch( + "aws_durable_execution_sdk_python_testing.cli.WebRunner" + ) as mock_web_runner_class: + mock_runner = Mock() + mock_web_runner_class.return_value = mock_runner + mock_runner.__enter__ = Mock(return_value=mock_runner) + mock_runner.__exit__ = Mock(return_value=None) + mock_runner.serve_forever.side_effect = KeyboardInterrupt() + + # Mock logging to verify backward compatible messages + with patch( + "aws_durable_execution_sdk_python_testing.cli.logger" + ) as mock_logger: + # Act + exit_code = app.run( + ["start-server", "--host", "compat-host", "--port", "5555"] + ) + + # Assert - Verify backward compatible behavior + assert exit_code == 130 # KeyboardInterrupt exit code + + # Verify backward compatible logging messages + mock_logger.info.assert_any_call( + "Starting Durable Functions Local Runner on %s:%s", + "compat-host", + 5555, + ) + mock_logger.info.assert_any_call("Configuration:") + mock_logger.info.assert_any_call(" Host: %s", "compat-host") + mock_logger.info.assert_any_call(" Port: %s", 5555) + mock_logger.info.assert_any_call( + "Server started successfully. Press Ctrl+C to stop." + ) + mock_logger.info.assert_any_call( + "Received shutdown signal, stopping server..." + ) + + +def test_stop_handles_unstarted_runner_gracefully(): + """Test that stop() handles unstarted runner gracefully through public API.""" + # Arrange + web_config = WebServiceConfig() + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Act & Assert - should not raise any exceptions when stopping unstarted runner + runner.stop() + + # Verify runner remains in stopped state (public behavior) + with pytest.raises(DurableFunctionsLocalRunnerError, match="Server not started"): + runner.serve_forever() + + +def test_complete_lifecycle_start_serve_stop(): + """Test complete lifecycle: start -> serve_forever -> stop through public API.""" + # Arrange + web_config = WebServiceConfig() + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Mock all dependencies + with ( + patch("boto3.client") as mock_boto3_client, + patch( + "aws_durable_execution_sdk_python_testing.runner.WebServer" + ) as mock_web_server_class, + patch( + "aws_durable_execution_sdk_python_testing.runner.Scheduler" + ) as mock_scheduler_class, + ): + mock_client = Mock() + mock_server = Mock() + mock_scheduler = Mock() + + mock_boto3_client.return_value = mock_client + mock_web_server_class.return_value = mock_server + mock_scheduler_class.return_value = mock_scheduler + + # Act - complete lifecycle through public API + runner.start() + runner.serve_forever() + runner.stop() + + # Assert - Verify all methods were called + mock_server.serve_forever.assert_called_once() + mock_server.server_close.assert_called_once() + mock_scheduler.start.assert_called_once() + mock_scheduler.stop.assert_called_once() + + # Verify runner is back to stopped state (public behavior) + with pytest.raises( + DurableFunctionsLocalRunnerError, match="Server not started" + ): + runner.serve_forever() + + +def test_context_manager_calls_start_and_stop(): + """Test that context manager properly calls start() and stop().""" + # Arrange + web_config = WebServiceConfig() + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Mock start and stop to track calls + with ( + patch.object(runner, "start") as mock_start, + patch.object(runner, "stop") as mock_stop, + ): + # Act + with runner as context_runner: + # Verify start was called and runner returned + mock_start.assert_called_once() + assert context_runner is runner + + # Assert stop was called on exit + mock_stop.assert_called_once() + + +def test_context_manager_calls_stop_on_exception(): + """Test that context manager calls stop() even when exception occurs.""" + # Arrange + web_config = WebServiceConfig() + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Mock start and stop + with ( + patch.object(runner, "start") as mock_start, + patch.object(runner, "stop") as mock_stop, + ): + # Act & Assert + with pytest.raises(ValueError, match="Test exception"): # noqa: PT012 + with runner: + mock_start.assert_called_once() + raise ValueError("Test exception") # noqa: TRY003, EM101 + + # Verify stop was still called despite exception + mock_stop.assert_called_once() + + +def test_state_transitions_prevent_invalid_operations(): + """Test that state checking prevents invalid operation sequences through public API.""" + # Arrange + web_config = WebServiceConfig() + runner_config = WebRunnerConfig(web_service=web_config) + runner = WebRunner(runner_config) + + # Test serve_forever before start + with pytest.raises(DurableFunctionsLocalRunnerError, match="Server not started"): + runner.serve_forever() + + # Mock dependencies for start + with ( + patch("boto3.client") as mock_boto3_client, + patch( + "aws_durable_execution_sdk_python_testing.runner.WebServer" + ) as mock_web_server_class, + patch( + "aws_durable_execution_sdk_python_testing.runner.Scheduler" + ) as mock_scheduler_class, + ): + mock_client = Mock() + mock_server = Mock() + mock_scheduler = Mock() + + mock_boto3_client.return_value = mock_client + mock_web_server_class.return_value = mock_server + mock_scheduler_class.return_value = mock_scheduler + + # Start server + runner.start() + + # Test double start + with pytest.raises( + DurableFunctionsLocalRunnerError, match="Server is already running" + ): + runner.start() + + # Verify serve_forever works after start + runner.serve_forever() + mock_server.serve_forever.assert_called_once() + + # Stop and verify serve_forever fails again + runner.stop() + with pytest.raises( + DurableFunctionsLocalRunnerError, match="Server not started" + ): + runner.serve_forever() diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/scheduler_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/scheduler_test.py new file mode 100644 index 0000000..65db932 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/scheduler_test.py @@ -0,0 +1,729 @@ +"""Unit tests for scheduler.py""" + +import threading +import time +from concurrent.futures import Future +from unittest.mock import patch + +import pytest + +from aws_durable_execution_sdk_python_testing.scheduler import Event, Scheduler + + +def wait_for_condition(condition_func, timeout_iterations=100): + """Wait for a condition to become true with polling.""" + for _ in range(timeout_iterations): + if condition_func(): + return True + time.sleep(0.001) + return False + + +def test_scheduler_init(): + """Test Scheduler initialization.""" + scheduler = Scheduler() + assert not scheduler.is_started() + assert scheduler.event_count() == 0 + + +def test_scheduler_context_manager(): + """Test Scheduler as context manager.""" + with Scheduler() as scheduler: + assert scheduler.is_started() + assert not scheduler.is_started() + + +def test_scheduler_start_stop(): + """Test Scheduler start and stop methods.""" + scheduler = Scheduler() + + scheduler.start() + assert scheduler.is_started() + + # Test start when already running + scheduler.start() + assert scheduler.is_started() + + scheduler.stop() + assert not scheduler.is_started() + + # Test stop when not running + scheduler.stop() + assert not scheduler.is_started() + + +def test_scheduler_is_started(): + """Test Scheduler is_started method.""" + scheduler = Scheduler() + + # Initially not started + assert not scheduler.is_started() + + # After start + scheduler.start() + assert scheduler.is_started() + + # After stop + scheduler.stop() + assert not scheduler.is_started() + + +def test_scheduler_event_count(): + """Test Scheduler event_count method.""" + scheduler = Scheduler() + scheduler.start() + + # Initially no events + assert scheduler.event_count() == 0 + + # Create events + event1 = scheduler.create_event() + assert scheduler.event_count() == 1 + + scheduler.create_event() + assert scheduler.event_count() == 2 + + # Remove event + event1.remove() + wait_for_condition(lambda: scheduler.event_count() == 1) + assert scheduler.event_count() == 1 + + scheduler.stop() + + +def test_scheduler_task_count(): + """Test Scheduler task_count method.""" + scheduler = Scheduler() + + # When not started, task count is 0 + assert scheduler.task_count() == 0 + + scheduler.start() + + # Create tasks with longer delay to ensure they're counted + future1 = scheduler.call_later(lambda: None, delay=0.5) + # Give a moment for the task to be created + time.sleep(0.01) + assert scheduler.task_count() >= 1 + + future2 = scheduler.call_later(lambda: None, delay=0.5) + time.sleep(0.01) + assert scheduler.task_count() >= 2 + + # Cancel tasks to clean up + future1.cancel() + future2.cancel() + + # Wait for tasks to complete or be cancelled + wait_for_condition(lambda: scheduler.task_count() == 0, timeout_iterations=200) + + scheduler.stop() + + +def test_scheduler_call_later_sync_function(): + """Test call_later with sync function.""" + scheduler = Scheduler() + scheduler.start() + + result = [] + + def sync_func(): + result.append("executed") + + future = scheduler.call_later(sync_func, delay=0.01) + wait_for_condition(lambda: future.done()) + + assert isinstance(future, Future) + assert result == ["executed"] + assert future.done() + + scheduler.stop() + + +def test_scheduler_call_later_async_function(): + """Test call_later with async function.""" + scheduler = Scheduler() + scheduler.start() + + result = [] + + async def async_func(): + result.append("async_executed") + + future = scheduler.call_later(async_func, delay=0.01) + wait_for_condition(lambda: future.done()) + + assert isinstance(future, Future) + assert result == ["async_executed"] + assert future.done() + + scheduler.stop() + + +def test_scheduler_call_later_multiple_count(): + """Test call_later with multiple executions.""" + scheduler = Scheduler() + scheduler.start() + + result = [] + + def func(): + result.append("count") + + # Note: Current implementation only executes once due to early return + future = scheduler.call_later(func, delay=0.01, count=3) + wait_for_condition(lambda: future.done()) + + # Current implementation only executes once + assert len(result) == 1 + assert future.done() + + scheduler.stop() + + +def test_scheduler_call_later_infinite_count(): + """Test call_later with infinite count.""" + scheduler = Scheduler() + scheduler.start() + + result = [] + + def func(): + result.append("infinite") + + # Note: Current implementation only executes once due to early return + future = scheduler.call_later(func, delay=0.01, count=None) + wait_for_condition(lambda: future.done()) + + # Current implementation only executes once + assert len(result) == 1 + assert future.done() + + scheduler.stop() + + +def test_scheduler_call_later_function_exception(): + """Test call_later with function that raises exception.""" + scheduler = Scheduler() + scheduler.start() + + def failing_func() -> None: + msg: str = "test error" + + raise ValueError(msg) + + with patch( + "aws_durable_execution_sdk_python_testing.scheduler.logger" + ) as mock_logger: + future = scheduler.call_later(failing_func, delay=0.01) + wait_for_condition(lambda: future.done()) + + assert future.done() + mock_logger.exception.assert_called() + + scheduler.stop() + + +def test_scheduler_create_event(): + """Test create_event method.""" + scheduler = Scheduler() + scheduler.start() + + event = scheduler.create_event() + + assert isinstance(event, Event) + assert scheduler.event_count() == 1 + + scheduler.stop() + + +def test_task_cancel(): + """Test Future cancel method.""" + scheduler = Scheduler() + scheduler.start() + + def func(): + pass + + future = scheduler.call_later(func, delay=0.1, count=None) + future.cancel() + + # Wait briefly for cancellation to take effect + wait_for_condition(lambda: future.cancelled()) + + assert future.cancelled() + + scheduler.stop() + + +def test_task_is_done(): + """Test Future done property.""" + scheduler = Scheduler() + scheduler.start() + + def quick_func(): + pass + + future = scheduler.call_later(quick_func, delay=0.01) + assert not future.done() + + wait_for_condition(lambda: future.done()) + assert future.done() + + # Small delay to ensure coroutine cleanup completes + time.sleep(0.01) + scheduler.stop() + + +def test_task_result(): + """Test Future result method.""" + scheduler = Scheduler() + scheduler.start() + + def func(): + return None + + future = scheduler.call_later(func, delay=0.01) + wait_for_condition(lambda: future.done()) + + result = future.result() + assert result is None + + scheduler.stop() + + +def test_task_cancel_method(): + """Test Future cancel method.""" + scheduler = Scheduler() + scheduler.start() + + # Create a future and cancel it immediately + future = scheduler.call_later(lambda: None, delay=0.01) + future.cancel() + + # The cancel method should work without hanging + # We don't test the result here to avoid timing issues + + scheduler.stop() + + +def test_task_result_completed(): + """Test Future result method when completed.""" + scheduler = Scheduler() + scheduler.start() + + def func(): + return "test_result" + + future = scheduler.call_later(func, delay=0.01) + wait_for_condition(lambda: future.done()) + assert future.done() + + # Small delay to ensure coroutine cleanup completes + time.sleep(0.01) + scheduler.stop() + + +def test_event_set_and_wait_timeout(): + """Test Event set and wait with timeout.""" + scheduler = Scheduler() + scheduler.start() + + event = scheduler.create_event() + + # Test wait with timeout (should timeout) + result = event.wait(timeout=0.01, clear_on_set=False) + assert result is False + + # Set the event + event.set() + + # Wait should now succeed + result = event.wait(timeout=0.1, clear_on_set=True) + assert result is True + + scheduler.stop() + + +def test_event_wait_set_by_thread(): + """Test Event wait when set by another thread.""" + scheduler = Scheduler() + scheduler.start() + + event = scheduler.create_event() + result_container = [] + start_event = threading.Event() + + def set_event(): + start_event.wait() # Wait for signal to start + event.set() + + def wait_for_event(): + result = event.wait(timeout=1.0) + result_container.append(result) + + set_thread = threading.Thread(target=set_event) + wait_thread = threading.Thread(target=wait_for_event) + + set_thread.start() + wait_thread.start() + start_event.set() # Signal to start setting event + + set_thread.join() + wait_thread.join() + + assert result_container[0] is True + + scheduler.stop() + + +def test_event_wait_clear_on_set_false(): + """Test Event wait with clear_on_set=False.""" + scheduler = Scheduler() + scheduler.start() + + event = scheduler.create_event() + event.set() + + result = event.wait(clear_on_set=False) + assert result is True + assert scheduler.event_count() == 1 + + scheduler.stop() + + +def test_event_remove(): + """Test Event remove method.""" + scheduler = Scheduler() + scheduler.start() + + event = scheduler.create_event() + assert scheduler.event_count() == 1 + + event.remove() + wait_for_condition(lambda: scheduler.event_count() == 0) + + assert scheduler.event_count() == 0 + + scheduler.stop() + + +def test_event_wait_removed_event(): + """Test Event wait on removed event.""" + scheduler = Scheduler() + scheduler.start() + + event = scheduler.create_event() + event.remove() + wait_for_condition(lambda: scheduler.event_count() == 0) + + result = event.wait(timeout=0.01) + assert result is False + + scheduler.stop() + + +def test_event_set_removed_event(): + """Test Event set on removed event.""" + scheduler = Scheduler() + scheduler.start() + + event = scheduler.create_event() + event.remove() + wait_for_condition(lambda: scheduler.event_count() == 0) + + # Should not crash + event.set() + + scheduler.stop() + + +def test_scheduler_cleanup_on_stop(): + """Test scheduler cleanup when stopped.""" + scheduler = Scheduler() + scheduler.start() + + # Create a future and event + scheduler.call_later(lambda: None, delay=0.1, count=1) + scheduler.create_event() + + # Stop scheduler immediately + scheduler.stop() + + # Events should be cleared (this is what we can reliably test) + assert scheduler.event_count() == 0 + # Future state may vary due to timing, but scheduler should be stopped + assert not scheduler.is_started() + + +def test_scheduler_multiple_events(): + """Test scheduler with multiple events.""" + scheduler = Scheduler() + scheduler.start() + + event1 = scheduler.create_event() + event2 = scheduler.create_event() + + assert scheduler.event_count() == 2 + + event1.set() + result1 = event1.wait(timeout=0.01) + assert result1 is True + + result2 = event2.wait(timeout=0.01) + assert result2 is False + + scheduler.stop() + + +def test_task_properties_after_scheduler_stop(): + """Test Future properties after scheduler is stopped.""" + scheduler = Scheduler() + scheduler.start() + + def func(): + pass + + future = scheduler.call_later(func, delay=0.01) + wait_for_condition(lambda: future.done()) + + scheduler.stop() + + assert future.done() + assert not future.cancelled() + + +def test_event_timeout_handling(): + """Test Event timeout handling.""" + scheduler = Scheduler() + scheduler.start() + + event = scheduler.create_event() + + start_time = time.time() + result = event.wait(timeout=0.05) + end_time = time.time() + + assert result is False + assert 0.04 <= (end_time - start_time) <= 0.1 + + scheduler.stop() + + +def test_scheduler_call_later_zero_delay(): + """Test call_later with zero delay.""" + scheduler = Scheduler() + scheduler.start() + + result = [] + + def func(): + result.append("zero_delay") + + future = scheduler.call_later(func, delay=0) + wait_for_condition(lambda: future.done()) + + assert result == ["zero_delay"] + assert future.done() + + scheduler.stop() + + +def test_scheduler_call_later_default_parameters(): + """Test call_later with default parameters.""" + scheduler = Scheduler() + scheduler.start() + + result = [] + + def func(): + result.append("default") + + future = scheduler.call_later(func) + wait_for_condition(lambda: future.done()) + + assert result == ["default"] + assert future.done() + + scheduler.stop() + + +def test_task_result_with_exception(): + """Test Future result method when function raises exception.""" + scheduler = Scheduler() + scheduler.start() + + def failing_func() -> None: + msg: str = "test exception" + + raise ValueError(msg) + + # Test that user function exceptions are propagated through the Future + with patch( + "aws_durable_execution_sdk_python_testing.scheduler.logger" + ) as mock_logger: + future = scheduler.call_later(failing_func, delay=0.01) + wait_for_condition(lambda: future.done()) + + # Future should be done and exception should be logged + assert future.done() + mock_logger.exception.assert_called() + + # Exception should be propagated through Future.result() + with pytest.raises(ValueError, match="test exception"): + future.result() + + scheduler.stop() + + +def test_get_task_result_exception_handling(): + """Test Future result exception handling.""" + scheduler = Scheduler() + scheduler.start() + + def func(): + pass + + future = scheduler.call_later(func, delay=0.01) + wait_for_condition(lambda: future.done()) + + # Future result should work normally + result = future.result() + assert result is None + + scheduler.stop() + + +def test_call_later_with_sync_function(): + """Test call_later correctly identifies and runs sync functions.""" + scheduler = Scheduler() + scheduler.start() + + result = [] + + def sync_function(): + result.append("sync_executed") + + future = scheduler.call_later(sync_function, delay=0.01) + wait_for_condition(lambda: future.done()) + + assert result == ["sync_executed"] + assert future.done() + + scheduler.stop() + + +def test_call_later_with_async_function(): + """Test call_later correctly identifies and runs async functions.""" + scheduler = Scheduler() + scheduler.start() + + result = [] + + async def async_function(): + result.append("async_executed") + + future = scheduler.call_later(async_function, delay=0.01) + wait_for_condition(lambda: future.done()) + + assert result == ["async_executed"] + assert future.done() + + scheduler.stop() + + +def test_event_set_exception(): + """Test Event set_exception method.""" + scheduler = Scheduler() + scheduler.start() + + event = scheduler.create_event() + test_exception = ValueError("test exception") + + event.set_exception(test_exception) + + with pytest.raises(ValueError, match="test exception"): + event.wait() + + scheduler.stop() + + +def test_call_later_with_completion_event_exception(): + """Test call_later with completion_event when function raises exception.""" + scheduler = Scheduler() + scheduler.start() + + completion_event = scheduler.create_event() + + def failing_func() -> None: + msg: str = "completion event test" + + raise RuntimeError(msg) + + scheduler.call_later(failing_func, delay=0.01, completion_event=completion_event) + + # Wait for the completion event to be set with exception + with pytest.raises(RuntimeError, match="completion event test"): + completion_event.wait(timeout=1.0) + + scheduler.stop() + + +def test_call_later_multiple_iterations(): + """Test call_later with multiple count iterations.""" + scheduler = Scheduler() + scheduler.start() + + result = [] + + def func(): + result.append("iteration") + # Return early to test the loop behavior + if len(result) >= 2: + return "done" + return + + # Use a very small delay and count=3 to test the loop + future = scheduler.call_later(func, delay=0.001, count=3) + wait_for_condition(lambda: future.done(), timeout_iterations=500) + + # Should execute at least once + assert len(result) >= 1 + assert future.done() + + scheduler.stop() + + +def test_wait_for_event_timeout_exception(): + """Test _wait_for_event with timeout exception handling.""" + scheduler = Scheduler() + scheduler.start() + + event = scheduler.create_event() + + # Test timeout behavior + result = event.wait(timeout=0.001) + assert result is False + + scheduler.stop() + + +def test_call_later_loop_exit_condition(): + """Test call_later loop exit condition with count=0.""" + scheduler = Scheduler() + scheduler.start() + + result = [] + + def func(): + result.append("should_not_execute") + + # Test with count=0 to hit the loop exit condition + future = scheduler.call_later(func, delay=0.01, count=0) + wait_for_condition(lambda: future.done()) + + # Should not execute the function at all + assert len(result) == 0 + assert future.done() + + scheduler.stop() diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/stores/__init__.py b/packages/aws-durable-execution-sdk-python-testing/tests/stores/__init__.py new file mode 100644 index 0000000..dbc7145 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/stores/__init__.py @@ -0,0 +1 @@ +"""Tests for store implementations.""" diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/stores/concurrent_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/stores/concurrent_test.py new file mode 100644 index 0000000..8703d4f --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/stores/concurrent_test.py @@ -0,0 +1,278 @@ +"""Concurrent access tests for execution stores.""" + +import tempfile +import threading +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + +import pytest + +from aws_durable_execution_sdk_python_testing.execution import Execution +from aws_durable_execution_sdk_python_testing.model import StartDurableExecutionInput +from aws_durable_execution_sdk_python_testing.stores.filesystem import ( + FileSystemExecutionStore, +) +from aws_durable_execution_sdk_python_testing.stores.memory import ( + InMemoryExecutionStore, +) +from aws_durable_execution_sdk_python_testing.stores.sqlite import SQLiteExecutionStore + + +def test_concurrent_save_load(): + """Test concurrent save and load operations.""" + store = InMemoryExecutionStore() + results = [] + results_lock = threading.Lock() + + def save_execution(i: int): + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name=f"test-{i}", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id=f"inv-{i}", + input=f'{{"test": {i}}}', + ) + execution = Execution.new(input_data) + execution.durable_execution_arn = f"arn-{i}" + store.save(execution) + with results_lock: + results.append(f"saved-{i}") + + def load_execution(i: int): + try: + execution = store.load(f"arn-{i}") + with results_lock: + results.append(f"loaded-{execution.start_input.execution_name}") + except KeyError: + with results_lock: + results.append(f"not-found-{i}") + + with ThreadPoolExecutor(max_workers=10) as executor: + # Submit save operations first + futures = [executor.submit(save_execution, i) for i in range(5)] + # Wait for saves to complete + for future in as_completed(futures): + future.result() + + # Then submit load operations + futures = [] + for i in range(5): + futures.append(executor.submit(load_execution, i)) + # Wait for loads to complete + for future in as_completed(futures): + future.result() + + assert len(results) == 10 + + +def test_concurrent_update_list(): + """Test concurrent update and list operations.""" + store = InMemoryExecutionStore() + results = [] + results_lock = threading.Lock() + + # Pre-populate store + for i in range(3): + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name=f"test-{i}", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id=f"inv-{i}", + input=f'{{"test": {i}}}', + ) + execution = Execution.new(input_data) + execution.durable_execution_arn = f"arn-{i}" + store.save(execution) + + def update_execution(i: int): + execution = store.load(f"arn-{i}") + execution.is_complete = True + store.update(execution) + with results_lock: + results.append(f"updated-{i}") + + def list_executions(): + executions = store.list_all() + with results_lock: + results.append(f"listed-{len(executions)}") + + with ThreadPoolExecutor(max_workers=6) as executor: + # Submit update operations + futures = [executor.submit(update_execution, i) for i in range(3)] + # Submit list operations + futures.extend([executor.submit(list_executions) for _ in range(3)]) + + # Wait for all operations to complete + for future in as_completed(futures): + future.result() + + assert len(results) == 6 + final_list = store.list_all() + assert len(final_list) == 3 + + +@pytest.fixture +def temp_storage_dir(): + """Create a temporary directory for testing.""" + with tempfile.TemporaryDirectory() as temp_dir: + yield Path(temp_dir) + + +@pytest.fixture +def temp_db_path(): + """Create a temporary database file for testing.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as temp_file: + temp_path = Path(temp_file.name) + yield temp_path + if temp_path.exists(): + temp_path.unlink() + + +def test_concurrent_filesystem_save_load(temp_storage_dir): + """Test concurrent save and load operations with filesystem store.""" + store = FileSystemExecutionStore.create(temp_storage_dir) + results = [] + results_lock = threading.Lock() + + def save_execution(i: int): + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name=f"test-{i}", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id=f"inv-{i}", + input=f'{{"test": {i}}}', + ) + execution = Execution.new(input_data) + execution.durable_execution_arn = f"arn-{i}" + execution.start() + store.save(execution) + with results_lock: + results.append(f"saved-{i}") + + def load_execution(i: int): + try: + execution = store.load(f"arn-{i}") + with results_lock: + results.append(f"loaded-{execution.start_input.execution_name}") + except KeyError: + with results_lock: + results.append(f"not-found-{i}") + + with ThreadPoolExecutor(max_workers=8) as executor: + # Submit save operations first + futures = [executor.submit(save_execution, i) for i in range(4)] + for future in as_completed(futures): + future.result() + + # Then submit load operations + futures = [executor.submit(load_execution, i) for i in range(4)] + for future in as_completed(futures): + future.result() + + assert len(results) == 8 + + +def test_concurrent_sqlite_save_load(temp_db_path): + """Test concurrent save and load operations with SQLite store.""" + store = SQLiteExecutionStore.create_and_initialize(temp_db_path) + results = [] + results_lock = threading.Lock() + + def save_execution(i: int): + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name=f"test-{i}", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id=f"inv-{i}", + input=f'{{"test": {i}}}', + ) + execution = Execution.new(input_data) + execution.durable_execution_arn = f"arn-{i}" + execution.start() + store.save(execution) + with results_lock: + results.append(f"saved-{i}") + + def load_execution(i: int): + try: + execution = store.load(f"arn-{i}") + with results_lock: + results.append(f"loaded-{execution.start_input.execution_name}") + except KeyError: + with results_lock: + results.append(f"not-found-{i}") + + with ThreadPoolExecutor(max_workers=8) as executor: + # Submit save operations first + futures = [executor.submit(save_execution, i) for i in range(4)] + for future in as_completed(futures): + future.result() + + # Then submit load operations + futures = [executor.submit(load_execution, i) for i in range(4)] + for future in as_completed(futures): + future.result() + + assert len(results) == 8 + + +def test_concurrent_query_operations(): + """Test concurrent query operations on memory store.""" + store = InMemoryExecutionStore() + results = [] + results_lock = threading.Lock() + + # Pre-populate store with test data + for i in range(10): + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name=f"function-{i % 3}", # 3 different functions + function_qualifier="$LATEST", + execution_name=f"exec-{i}", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id=f"inv-{i}", + ) + execution = Execution.new(input_data) + execution.start() + # Complete some executions + if i % 4 == 0: + execution.complete_success("success") + store.save(execution) + + def query_store(query_type: str): + if query_type == "function": + executions, next_marker = store.query(function_name="function-1") + elif query_type == "status": + executions, next_marker = store.query(status_filter="SUCCEEDED") + elif query_type == "pagination": + executions, next_marker = store.query(limit=3, offset=2) + else: + executions, next_marker = store.query() + + with results_lock: + results.append(f"{query_type}-{len(executions)}") + + with ThreadPoolExecutor(max_workers=4) as executor: + futures = [ + executor.submit(query_store, "function"), + executor.submit(query_store, "status"), + executor.submit(query_store, "pagination"), + executor.submit(query_store, "all"), + ] + for future in as_completed(futures): + future.result() + + assert len(results) == 4 diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/stores/filesystem_store_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/stores/filesystem_store_test.py new file mode 100644 index 0000000..01da777 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/stores/filesystem_store_test.py @@ -0,0 +1,420 @@ +"""Tests for FileSystemExecutionStore.""" + +import tempfile +from pathlib import Path + +import pytest + +from aws_durable_execution_sdk_python_testing.exceptions import ( + ResourceNotFoundException, +) +from aws_durable_execution_sdk_python_testing.execution import Execution +from aws_durable_execution_sdk_python_testing.model import StartDurableExecutionInput +from aws_durable_execution_sdk_python_testing.stores.filesystem import ( + FileSystemExecutionStore, +) + +from datetime import datetime, timezone + + +@pytest.fixture +def temp_storage_dir(): + """Create a temporary directory for testing.""" + with tempfile.TemporaryDirectory() as temp_dir: + yield Path(temp_dir) + + +@pytest.fixture +def store(temp_storage_dir): + """Create a FileSystemExecutionStore with temporary storage.""" + return FileSystemExecutionStore.create(temp_storage_dir) + + +@pytest.fixture +def sample_execution(): + """Create a sample execution for testing.""" + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + return Execution.new(input_data) + + +def test_filesystem_execution_store_save_and_load(store, sample_execution): + """Test saving and loading an execution.""" + store.save(sample_execution) + loaded_execution = store.load(sample_execution.durable_execution_arn) + + assert ( + loaded_execution.durable_execution_arn == sample_execution.durable_execution_arn + ) + assert ( + loaded_execution.start_input.function_name + == sample_execution.start_input.function_name + ) + assert ( + loaded_execution.start_input.execution_name + == sample_execution.start_input.execution_name + ) + assert loaded_execution.token_sequence == sample_execution.token_sequence + assert loaded_execution.is_complete == sample_execution.is_complete + + +def test_filesystem_execution_store_load_nonexistent(store): + """Test loading a nonexistent execution raises ResourceNotFoundException.""" + with pytest.raises( + ResourceNotFoundException, match="Execution nonexistent-arn not found" + ): + store.load("nonexistent-arn") + + +def test_filesystem_execution_store_update(store, sample_execution): + """Test updating an execution.""" + store.save(sample_execution) + + sample_execution.is_complete = True + for _ in range(5): + sample_execution.get_new_checkpoint_token() + store.update(sample_execution) + + loaded_execution = store.load(sample_execution.durable_execution_arn) + assert loaded_execution.is_complete is True + assert loaded_execution.token_sequence == 5 + + +def test_filesystem_execution_store_update_overwrites(store, temp_storage_dir): + """Test that update overwrites existing execution.""" + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ) + execution1 = Execution.new(input_data) + execution2 = Execution.new(input_data) + execution2.durable_execution_arn = execution1.durable_execution_arn + for _ in range(10): + execution2.get_new_checkpoint_token() + + store.save(execution1) + store.update(execution2) + + loaded_execution = store.load(execution1.durable_execution_arn) + assert loaded_execution.token_sequence == 10 + + +def test_filesystem_execution_store_multiple_executions(store): + """Test storing multiple executions.""" + input_data1 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function-1", + function_qualifier="$LATEST", + execution_name="test-execution-1", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ) + input_data2 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function-2", + function_qualifier="$LATEST", + execution_name="test-execution-2", + execution_timeout_seconds=600, + execution_retention_period_days=14, + ) + + execution1 = Execution.new(input_data1) + execution2 = Execution.new(input_data2) + + store.save(execution1) + store.save(execution2) + + loaded_execution1 = store.load(execution1.durable_execution_arn) + loaded_execution2 = store.load(execution2.durable_execution_arn) + + assert loaded_execution1.durable_execution_arn == execution1.durable_execution_arn + assert loaded_execution2.durable_execution_arn == execution2.durable_execution_arn + assert loaded_execution1.start_input.function_name == "test-function-1" + assert loaded_execution2.start_input.function_name == "test-function-2" + + +def test_filesystem_execution_store_list_all_empty(store): + """Test list_all method with empty store.""" + result = store.list_all() + assert result == [] + + +def test_filesystem_execution_store_list_all_with_executions(store): + """Test list_all method with multiple executions.""" + # Create test executions + input_data1 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function-1", + function_qualifier="$LATEST", + execution_name="test-execution-1", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ) + input_data2 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function-2", + function_qualifier="$LATEST", + execution_name="test-execution-2", + execution_timeout_seconds=600, + execution_retention_period_days=14, + ) + input_data3 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function-3", + function_qualifier="$LATEST", + execution_name="test-execution-3", + execution_timeout_seconds=900, + execution_retention_period_days=21, + ) + + execution1 = Execution.new(input_data1) + execution2 = Execution.new(input_data2) + execution3 = Execution.new(input_data3) + + # Save executions + store.save(execution1) + store.save(execution2) + store.save(execution3) + + # Test list_all + result = store.list_all() + + assert len(result) == 3 + arns = {execution.durable_execution_arn for execution in result} + assert execution1.durable_execution_arn in arns + assert execution2.durable_execution_arn in arns + assert execution3.durable_execution_arn in arns + + +def test_filesystem_execution_store_file_path_generation( + store, sample_execution, temp_storage_dir +): + """Test that file paths are generated correctly with safe filenames.""" + arn_with_colons = "arn:aws:lambda:us-east-1:123456789012:durable-execution:test" + expected_filename = ( + "arn_aws_lambda_us-east-1_123456789012_durable-execution_test.json" + ) + + # Test by saving and checking the file exists with expected name + sample_execution.durable_execution_arn = arn_with_colons + store.save(sample_execution) + + expected_file = temp_storage_dir / expected_filename + assert expected_file.exists() + + +def test_filesystem_execution_store_corrupted_file_handling(store, temp_storage_dir): + """Test that corrupted files are skipped during list_all.""" + # Create a valid execution + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ) + execution = Execution.new(input_data) + store.save(execution) + + # Create a corrupted file + corrupted_file = temp_storage_dir / "corrupted.json" + with open(corrupted_file, "w") as f: + f.write("invalid json content") + + # list_all should skip the corrupted file and return only valid executions + result = store.list_all() + assert len(result) == 1 + assert result[0].durable_execution_arn == execution.durable_execution_arn + + +def test_filesystem_execution_store_custom_storage_dir(): + """Test creating store with custom storage directory.""" + with tempfile.TemporaryDirectory() as temp_dir: + custom_dir = Path(temp_dir) / "custom_storage" + FileSystemExecutionStore.create(custom_dir) + + # Directory should be created + assert custom_dir.exists() + assert custom_dir.is_dir() + + +def test_filesystem_execution_store_init_no_side_effects(): + """Test that __init__ doesn't create directories (no side effects).""" + with tempfile.TemporaryDirectory() as temp_dir: + nonexistent_dir = Path(temp_dir) / "nonexistent" + + # __init__ should not create the directory + FileSystemExecutionStore(nonexistent_dir) + assert not nonexistent_dir.exists() + + +def test_filesystem_execution_store_thread_safety_basic(store, sample_execution): + """Basic test that operations work without locking (atomic file operations).""" + # Test that basic operations work - atomic file operations provide thread safety + store.save(sample_execution) + loaded = store.load(sample_execution.durable_execution_arn) + assert loaded.durable_execution_arn == sample_execution.durable_execution_arn + + +def test_filesystem_execution_store_query_empty(store): + """Test query method with empty store.""" + executions, next_marker = store.query() + + assert executions == [] + assert next_marker is None + + +def test_filesystem_execution_store_query_by_function_name(store): + """Test query filtering by function name.""" + # Create executions with different function names + input1 = StartDurableExecutionInput( + account_id="123456789012", + function_name="function-a", + function_qualifier="$LATEST", + execution_name="exec-1", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-1", + ) + input2 = StartDurableExecutionInput( + account_id="123456789012", + function_name="function-b", + function_qualifier="$LATEST", + execution_name="exec-2", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-2", + ) + + exec1 = Execution.new(input1) + exec1.start() + exec2 = Execution.new(input2) + exec2.start() + store.save(exec1) + store.save(exec2) + + # Query for function-a only + executions, next_marker = store.query(function_name="function-a") + + assert len(executions) == 1 + assert executions[0].durable_execution_arn == exec1.durable_execution_arn + assert next_marker is None + + +def test_filesystem_execution_store_query_by_status(store): + """Test query filtering by status.""" + # Create running execution + input1 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="running-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-1", + ) + exec1 = Execution.new(input1) + exec1.start() + + # Create completed execution + input2 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="completed-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-2", + ) + exec2 = Execution.new(input2) + exec2.start() + exec2.complete_success("success result") + + store.save(exec1) + store.save(exec2) + + # Query for running executions + executions, next_marker = store.query(status_filter="RUNNING") + + assert len(executions) == 1 + assert executions[0].durable_execution_arn == exec1.durable_execution_arn + + # Query for succeeded executions + executions, next_marker = store.query(status_filter="SUCCEEDED") + + assert len(executions) == 1 + assert executions[0].durable_execution_arn == exec2.durable_execution_arn + + +def test_filesystem_execution_store_query_pagination(store): + """Test query pagination.""" + # Create multiple executions + executions = [] + for i in range(5): + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name=f"exec-{i}", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id=f"invocation-{i}", + ) + exec_obj = Execution.new(input_data) + exec_obj.start() + executions.append(exec_obj) + store.save(exec_obj) + + # Test first page + executions, next_marker = store.query(limit=2, offset=0) + + assert len(executions) == 2 + assert next_marker is not None + + # Test last page + executions, next_marker = store.query(limit=2, offset=4) + + assert len(executions) == 1 + assert next_marker is None + + +def test_filesystem_execution_store_query_corrupted_file_handling( + store, temp_storage_dir +): + """Test that corrupted files are skipped during query.""" + # Create a valid execution + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution.new(input_data) + execution.start() + store.save(execution) + + # Create a corrupted file + corrupted_file = temp_storage_dir / "corrupted.json" + with open(corrupted_file, "w") as f: + f.write("invalid json content") + + # Query should skip the corrupted file and return only valid executions + executions, next_marker = store.query() + + assert len(executions) == 1 + assert executions[0].durable_execution_arn == execution.durable_execution_arn diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/stores/memory_store_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/stores/memory_store_test.py new file mode 100644 index 0000000..b4d5b3e --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/stores/memory_store_test.py @@ -0,0 +1,494 @@ +"""Tests for InMemoryExecutionStore.""" + +from datetime import UTC +from unittest.mock import Mock + +import pytest + +from aws_durable_execution_sdk_python_testing.execution import Execution +from aws_durable_execution_sdk_python_testing.model import StartDurableExecutionInput +from aws_durable_execution_sdk_python_testing.stores.memory import ( + InMemoryExecutionStore, +) + + +def test_in_memory_execution_store_save_and_load(): + """Test saving and loading an execution.""" + store = InMemoryExecutionStore() + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution.new(input_data) + + store.save(execution) + loaded_execution = store.load(execution.durable_execution_arn) + + assert loaded_execution is execution + + +def test_in_memory_execution_store_load_nonexistent(): + """Test loading a nonexistent execution raises KeyError.""" + store = InMemoryExecutionStore() + + with pytest.raises(KeyError): + store.load("nonexistent-arn") + + +def test_in_memory_execution_store_update(): + """Test updating an execution.""" + store = InMemoryExecutionStore() + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ) + execution = Execution.new(input_data) + store.save(execution) + + execution.is_complete = True + store.update(execution) + + loaded_execution = store.load(execution.durable_execution_arn) + assert loaded_execution.is_complete is True + + +def test_in_memory_execution_store_update_overwrites(): + """Test that update overwrites existing execution.""" + store = InMemoryExecutionStore() + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ) + execution1 = Execution.new(input_data) + execution2 = Execution.new(input_data) + execution2.durable_execution_arn = execution1.durable_execution_arn + + store.save(execution1) + store.update(execution2) + + loaded_execution = store.load(execution1.durable_execution_arn) + assert loaded_execution is execution2 + + +def test_in_memory_execution_store_multiple_executions(): + """Test storing multiple executions.""" + store = InMemoryExecutionStore() + input_data1 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function-1", + function_qualifier="$LATEST", + execution_name="test-execution-1", + execution_timeout_seconds=300, + execution_retention_period_days=7, + ) + input_data2 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function-2", + function_qualifier="$LATEST", + execution_name="test-execution-2", + execution_timeout_seconds=600, + execution_retention_period_days=14, + ) + + execution1 = Execution.new(input_data1) + execution2 = Execution.new(input_data2) + + store.save(execution1) + store.save(execution2) + + loaded_execution1 = store.load(execution1.durable_execution_arn) + loaded_execution2 = store.load(execution2.durable_execution_arn) + + assert loaded_execution1 is execution1 + assert loaded_execution2 is execution2 + + +def test_in_memory_execution_store_list_all_empty(): + """Test list_all method with empty store.""" + store = InMemoryExecutionStore() + + result = store.list_all() + + assert result == [] + + +def test_in_memory_execution_store_list_all_with_executions(): + """Test list_all method with multiple executions.""" + store = InMemoryExecutionStore() + + # Create test executions + execution1 = Mock() + execution1.durable_execution_arn = "arn1" + execution2 = Mock() + execution2.durable_execution_arn = "arn2" + execution3 = Mock() + execution3.durable_execution_arn = "arn3" + + # Save executions + store.save(execution1) + store.save(execution2) + store.save(execution3) + + # Test list_all + result = store.list_all() + + assert len(result) == 3 + assert execution1 in result + assert execution2 in result + assert execution3 in result + + +def test_in_memory_execution_store_query_empty(): + """Test query method with empty store.""" + store = InMemoryExecutionStore() + + executions, next_marker = store.query() + + assert executions == [] + assert next_marker is None + + +def test_in_memory_execution_store_query_by_function_name(): + """Test query filtering by function name.""" + store = InMemoryExecutionStore() + + # Create executions with different function names + input1 = StartDurableExecutionInput( + account_id="123456789012", + function_name="function-a", + function_qualifier="$LATEST", + execution_name="exec-1", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-1", + ) + input2 = StartDurableExecutionInput( + account_id="123456789012", + function_name="function-b", + function_qualifier="$LATEST", + execution_name="exec-2", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-2", + ) + + exec1 = Execution.new(input1) + exec1.start() + exec2 = Execution.new(input2) + exec2.start() + store.save(exec1) + store.save(exec2) + + # Query for function-a only + executions, next_marker = store.query(function_name="function-a") + + assert len(executions) == 1 + assert executions[0] is exec1 + assert next_marker is None + + +def test_in_memory_execution_store_query_by_execution_name(): + """Test query filtering by execution name.""" + store = InMemoryExecutionStore() + + input1 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="exec-alpha", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-1", + ) + input2 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="exec-beta", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-2", + ) + + exec1 = Execution.new(input1) + exec1.start() + exec2 = Execution.new(input2) + exec2.start() + store.save(exec1) + store.save(exec2) + + executions, next_marker = store.query(execution_name="exec-beta") + + assert len(executions) == 1 + assert executions[0] is exec2 + + +def test_in_memory_execution_store_query_by_status(): + """Test query filtering by status.""" + store = InMemoryExecutionStore() + + # Create running execution + input1 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="running-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-1", + ) + exec1 = Execution.new(input1) + exec1.start() + + # Create completed execution + input2 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="completed-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-2", + ) + exec2 = Execution.new(input2) + exec2.start() + exec2.complete_success("success result") + + store.save(exec1) + store.save(exec2) + + # Query for running executions + executions, next_marker = store.query(status_filter="RUNNING") + + assert len(executions) == 1 + assert executions[0] is exec1 + + # Query for succeeded executions + executions, next_marker = store.query(status_filter="SUCCEEDED") + + assert len(executions) == 1 + assert executions[0] is exec2 + + +def test_in_memory_execution_store_query_pagination(): + """Test query pagination.""" + store = InMemoryExecutionStore() + + # Create multiple executions + executions = [] + for i in range(5): + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name=f"exec-{i}", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id=f"invocation-{i}", + ) + exec_obj = Execution.new(input_data) + exec_obj.start() + executions.append(exec_obj) + store.save(exec_obj) + + # Test first page + executions, next_marker = store.query(limit=2, offset=0) + + assert len(executions) == 2 + assert next_marker is not None + + # Test second page + executions, next_marker = store.query(limit=2, offset=2) + + assert len(executions) == 2 + assert next_marker is not None + + # Test last page + executions, next_marker = store.query(limit=2, offset=4) + + assert len(executions) == 1 + assert next_marker is None + + +def test_in_memory_execution_store_query_sorting(): + """Test query sorting by timestamp.""" + store = InMemoryExecutionStore() + + # Create executions - they will be sorted by creation order + executions = [] + for i in range(3): + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name=f"exec-{i}", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id=f"invocation-{i}", + ) + exec_obj = Execution.new(input_data) + exec_obj.start() + executions.append(exec_obj) + store.save(exec_obj) + + # Test ascending order (default) + executions, next_marker = store.query(reverse_order=False) + + assert len(executions) == 3 + + # Test descending order + executions, next_marker = store.query(reverse_order=True) + + assert len(executions) == 3 + + +def test_in_memory_execution_store_query_combined_filters(): + """Test query with multiple filters combined.""" + store = InMemoryExecutionStore() + + # Create various executions + inputs = [ + StartDurableExecutionInput( + account_id="123456789012", + function_name="function-a", + function_qualifier="$LATEST", + execution_name="target-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-1", + ), + StartDurableExecutionInput( + account_id="123456789012", + function_name="function-b", + function_qualifier="$LATEST", + execution_name="target-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-2", + ), + StartDurableExecutionInput( + account_id="123456789012", + function_name="function-a", + function_qualifier="$LATEST", + execution_name="other-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-3", + ), + ] + + executions = [] + for input_data in inputs: + exec_obj = Execution.new(input_data) + exec_obj.start() + executions.append(exec_obj) + store.save(exec_obj) + + # Query with both function_name and execution_name filters + filtered_executions, next_marker = store.query( + function_name="function-a", execution_name="target-exec" + ) + + assert len(filtered_executions) == 1 + assert filtered_executions[0] is executions[0] + + +def test_time_filtering_logic(): + """Test time filtering logic in process_query method.""" + from datetime import datetime + from unittest.mock import Mock + + store = InMemoryExecutionStore() + + # Create mock executions with different timestamps + exec1 = Mock() + exec1.start_input.function_name = "test-function" + exec1.start_input.execution_name = "exec1" + exec1.status = "RUNNING" + + exec2 = Mock() + exec2.start_input.function_name = "test-function" + exec2.start_input.execution_name = "exec2" + exec2.status = "RUNNING" + + exec3 = Mock() + exec3.start_input.function_name = "test-function" + exec3.start_input.execution_name = "exec3" + exec3.status = "RUNNING" + + # Use real datetime objects for timestamps + op1 = Mock() + op1.start_timestamp = datetime(2023, 1, 1, 12, 0, 0, tzinfo=UTC) + + op2 = Mock() + op2.start_timestamp = datetime(2023, 1, 2, 12, 0, 0, tzinfo=UTC) + + op3 = Mock() + op3.start_timestamp = datetime(2023, 1, 3, 12, 0, 0) # noqa: DTZ001 + + exec1.get_operation_execution_started.return_value = op1 + exec2.get_operation_execution_started.return_value = op2 + exec3.get_operation_execution_started.return_value = op3 + + executions = [exec1, exec2, exec3] + + # Test time_after filtering + filtered, _ = store.process_query( + executions, + started_after="1672617600.0", # 2023-01-01 24:00:00 UTC (between exec1 and exec2) + ) + assert len(filtered) == 2 + assert exec2 in filtered + assert exec3 in filtered + assert exec1 not in filtered + + # Test time_before filtering + filtered, _ = store.process_query( + executions, + started_before="1672617600.0", # 2023-01-01 24:00:00 UTC + ) + assert len(filtered) == 1 + assert exec1 in filtered + assert exec2 not in filtered + assert exec3 not in filtered + + # Test both time_after and time_before + filtered, _ = store.process_query( + executions, + started_after="1672617600.0", # 2023-01-02 00:00:00 UTC (between exec1 and exec2) + started_before="1672704000.0", # 2023-01-03 00:00:00 UTC (between exec2 and exec3) + ) + assert len(filtered) == 1 + assert exec2 in filtered + + # Test exception handling - exec with AttributeError + exec_error = Mock() + exec_error.start_input.function_name = "test-function" + exec_error.start_input.execution_name = "exec_error" + exec_error.status = "RUNNING" + exec_error.get_operation_execution_started.side_effect = AttributeError( + "No operation" + ) + + executions_with_error = [exec1, exec_error, exec2] + filtered, _ = store.process_query( + executions_with_error, + started_after="1672617600.0", # After exec1, before exec2 + ) + # exec_error should be filtered out due to exception, only exec2 should remain + assert len(filtered) == 1 + assert exec2 in filtered + assert exec_error not in filtered diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/stores/sqlite_store_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/stores/sqlite_store_test.py new file mode 100644 index 0000000..7c7feb4 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/stores/sqlite_store_test.py @@ -0,0 +1,860 @@ +"""Tests for SQLiteExecutionStore.""" + +import tempfile +import time +from datetime import datetime, UTC +from pathlib import Path + +import pytest + +from aws_durable_execution_sdk_python_testing.exceptions import ( + ResourceNotFoundException, + InvalidParameterValueException, +) +from aws_durable_execution_sdk_python_testing.execution import ( + ExecutionStatus, + Execution, +) +from aws_durable_execution_sdk_python_testing.model import StartDurableExecutionInput +from aws_durable_execution_sdk_python_testing.stores.sqlite import SQLiteExecutionStore + + +@pytest.fixture +def temp_db_path(): + """Create a temporary database file for testing.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as temp_file: + temp_path = Path(temp_file.name) + yield temp_path + # Cleanup + if temp_path.exists(): + temp_path.unlink() + + +@pytest.fixture +def store(temp_db_path): + """Create a SQLiteExecutionStore with temporary database.""" + return SQLiteExecutionStore.create_and_initialize(temp_db_path) + + +@pytest.fixture +def sample_execution(): + """Create a sample execution for testing.""" + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + return Execution.new(input_data) + + +def test_sqlite_execution_store_save_and_load(store, sample_execution): + """Test saving and loading an execution.""" + sample_execution.start() + store.save(sample_execution) + loaded_execution = store.load(sample_execution.durable_execution_arn) + + assert ( + loaded_execution.durable_execution_arn == sample_execution.durable_execution_arn + ) + assert ( + loaded_execution.start_input.function_name + == sample_execution.start_input.function_name + ) + assert ( + loaded_execution.start_input.execution_name + == sample_execution.start_input.execution_name + ) + assert loaded_execution.token_sequence == sample_execution.token_sequence + assert loaded_execution.is_complete == sample_execution.is_complete + + +def test_sqlite_execution_store_load_nonexistent(store): + """Test loading a nonexistent execution raises KeyError.""" + with pytest.raises( + ResourceNotFoundException, match="Execution nonexistent-arn not found" + ): + store.load("nonexistent-arn") + + +def test_sqlite_execution_store_update(store, sample_execution): + """Test updating an execution.""" + sample_execution.start() + store.save(sample_execution) + + sample_execution.is_complete = True + sample_execution.close_status = ExecutionStatus.SUCCEEDED + for _ in range(5): + sample_execution.get_new_checkpoint_token() + store.update(sample_execution) + + loaded_execution = store.load(sample_execution.durable_execution_arn) + assert loaded_execution.is_complete is True + assert loaded_execution.token_sequence == 5 + + +def test_sqlite_execution_store_update_overwrites(store): + """Test that update overwrites existing execution.""" + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution1 = Execution.new(input_data) + execution1.start() + execution2 = Execution.new(input_data) + execution2.start() + execution2.durable_execution_arn = execution1.durable_execution_arn + for _ in range(10): + execution2.get_new_checkpoint_token() + + store.save(execution1) + store.update(execution2) + + loaded_execution = store.load(execution1.durable_execution_arn) + assert loaded_execution.token_sequence == 10 + + +def test_sqlite_execution_store_multiple_executions(store): + """Test storing multiple executions.""" + input_data1 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function-1", + function_qualifier="$LATEST", + execution_name="test-execution-1", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id-1", + ) + input_data2 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function-2", + function_qualifier="$LATEST", + execution_name="test-execution-2", + execution_timeout_seconds=600, + execution_retention_period_days=14, + invocation_id="test-invocation-id-2", + ) + + execution1 = Execution.new(input_data1) + execution1.start() + execution2 = Execution.new(input_data2) + execution2.start() + + store.save(execution1) + store.save(execution2) + + loaded_execution1 = store.load(execution1.durable_execution_arn) + loaded_execution2 = store.load(execution2.durable_execution_arn) + + assert loaded_execution1.durable_execution_arn == execution1.durable_execution_arn + assert loaded_execution2.durable_execution_arn == execution2.durable_execution_arn + assert loaded_execution1.start_input.function_name == "test-function-1" + assert loaded_execution2.start_input.function_name == "test-function-2" + + +def test_sqlite_execution_store_list_all_empty(store): + """Test list_all method with empty store.""" + result = store.list_all() + assert result == [] + + +def test_sqlite_execution_store_list_all_with_executions(store): + """Test list_all method with multiple executions.""" + # Create test executions + executions = [] + for i in range(3): + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name=f"test-function-{i}", + function_qualifier="$LATEST", + execution_name=f"test-execution-{i}", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id=f"test-invocation-id-{i}", + ) + execution = Execution.new(input_data) + execution.start() + executions.append(execution) + store.save(execution) + + # Test list_all + result = store.list_all() + + assert len(result) == 3 + arns = {execution.durable_execution_arn for execution in result} + for execution in executions: + assert execution.durable_execution_arn in arns + + +def test_sqlite_execution_store_query_empty(store): + """Test query method with empty store.""" + executions, next_marker = store.query() + + assert executions == [] + assert next_marker is None + + +def test_sqlite_execution_store_query_by_function_name(store): + """Test query filtering by function name.""" + # Create executions with different function names + input1 = StartDurableExecutionInput( + account_id="123456789012", + function_name="function-a", + function_qualifier="$LATEST", + execution_name="exec-1", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-1", + ) + input2 = StartDurableExecutionInput( + account_id="123456789012", + function_name="function-b", + function_qualifier="$LATEST", + execution_name="exec-2", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-2", + ) + + exec1 = Execution.new(input1) + exec1.start() + exec2 = Execution.new(input2) + exec2.start() + store.save(exec1) + store.save(exec2) + + # Query for function-a only + executions, next_marker = store.query(function_name="function-a") + + assert len(executions) == 1 + assert executions[0].durable_execution_arn == exec1.durable_execution_arn + assert next_marker is None + + +def test_sqlite_execution_store_query_by_execution_name(store): + """Test query filtering by execution name.""" + input1 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="exec-alpha", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-1", + ) + input2 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="exec-beta", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-2", + ) + + exec1 = Execution.new(input1) + exec1.start() + exec2 = Execution.new(input2) + exec2.start() + store.save(exec1) + store.save(exec2) + + executions, next_marker = store.query(execution_name="exec-beta") + + assert len(executions) == 1 + assert executions[0].durable_execution_arn == exec2.durable_execution_arn + + +def test_sqlite_execution_store_query_by_status(store): + """Test query filtering by status.""" + # Create running execution + input1 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="running-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-1", + ) + exec1 = Execution.new(input1) + exec1.start() + + # Create completed execution + input2 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="completed-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-2", + ) + exec2 = Execution.new(input2) + exec2.start() + exec2.complete_success("success result") + + store.save(exec1) + store.save(exec2) + + # Query for running executions + executions, next_marker = store.query(status_filter="RUNNING") + + assert len(executions) == 1 + assert executions[0].durable_execution_arn == exec1.durable_execution_arn + + # Query for succeeded executions + executions, next_marker = store.query(status_filter="SUCCEEDED") + + assert len(executions) == 1 + assert executions[0].durable_execution_arn == exec2.durable_execution_arn + + +def test_sqlite_execution_store_query_pagination(store): + """Test query pagination.""" + # Create multiple executions + executions = [] + for i in range(5): + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name=f"exec-{i}", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id=f"invocation-{i}", + ) + exec_obj = Execution.new(input_data) + exec_obj.start() + executions.append(exec_obj) + store.save(exec_obj) + + # Test first page + executions, next_marker = store.query(limit=2, offset=0) + + assert len(executions) == 2 + assert next_marker is not None + + # Test second page + executions, next_marker = store.query(limit=2, offset=2) + + assert len(executions) == 2 + assert next_marker is not None + + # Test last page + executions, next_marker = store.query(limit=2, offset=4) + + assert len(executions) == 1 + assert next_marker is None + + +def test_sqlite_execution_store_query_sorting(store): + """Test query sorting by timestamp.""" + # Create executions - they will be sorted by creation order + executions = [] + for i in range(3): + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name=f"exec-{i}", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id=f"invocation-{i}", + ) + exec_obj = Execution.new(input_data) + exec_obj.start() + executions.append(exec_obj) + store.save(exec_obj) + + # Test ascending order (default) + executions, next_marker = store.query(reverse_order=False) + + assert len(executions) == 3 + + # Test descending order + executions, next_marker = store.query(reverse_order=True) + + assert len(executions) == 3 + + +def test_sqlite_execution_store_query_combined_filters(store): + """Test query with multiple filters combined.""" + # Create various executions + inputs = [ + StartDurableExecutionInput( + account_id="123456789012", + function_name="function-a", + function_qualifier="$LATEST", + execution_name="target-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-1", + ), + StartDurableExecutionInput( + account_id="123456789012", + function_name="function-b", + function_qualifier="$LATEST", + execution_name="target-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-2", + ), + StartDurableExecutionInput( + account_id="123456789012", + function_name="function-a", + function_qualifier="$LATEST", + execution_name="other-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-3", + ), + ] + + executions = [] + for input_data in inputs: + exec_obj = Execution.new(input_data) + exec_obj.start() + executions.append(exec_obj) + store.save(exec_obj) + + # Query with both function_name and execution_name filters + filtered_executions, next_marker = store.query( + function_name="function-a", execution_name="target-exec" + ) + + assert len(filtered_executions) == 1 + assert ( + filtered_executions[0].durable_execution_arn + == executions[0].durable_execution_arn + ) + + +def test_sqlite_execution_store_database_initialization(temp_db_path): + """Test that database is properly initialized with schema.""" + store = SQLiteExecutionStore.create_and_initialize(temp_db_path) + + # Verify database file exists + assert temp_db_path.exists() + + # Verify we can perform basic operations + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution.new(input_data) + execution.start() + + store.save(execution) + loaded = store.load(execution.durable_execution_arn) + assert loaded.durable_execution_arn == execution.durable_execution_arn + + +def test_sqlite_execution_store_custom_db_path(): + """Test creating store with custom database path.""" + with tempfile.TemporaryDirectory() as temp_dir: + custom_path = Path(temp_dir) / "custom" / "executions.db" + store = SQLiteExecutionStore.create_and_initialize(custom_path) + + # Directory should be created + assert custom_path.parent.exists() + assert custom_path.exists() + + # Verify functionality + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution.new(input_data) + execution.start() + + store.save(execution) + loaded = store.load(execution.durable_execution_arn) + assert loaded.durable_execution_arn == execution.durable_execution_arn + + +def test_sqlite_execution_store_failed_execution_status(store): + """Test that failed executions are properly stored and queried.""" + from aws_durable_execution_sdk_python.lambda_service import ErrorObject + + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="failed-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution.new(input_data) + execution.start() + + # Complete with failure + error = ErrorObject( + type="TestError", message="Test failure", data=None, stack_trace=None + ) + execution.complete_fail(error) + + store.save(execution) + + # Query for failed executions + executions, next_marker = store.query(status_filter="FAILED") + + assert len(executions) == 1 + assert executions[0].durable_execution_arn == execution.durable_execution_arn + assert executions[0].is_complete is True + + +def test_sqlite_execution_store_error_handling(temp_db_path): + """Test error handling for database operations.""" + store = SQLiteExecutionStore.create_and_initialize(temp_db_path) + + # Test with corrupted database by removing the file after creation + temp_db_path.unlink() + + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution.new(input_data) + execution.start() + + # Should raise RuntimeError for database operations + with pytest.raises(RuntimeError, match="Failed to save execution"): + store.save(execution) + + +def test_sqlite_execution_store_invalid_execution_data(store): + """Test handling of invalid execution data.""" + # Create execution and start it + execution = Execution.new( + StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + ) + execution.start() + + # Corrupt the execution object to trigger serialization error + execution.start_input = None + + with pytest.raises(ValueError, match="Invalid execution data"): + store.save(execution) + + +def test_sqlite_execution_store_sql_injection_protection(store): + """Test SQL injection protection in query parameters.""" + # Create test execution + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution.new(input_data) + execution.start() + store.save(execution) + + # Try SQL injection attempts - should be safely parameterized + malicious_inputs = [ + "'; DROP TABLE executions; --", + "test' OR '1'='1", + "test'; DELETE FROM executions; --", + "test' UNION SELECT * FROM executions --", + ] + + for malicious_input in malicious_inputs: + # These should return empty results, not cause SQL errors + executions, _ = store.query(function_name=malicious_input) + assert executions == [] + + executions, _ = store.query(execution_name=malicious_input) + assert executions == [] + + executions, _ = store.query(status_filter=malicious_input) + assert executions == [] + + +def test_sqlite_execution_store_time_filtering(store): + """Test time-based filtering with edge cases.""" + + # Create executions at different times + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + + execution1 = Execution.new(input_data) + execution1.start() + store.save(execution1) + + # Small delay to ensure different timestamps + time.sleep(0.01) + + execution2 = Execution.new(input_data) + execution2.start() + store.save(execution2) + + # Get timestamps as ISO strings + start_time_iso = ( + execution1.get_operation_execution_started().start_timestamp.isoformat() + ) + mid_time = ( + execution1.get_operation_execution_started().start_timestamp.timestamp() + 0.005 + ) + mid_time_iso = datetime.fromtimestamp(mid_time, tz=UTC).isoformat() + end_time_iso = datetime.fromtimestamp( + execution2.get_operation_execution_started().start_timestamp.timestamp() + 1, + tz=UTC, + ).isoformat() + + # Test started_after filter + executions, _ = store.query(started_after=mid_time_iso) + assert len(executions) == 1 + + # Test started_before filter + executions, _ = store.query(started_before=mid_time_iso) + assert len(executions) == 1 + + # Test both filters + executions, _ = store.query( + started_after=start_time_iso, started_before=end_time_iso + ) + assert len(executions) == 2 + + +def test_sqlite_execution_store_corrupted_data_handling(store, temp_db_path): + """Test handling of corrupted JSON data in database.""" + import sqlite3 + + # Insert corrupted JSON data directly + with sqlite3.connect(temp_db_path) as conn: + conn.execute( + """ + INSERT INTO executions + (durable_execution_arn, function_name, execution_name, status, start_timestamp, end_timestamp, data) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, + ( + "corrupted-arn", + "test-function", + "test-execution", + "RUNNING", + 1234567890.0, + None, + "invalid json data {{{", + ), + ) + + # Loading corrupted data should raise ValueError + with pytest.raises(ValueError, match="Corrupted execution data"): + store.load("corrupted-arn") + + # Query should skip corrupted records and continue + executions, _ = store.query() + # Should not include the corrupted record + assert all(exec.durable_execution_arn != "corrupted-arn" for exec in executions) + + +def test_sqlite_execution_store_get_execution_metadata(store): + """Test get_execution_metadata method.""" + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution.new(input_data) + execution.start() + store.save(execution) + + # Test existing execution + metadata = store.get_execution_metadata(execution.durable_execution_arn) + assert metadata is not None + assert metadata["durable_execution_arn"] == execution.durable_execution_arn + assert metadata["function_name"] == "test-function" + assert metadata["execution_name"] == "test-execution" + assert metadata["status"] == "RUNNING" + assert metadata["start_timestamp"] is not None + + # Test nonexistent execution + metadata = store.get_execution_metadata("nonexistent-arn") + assert metadata is None + + +def test_sqlite_execution_store_database_init_error(): + """Test database initialization error handling.""" + # Try to create database in non-existent directory without permission + invalid_path = Path("/invalid/path/that/does/not/exist/test.db") + + with pytest.raises(RuntimeError, match="Failed to initialize database"): + store = SQLiteExecutionStore(invalid_path) + store._init_db() + + +def test_sqlite_execution_store_query_invalid_parameters(store): + """Test query with invalid parameters.""" + # Test with invalid time parameters + with pytest.raises( + InvalidParameterValueException, match="Invalid query parameters" + ): + store.query(started_after="invalid_timestamp") + + with pytest.raises( + InvalidParameterValueException, match="Invalid query parameters" + ): + store.query(started_before="not_a_number") + + +def test_sqlite_execution_store_query_no_limit_no_offset(store): + """Test query without limit and offset parameters.""" + # Create test execution + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution.new(input_data) + execution.start() + store.save(execution) + + # Query without limit should use different code path + executions, next_marker = store.query() + assert len(executions) == 1 + assert next_marker is None + + +def test_sqlite_execution_store_query_with_end_timestamp(store): + """Test execution with end timestamp.""" + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution.new(input_data) + execution.start() + execution.complete_success("result") # This should set end_timestamp + store.save(execution) + + loaded = store.load(execution.durable_execution_arn) + assert loaded.is_complete is True + + +def test_sqlite_execution_store_metadata_error_handling(temp_db_path): + """Test metadata retrieval error handling.""" + store = SQLiteExecutionStore.create_and_initialize(temp_db_path) + + # Remove database file to trigger error + temp_db_path.unlink() + + with pytest.raises(RuntimeError, match="Failed to get metadata"): + store.get_execution_metadata("test-arn") + + +def test_sqlite_execution_store_load_error_handling(temp_db_path): + """Test load error handling.""" + store = SQLiteExecutionStore.create_and_initialize(temp_db_path) + + # Remove database file to trigger error + temp_db_path.unlink() + + with pytest.raises(RuntimeError, match="Failed to load execution"): + store.load("test-arn") + + +def test_sqlite_execution_store_query_with_corrupted_data_warning( + store, temp_db_path, capsys +): + """Test that corrupted data in query results prints warning and continues.""" + import sqlite3 + + # Create a valid execution first + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution.new(input_data) + execution.start() + store.save(execution) + + # Insert corrupted JSON data directly + with sqlite3.connect(temp_db_path) as conn: + conn.execute( + """ + INSERT INTO executions + (durable_execution_arn, function_name, execution_name, status, start_timestamp, end_timestamp, data) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, + ( + "corrupted-arn-2", + "test-function", + "test-execution", + "RUNNING", + 1234567890.0, + None, + "invalid json data {{{", + ), + ) + + # Query should skip corrupted records and print warning + executions, _ = store.query() + + # Should get the valid execution, skip the corrupted one + assert len(executions) == 1 + assert executions[0].durable_execution_arn == execution.durable_execution_arn + + # Check that warning was printed + captured = capsys.readouterr() + assert "Warning: Skipping corrupted execution corrupted-arn-2" in captured.out diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/token_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/token_test.py new file mode 100644 index 0000000..714d8c9 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/token_test.py @@ -0,0 +1,132 @@ +"""Unit tests for token models.""" + +import base64 +import json + +import pytest + +from aws_durable_execution_sdk_python_testing.token import ( + CallbackToken, + CheckpointToken, +) + + +def test_checkpoint_token_init(): + """Test CheckpointToken initialization.""" + token = CheckpointToken("arn:aws:states:us-east-1:123456789012:execution:test", 42) + + assert token.execution_arn == "arn:aws:states:us-east-1:123456789012:execution:test" + assert token.token_sequence == 42 + + +def test_checkpoint_token_to_str(): + """Test CheckpointToken serialization to string.""" + token = CheckpointToken("arn:aws:states:us-east-1:123456789012:execution:test", 42) + + result = token.to_str() + + # Decode and verify the structure + decoded = base64.b64decode(result).decode() + data = json.loads(decoded) + assert data["arn"] == "arn:aws:states:us-east-1:123456789012:execution:test" + assert data["seq"] == 42 + + +def test_checkpoint_token_from_str(): + """Test CheckpointToken deserialization from string.""" + data = {"arn": "arn:aws:states:us-east-1:123456789012:execution:test", "seq": 42} + json_str = json.dumps(data, separators=(",", ":")) + token_str = base64.b64encode(json_str.encode()).decode() + + token = CheckpointToken.from_str(token_str) + + assert token.execution_arn == "arn:aws:states:us-east-1:123456789012:execution:test" + assert token.token_sequence == 42 + + +def test_checkpoint_token_round_trip(): + """Test CheckpointToken serialization and deserialization round trip.""" + original = CheckpointToken( + "arn:aws:states:us-east-1:123456789012:execution:test", 123 + ) + + token_str = original.to_str() + restored = CheckpointToken.from_str(token_str) + + assert restored == original + + +def test_checkpoint_token_frozen_dataclass(): + """Test that CheckpointToken is immutable.""" + token = CheckpointToken("arn:aws:states:us-east-1:123456789012:execution:test", 42) + + with pytest.raises(AttributeError): + token.execution_arn = "new-arn" + + with pytest.raises(AttributeError): + token.token_sequence = 999 + + +def test_callback_token_init(): + """Test CallbackToken initialization.""" + token = CallbackToken( + "arn:aws:states:us-east-1:123456789012:execution:test", "op-123" + ) + + assert token.execution_arn == "arn:aws:states:us-east-1:123456789012:execution:test" + assert token.operation_id == "op-123" + + +def test_callback_token_to_str(): + """Test CallbackToken serialization to string.""" + token = CallbackToken( + "arn:aws:states:us-east-1:123456789012:execution:test", "op-123" + ) + + result = token.to_str() + + # Decode and verify the structure + decoded = base64.b64decode(result).decode() + data = json.loads(decoded) + assert data["arn"] == "arn:aws:states:us-east-1:123456789012:execution:test" + assert data["op"] == "op-123" + + +def test_callback_token_from_str(): + """Test CallbackToken deserialization from string.""" + data = { + "arn": "arn:aws:states:us-east-1:123456789012:execution:test", + "op": "op-123", + } + json_str = json.dumps(data, separators=(",", ":")) + token_str = base64.b64encode(json_str.encode()).decode() + + token = CallbackToken.from_str(token_str) + + assert token.execution_arn == "arn:aws:states:us-east-1:123456789012:execution:test" + assert token.operation_id == "op-123" + + +def test_callback_token_round_trip(): + """Test CallbackToken serialization and deserialization round trip.""" + original = CallbackToken( + "arn:aws:states:us-east-1:123456789012:execution:test", "callback-op" + ) + + token_str = original.to_str() + restored = CallbackToken.from_str(token_str) + + assert restored == original + + +def test_callback_token_frozen_dataclass(): + """Test that CallbackToken is immutable.""" + token = CallbackToken( + "arn:aws:states:us-east-1:123456789012:execution:test", "op-123" + ) + + with pytest.raises(AttributeError): + token.execution_arn = "new-arn" + + with pytest.raises(AttributeError): + token.operation_id = "new-op" diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/web/__init__.py b/packages/aws-durable-execution-sdk-python-testing/tests/web/__init__.py new file mode 100644 index 0000000..5a4e39e --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/web/__init__.py @@ -0,0 +1 @@ +"""Tests for web server module.""" diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/web/e2e/__init__.py b/packages/aws-durable-execution-sdk-python-testing/tests/web/e2e/__init__.py new file mode 100644 index 0000000..cbd1352 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/web/e2e/__init__.py @@ -0,0 +1 @@ +"""End-to-end integration tests for web components.""" diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/web/e2e/routes_arn_encoding_int_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/web/e2e/routes_arn_encoding_int_test.py new file mode 100644 index 0000000..4b9c2a5 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/web/e2e/routes_arn_encoding_int_test.py @@ -0,0 +1,238 @@ +"""Integration test: WebServer route layer URL-decodes DurableExecutionArn. + +Drives a real ``boto3`` Lambda client against a live ``WebServer`` and asserts +that ``DurableExecutionArn`` values containing characters that boto +percent-encodes in URI labels (e.g. ``/`` -> ``%2F``) round-trip correctly so +the store lookup hits. +""" + +from __future__ import annotations + +import threading +import time +from typing import Any + +import boto3 # type: ignore +import pytest +from botocore.config import Config # type: ignore +from botocore.exceptions import ClientError # type: ignore + +from aws_durable_execution_sdk_python_testing.checkpoint.processor import ( + CheckpointProcessor, +) +from aws_durable_execution_sdk_python_testing.execution import Execution +from aws_durable_execution_sdk_python_testing.executor import Executor +from aws_durable_execution_sdk_python_testing.model import ( + StartDurableExecutionInput, +) +from aws_durable_execution_sdk_python_testing.scheduler import Scheduler +from aws_durable_execution_sdk_python_testing.stores.memory import ( + InMemoryExecutionStore, +) +from aws_durable_execution_sdk_python_testing.web.server import ( + WebServer, + WebServiceConfig, +) + + +class _NoOpInvoker: + """Satisfies the Invoker protocol without invoking anything. + + The route-layer regression doesn't depend on actually executing the + function; the executor just needs *some* invoker to construct it. + """ + + def create_invocation_input(self, execution: Any) -> Any: # noqa: ARG002 + return None + + def invoke(self, *args: Any, **kwargs: Any) -> Any: # noqa: ARG002 + return None + + def update_endpoint(self, *args: Any, **kwargs: Any) -> None: # noqa: ARG002 + return None + + +def _assert_no_percent_encoding_in_error(exc: ClientError, arn: str) -> None: + """Fail the test if a ResourceNotFoundException carries a %2F-form ARN. + + Other errors (e.g. invalid checkpoint token, wrong state) are fine; this + test is narrowly about whether the route layer decoded the path segment. + """ + msg = str(exc) + assert "%2F" not in msg, ( + f"WebServer route layer did not URL-decode DurableExecutionArn. " + f"Original ARN: {arn!r}. Error: {msg}" + ) + + +@pytest.fixture +def server_with_slash_arn(): + """Yield ``(boto_client, arn, executor, store)`` for a live WebServer. + + The yielded ARN contains a literal ``/`` matching the v1.2.0+ format + produced by ``Execution.new()``. The Execution is pre-started and saved + so read paths have something to find. + """ + store = InMemoryExecutionStore() + scheduler = Scheduler() + checkpoint_processor = CheckpointProcessor(store=store, scheduler=scheduler) + executor = Executor( + store=store, + scheduler=scheduler, + invoker=_NoOpInvoker(), + checkpoint_processor=checkpoint_processor, + ) + checkpoint_processor.add_execution_observer(executor) + scheduler.start() + + # Hand-build a started Execution whose ARN contains '/' so we control + # the format under test without going through executor.start_execution + # (which schedules a real invoke + timeout). + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-fn", + function_qualifier="$LATEST", + execution_name="test-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="inv-12345", + input='"hi"', + ) + execution = Execution.new(start_input) + execution.start() + store.save(execution) + arn = execution.durable_execution_arn + assert "/" in arn, "regression precondition: ARN must contain literal '/'" + + config = WebServiceConfig(host="127.0.0.1", port=0) + server = WebServer(config, executor) + port = server.server_address[1] + server_thread = threading.Thread(target=server.serve_forever, daemon=True) + server_thread.start() + # Give the listener a beat to come up before the boto client connects. + time.sleep(0.05) + + client = boto3.client( + "lambda", + endpoint_url=f"http://127.0.0.1:{port}", + region_name="us-east-1", + aws_access_key_id="x", # noqa: S106 - test stub + aws_secret_access_key="y", # noqa: S106 - test stub + config=Config(parameter_validation=False, retries={"max_attempts": 0}), + ) + + try: + yield client, arn, executor, store + finally: + server.shutdown() + server.server_close() + scheduler.stop() + + +def test_get_durable_execution_decodes_slash_in_arn(server_with_slash_arn): + """GetDurableExecution: %2F must be decoded so the store lookup hits.""" + client, arn, _executor, _store = server_with_slash_arn + + response = client.get_durable_execution(DurableExecutionArn=arn) + + assert response["DurableExecutionArn"] == arn + + +def test_get_durable_execution_state_decodes_slash_in_arn(server_with_slash_arn): + """GetDurableExecutionState: %2F must be decoded so the store lookup hits.""" + client, arn, _executor, _store = server_with_slash_arn + + response = client.get_durable_execution_state( + DurableExecutionArn=arn, + CheckpointToken="ignored-by-route-layer", # noqa: S106 - test stub + ) + + # Response shape varies; the only assertion this test cares about is + # that we got past route resolution. + assert response is not None + + +def test_get_durable_execution_history_decodes_slash_in_arn(server_with_slash_arn): + """GetDurableExecutionHistory: %2F must be decoded so the store lookup hits.""" + client, arn, _executor, _store = server_with_slash_arn + + response = client.get_durable_execution_history(DurableExecutionArn=arn) + + assert response is not None + + +def test_checkpoint_durable_execution_decodes_slash_in_arn(server_with_slash_arn): + """CheckpointDurableExecution: %2F must be decoded so the store lookup hits. + + A checkpoint with no operation updates may still trip secondary + validation; we only assert the failure (if any) is not the + %2F-in-message 404 that indicates the route layer dropped the ball. + """ + client, arn, _executor, store = server_with_slash_arn + execution = store.load(arn) + token = execution.get_new_checkpoint_token() + + try: + client.checkpoint_durable_execution( + DurableExecutionArn=arn, + CheckpointToken=token, + Updates=[], + ) + except ClientError as exc: + _assert_no_percent_encoding_in_error(exc, arn) + + +def test_stop_durable_execution_decodes_slash_in_arn(server_with_slash_arn): + """StopDurableExecution: %2F must be decoded so the store lookup hits.""" + client, arn, _executor, _store = server_with_slash_arn + + try: + client.stop_durable_execution(DurableExecutionArn=arn) + except ClientError as exc: + _assert_no_percent_encoding_in_error(exc, arn) + + +def test_list_durable_executions_by_function_decodes_colon_in_name( + server_with_slash_arn, +): + """ListDurableExecutionsByFunction: %3A/%24 in FunctionName must be decoded. + + boto percent-encodes ``:`` and ``$`` in the non-greedy ``{FunctionName}`` + URI label, so a realistic value like ``MyFunction:$LATEST`` arrives as + ``MyFunction%3A%24LATEST``. The route layer must decode the segment so + the store's exact-match filter on ``function_name`` returns the expected + execution. + + Pre-fix behavior: handler filters on the encoded string, response has + no executions. Post-fix: handler filters on the decoded string, response + returns the seeded execution. + """ + client, _arn, _executor, store = server_with_slash_arn + + # Seed an execution whose function_name contains characters boto encodes. + realistic_function_name = "MyFunction:$LATEST" + seed = StartDurableExecutionInput( + account_id="123456789012", + function_name=realistic_function_name, + function_qualifier="$LATEST", + execution_name="encoded-fn-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="inv-encoded-fn", + input='"hi"', + ) + seeded = Execution.new(seed) + seeded.start() + store.save(seeded) + + response = client.list_durable_executions_by_function( + FunctionName=realistic_function_name, + ) + + arns = [e["DurableExecutionArn"] for e in response.get("DurableExecutions", [])] + assert seeded.durable_execution_arn in arns, ( + f"WebServer route layer did not URL-decode FunctionName. " + f"Seeded function_name {realistic_function_name!r} produced arn " + f"{seeded.durable_execution_arn!r}, but list response contained " + f"{arns!r}." + ) diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/web/e2e/server_int_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/web/e2e/server_int_test.py new file mode 100644 index 0000000..0186606 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/web/e2e/server_int_test.py @@ -0,0 +1,101 @@ +"""Integration tests for web server routing and handler integration.""" + +from __future__ import annotations + +from unittest.mock import Mock + +from aws_durable_execution_sdk_python_testing.web.models import ( + HTTPRequest, + HTTPResponse, +) +from aws_durable_execution_sdk_python_testing.web.routes import ( + HealthRoute, + StartExecutionRoute, +) +from aws_durable_execution_sdk_python_testing.web.server import ( + WebServer, + WebServiceConfig, +) + + +def test_web_server_router_integration(): + """Test that router can find routes and handlers can handle them.""" + executor = Mock() + config = WebServiceConfig(port=0) # Use port 0 to get any available port + + server = WebServer(config, executor) + + try: + # Test router can find a route + route = server.router.find_route("/health", "GET") + assert isinstance(route, HealthRoute) + + # Test handler exists for the route + handler = server.endpoint_handlers.get(type(route)) + assert handler is not None + + # Test handler can handle the route + request = HTTPRequest( + method="GET", path=route, headers={}, query_params={}, body={} + ) + + response = handler.handle(route, request) + assert isinstance(response, HTTPResponse) + assert response.status_code == 200 + assert response.body == {"status": "healthy"} + finally: + server.server_close() + + +def test_web_server_start_execution_route_integration(): + """Test that start execution route is properly integrated.""" + executor = Mock() + config = WebServiceConfig(port=0) # Use port 0 to get any available port + + server = WebServer(config, executor) + + try: + # Test router can find start execution route + route = server.router.find_route("/start-durable-execution", "POST") + assert isinstance(route, StartExecutionRoute) + + # Test handler exists for the route + handler = server.endpoint_handlers.get(type(route)) + assert handler is not None + + # Test handler returns 400 for invalid input (now implemented) + request = HTTPRequest( + method="POST", + path=route, + headers={}, + query_params={}, + body={"test": "data"}, # Invalid input - missing required fields + ) + + response = handler.handle(route, request) + assert isinstance(response, HTTPResponse) + assert response.status_code == 400 # Bad request for invalid input + finally: + server.server_close() + + +def test_web_server_context_manager_with_integration(): + """Test that WebServer context manager works with integrated components.""" + executor = Mock() + config = WebServiceConfig(port=0) # Use port 0 to get any available port + + with WebServer(config, executor) as server: + # Verify server is properly initialized + assert server.router is not None + assert server.endpoint_handlers is not None + + # Test a simple route resolution + route = server.router.find_route("/health", "GET") + handler = server.endpoint_handlers[type(route)] + + request = HTTPRequest( + method="GET", path=route, headers={}, query_params={}, body={} + ) + + response = handler.handle(route, request) + assert response.status_code == 200 diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/web/handlers_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/web/handlers_test.py new file mode 100644 index 0000000..3cb84a8 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/web/handlers_test.py @@ -0,0 +1,2635 @@ +"""Tests for HTTP endpoint handlers.""" + +from __future__ import annotations + +import base64 +import json +from typing import TYPE_CHECKING, Any +from unittest.mock import Mock + +import pytest +from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, + Operation, + OperationStatus, + OperationType, +) + +from aws_durable_execution_sdk_python_testing.exceptions import ( + AwsApiException, + IllegalArgumentException, + IllegalStateException, + InvalidParameterValueException, + ResourceNotFoundException, +) + + +if TYPE_CHECKING: + from aws_durable_execution_sdk_python_testing.executor import Executor +from aws_durable_execution_sdk_python_testing.model import ( + CheckpointDurableExecutionResponse, + Event, + ExecutionStartedDetails, + GetDurableExecutionHistoryResponse, + GetDurableExecutionResponse, + GetDurableExecutionStateResponse, + ListDurableExecutionsByFunctionResponse, + ListDurableExecutionsResponse, + SendDurableExecutionCallbackFailureRequest, + SendDurableExecutionCallbackFailureResponse, + SendDurableExecutionCallbackHeartbeatRequest, + SendDurableExecutionCallbackHeartbeatResponse, + SendDurableExecutionCallbackSuccessRequest, + SendDurableExecutionCallbackSuccessResponse, + StartDurableExecutionInput, + StartDurableExecutionOutput, + StopDurableExecutionResponse, +) +from aws_durable_execution_sdk_python_testing.model import ( + Execution as ExecutionSummary, +) +from aws_durable_execution_sdk_python_testing.web import handlers +from aws_durable_execution_sdk_python_testing.web.handlers import ( + CheckpointDurableExecutionHandler, + EndpointHandler, + GetDurableExecutionHandler, + GetDurableExecutionHistoryHandler, + GetDurableExecutionStateHandler, + HealthHandler, + ListDurableExecutionsByFunctionHandler, + ListDurableExecutionsHandler, + MetricsHandler, + SendDurableExecutionCallbackFailureHandler, + SendDurableExecutionCallbackHeartbeatHandler, + SendDurableExecutionCallbackSuccessHandler, + StartExecutionHandler, + StopDurableExecutionHandler, +) +from aws_durable_execution_sdk_python_testing.web.models import ( + HTTPRequest, + HTTPResponse, +) +from aws_durable_execution_sdk_python_testing.web.routes import ( + CallbackFailureRoute, + CallbackHeartbeatRoute, + CallbackSuccessRoute, + GetDurableExecutionRoute, + ListDurableExecutionsRoute, + Route, + Router, + StartExecutionRoute, +) + + +class MockableEndpointHandler(EndpointHandler): + """Test-specific handler that exposes private methods for testing.""" + + def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: + """Handle request - test implementation.""" + return self._success_response({"test": "data"}) + + # Public methods that expose private functionality for testing + def parse_json_body(self, request: HTTPRequest) -> dict[str, Any]: + """Public wrapper for _parse_json_body.""" + return self._parse_json_body(request) + + def json_response( + self, + status_code: int, + data: dict[str, Any], + additional_headers: dict[str, str] | None = None, + ) -> HTTPResponse: + """Public wrapper for _json_response.""" + return self._json_response(status_code, data, additional_headers) + + def success_response( + self, data: dict[str, Any], additional_headers: dict[str, str] | None = None + ) -> HTTPResponse: + """Public wrapper for _success_response.""" + return self._success_response(data, additional_headers) + + def created_response( + self, data: dict[str, Any], additional_headers: dict[str, str] | None = None + ) -> HTTPResponse: + """Public wrapper for _created_response.""" + return self._created_response(data, additional_headers) + + def no_content_response( + self, additional_headers: dict[str, str] | None = None + ) -> HTTPResponse: + """Public wrapper for _no_content_response.""" + return self._no_content_response(additional_headers) + + def parse_query_param(self, request: HTTPRequest, param_name: str) -> str | None: + """Public wrapper for _parse_query_param.""" + return self._parse_query_param(request, param_name) + + def parse_query_param_list( + self, request: HTTPRequest, param_name: str + ) -> list[str]: + """Public wrapper for _parse_query_param_list.""" + return self._parse_query_param_list(request, param_name) + + def validate_required_fields( + self, data: dict[str, Any], required_fields: list[str] + ) -> None: + """Public wrapper for _validate_required_fields.""" + return self._validate_required_fields(data, required_fields) + + +def test_endpoint_handler_initialization(): + """Test EndpointHandler initialization.""" + executor = Mock() + handler = MockableEndpointHandler(executor) + assert handler.executor == executor + + +def test_endpoint_handler_parse_json_body_valid(): + """Test parse_json_body with valid JSON.""" + executor = Mock() + handler = MockableEndpointHandler(executor) + + request = HTTPRequest( + method="POST", + path=Route.from_string("/test"), + headers={"Content-Type": "application/json"}, + query_params={}, + body={"key": "value"}, + ) + + result = handler.parse_json_body(request) + assert result == {"key": "value"} + + +def test_endpoint_handler_parse_json_body_empty(): + """Test parse_json_body with empty body.""" + executor = Mock() + handler = MockableEndpointHandler(executor) + + request = HTTPRequest( + method="POST", + path=Route.from_string("/test"), + headers={"Content-Type": "application/json"}, + query_params={}, + body={}, + ) + + with pytest.raises( + InvalidParameterValueException, match="Request body is required" + ): + handler.parse_json_body(request) + + +def test_endpoint_handler_parse_json_body_invalid(): + """Test parse_json_body with invalid JSON - now this test is not applicable since body is already a dict.""" + executor = Mock() + handler = MockableEndpointHandler(executor) + + # Since body is now a dict, this test case doesn't apply anymore + # The validation happens during HTTPRequest.from_bytes() deserialization + request = HTTPRequest( + method="POST", + path=Route.from_string("/test"), + headers={"Content-Type": "application/json"}, + query_params={}, + body={"valid": "json"}, # Body is always valid dict now + ) + + # This should work fine now since body is already parsed + result = handler.parse_json_body(request) + assert result == {"valid": "json"} + + +def test_endpoint_handler_json_response(): + """Test json_response method.""" + executor = Mock() + handler = MockableEndpointHandler(executor) + + response = handler.json_response(200, {"test": "data"}) + assert response.status_code == 200 + assert response.headers["Content-Type"] == "application/json" + assert response.body == {"test": "data"} + + # Verify serialization to bytes works + body_bytes = response.body_to_bytes() + assert b'"test":"data"' in body_bytes + + +def test_endpoint_handler_success_response(): + """Test success_response method.""" + executor = Mock() + handler = MockableEndpointHandler(executor) + + response = handler.success_response({"test": "data"}) + assert response.status_code == 200 + assert response.headers["Content-Type"] == "application/json" + + +def test_endpoint_handler_created_response(): + """Test created_response method.""" + executor = Mock() + handler = MockableEndpointHandler(executor) + + response = handler.created_response({"test": "data"}) + assert response.status_code == 201 + assert response.headers["Content-Type"] == "application/json" + + +def test_endpoint_handler_no_content_response(): + """Test no_content_response method.""" + executor = Mock() + handler = MockableEndpointHandler(executor) + + response = handler.no_content_response() + assert response.status_code == 204 + assert response.body == {} + + +def test_endpoint_handler_error_response(): + """Test error response creation using HTTPResponse.create_error_from_exception.""" + # Test that we can create error responses using the new method + exception = InvalidParameterValueException("Bad request") + + response = HTTPResponse.create_error_from_exception(exception) + assert response.status_code == 400 + assert response.headers["Content-Type"] == "application/json" + + # The new format doesn't wrap in an "error" object + # InvalidParameterValueException uses lowercase "message" per Smithy definition + expected_body = { + "Type": "InvalidParameterValueException", + "message": "Bad request", + } + assert response.body == expected_body + + # Verify serialization to bytes works + body_bytes = response.body_to_bytes() + assert b'"message":"Bad request"' in body_bytes + assert b'"Type":"InvalidParameterValueException"' in body_bytes + + +def test_endpoint_handler_parse_query_param(): + """Test parse_query_param method.""" + executor = Mock() + handler = MockableEndpointHandler(executor) + + request = HTTPRequest( + method="GET", + path=Route.from_string("/test"), + headers={}, + query_params={"param1": ["value1"], "param2": ["value2a", "value2b"]}, + body={}, + ) + + assert handler.parse_query_param(request, "param1") == "value1" + assert handler.parse_query_param(request, "param2") == "value2a" # First value + assert handler.parse_query_param(request, "nonexistent") is None + + +def test_endpoint_handler_parse_query_param_list(): + """Test parse_query_param_list method.""" + executor = Mock() + handler = MockableEndpointHandler(executor) + + request = HTTPRequest( + method="GET", + path=Route.from_string("/test"), + headers={}, + query_params={"param1": ["value1"], "param2": ["value2a", "value2b"]}, + body={}, + ) + + assert handler.parse_query_param_list(request, "param1") == ["value1"] + assert handler.parse_query_param_list(request, "param2") == ["value2a", "value2b"] + assert handler.parse_query_param_list(request, "nonexistent") == [] + + +def test_endpoint_handler_validate_required_fields_valid(): + """Test validate_required_fields with valid data.""" + executor = Mock() + handler = MockableEndpointHandler(executor) + + data = {"field1": "value1", "field2": "value2", "field3": "value3"} + required_fields = ["field1", "field2"] + + # Should not raise an exception + handler.validate_required_fields(data, required_fields) + + +def test_endpoint_handler_validate_required_fields_missing(): + """Test validate_required_fields with missing fields.""" + executor = Mock() + handler = MockableEndpointHandler(executor) + + data = {"field1": "value1"} + required_fields = ["field1", "field2", "field3"] + + with pytest.raises( + InvalidParameterValueException, match="Missing required fields: field2, field3" + ): + handler.validate_required_fields(data, required_fields) + + +def test_start_execution_handler_success(): + """Test StartExecutionHandler with successful execution start.""" + executor = Mock() + handler = StartExecutionHandler(executor) + + # Mock successful executor response + mock_output = StartDurableExecutionOutput(execution_arn="test-execution-arn") + executor.start_execution.return_value = mock_output + + # Create request with valid input data + request_data = { + "AccountId": "123456789012", + "FunctionName": "test-function", + "FunctionQualifier": "$LATEST", + "ExecutionName": "test-execution", + "ExecutionTimeoutSeconds": 300, + "ExecutionRetentionPeriodDays": 7, + "Input": '{"test": "data"}', + } + + request = HTTPRequest( + method="POST", + path=StartExecutionRoute.from_string("/start-durable-execution"), + headers={"Content-Type": "application/json"}, + query_params={}, + body=request_data, + ) + + route = StartExecutionRoute.from_string("/start-durable-execution") + response = handler.handle(route, request) + + # Verify response + assert response.status_code == 201 + assert response.headers["Content-Type"] == "application/json" + assert response.body == {"ExecutionArn": "test-execution-arn"} + + # Verify executor was called with correct input + executor.start_execution.assert_called_once() + call_args = executor.start_execution.call_args[0][0] + assert isinstance(call_args, StartDurableExecutionInput) + assert call_args.account_id == "123456789012" + assert call_args.function_name == "test-function" + assert call_args.execution_name == "test-execution" + + +def test_start_execution_handler_empty_body(): + """Test StartExecutionHandler with empty request body.""" + executor = Mock() + handler = StartExecutionHandler(executor) + + request = HTTPRequest( + method="POST", + path=StartExecutionRoute.from_string("/start-durable-execution"), + headers={"Content-Type": "application/json"}, + query_params={}, + body={}, + ) + + route = StartExecutionRoute.from_string("/start-durable-execution") + response = handler.handle(route, request) + + # Should return 400 Bad Request for empty body with AWS-compliant format + assert response.status_code == 400 + assert response.body["Type"] == "InvalidParameterValueException" + assert "Request body is required" in response.body["message"] + + +def test_start_execution_handler_missing_required_fields(): + """Test StartExecutionHandler with missing required fields.""" + executor = Mock() + handler = StartExecutionHandler(executor) + + # Request missing required fields + request_data = { + "AccountId": "123456789012", + "FunctionName": "test-function", + # Missing other required fields + } + + request = HTTPRequest( + method="POST", + path=StartExecutionRoute.from_string("/start-durable-execution"), + headers={"Content-Type": "application/json"}, + query_params={}, + body=request_data, + ) + + route = StartExecutionRoute.from_string("/start-durable-execution") + response = handler.handle(route, request) + + # Should return 400 Bad Request for missing fields with AWS-compliant format + assert response.status_code == 400 + assert response.body["Type"] == "InvalidParameterValueException" + assert "FunctionQualifier" in response.body["message"] + + +def test_start_execution_handler_invalid_parameter_error(): + """Test StartExecutionHandler with IllegalArgumentException from executor.""" + + executor = Mock() + handler = StartExecutionHandler(executor) + + # Mock executor to raise IllegalArgumentException + executor.start_execution.side_effect = IllegalArgumentException( + "Invalid timeout value" + ) + + request_data = { + "AccountId": "123456789012", + "FunctionName": "test-function", + "FunctionQualifier": "$LATEST", + "ExecutionName": "test-execution", + "ExecutionTimeoutSeconds": -1, # Invalid value + "ExecutionRetentionPeriodDays": 7, + } + + request = HTTPRequest( + method="POST", + path=StartExecutionRoute.from_string("/start-durable-execution"), + headers={"Content-Type": "application/json"}, + query_params={}, + body=request_data, + ) + + route = StartExecutionRoute.from_string("/start-durable-execution") + response = handler.handle(route, request) + + # Should return 400 Bad Request with AWS-compliant format + assert response.status_code == 400 + assert response.body["Type"] == "InvalidParameterValueException" + assert response.body["message"] == "Invalid timeout value" + + +def test_start_execution_handler_execution_already_exists(): + """Test StartExecutionHandler with execution already exists error.""" + + executor = Mock() + handler = StartExecutionHandler(executor) + + # Mock executor to raise IllegalStateException (execution already exists) + executor.start_execution.side_effect = IllegalStateException( + "Execution with name 'test-execution' already exists" + ) + + request_data = { + "AccountId": "123456789012", + "FunctionName": "test-function", + "FunctionQualifier": "$LATEST", + "ExecutionName": "test-execution", + "ExecutionTimeoutSeconds": 300, + "ExecutionRetentionPeriodDays": 7, + } + + request = HTTPRequest( + method="POST", + path=StartExecutionRoute.from_string("/start-durable-execution"), + headers={"Content-Type": "application/json"}, + query_params={}, + body=request_data, + ) + + route = StartExecutionRoute.from_string("/start-durable-execution") + response = handler.handle(route, request) + + # Should return 409 Conflict with AWS-compliant format (ExecutionAlreadyStartedException has no Type field) + assert response.status_code == 409 + assert "already exists" in response.body["message"] + assert ( + response.body["DurableExecutionArn"] + == "arn:aws:lambda:us-east-1:123456789012:function:test" + ) + assert ( + "Type" not in response.body + ) # ExecutionAlreadyStartedException doesn't have Type field + + +def test_start_execution_handler_unexpected_error(): + """Test StartExecutionHandler with unexpected error from executor.""" + executor = Mock() + handler = StartExecutionHandler(executor) + + # Mock executor to raise unexpected error + executor.start_execution.side_effect = RuntimeError("Unexpected database error") + + request_data = { + "AccountId": "123456789012", + "FunctionName": "test-function", + "FunctionQualifier": "$LATEST", + "ExecutionName": "test-execution", + "ExecutionTimeoutSeconds": 300, + "ExecutionRetentionPeriodDays": 7, + } + + request = HTTPRequest( + method="POST", + path=StartExecutionRoute.from_string("/start-durable-execution"), + headers={"Content-Type": "application/json"}, + query_params={}, + body=request_data, + ) + + route = StartExecutionRoute.from_string("/start-durable-execution") + response = handler.handle(route, request) + + # Should return 500 Internal Server Error with AWS-compliant format + assert response.status_code == 500 + assert response.body["Type"] == "ServiceException" + assert response.body["Message"] == "Unexpected database error" + + +def test_start_execution_handler_with_optional_fields(): + """Test StartExecutionHandler with optional fields included.""" + + executor = Mock() + handler = StartExecutionHandler(executor) + + # Mock successful executor response + mock_output = StartDurableExecutionOutput(execution_arn="test-execution-arn") + executor.start_execution.return_value = mock_output + + # Create request with optional fields + request_data = { + "AccountId": "123456789012", + "FunctionName": "test-function", + "FunctionQualifier": "$LATEST", + "ExecutionName": "test-execution", + "ExecutionTimeoutSeconds": 300, + "ExecutionRetentionPeriodDays": 7, + "InvocationId": "test-invocation-id", + "TraceFields": {"traceId": "test-trace"}, + "TenantId": "test-tenant", + "Input": '{"test": "data"}', + } + + request = HTTPRequest( + method="POST", + path=StartExecutionRoute.from_string("/start-durable-execution"), + headers={"Content-Type": "application/json"}, + query_params={}, + body=request_data, + ) + + route = StartExecutionRoute.from_string("/start-durable-execution") + response = handler.handle(route, request) + + # Verify response + assert response.status_code == 201 + assert response.body == {"ExecutionArn": "test-execution-arn"} + + # Verify executor was called with correct input including optional fields + executor.start_execution.assert_called_once() + call_args = executor.start_execution.call_args[0][0] + assert isinstance(call_args, StartDurableExecutionInput) + assert call_args.invocation_id == "test-invocation-id" + assert call_args.trace_fields == {"traceId": "test-trace"} + assert call_args.tenant_id == "test-tenant" + assert call_args.input == '{"test": "data"}' + + +def test_get_durable_execution_handler_success(): + """Test GetDurableExecutionHandler with successful execution retrieval.""" + + executor = Mock() + handler = GetDurableExecutionHandler(executor) + + # Mock the executor response + mock_response = GetDurableExecutionResponse( + durable_execution_arn="test-arn", + durable_execution_name="test-execution", + function_arn="arn:aws:lambda:us-east-1:123456789012:function:test-function", + status="SUCCEEDED", + start_timestamp="2023-01-01T00:00:00Z", + input_payload="test-input", + result="test-result", + error=None, + end_timestamp="2023-01-01T00:01:00Z", + version="1.0", + ) + executor.get_execution_details.return_value = mock_response + + # Create strongly-typed route + base_route = Route.from_string("/2025-12-01/durable-executions/test-arn") + typed_route = GetDurableExecutionRoute.from_route(base_route) + + request = HTTPRequest( + method="GET", + path=typed_route, + headers={}, + query_params={}, + body={}, + ) + + response = handler.handle(typed_route, request) + + # Verify response + assert response.status_code == 200 + expected_body = { + "DurableExecutionArn": "test-arn", + "DurableExecutionName": "test-execution", + "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:test-function", + "Status": "SUCCEEDED", + "StartTimestamp": "2023-01-01T00:00:00Z", + "InputPayload": "test-input", + "Result": "test-result", + "EndTimestamp": "2023-01-01T00:01:00Z", + "Version": "1.0", + } + assert response.body == expected_body + + # Verify executor was called with correct ARN + executor.get_execution_details.assert_called_once_with("test-arn") + + +def test_get_durable_execution_handler_resource_not_found(): + """Test GetDurableExecutionHandler with ResourceNotFoundException.""" + + executor = Mock() + handler = GetDurableExecutionHandler(executor) + + # Mock executor to raise ResourceNotFoundException + executor.get_execution_details.side_effect = ResourceNotFoundException( + "Execution not-found-arn not found" + ) + + # Create strongly-typed route + base_route = Route.from_string("/2025-12-01/durable-executions/not-found-arn") + typed_route = GetDurableExecutionRoute.from_route(base_route) + + request = HTTPRequest( + method="GET", + path=typed_route, + headers={}, + query_params={}, + body={}, + ) + + response = handler.handle(typed_route, request) + + # Verify error response with AWS-compliant format + assert response.status_code == 404 + assert response.body["Type"] == "ResourceNotFoundException" + assert response.body["Message"] == "Execution not-found-arn not found" + + # Verify executor was called + executor.get_execution_details.assert_called_once_with("not-found-arn") + + +def test_get_durable_execution_handler_invalid_parameter(): + """Test GetDurableExecutionHandler with IllegalArgumentException.""" + + executor = Mock() + handler = GetDurableExecutionHandler(executor) + + # Mock executor to raise IllegalArgumentException + executor.get_execution_details.side_effect = IllegalArgumentException( + "Invalid execution ARN format" + ) + + # Create strongly-typed route + base_route = Route.from_string("/2025-12-01/durable-executions/invalid-arn") + typed_route = GetDurableExecutionRoute.from_route(base_route) + + request = HTTPRequest( + method="GET", + path=typed_route, + headers={}, + query_params={}, + body={}, + ) + + response = handler.handle(typed_route, request) + + # Verify error response with AWS-compliant format + assert response.status_code == 400 + assert response.body["Type"] == "InvalidParameterValueException" + assert response.body["message"] == "Invalid execution ARN format" + + # Verify executor was called + executor.get_execution_details.assert_called_once_with("invalid-arn") + + +def test_get_durable_execution_handler_unexpected_error(): + """Test GetDurableExecutionHandler with unexpected error.""" + + executor = Mock() + handler = GetDurableExecutionHandler(executor) + + # Mock executor to raise unexpected error + executor.get_execution_details.side_effect = RuntimeError("Unexpected error") + + # Create strongly-typed route + base_route = Route.from_string("/2025-12-01/durable-executions/test-arn") + typed_route = GetDurableExecutionRoute.from_route(base_route) + + request = HTTPRequest( + method="GET", + path=typed_route, + headers={}, + query_params={}, + body={}, + ) + + response = handler.handle(typed_route, request) + + # Verify error response with AWS-compliant format + assert response.status_code == 500 + assert response.body["Type"] == "ServiceException" + assert response.body["Message"] == "Unexpected error" + + # Verify executor was called + executor.get_execution_details.assert_called_once_with("test-arn") + + +def test_checkpoint_durable_execution_handler_success(): + """Test CheckpointDurableExecutionHandler with successful checkpoint processing.""" + + executor = Mock() + handler = CheckpointDurableExecutionHandler(executor) + + # Mock the executor response + mock_response = CheckpointDurableExecutionResponse( + checkpoint_token="new-token-123", # noqa: S106 + new_execution_state=None, + ) + executor.checkpoint_execution.return_value = mock_response + + # Create request with proper checkpoint data + request_body = { + "CheckpointToken": "current-token-123", + "Updates": [ + {"Id": "op-1", "Type": "STEP", "Action": "SUCCEED", "SubType": "Step"} + ], + "ClientToken": "client-token-123", + } + + # Create strongly-typed route using Router + router = Router() + typed_route = router.find_route( + "/2025-12-01/durable-executions/test-arn/checkpoint", "POST" + ) + + request = HTTPRequest( + method="POST", + path=typed_route, + headers={"Content-Type": "application/json"}, + query_params={}, + body=request_body, + ) + + response = handler.handle(typed_route, request) + + # Verify response + assert response.status_code == 200 + assert response.body == { + "CheckpointToken": "new-token-123", + } + + # Verify executor was called with correct parameters + executor.checkpoint_execution.assert_called_once() + call_args = executor.checkpoint_execution.call_args + assert call_args[0][0] == "test-arn" # execution_arn + assert call_args[0][1] == "current-token-123" # checkpoint_token + assert call_args[0][3] == "client-token-123" # client_token + + # Verify the updates parameter + updates = call_args[0][2] + assert len(updates) == 1 + assert updates[0].operation_id == "op-1" + + +def test_checkpoint_durable_execution_handler_invalid_request(): + """Test CheckpointDurableExecutionHandler with invalid request body.""" + + executor = Mock() + handler = CheckpointDurableExecutionHandler(executor) + + # Create strongly-typed route using Router + router = Router() + typed_route = router.find_route( + "/2025-12-01/durable-executions/test-arn/checkpoint", "POST" + ) + + request = HTTPRequest( + method="POST", + path=typed_route, + headers={}, + query_params={}, + body={}, + ) + + response = handler.handle(typed_route, request) + + # Verify AWS-compliant error format + assert response.status_code == 400 + assert response.body["Type"] == "InvalidParameterValueException" + assert "Request body is required" in response.body["message"] + + +def test_checkpoint_durable_execution_handler_invalid_checkpoint_exception(): + """Test CheckpointDurableExecutionHandler with IllegalStateException mapping to ServiceException.""" + + executor = Mock() + handler = CheckpointDurableExecutionHandler(executor) + + # Mock executor to raise IllegalStateException + executor.checkpoint_execution.side_effect = IllegalStateException( + "Invalid checkpoint token" + ) + + request_body = { + "CheckpointToken": "invalid-token", + } + + # Create strongly-typed route using Router + router = Router() + typed_route = router.find_route( + "/2025-12-01/durable-executions/test-arn/checkpoint", "POST" + ) + + request = HTTPRequest( + method="POST", + path=typed_route, + headers={"Content-Type": "application/json"}, + query_params={}, + body=request_body, + ) + + response = handler.handle(typed_route, request) + + # Verify IllegalStateException maps to ServiceException in AWS-compliant format + assert response.status_code == 500 + assert response.body["Type"] == "ServiceException" + assert response.body["Message"] == "Invalid checkpoint token" + + +def test_stop_durable_execution_handler_success(): + """Test StopDurableExecutionHandler with successful execution stop.""" + + executor = Mock() + handler = StopDurableExecutionHandler(executor) + + # Mock the executor response + mock_response = StopDurableExecutionResponse(stop_timestamp="2023-01-01T00:01:00Z") + executor.stop_execution.return_value = mock_response + + # Create request with proper stop data + request_body = { + "DurableExecutionArn": "test-arn", + "Error": { + "ErrorMessage": "User requested stop", + "ErrorType": "UserStop", + }, + } + + # Create strongly-typed route using Router + router = Router() + typed_route = router.find_route( + "/2025-12-01/durable-executions/test-arn/stop", "POST" + ) + + request = HTTPRequest( + method="POST", + path=typed_route, + headers={"Content-Type": "application/json"}, + query_params={}, + body=request_body, + ) + + response = handler.handle(typed_route, request) + + # Verify response + assert response.status_code == 200 + assert response.body == {"StopTimestamp": "2023-01-01T00:01:00Z"} + + # Verify executor was called with correct parameters + executor.stop_execution.assert_called_once() + call_args = executor.stop_execution.call_args + assert call_args[0][0] == "test-arn" # execution_arn + + +def test_stop_durable_execution_handler_execution_already_stopped(): + """Test StopDurableExecutionHandler with execution already stopped returns idempotent response.""" + + executor = Mock() + handler = StopDurableExecutionHandler(executor) + + # Mock executor to return stop response with timestamp + stop_timestamp = "2023-01-01T00:01:00Z" + executor.stop_execution.return_value = StopDurableExecutionResponse( + stop_timestamp=stop_timestamp + ) + + request_body = { + "DurableExecutionArn": "test-arn", + } + + # Create strongly-typed route using Router + router = Router() + typed_route = router.find_route( + "/2025-12-01/durable-executions/test-arn/stop", "POST" + ) + + request = HTTPRequest( + method="POST", + path=typed_route, + headers={"Content-Type": "application/json"}, + query_params={}, + body=request_body, + ) + + response = handler.handle(typed_route, request) + + # Verify idempotent response with stop timestamp + assert response.status_code == 200 + assert response.body["StopTimestamp"] == stop_timestamp + + +def test_stop_durable_execution_handler_resource_not_found(): + """Test StopDurableExecutionHandler with ResourceNotFoundException.""" + + executor = Mock() + handler = StopDurableExecutionHandler(executor) + + # Mock executor to raise ResourceNotFoundException + executor.stop_execution.side_effect = ResourceNotFoundException( + "Execution not-found-arn not found" + ) + + request_body = { + "DurableExecutionArn": "not-found-arn", + } + + # Create strongly-typed route using Router + router = Router() + typed_route = router.find_route( + "/2025-12-01/durable-executions/not-found-arn/stop", "POST" + ) + + request = HTTPRequest( + method="POST", + path=typed_route, + headers={"Content-Type": "application/json"}, + query_params={}, + body=request_body, + ) + + response = handler.handle(typed_route, request) + + # Verify error response with AWS-compliant format + assert response.status_code == 404 + assert response.body["Type"] == "ResourceNotFoundException" + assert response.body["Message"] == "Execution not-found-arn not found" + + +def test_get_durable_execution_state_handler_success(): + """Test GetDurableExecutionStateHandler with successful state retrieval.""" + + executor = Mock() + handler = GetDurableExecutionStateHandler(executor) + + # Mock the executor response with operations + + mock_operations = [ + Operation( + operation_id="op-1", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + name="test-step", + ) + ] + mock_response = GetDurableExecutionStateResponse( + operations=mock_operations, next_marker=None + ) + executor.get_execution_state.return_value = mock_response + + # Create strongly-typed route using Router + router = Router() + typed_route = router.find_route( + "/2025-12-01/durable-executions/test-arn/state", "GET" + ) + + request = HTTPRequest( + method="GET", + path=typed_route, + headers={}, + query_params={}, + body={}, + ) + + response = handler.handle(typed_route, request) + + # Verify response + assert response.status_code == 200 + assert "Operations" in response.body + assert len(response.body["Operations"]) == 1 + assert response.body["Operations"][0]["Id"] == "op-1" + assert response.body["Operations"][0]["Type"] == "STEP" + + # Verify executor was called with correct ARN + executor.get_execution_state.assert_called_once_with("test-arn") + + +def test_get_durable_execution_state_handler_resource_not_found(): + """Test GetDurableExecutionStateHandler with ResourceNotFoundException.""" + + executor = Mock() + handler = GetDurableExecutionStateHandler(executor) + + # Mock executor to raise ResourceNotFoundException + executor.get_execution_state.side_effect = ResourceNotFoundException( + "Execution not-found-arn not found" + ) + + # Create strongly-typed route using Router + router = Router() + typed_route = router.find_route( + "/2025-12-01/durable-executions/not-found-arn/state", "GET" + ) + + request = HTTPRequest( + method="GET", + path=typed_route, + headers={}, + query_params={}, + body={}, + ) + + response = handler.handle(typed_route, request) + + # Verify error response with AWS-compliant format + assert response.status_code == 404 + assert response.body["Type"] == "ResourceNotFoundException" + assert response.body["Message"] == "Execution not-found-arn not found" + + +def test_get_durable_execution_state_handler_invalid_parameter(): + """Test GetDurableExecutionStateHandler with IllegalArgumentException.""" + + executor = Mock() + handler = GetDurableExecutionStateHandler(executor) + + # Mock executor to raise IllegalArgumentException + executor.get_execution_state.side_effect = IllegalArgumentException( + "Invalid checkpoint token" + ) + + # Create strongly-typed route using Router + router = Router() + typed_route = router.find_route( + "/2025-12-01/durable-executions/test-arn/state", "GET" + ) + + request = HTTPRequest( + method="GET", + path=typed_route, + headers={}, + query_params={}, + body={}, + ) + + response = handler.handle(typed_route, request) + + # Verify error response with AWS-compliant format + assert response.status_code == 400 + assert response.body["Type"] == "InvalidParameterValueException" + assert response.body["message"] == "Invalid checkpoint token" + + +def test_get_durable_execution_history_handler_success(): + """Test GetDurableExecutionHistoryHandler with successful history retrieval.""" + + executor = Mock() + handler = GetDurableExecutionHistoryHandler(executor) + + # Mock the executor response with events + mock_events = [ + Event( + event_type="ExecutionStarted", + event_timestamp="2023-01-01T00:00:00Z", + event_id=1, + operation_id="exec-1", + execution_started_details=ExecutionStartedDetails(), + ) + ] + mock_response = GetDurableExecutionHistoryResponse( + events=mock_events, next_marker=None + ) + executor.get_execution_history.return_value = mock_response + + # Create strongly-typed route using Router + router = Router() + typed_route = router.find_route( + "/2025-12-01/durable-executions/test-arn/history", "GET" + ) + + request = HTTPRequest( + method="GET", + path=typed_route, + headers={}, + query_params={"MaxItems": ["10"], "Marker": ["token-123"]}, + body={}, + ) + + response = handler.handle(typed_route, request) + + # Verify response + assert response.status_code == 200 + assert "Events" in response.body + assert len(response.body["Events"]) == 1 + assert response.body["Events"][0]["EventType"] == "ExecutionStarted" + assert response.body["Events"][0]["EventId"] == 1 + + # Verify executor was called with correct parameters + executor.get_execution_history.assert_called_once_with( + "test-arn", + include_execution_data=False, + reverse_order=False, + marker="token-123", + max_items=10, + ) + + +def test_get_durable_execution_history_handler_resource_not_found(): + """Test GetDurableExecutionHistoryHandler with ResourceNotFoundException.""" + + executor = Mock() + handler = GetDurableExecutionHistoryHandler(executor) + + # Mock executor to raise ResourceNotFoundException + executor.get_execution_history.side_effect = ResourceNotFoundException( + "Execution not-found-arn not found" + ) + + # Create strongly-typed route using Router + router = Router() + typed_route = router.find_route( + "/2025-12-01/durable-executions/not-found-arn/history", "GET" + ) + + request = HTTPRequest( + method="GET", + path=typed_route, + headers={}, + query_params={}, + body={}, + ) + + response = handler.handle(typed_route, request) + + # Verify error response with AWS-compliant format + assert response.status_code == 404 + assert response.body["Type"] == "ResourceNotFoundException" + assert response.body["Message"] == "Execution not-found-arn not found" + + +def test_get_durable_execution_history_handler_with_query_params(): + """Test GetDurableExecutionHistoryHandler with query parameters.""" + + executor = Mock() + handler = GetDurableExecutionHistoryHandler(executor) + + # Mock the executor response + mock_response = GetDurableExecutionHistoryResponse(events=[], next_marker=None) + executor.get_execution_history.return_value = mock_response + + # Create strongly-typed route using Router + router = Router() + typed_route = router.find_route( + "/2025-12-01/durable-executions/test-arn/history", "GET" + ) + + request = HTTPRequest( + method="GET", + path=typed_route, + headers={}, + query_params={"MaxItems": ["25"]}, + body={}, + ) + + response = handler.handle(typed_route, request) + + # Verify response + assert response.status_code == 200 + assert response.body == {"Events": []} + + # Verify executor was called with correct parameters + executor.get_execution_history.assert_called_once_with( + "test-arn", + include_execution_data=False, + reverse_order=False, + marker=None, + max_items=25, + ) + + +def test_get_durable_execution_history_handler_with_include_execution_data(): + """Test GetDurableExecutionHistoryHandler with IncludeExecutionData parameter.""" + + executor = Mock() + handler = GetDurableExecutionHistoryHandler(executor) + + # Mock the executor response + mock_response = GetDurableExecutionHistoryResponse(events=[], next_marker=None) + executor.get_execution_history.return_value = mock_response + + # Create strongly-typed route using Router + router = Router() + typed_route = router.find_route( + "/2025-12-01/durable-executions/test-arn/history", "GET" + ) + + request = HTTPRequest( + method="GET", + path=typed_route, + headers={}, + query_params={"IncludeExecutionData": ["true"], "MaxItems": ["1000"]}, + body={}, + ) + + response = handler.handle(typed_route, request) + + # Verify response + assert response.status_code == 200 + assert response.body == {"Events": []} + + # Verify executor was called with include_execution_data=True + executor.get_execution_history.assert_called_once_with( + "test-arn", + include_execution_data=True, + reverse_order=False, + marker=None, + max_items=1000, + ) + + +def test_get_durable_execution_history_handler_with_include_execution_data_false(): + """Test GetDurableExecutionHistoryHandler with IncludeExecutionData=false.""" + + executor = Mock() + handler = GetDurableExecutionHistoryHandler(executor) + + # Mock the executor response + mock_response = GetDurableExecutionHistoryResponse(events=[], next_marker=None) + executor.get_execution_history.return_value = mock_response + + # Create strongly-typed route using Router + router = Router() + typed_route = router.find_route( + "/2025-12-01/durable-executions/test-arn/history", "GET" + ) + + request = HTTPRequest( + method="GET", + path=typed_route, + headers={}, + query_params={"IncludeExecutionData": ["false"]}, + body={}, + ) + + response = handler.handle(typed_route, request) + + # Verify response + assert response.status_code == 200 + assert response.body == {"Events": []} + + # Verify executor was called with include_execution_data=False + executor.get_execution_history.assert_called_once_with( + "test-arn", + include_execution_data=False, + reverse_order=False, + marker=None, + max_items=None, + ) + + +def test_list_durable_executions_handler_success(): + """Test ListDurableExecutionsHandler with successful execution listing.""" + executor = Mock() + handler = ListDurableExecutionsHandler(executor) + + # Mock the executor response + mock_executions = [ + ExecutionSummary( + durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:test-function:execution:test-1", + durable_execution_name="test-execution-1", + function_arn="arn:aws:lambda:us-east-1:123456789012:function:test-function", + status="SUCCEEDED", + start_timestamp="2023-01-01T00:00:00Z", + end_timestamp="2023-01-01T00:01:00Z", + ), + ExecutionSummary( + durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:test-function:execution:test-2", + durable_execution_name="test-execution-2", + function_arn="arn:aws:lambda:us-east-1:123456789012:function:test-function", + status="RUNNING", + start_timestamp="2023-01-01T00:02:00Z", + end_timestamp=None, + ), + ] + + mock_response = ListDurableExecutionsResponse( + durable_executions=mock_executions, + next_marker=None, + ) + executor.list_executions.return_value = mock_response + + # Create strongly-typed route + base_route = Route.from_string("/2025-12-01/durable-executions") + typed_route = ListDurableExecutionsRoute.from_route(base_route) + + request = HTTPRequest( + method="GET", + path=typed_route, + headers={}, + query_params={}, + body={}, + ) + + response = handler.handle(typed_route, request) + + # Verify response + assert response.status_code == 200 + expected_body = { + "DurableExecutions": [ + { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test-function:execution:test-1", + "DurableExecutionName": "test-execution-1", + "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:test-function", + "Status": "SUCCEEDED", + "StartTimestamp": "2023-01-01T00:00:00Z", + "EndTimestamp": "2023-01-01T00:01:00Z", + }, + { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test-function:execution:test-2", + "DurableExecutionName": "test-execution-2", + "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:test-function", + "Status": "RUNNING", + "StartTimestamp": "2023-01-01T00:02:00Z", + }, + ] + } + assert response.body == expected_body + + # Verify executor was called with correct parameters (all None for no filters) + executor.list_executions.assert_called_once_with( + function_name=None, + function_version=None, + execution_name=None, + status_filter=None, + started_after=None, + started_before=None, + marker=None, + max_items=None, + reverse_order=False, + ) + + +def test_list_durable_executions_handler_with_filters(): + """Test ListDurableExecutionsHandler with query parameter filters.""" + executor = Mock() + handler = ListDurableExecutionsHandler(executor) + + # Mock the executor response + mock_executions = [ + ExecutionSummary( + durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:test-function:execution:filtered-1", + durable_execution_name="filtered-execution", + function_arn="arn:aws:lambda:us-east-1:123456789012:function:test-function", + status="SUCCEEDED", + start_timestamp="2023-01-01T00:00:00Z", + end_timestamp="2023-01-01T00:01:00Z", + ), + ] + + mock_response = ListDurableExecutionsResponse( + durable_executions=mock_executions, + next_marker="next-page-token", + ) + executor.list_executions.return_value = mock_response + + # Create strongly-typed route + base_route = Route.from_string("/2025-12-01/durable-executions") + typed_route = ListDurableExecutionsRoute.from_route(base_route) + + # Create request with query parameters + request = HTTPRequest( + method="GET", + path=typed_route, + headers={}, + query_params={ + "FunctionName": ["test-function"], + "FunctionVersion": ["$LATEST"], + "DurableExecutionName": ["filtered-execution"], + "StatusFilter": ["SUCCEEDED"], + "StartedAfter": ["2023-01-01T00:00:00Z"], + "StartedBefore": ["2023-01-01T23:59:59Z"], + "Marker": ["start-token"], + "MaxItems": ["10"], + "ReverseOrder": ["true"], + }, + body={}, + ) + + response = handler.handle(typed_route, request) + + # Verify response + assert response.status_code == 200 + expected_body = { + "DurableExecutions": [ + { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test-function:execution:filtered-1", + "DurableExecutionName": "filtered-execution", + "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:test-function", + "Status": "SUCCEEDED", + "StartTimestamp": "2023-01-01T00:00:00Z", + "EndTimestamp": "2023-01-01T00:01:00Z", + }, + ], + "NextMarker": "next-page-token", + } + assert response.body == expected_body + + # Verify executor was called with correct filtered parameters + executor.list_executions.assert_called_once_with( + function_name="test-function", + function_version="$LATEST", + execution_name="filtered-execution", + status_filter="SUCCEEDED", + started_after="2023-01-01T00:00:00Z", + started_before="2023-01-01T23:59:59Z", + marker="start-token", + max_items=10, + reverse_order=True, + ) + + +def test_list_durable_executions_handler_pagination(): + """Test ListDurableExecutionsHandler with pagination support.""" + executor = Mock() + handler = ListDurableExecutionsHandler(executor) + + # Mock the executor response with pagination + mock_executions = [ + ExecutionSummary( + durable_execution_arn=f"arn:aws:lambda:us-east-1:123456789012:function:test-function:execution:page-{i}", + durable_execution_name=f"page-execution-{i}", + function_arn="arn:aws:lambda:us-east-1:123456789012:function:test-function", + status="SUCCEEDED", + start_timestamp=f"2023-01-0{i}T00:00:00Z", + end_timestamp=f"2023-01-0{i}T00:01:00Z", + ) + for i in range(1, 4) # 3 executions + ] + + mock_response = ListDurableExecutionsResponse( + durable_executions=mock_executions, + next_marker="next-page-marker", + ) + executor.list_executions.return_value = mock_response + + # Create strongly-typed route + base_route = Route.from_string("/2025-12-01/durable-executions") + typed_route = ListDurableExecutionsRoute.from_route(base_route) + + # Create request with pagination parameters + request = HTTPRequest( + method="GET", + path=typed_route, + headers={}, + query_params={ + "MaxItems": ["3"], + "Marker": ["current-page-marker"], + }, + body={}, + ) + + response = handler.handle(typed_route, request) + + # Verify response includes pagination + assert response.status_code == 200 + assert len(response.body["DurableExecutions"]) == 3 + assert response.body["NextMarker"] == "next-page-marker" + + # Verify executor was called with pagination parameters + executor.list_executions.assert_called_once_with( + function_name=None, + function_version=None, + execution_name=None, + status_filter=None, + started_after=None, + started_before=None, + marker="current-page-marker", + max_items=3, + reverse_order=False, + ) + + +def test_list_durable_executions_handler_empty_results(): + """Test ListDurableExecutionsHandler with no executions found.""" + + executor = Mock() + handler = ListDurableExecutionsHandler(executor) + + # Mock empty executor response + mock_response = ListDurableExecutionsResponse( + durable_executions=[], + next_marker=None, + ) + executor.list_executions.return_value = mock_response + + # Create strongly-typed route + base_route = Route.from_string("/2025-12-01/durable-executions") + typed_route = ListDurableExecutionsRoute.from_route(base_route) + + request = HTTPRequest( + method="GET", + path=typed_route, + headers={}, + query_params={}, + body={}, + ) + + response = handler.handle(typed_route, request) + + # Verify response + assert response.status_code == 200 + assert response.body == {"DurableExecutions": []} + + # Verify executor was called + executor.list_executions.assert_called_once() + + +def test_list_durable_executions_handler_dataclass_serialization(): + """Test ListDurableExecutionsHandler uses from_dict/to_dict methods for serialization.""" + executor = Mock() + handler = ListDurableExecutionsHandler(executor) + + # Mock the executor response + mock_executions = [ + ExecutionSummary( + durable_execution_arn="test-arn", + durable_execution_name="test-execution", + function_arn="test-function-arn", + status="SUCCEEDED", + start_timestamp="2023-01-01T00:00:00Z", + end_timestamp="2023-01-01T00:01:00Z", + ), + ] + + mock_response = ListDurableExecutionsResponse( + durable_executions=mock_executions, + next_marker=None, + ) + executor.list_executions.return_value = mock_response + + # Create strongly-typed route + base_route = Route.from_string("/2025-12-01/durable-executions") + typed_route = ListDurableExecutionsRoute.from_route(base_route) + + # Create request with query parameters to test from_dict + request = HTTPRequest( + method="GET", + path=typed_route, + headers={}, + query_params={ + "FunctionName": ["test-function"], + "MaxItems": ["5"], + }, + body={}, + ) + + response = handler.handle(typed_route, request) + + # Verify response uses to_dict() serialization + assert response.status_code == 200 + assert "DurableExecutions" in response.body + assert isinstance(response.body["DurableExecutions"], list) + + # Verify the response structure matches to_dict() output + execution_data = response.body["DurableExecutions"][0] + assert execution_data["DurableExecutionArn"] == "test-arn" + assert execution_data["DurableExecutionName"] == "test-execution" + assert execution_data["Status"] == "SUCCEEDED" + + # Verify executor was called (implicitly tests from_dict was used for request parsing) + executor.list_executions.assert_called_once_with( + function_name="test-function", + function_version=None, + execution_name=None, + status_filter=None, + started_after=None, + started_before=None, + marker=None, + max_items=5, + reverse_order=False, + ) + + +def test_list_durable_executions_handler_invalid_parameter_error(): + """Test ListDurableExecutionsHandler with IllegalArgumentException from executor.""" + + executor = Mock() + handler = ListDurableExecutionsHandler(executor) + + # Mock executor to raise IllegalArgumentException + executor.list_executions.side_effect = IllegalArgumentException( + "Invalid MaxItems value" + ) + + # Create strongly-typed route + base_route = Route.from_string("/2025-12-01/durable-executions") + typed_route = ListDurableExecutionsRoute.from_route(base_route) + + request = HTTPRequest( + method="GET", + path=typed_route, + headers={}, + query_params={ + "MaxItems": ["-1"], # Invalid value + }, + body={}, + ) + + response = handler.handle(typed_route, request) + + # Verify error response with AWS-compliant format + assert response.status_code == 400 + assert response.body["Type"] == "InvalidParameterValueException" + assert response.body["message"] == "Invalid MaxItems value" + + +def test_list_durable_executions_handler_unexpected_error(): + """Test ListDurableExecutionsHandler with unexpected error from executor.""" + + executor = Mock() + handler = ListDurableExecutionsHandler(executor) + + # Mock executor to raise unexpected error + executor.list_executions.side_effect = RuntimeError("Database connection failed") + + # Create strongly-typed route + base_route = Route.from_string("/2025-12-01/durable-executions") + typed_route = ListDurableExecutionsRoute.from_route(base_route) + + request = HTTPRequest( + method="GET", + path=typed_route, + headers={}, + query_params={}, + body={}, + ) + + response = handler.handle(typed_route, request) + + # Verify error response with AWS-compliant format + assert response.status_code == 500 + assert response.body["Type"] == "ServiceException" + assert response.body["Message"] == "Database connection failed" + + +def test_list_durable_executions_handler_common_exception_handling(): + """Test ListDurableExecutionsHandler uses base class _handle_common_exceptions method.""" + + executor = Mock() + handler = ListDurableExecutionsHandler(executor) + + # Mock executor to raise ResourceNotFoundException + executor.list_executions.side_effect = ResourceNotFoundException( + "Function not found" + ) + + # Create strongly-typed route + base_route = Route.from_string("/2025-12-01/durable-executions") + typed_route = ListDurableExecutionsRoute.from_route(base_route) + + request = HTTPRequest( + method="GET", + path=typed_route, + headers={}, + query_params={}, + body={}, + ) + + response = handler.handle(typed_route, request) + + # Verify error response uses common exception handling with AWS-compliant format + assert response.status_code == 404 + assert response.body["Type"] == "ResourceNotFoundException" + assert response.body["Message"] == "Function not found" + + +def test_list_durable_executions_by_function_handler_success(): + """Test ListDurableExecutionsByFunctionHandler with successful execution listing.""" + + executor = Mock() + handler = ListDurableExecutionsByFunctionHandler(executor) + + # Mock the executor response + mock_executions = [ + ExecutionSummary( + durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:test-function:execution:func-1", + durable_execution_name="function-execution-1", + function_arn="arn:aws:lambda:us-east-1:123456789012:function:test-function", + status="SUCCEEDED", + start_timestamp="2023-01-01T00:00:00Z", + end_timestamp="2023-01-01T00:01:00Z", + ), + ExecutionSummary( + durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:test-function:execution:func-2", + durable_execution_name="function-execution-2", + function_arn="arn:aws:lambda:us-east-1:123456789012:function:test-function", + status="RUNNING", + start_timestamp="2023-01-01T00:02:00Z", + end_timestamp=None, + ), + ] + + mock_response = ListDurableExecutionsByFunctionResponse( + durable_executions=mock_executions, + next_marker=None, + ) + executor.list_executions_by_function.return_value = mock_response + + # Create strongly-typed route using Router + router = Router() + typed_route = router.find_route( + "/2025-12-01/functions/test-function/durable-executions", "GET" + ) + + request = HTTPRequest( + method="GET", + path=typed_route, + headers={}, + query_params={}, + body={}, + ) + + response = handler.handle(typed_route, request) + + # Verify response + assert response.status_code == 200 + expected_body = { + "DurableExecutions": [ + { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test-function:execution:func-1", + "DurableExecutionName": "function-execution-1", + "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:test-function", + "Status": "SUCCEEDED", + "StartTimestamp": "2023-01-01T00:00:00Z", + "EndTimestamp": "2023-01-01T00:01:00Z", + }, + { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test-function:execution:func-2", + "DurableExecutionName": "function-execution-2", + "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:test-function", + "Status": "RUNNING", + "StartTimestamp": "2023-01-01T00:02:00Z", + }, + ] + } + assert response.body == expected_body + + # Verify executor was called with correct function name + executor.list_executions_by_function.assert_called_once_with( + function_name="test-function", + qualifier=None, + execution_name=None, + status_filter=None, + started_after=None, + started_before=None, + marker=None, + max_items=None, + reverse_order=False, + ) + + +def test_list_durable_executions_by_function_handler_with_filters(): + """Test ListDurableExecutionsByFunctionHandler with query parameter filters.""" + + executor = Mock() + handler = ListDurableExecutionsByFunctionHandler(executor) + + # Mock the executor response + mock_executions = [ + ExecutionSummary( + durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:test-function:execution:filtered", + durable_execution_name="filtered-execution", + function_arn="arn:aws:lambda:us-east-1:123456789012:function:test-function", + status="SUCCEEDED", + start_timestamp="2023-01-01T00:00:00Z", + end_timestamp="2023-01-01T00:01:00Z", + ), + ] + + mock_response = ListDurableExecutionsByFunctionResponse( + durable_executions=mock_executions, + next_marker="next-page-token", + ) + executor.list_executions_by_function.return_value = mock_response + + # Create strongly-typed route using Router + router = Router() + typed_route = router.find_route( + "/2025-12-01/functions/test-function/durable-executions", "GET" + ) + + # Create request with query parameters + request = HTTPRequest( + method="GET", + path=typed_route, + headers={}, + query_params={ + "functionVersion": ["$LATEST"], + "executionName": ["filtered-execution"], + "statusFilter": ["SUCCEEDED"], + "startedAfter": ["2023-01-01T00:00:00Z"], + "startedBefore": ["2023-01-01T23:59:59Z"], + "marker": ["start-token"], + "maxItems": ["5"], + "reverseOrder": ["true"], + }, + body={}, + ) + + response = handler.handle(typed_route, request) + + # Verify response + assert response.status_code == 200 + expected_body = { + "DurableExecutions": [ + { + "DurableExecutionArn": "arn:aws:lambda:us-east-1:123456789012:function:test-function:execution:filtered", + "DurableExecutionName": "filtered-execution", + "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:test-function", + "Status": "SUCCEEDED", + "StartTimestamp": "2023-01-01T00:00:00Z", + "EndTimestamp": "2023-01-01T00:01:00Z", + }, + ], + "NextMarker": "next-page-token", + } + assert response.body == expected_body + + # Verify executor was called with correct filtered parameters + executor.list_executions_by_function.assert_called_once_with( + function_name="test-function", + qualifier="$LATEST", + execution_name="filtered-execution", + status_filter="SUCCEEDED", + started_after="2023-01-01T00:00:00Z", + started_before="2023-01-01T23:59:59Z", + marker="start-token", + max_items=5, + reverse_order=True, + ) + + +def test_list_durable_executions_by_function_handler_dataclass_serialization(): + """Test ListDurableExecutionsByFunctionHandler uses from_dict/to_dict methods for serialization.""" + + executor = Mock() + handler = ListDurableExecutionsByFunctionHandler(executor) + + # Mock the executor response + mock_executions = [ + ExecutionSummary( + durable_execution_arn="test-arn", + durable_execution_name="test-execution", + function_arn="test-function-arn", + status="SUCCEEDED", + start_timestamp="2023-01-01T00:00:00Z", + end_timestamp="2023-01-01T00:01:00Z", + ), + ] + + mock_response = ListDurableExecutionsByFunctionResponse( + durable_executions=mock_executions, + next_marker=None, + ) + executor.list_executions_by_function.return_value = mock_response + + # Create strongly-typed route using Router + router = Router() + typed_route = router.find_route( + "/2025-12-01/functions/test-function/durable-executions", "GET" + ) + + # Create request with query parameters to test from_dict + request = HTTPRequest( + method="GET", + path=typed_route, + headers={}, + query_params={ + "functionVersion": ["$LATEST"], + "maxItems": ["10"], + }, + body={}, + ) + + response = handler.handle(typed_route, request) + + # Verify response uses to_dict() serialization + assert response.status_code == 200 + assert "DurableExecutions" in response.body + assert isinstance(response.body["DurableExecutions"], list) + + # Verify the response structure matches to_dict() output + execution_data = response.body["DurableExecutions"][0] + assert execution_data["DurableExecutionArn"] == "test-arn" + assert execution_data["DurableExecutionName"] == "test-execution" + assert execution_data["Status"] == "SUCCEEDED" + + # Verify executor was called (implicitly tests from_dict was used for request parsing) + executor.list_executions_by_function.assert_called_once_with( + function_name="test-function", + qualifier="$LATEST", + execution_name=None, + status_filter=None, + started_after=None, + started_before=None, + marker=None, + max_items=10, + reverse_order=False, + ) + + +def test_list_durable_executions_by_function_handler_resource_not_found(): + """Test ListDurableExecutionsByFunctionHandler with ResourceNotFoundException.""" + + executor = Mock() + handler = ListDurableExecutionsByFunctionHandler(executor) + + # Mock executor to raise ResourceNotFoundException + executor.list_executions_by_function.side_effect = ResourceNotFoundException( + "Function not-found-function not found" + ) + + # Create strongly-typed route using Router + router = Router() + typed_route = router.find_route( + "/2025-12-01/functions/not-found-function/durable-executions", "GET" + ) + + request = HTTPRequest( + method="GET", + path=typed_route, + headers={}, + query_params={}, + body={}, + ) + + response = handler.handle(typed_route, request) + + # Verify error response uses common exception handling with AWS-compliant format + assert response.status_code == 404 + assert response.body["Type"] == "ResourceNotFoundException" + assert response.body["Message"] == "Function not-found-function not found" + + +def test_list_durable_executions_by_function_handler_common_exception_handling(): + """Test ListDurableExecutionsByFunctionHandler uses base class _handle_common_exceptions method.""" + + executor = Mock() + handler = ListDurableExecutionsByFunctionHandler(executor) + + # Mock executor to raise IllegalArgumentException + executor.list_executions_by_function.side_effect = IllegalArgumentException( + "Invalid function name format" + ) + + # Create strongly-typed route using Router + router = Router() + typed_route = router.find_route( + "/2025-12-01/functions/invalid-function/durable-executions", "GET" + ) + + request = HTTPRequest( + method="GET", + path=typed_route, + headers={}, + query_params={}, + body={}, + ) + + response = handler.handle(typed_route, request) + + # Verify error response uses common exception handling with AWS-compliant format + assert response.status_code == 400 + assert response.body["Type"] == "InvalidParameterValueException" + assert response.body["message"] == "Invalid function name format" + + +def test_send_durable_execution_callback_success_handler(): + """Test SendDurableExecutionCallbackSuccessHandler with valid request.""" + + executor = Mock() + executor.send_callback_success.return_value = ( + SendDurableExecutionCallbackSuccessResponse() + ) + handler = SendDurableExecutionCallbackSuccessHandler(executor) + + # Create route using Router + router = Router() + route = router.find_route( + "/2025-12-01/durable-execution-callbacks/test-callback-id/succeed", "POST" + ) + assert isinstance(route, CallbackSuccessRoute) + assert route.callback_id == "test-callback-id" + + # Result is sent as raw binary body + request_body = b"success-result" + + request = HTTPRequest( + method="POST", + path=route, + headers={}, + query_params={}, + body=request_body, + ) + + response = handler.handle(route, request) + + # Verify successful response + assert response.status_code == 200 + assert response.body == {} + + # Verify executor was called with correct parameters + executor.send_callback_success.assert_called_once_with( + callback_id="test-callback-id", result=b"success-result" + ) + + +def test_send_durable_execution_callback_success_handler_empty_body(): + """Test SendDurableExecutionCallbackSuccessHandler with empty body.""" + executor = Mock() + executor.send_callback_success.return_value = ( + SendDurableExecutionCallbackSuccessResponse() + ) + handler = SendDurableExecutionCallbackSuccessHandler(executor) + + base_route = Route.from_string( + "/2025-12-01/durable-execution-callbacks/test-id/succeed" + ) + callback_route = CallbackSuccessRoute.from_route(base_route) + + request = HTTPRequest( + method="POST", + path=callback_route, + headers={}, + query_params={}, + body=b"", + ) + + response = handler.handle(callback_route, request) + # Handler should accept empty body (Result is optional) and return 200 + assert response.status_code == 200 + assert response.body == {} + + # Verify executor was called with empty result + executor.send_callback_success.assert_called_once_with( + callback_id="test-id", result=b"" + ) + + +def test_send_durable_execution_callback_failure_handler(): + """Test SendDurableExecutionCallbackFailureHandler with valid request.""" + + executor = Mock() + executor.send_callback_failure.return_value = ( + SendDurableExecutionCallbackFailureResponse() + ) + handler = SendDurableExecutionCallbackFailureHandler(executor) + + # Create route using Router + router = Router() + route = router.find_route( + "/2025-12-01/durable-execution-callbacks/test-callback-id/fail", "POST" + ) + assert isinstance(route, CallbackFailureRoute) + assert route.callback_id == "test-callback-id" + + # Test with valid request body including error + error_data = { + "ErrorMessage": "Test error", + "ErrorType": "TestException", + "ErrorData": None, + "StackTrace": None, + } + request = HTTPRequest( + method="POST", + path=route, + headers={"Content-Type": "application/json"}, + query_params={}, + body=error_data, # Pass error data directly as body + ) + response = handler.handle(route, request) + + # Verify successful response + assert response.status_code == 200 + assert response.body == {} + + # Verify executor was called with correct parameters + executor.send_callback_failure.assert_called_once() + call_args = executor.send_callback_failure.call_args + assert call_args[1]["callback_id"] == "test-callback-id" + assert isinstance(call_args[1]["error"], ErrorObject) + assert call_args[1]["error"].message == "Test error" + + +def test_update_lambda_endpoint_handler_success(): + """Test UpdateLambdaEndpointHandler with valid request.""" + from aws_durable_execution_sdk_python_testing.invoker import LambdaInvoker + from aws_durable_execution_sdk_python_testing.web.handlers import ( + UpdateLambdaEndpointHandler, + ) + from aws_durable_execution_sdk_python_testing.web.routes import ( + UpdateLambdaEndpointRoute, + ) + + executor = Mock() + lambda_invoker = Mock(spec=LambdaInvoker) + executor._invoker = lambda_invoker # noqa: SLF001 + handler = UpdateLambdaEndpointHandler(executor) + + base_route = Route.from_string("/lambda-endpoint") + update_route = UpdateLambdaEndpointRoute.from_route(base_route) + + request = HTTPRequest( + method="PUT", + path=update_route, + headers={"Content-Type": "application/json"}, + query_params={}, + body={"EndpointUrl": "http://localhost:8080", "RegionName": "us-west-2"}, + ) + + response = handler.handle(update_route, request) + + assert response.status_code == 200 + assert response.body == {"message": "Lambda endpoint updated successfully"} + lambda_invoker.update_endpoint.assert_called_once_with( + "http://localhost:8080", "us-west-2" + ) + + +def test_update_lambda_endpoint_handler_missing_endpoint_url(): + """Test UpdateLambdaEndpointHandler with missing EndpointUrl.""" + from aws_durable_execution_sdk_python_testing.web.handlers import ( + UpdateLambdaEndpointHandler, + ) + from aws_durable_execution_sdk_python_testing.web.routes import ( + UpdateLambdaEndpointRoute, + ) + + executor = Mock() + handler = UpdateLambdaEndpointHandler(executor) + + base_route = Route.from_string("/lambda-endpoint") + update_route = UpdateLambdaEndpointRoute.from_route(base_route) + + request = HTTPRequest( + method="PUT", + path=update_route, + headers={"Content-Type": "application/json"}, + query_params={}, + body={"RegionName": "us-west-2"}, + ) + + response = handler.handle(update_route, request) + + assert response.status_code == 400 + assert response.body["Type"] == "InvalidParameterValueException" + assert response.body["message"] == "EndpointUrl is required" + + +def test_update_lambda_endpoint_handler_default_region(): + """Test UpdateLambdaEndpointHandler uses default region when not specified.""" + from aws_durable_execution_sdk_python_testing.invoker import LambdaInvoker + from aws_durable_execution_sdk_python_testing.web.handlers import ( + UpdateLambdaEndpointHandler, + ) + from aws_durable_execution_sdk_python_testing.web.routes import ( + UpdateLambdaEndpointRoute, + ) + + executor = Mock() + lambda_invoker = Mock(spec=LambdaInvoker) + executor._invoker = lambda_invoker # noqa: SLF001 + handler = UpdateLambdaEndpointHandler(executor) + + base_route = Route.from_string("/lambda-endpoint") + update_route = UpdateLambdaEndpointRoute.from_route(base_route) + + request = HTTPRequest( + method="PUT", + path=update_route, + headers={"Content-Type": "application/json"}, + query_params={}, + body={"EndpointUrl": "http://localhost:8080"}, + ) + + response = handler.handle(update_route, request) + + assert response.status_code == 200 + lambda_invoker.update_endpoint.assert_called_once_with( + "http://localhost:8080", "us-east-1" + ) + + +def test_send_durable_execution_callback_failure_handler_empty_body(): + """Test SendDurableExecutionCallbackFailureHandler with empty body.""" + executor = Mock() + handler = SendDurableExecutionCallbackFailureHandler(executor) + + base_route = Route.from_string( + "/2025-12-01/durable-execution-callbacks/test-id/fail" + ) + callback_route = CallbackFailureRoute.from_route(base_route) + + request = HTTPRequest( + method="POST", + path=callback_route, + headers={}, + query_params={}, + body={}, + ) + + response = handler.handle(callback_route, request) + # Handler should accept empty body for failure requests + assert response.status_code == 200 + + +def test_send_durable_execution_callback_heartbeat_handler(): + """Test SendDurableExecutionCallbackHeartbeatHandler with valid request.""" + + executor = Mock() + executor.send_callback_heartbeat.return_value = ( + SendDurableExecutionCallbackHeartbeatResponse() + ) + handler = SendDurableExecutionCallbackHeartbeatHandler(executor) + + # Create route using Router + router = Router() + route = router.find_route( + "/2025-12-01/durable-execution-callbacks/test-callback-id/heartbeat", "POST" + ) + assert isinstance(route, CallbackHeartbeatRoute) + assert route.callback_id == "test-callback-id" + + # Test with valid request body + request = HTTPRequest( + method="POST", + path=route, + headers={"Content-Type": "application/json"}, + query_params={}, + body={"CallbackId": "test-callback-id"}, + ) + response = handler.handle(route, request) + + # Verify successful response + assert response.status_code == 200 + assert response.body == {} + + # Verify executor was called with correct parameters + executor.send_callback_heartbeat.assert_called_once_with( + callback_id="test-callback-id" + ) + + +def test_send_durable_execution_callback_heartbeat_handler_empty_body(): + """Test SendDurableExecutionCallbackHeartbeatHandler with empty body.""" + executor = Mock() + handler = SendDurableExecutionCallbackHeartbeatHandler(executor) + + base_route = Route.from_string( + "/2025-12-01/durable-execution-callbacks/test-id/heartbeat" + ) + callback_route = CallbackHeartbeatRoute.from_route(base_route) + + request = HTTPRequest( + method="POST", + path=callback_route, + headers={}, + query_params={}, + body={}, + ) + + response = handler.handle(callback_route, request) + # Handler should accept empty body for heartbeat requests + assert response.status_code == 200 + + +def test_health_handler(): + """Test HealthHandler returns healthy status.""" + executor = Mock() + handler = HealthHandler(executor) + + request = HTTPRequest( + method="GET", + path=Route.from_string("/health"), + headers={}, + query_params={}, + body={}, + ) + + response = handler.handle(Route.from_string("/health"), request) + assert response.status_code == 200 + assert response.body == {"status": "healthy"} + + +def test_metrics_handler(): + """Test MetricsHandler returns empty metrics.""" + executor = Mock() + handler = MetricsHandler(executor) + + request = HTTPRequest( + method="GET", + path=Route.from_string("/metrics"), + headers={}, + query_params={}, + body={}, + ) + + response = handler.handle(Route.from_string("/metrics"), request) + assert response.status_code == 200 + assert response.body == {"metrics": {}} + + +def test_handler_naming_matches_smithy_operations(): + """Test that handler names match the Smithy operation names.""" + # Verify that all handlers are named after their corresponding Smithy operations + handler_names = [ + "StartExecutionHandler", # Note: This one doesn't have "Durable" prefix in Smithy + "GetDurableExecutionHandler", + "CheckpointDurableExecutionHandler", + "StopDurableExecutionHandler", + "GetDurableExecutionStateHandler", + "GetDurableExecutionHistoryHandler", + "ListDurableExecutionsHandler", + "ListDurableExecutionsByFunctionHandler", + "SendDurableExecutionCallbackSuccessHandler", + "SendDurableExecutionCallbackFailureHandler", + "SendDurableExecutionCallbackHeartbeatHandler", + "HealthHandler", + "MetricsHandler", + ] + + # Import the handlers module to check all classes exist + + for handler_name in handler_names: + assert hasattr(handlers, handler_name), f"Handler {handler_name} not found" + handler_class = getattr(handlers, handler_name) + assert issubclass(handler_class, EndpointHandler), ( + f"{handler_name} should inherit from EndpointHandler" + ) + + +def test_all_handlers_have_executor(): + """Test that all handlers store the executor reference.""" + executor = Mock() + + handlers_to_test = [ + StartExecutionHandler, + GetDurableExecutionHandler, + CheckpointDurableExecutionHandler, + StopDurableExecutionHandler, + GetDurableExecutionStateHandler, + GetDurableExecutionHistoryHandler, + ListDurableExecutionsHandler, + ListDurableExecutionsByFunctionHandler, + SendDurableExecutionCallbackSuccessHandler, + SendDurableExecutionCallbackFailureHandler, + SendDurableExecutionCallbackHeartbeatHandler, + HealthHandler, + MetricsHandler, + ] + + for handler_class in handlers_to_test: + handler = handler_class(executor) + assert handler.executor == executor, ( + f"{handler_class.__name__} should store executor reference" + ) + + +class MockExceptionHandler(EndpointHandler): + """Test handler that can trigger specific exception types for testing.""" + + def __init__( + self, executor: Executor, exception_to_raise: Exception | None = None + ) -> None: + super().__init__(executor) + self.exception_to_raise = exception_to_raise + + def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: + """Handle request by raising the configured exception.""" + if self.exception_to_raise: + if isinstance(self.exception_to_raise, AwsApiException): + return self._handle_aws_exception(self.exception_to_raise) + + return self._handle_framework_exception(self.exception_to_raise) + return self._success_response({"status": "ok"}) + + +def test_framework_exception_handling(): + """Test the framework exception handling through public API.""" + + executor = Mock() + + # Test ValueError handling - maps to InvalidParameterValueException + handler = MockExceptionHandler(executor, ValueError("Invalid input")) + response = handler.handle(Mock(), Mock()) + assert response.status_code == 400 + assert response.body["Type"] == "InvalidParameterValueException" + assert response.body["message"] == "Invalid input" + + # Test KeyError handling - maps to InvalidParameterValueException + handler = MockExceptionHandler(executor, KeyError("missing_field")) + response = handler.handle(Mock(), Mock()) + assert response.status_code == 400 + assert response.body["Type"] == "InvalidParameterValueException" + assert response.body["message"] == "'missing_field'" + + # Test unexpected exception handling - maps to ServiceException + handler = MockExceptionHandler(executor, RuntimeError("Unexpected error")) + response = handler.handle(Mock(), Mock()) + assert response.status_code == 500 + assert response.body["Type"] == "ServiceException" + assert response.body["Message"] == "Unexpected error" + + +def test_aws_exception_handling(): + """Test the AWS exception handling through public API.""" + + executor = Mock() + + # Test ResourceNotFoundException handling + handler = MockExceptionHandler( + executor, ResourceNotFoundException("Resource not found") + ) + response = handler.handle(Mock(), Mock()) + assert response.status_code == 404 + assert response.body["Type"] == "ResourceNotFoundException" + assert response.body["Message"] == "Resource not found" + + # Test IllegalArgumentException handling + handler = MockExceptionHandler( + executor, IllegalArgumentException("Invalid parameter") + ) + response = handler.handle(Mock(), Mock()) + assert response.status_code == 400 + assert response.body["Type"] == "InvalidParameterValueException" + assert response.body["message"] == "Invalid parameter" + + +def test_send_durable_execution_callback_success_handler_invalid_callback_id(): + """Test SendDurableExecutionCallbackSuccessHandler with invalid callback ID.""" + + executor = Mock() + executor.send_callback_success.side_effect = IllegalArgumentException( + "callback_id is required" + ) + handler = SendDurableExecutionCallbackSuccessHandler(executor) + + # Create route using Router + router = Router() + route = router.find_route( + "/2025-12-01/durable-execution-callbacks/test-callback-id/succeed", "POST" + ) + + request = HTTPRequest( + method="POST", + path=route, + headers={"Content-Type": "application/json"}, + query_params={}, + body={"CallbackId": "test-callback-id"}, + ) + + response = handler.handle(route, request) + + # Verify error response with AWS-compliant format + assert response.status_code == 400 + assert response.body["Type"] == "InvalidParameterValueException" + assert "callback_id is required" in response.body["message"] + + +def test_send_durable_execution_callback_success_handler_callback_state_conflict(): + """Test SendDurableExecutionCallbackSuccessHandler with callback state conflict.""" + + executor = Mock() + executor.send_callback_success.side_effect = IllegalStateException( + "Callback already completed" + ) + handler = SendDurableExecutionCallbackSuccessHandler(executor) + + # Create route using Router + router = Router() + route = router.find_route( + "/2025-12-01/durable-execution-callbacks/test-callback-id/succeed", "POST" + ) + + request = HTTPRequest( + method="POST", + path=route, + headers={"Content-Type": "application/json"}, + query_params={}, + body={"CallbackId": "test-callback-id"}, + ) + + response = handler.handle(route, request) + + # Verify error response - IllegalStateException in callback context maps to ExecutionConflictException + assert response.status_code == 409 + assert response.body["Type"] == "ExecutionConflictException" + assert response.body["message"] == "Callback already completed" + + +def test_send_durable_execution_callback_failure_handler_callback_state_conflict(): + """Test SendDurableExecutionCallbackFailureHandler with callback state conflict.""" + + executor = Mock() + executor.send_callback_failure.side_effect = IllegalStateException( + "Callback already completed" + ) + handler = SendDurableExecutionCallbackFailureHandler(executor) + + # Create route using Router + router = Router() + route = router.find_route( + "/2025-12-01/durable-execution-callbacks/test-callback-id/fail", "POST" + ) + + request = HTTPRequest( + method="POST", + path=route, + headers={"Content-Type": "application/json"}, + query_params={}, + body={"CallbackId": "test-callback-id"}, + ) + + response = handler.handle(route, request) + + # Verify error response - IllegalStateException in callback context maps to ExecutionConflictException + assert response.status_code == 409 + assert response.body["Type"] == "ExecutionConflictException" + assert response.body["message"] == "Callback already completed" + + +def test_send_durable_execution_callback_heartbeat_handler_callback_state_conflict(): + """Test SendDurableExecutionCallbackHeartbeatHandler with callback state conflict.""" + + executor = Mock() + executor.send_callback_heartbeat.side_effect = IllegalStateException( + "Callback already completed" + ) + handler = SendDurableExecutionCallbackHeartbeatHandler(executor) + + # Create route using Router + router = Router() + route = router.find_route( + "/2025-12-01/durable-execution-callbacks/test-callback-id/heartbeat", "POST" + ) + + request = HTTPRequest( + method="POST", + path=route, + headers={"Content-Type": "application/json"}, + query_params={}, + body={"CallbackId": "test-callback-id"}, + ) + + response = handler.handle(route, request) + + # Verify error response - IllegalStateException in callback context maps to ExecutionConflictException + assert response.status_code == 409 + assert response.body["Type"] == "ExecutionConflictException" + assert response.body["message"] == "Callback already completed" + + +def test_callback_handlers_use_dataclass_serialization(): + """Test that all callback handlers use dataclass from_dict/to_dict methods.""" + + # Test that all callback request dataclasses have from_dict/to_dict methods + success_request = SendDurableExecutionCallbackSuccessRequest.from_dict( + {"CallbackId": "test-id", "Result": "test-result"} + ) + assert success_request.callback_id == "test-id" + assert success_request.result == "test-result" + assert success_request.to_dict() == { + "CallbackId": "test-id", + "Result": "test-result", + } + + failure_request = SendDurableExecutionCallbackFailureRequest.from_dict( + {}, "test-id" + ) + assert failure_request.callback_id == "test-id" + assert failure_request.error is None + assert failure_request.to_dict() == {"CallbackId": "test-id"} + + heartbeat_request = SendDurableExecutionCallbackHeartbeatRequest.from_dict( + {"CallbackId": "test-id"} + ) + assert heartbeat_request.callback_id == "test-id" + assert heartbeat_request.to_dict() == {"CallbackId": "test-id"} diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/web/models_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/web/models_test.py new file mode 100644 index 0000000..8148736 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/web/models_test.py @@ -0,0 +1,796 @@ +"""Tests for HTTP request/response data models and utilities.""" + +from __future__ import annotations + +import datetime +import json +from unittest.mock import Mock, patch + +import pytest + +from aws_durable_execution_sdk_python_testing.exceptions import ( + CallbackTimeoutException, + ExecutionAlreadyStartedException, + IllegalArgumentException, + IllegalStateException, + InvalidParameterValueException, + ResourceNotFoundException, + ServiceException, + TooManyRequestsException, +) +from aws_durable_execution_sdk_python_testing.web.models import ( + HTTPRequest, + HTTPResponse, + OperationHandler, +) +from aws_durable_execution_sdk_python_testing.web.routes import Route + + +def test_http_request_creation() -> None: + """Test HTTPRequest dataclass creation.""" + path = Route.from_string("/test/path") + request = HTTPRequest( + method="GET", + path=path, + headers={"Content-Type": "application/json"}, + query_params={"param1": ["value1"], "param2": ["value2a", "value2b"]}, + body={"test": "data"}, + ) + + assert request.method == "GET" + assert request.path == path + assert request.headers == {"Content-Type": "application/json"} + assert request.query_params == { + "param1": ["value1"], + "param2": ["value2a", "value2b"], + } + assert request.body == {"test": "data"} + + +def test_http_request_immutable() -> None: + """Test that HTTPRequest is immutable.""" + path = Route.from_string("/test/path") + request = HTTPRequest(method="GET", path=path, headers={}, query_params={}, body={}) + + # Should not be able to modify fields + with pytest.raises(AttributeError): + request.method = "POST" # type: ignore + + +def test_http_response_creation() -> None: + """Test HTTPResponse dataclass creation.""" + response = HTTPResponse( + status_code=200, + headers={"Content-Type": "application/json"}, + body={"result": "success"}, + ) + + assert response.status_code == 200 + assert response.headers == {"Content-Type": "application/json"} + assert response.body == {"result": "success"} + + +def test_http_response_immutable() -> None: + """Test that HTTPResponse is immutable.""" + response = HTTPResponse(status_code=200, headers={}, body={}) + + # Should not be able to modify fields + with pytest.raises(AttributeError): + response.status_code = 404 # type: ignore + + +def test_http_response_json_basic() -> None: + """Test creating basic JSON response.""" + data = {"message": "success", "id": 123} + response = HTTPResponse.create_json(200, data) + + assert response.status_code == 200 + assert response.headers["Content-Type"] == "application/json" + + # Verify the body is stored as dict + assert response.body == data + + # Verify serialization to bytes works + body_bytes = response.body_to_bytes() + parsed_body = json.loads(body_bytes.decode("utf-8")) + assert parsed_body == data + + +def test_http_response_json_with_additional_headers() -> None: + """Test creating JSON response with additional headers.""" + data = {"result": "ok"} + additional_headers = { + "X-Custom-Header": "custom-value", + "Cache-Control": "no-cache", + } + + response = HTTPResponse.create_json(201, data, additional_headers) + + assert response.status_code == 201 + assert response.headers["Content-Type"] == "application/json" + assert response.headers["X-Custom-Header"] == "custom-value" + assert response.headers["Cache-Control"] == "no-cache" + + # Verify the body is stored as dict + assert response.body == data + + # Verify serialization to bytes works + body_bytes = response.body_to_bytes() + parsed_body = json.loads(body_bytes.decode("utf-8")) + assert parsed_body == data + + +def test_http_response_json_compact_serialization() -> None: + """Test that JSON response uses compact serialization.""" + data = {"key": "value", "nested": {"inner": "data"}} + response = HTTPResponse.create_json(200, data) + + # Verify the body is stored as dict + assert response.body == data + + # Verify serialization to bytes uses compact format + body_bytes = response.body_to_bytes() + body_str = body_bytes.decode("utf-8") + assert " " not in body_str # No spaces after separators + assert "\n" not in body_str # No newlines + + +# Removed deprecated tests for create_error method + + +def test_http_response_empty_basic() -> None: + """Test creating basic empty response.""" + response = HTTPResponse.create_empty(204) + + assert response.status_code == 204 + assert response.headers == {} + assert response.body == {} + + +def test_http_response_empty_with_headers() -> None: + """Test creating empty response with additional headers.""" + additional_headers = {"Location": "/new-resource", "X-Request-ID": "123"} + response = HTTPResponse.create_empty(201, additional_headers) + + assert response.status_code == 201 + assert response.headers == additional_headers + assert response.body == {} + + +def test_operation_handler_protocol() -> None: + """Test that OperationHandler protocol works correctly.""" + + class TestHandler: + def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: + return HTTPResponse( + status_code=200, + headers={"Content-Type": "text/plain"}, + body={"message": "handled"}, + ) + + # Should be able to use as OperationHandler + handler: OperationHandler = TestHandler() + + path = Route.from_string("/test") + request = HTTPRequest(method="GET", path=path, headers={}, query_params={}, body={}) + + response = handler.handle(path, request) + assert response.status_code == 200 + assert response.body == {"message": "handled"} + + +def test_operation_handler_protocol_type_checking() -> None: + """Test that OperationHandler protocol enforces correct signature.""" + + class InvalidHandler: + def handle(self, wrong_params: str) -> str: # Wrong signature + return "invalid" + + # This should work at runtime but would fail type checking + # We can't test static type checking in unit tests, but this documents the expected behavior + invalid_handler = InvalidHandler() + + # The protocol is structural, so this would work at runtime + # but mypy would catch the type mismatch + assert hasattr(invalid_handler, "handle") + + +def test_http_response_edge_cases() -> None: + """Test edge cases for HTTP response factory methods.""" + + # Test with empty data + response = HTTPResponse.create_json(200, {}) + assert response.body == {} + + # Test with complex nested data + complex_data = { + "list": [1, 2, 3], + "nested": {"deep": {"value": True}}, + "null": None, + "unicode": "🚀", + } + response = HTTPResponse.create_json(200, complex_data) + assert response.body == complex_data + + # Verify serialization to bytes works + body_bytes = response.body_to_bytes() + parsed = json.loads(body_bytes.decode("utf-8")) + assert parsed == complex_data + + +def test_http_request_with_empty_collections() -> None: + """Test HTTPRequest with empty collections.""" + path = Route.from_string("/empty") + request = HTTPRequest(method="GET", path=path, headers={}, query_params={}, body={}) + + assert request.headers == {} + assert request.query_params == {} + assert request.body == {} + + +def test_http_response_with_empty_collections() -> None: + """Test HTTPResponse with empty collections.""" + response = HTTPResponse(status_code=204, headers={}, body={}) + + assert response.headers == {} + assert response.body == {} + + +# Tests for HTTPRequest.from_bytes method + + +def test_http_request_from_bytes_standard_json() -> None: + """Test HTTPRequest.from_bytes with standard JSON deserialization.""" + test_data = {"key": "value", "number": 42} + body_bytes = json.dumps(test_data).encode("utf-8") + + path = Route.from_string("/test") + request = HTTPRequest.from_bytes( + body_bytes=body_bytes, + method="POST", + path=path, + headers={"Content-Type": "application/json"}, + query_params={"param": ["value"]}, + ) + + assert request.method == "POST" + assert request.path == path + assert request.headers == {"Content-Type": "application/json"} + assert request.query_params == {"param": ["value"]} + assert request.body == test_data + + +def test_http_get_request_from_bytes_ignore_body() -> None: + """Test HTTPRequest.from_bytes with standard JSON deserialization.""" + test_data = {"key": "value", "number": 42} + body_bytes = json.dumps(test_data).encode("utf-8") + + path = Route.from_string("/test") + request = HTTPRequest.from_bytes( + body_bytes=body_bytes, + method="GET", + path=path, + headers={"Content-Type": "application/json"}, + query_params={"param": ["value"]}, + ) + + assert request.method == "GET" + assert request.path == path + assert request.headers == {"Content-Type": "application/json"} + assert request.query_params == {"param": ["value"]} + assert request.body == {} + + +def test_http_request_from_bytes_minimal_params() -> None: + """Test HTTPRequest.from_bytes with minimal parameters.""" + test_data = {"message": "hello"} + body_bytes = json.dumps(test_data).encode("utf-8") + + request = HTTPRequest.from_bytes(body_bytes=body_bytes) + + assert request.method == "POST" # Default + assert request.path.raw_path == "" # Default empty route + assert request.headers == {} # Default + assert request.query_params == {} # Default + assert request.body == test_data + + +def test_http_request_from_bytes_aws_operation_fallback() -> None: + """Test HTTPRequest.from_bytes with AWS operation that falls back to JSON.""" + test_data = {"Input": "test-input", "ExecutionName": "test-execution"} + body_bytes = json.dumps(test_data).encode("utf-8") + + # Use a non-existent operation name to trigger fallback + request = HTTPRequest.from_bytes( + body_bytes=body_bytes, + operation_name="NonExistentOperation", + method="POST", + ) + + assert request.method == "POST" + assert request.body == test_data + + +def test_http_request_from_bytes_invalid_json() -> None: + """Test HTTPRequest.from_bytes with invalid JSON raises InvalidParameterValueException.""" + invalid_json = b'{"invalid": json}' + + with pytest.raises( + InvalidParameterValueException, match="JSON deserialization failed" + ): + HTTPRequest.from_bytes(body_bytes=invalid_json) + + +def test_http_request_from_bytes_invalid_utf8() -> None: + """Test HTTPRequest.from_bytes with invalid UTF-8 raises InvalidParameterValueException.""" + invalid_utf8 = b'\xff\xfe{"test": "data"}' # Invalid UTF-8 BOM + + with pytest.raises( + InvalidParameterValueException, match="JSON deserialization failed" + ): + HTTPRequest.from_bytes(body_bytes=invalid_utf8) + + +def test_http_request_from_bytes_empty_body() -> None: + """Test HTTPRequest.from_bytes with empty body.""" + empty_body = b"{}" + + request = HTTPRequest.from_bytes(body_bytes=empty_body) + + assert request.body == {} + + +def test_http_request_from_bytes_complex_json() -> None: + """Test HTTPRequest.from_bytes with complex nested JSON.""" + complex_data = { + "list": [1, 2, 3], + "nested": {"deep": {"value": True}}, + "null": None, + "unicode": "🚀", + } + body_bytes = json.dumps(complex_data).encode("utf-8") + + request = HTTPRequest.from_bytes(body_bytes=body_bytes) + + assert request.body == complex_data + + +def test_http_request_from_bytes_aws_operation_success() -> None: + """Test HTTPRequest.from_bytes with valid AWS operation (if available).""" + # This test will use AWS deserialization if available, otherwise fall back to JSON + test_data = { + "Input": "test-input", + "ExecutionName": "test-execution", + "FunctionName": "test-function", + } + body_bytes = json.dumps(test_data).encode("utf-8") + + # Try with a real AWS operation name + request = HTTPRequest.from_bytes( + body_bytes=body_bytes, + operation_name="StartDurableExecution", + method="POST", + ) + + assert request.method == "POST" + assert request.body is not None + # The exact structure may vary depending on AWS deserialization vs JSON fallback + # but we should get some valid dict data + + +def test_http_request_from_bytes_preserves_field_names() -> None: + """Test that from_bytes preserves field names from the input.""" + # Test with AWS-style PascalCase field names + aws_style_data = { + "ExecutionName": "test-execution", + "FunctionName": "my-function", + "Input": {"Key": "Value"}, + } + body_bytes = json.dumps(aws_style_data).encode("utf-8") + + request = HTTPRequest.from_bytes(body_bytes=body_bytes) + + # Field names should be preserved as-is + assert isinstance(request.body, dict) + assert "ExecutionName" in request.body + assert "FunctionName" in request.body + assert request.body["ExecutionName"] == "test-execution" + assert request.body["FunctionName"] == "my-function" + assert request.body["Input"]["Key"] == "Value" + + +# Tests for HTTPResponse.body_to_bytes method + + +def test_http_response_body_to_bytes_standard_json() -> None: + """Test HTTPResponse.body_to_bytes with standard JSON serialization.""" + test_data = {"message": "success", "id": 123} + response = HTTPResponse( + status_code=200, + headers={"Content-Type": "application/json"}, + body=test_data, + ) + + body_bytes = response.body_to_bytes() + + # Verify it's bytes + assert isinstance(body_bytes, bytes) + + # Verify content is correct + parsed_data = json.loads(body_bytes.decode("utf-8")) + assert parsed_data == test_data + + +def test_http_response_body_to_bytes_compact_format() -> None: + """Test that body_to_bytes uses compact JSON format.""" + test_data = {"key": "value", "nested": {"inner": "data"}} + response = HTTPResponse(status_code=200, headers={}, body=test_data) + + body_bytes = response.body_to_bytes() + body_str = body_bytes.decode("utf-8") + + # Should not contain extra whitespace + assert " " not in body_str # No spaces after separators + assert "\n" not in body_str # No newlines + + +def test_http_response_body_to_bytes_empty_body() -> None: + """Test body_to_bytes with empty body.""" + response = HTTPResponse(status_code=204, headers={}, body={}) + + body_bytes = response.body_to_bytes() + + assert body_bytes == b"{}" + + +def test_http_response_body_to_bytes_complex_data() -> None: + """Test body_to_bytes with complex nested data.""" + complex_data = { + "list": [1, 2, 3], + "nested": {"deep": {"value": True}}, + "null": None, + "unicode": "🚀", + } + response = HTTPResponse(status_code=200, headers={}, body=complex_data) + + body_bytes = response.body_to_bytes() + parsed_data = json.loads(body_bytes.decode("utf-8")) + + assert parsed_data == complex_data + + +# Tests for HTTPResponse.from_dict method + + +def test_http_response_from_dict_basic() -> None: + """Test HTTPResponse.from_dict with basic parameters.""" + test_data = {"message": "success", "id": 123} + + response = HTTPResponse.from_dict(test_data) + + assert response.status_code == 200 # Default + assert response.headers == {} # Default + assert response.body == test_data + + +def test_http_response_from_dict_with_status_code() -> None: + """Test HTTPResponse.from_dict with custom status code.""" + test_data = {"error": "not found"} + + response = HTTPResponse.from_dict(test_data, status_code=404) + + assert response.status_code == 404 + assert response.headers == {} + assert response.body == test_data + + +def test_http_response_from_dict_with_headers() -> None: + """Test HTTPResponse.from_dict with custom headers.""" + test_data = {"result": "ok"} + headers = {"Content-Type": "application/json", "X-Custom": "value"} + + response = HTTPResponse.from_dict(test_data, headers=headers) + + assert response.status_code == 200 + assert response.headers == headers + assert response.body == test_data + + +def test_http_response_from_dict_with_all_params() -> None: + """Test HTTPResponse.from_dict with all parameters.""" + test_data = {"data": "test"} + headers = {"Content-Type": "application/json"} + + response = HTTPResponse.from_dict(test_data, status_code=201, headers=headers) + + assert response.status_code == 201 + assert response.headers == headers + assert response.body == test_data + + +def test_http_response_from_dict_empty_data() -> None: + """Test HTTPResponse.from_dict with empty data.""" + response = HTTPResponse.from_dict({}) + + assert response.status_code == 200 + assert response.headers == {} + assert response.body == {} + + +def test_http_response_from_dict_complex_data() -> None: + """Test HTTPResponse.from_dict with complex nested data.""" + complex_data = { + "list": [1, 2, 3], + "nested": {"deep": {"value": True}}, + "null": None, + "unicode": "🚀", + } + + response = HTTPResponse.from_dict(complex_data) + + assert response.body == complex_data + + +def test_http_response_from_dict_immutable() -> None: + """Test that HTTPResponse.from_dict creates immutable response.""" + test_data = {"key": "value"} + response = HTTPResponse.from_dict(test_data) + + # Should not be able to modify fields + with pytest.raises(AttributeError): + response.status_code = 404 # type: ignore + + +def test_http_response_from_dict_integration_with_body_to_bytes() -> None: + """Test that from_dict works with body_to_bytes method.""" + test_data = {"message": "integration test", "success": True} + + response = HTTPResponse.from_dict(test_data, status_code=201) + body_bytes = response.body_to_bytes() + + # Verify round-trip serialization + parsed_data = json.loads(body_bytes.decode("utf-8")) + assert parsed_data == test_data + assert response.status_code == 201 + + +def test_http_request_from_bytes_aws_deserialization_success() -> None: + """Test HTTPRequest.from_bytes with successful AWS deserialization.""" + test_data = {"ExecutionName": "test-execution", "Input": "test-input"} + body_bytes = json.dumps(test_data).encode("utf-8") + + # Mock successful AWS deserialization + mock_deserializer = Mock() + mock_deserializer.from_bytes.return_value = test_data + + with patch( + "aws_durable_execution_sdk_python_testing.web.models.AwsRestJsonDeserializer.create", + return_value=mock_deserializer, + ): + request = HTTPRequest.from_bytes( + body_bytes=body_bytes, operation_name="StartDurableExecution" + ) + + assert request.body == test_data + mock_deserializer.from_bytes.assert_called_once_with(body_bytes) + + +def test_http_request_from_bytes_aws_deserialization_fallback_error() -> None: + """Test HTTPRequest.from_bytes when both AWS and JSON deserialization fail.""" + + invalid_bytes = b"invalid json data" + + # Mock AWS deserialization failure + mock_deserializer = Mock() + mock_deserializer.from_bytes.side_effect = InvalidParameterValueException( + "AWS failed" + ) + + with patch( + "aws_durable_execution_sdk_python_testing.web.models.AwsRestJsonDeserializer.create", + return_value=mock_deserializer, + ): + with pytest.raises( + InvalidParameterValueException, + match="Both AWS and JSON deserialization failed", + ): + HTTPRequest.from_bytes( + body_bytes=invalid_bytes, operation_name="StartDurableExecution" + ) + + +def test_http_response_body_to_bytes_serialization_error() -> None: + """Test HTTPResponse.body_to_bytes when JSON serialization fail.""" + + # Create data that can't be JSON serialized + class CustomObject: + pass + + test_data = {"custom": CustomObject()} + response = HTTPResponse(status_code=200, headers={}, body=test_data) + + with pytest.raises( + InvalidParameterValueException, + match="Failed to serialize data to JSON: Object of type CustomObject is not JSON serializable", + ): + response.body_to_bytes() + + +# Tests for HTTPResponse.create_error_from_exception method + + +def test_create_error_from_exception_invalid_parameter_value() -> None: + """Test create_error_from_exception with InvalidParameterValueException.""" + + exception = InvalidParameterValueException("Parameter 'name' is required") + response = HTTPResponse.create_error_from_exception(exception) + + assert response.status_code == 400 + assert response.headers["Content-Type"] == "application/json" + + expected_body = { + "Type": "InvalidParameterValueException", + "message": "Parameter 'name' is required", + } + assert response.body == expected_body + + +def test_create_error_from_exception_resource_not_found() -> None: + """Test create_error_from_exception with ResourceNotFoundException.""" + + exception = ResourceNotFoundException("Execution not found") + response = HTTPResponse.create_error_from_exception(exception) + + assert response.status_code == 404 + assert response.headers["Content-Type"] == "application/json" + + expected_body = { + "Type": "ResourceNotFoundException", + "Message": "Execution not found", + } + assert response.body == expected_body + + +def test_create_error_from_exception_service_exception() -> None: + """Test create_error_from_exception with ServiceException.""" + + exception = ServiceException("Internal server error") + response = HTTPResponse.create_error_from_exception(exception) + + assert response.status_code == 500 + assert response.headers["Content-Type"] == "application/json" + + expected_body = {"Type": "ServiceException", "Message": "Internal server error"} + assert response.body == expected_body + + +def test_create_error_from_exception_execution_already_started() -> None: + """Test create_error_from_exception with ExecutionAlreadyStartedException.""" + + arn = "arn:aws:lambda:us-east-1:123456789012:function:test" + exception = ExecutionAlreadyStartedException("Execution already exists", arn) + response = HTTPResponse.create_error_from_exception(exception) + + assert response.status_code == 409 + assert response.headers["Content-Type"] == "application/json" + + # ExecutionAlreadyStartedException has no Type field per Smithy definition + expected_body = {"message": "Execution already exists", "DurableExecutionArn": arn} + assert response.body == expected_body + + +def test_create_error_from_exception_callback_timeout() -> None: + """Test create_error_from_exception with CallbackTimeoutException.""" + + exception = CallbackTimeoutException("Callback timed out") + response = HTTPResponse.create_error_from_exception(exception) + + assert response.status_code == 408 + assert response.headers["Content-Type"] == "application/json" + + expected_body = { + "Type": "CallbackTimeoutException", + "message": "Callback timed out", + } + assert response.body == expected_body + + +def test_create_error_from_exception_too_many_requests() -> None: + """Test create_error_from_exception with TooManyRequestsException.""" + + exception = TooManyRequestsException("Rate limit exceeded") + response = HTTPResponse.create_error_from_exception(exception) + + assert response.status_code == 429 + assert response.headers["Content-Type"] == "application/json" + + expected_body = { + "Type": "TooManyRequestsException", + "message": "Rate limit exceeded", + } + assert response.body == expected_body + + +def test_create_error_from_exception_illegal_state() -> None: + """Test create_error_from_exception with IllegalStateException (unmapped).""" + + exception = IllegalStateException("Invalid state transition") + response = HTTPResponse.create_error_from_exception(exception) + + assert response.status_code == 500 + assert response.headers["Content-Type"] == "application/json" + + # IllegalStateException maps to ServiceException when serialized + expected_body = {"Type": "ServiceException", "Message": "Invalid state transition"} + assert response.body == expected_body + + +def test_create_error_from_exception_runtime_exception() -> None: + """Test create_error_from_exception with RuntimeException (unmapped).""" + + exception = IllegalArgumentException("Invalid argument provided") + response = HTTPResponse.create_error_from_exception(exception) + + assert response.status_code == 400 + assert response.headers["Content-Type"] == "application/json" + + # IllegalArgumentException maps to InvalidParameterValueException when serialized + expected_body = { + "Type": "InvalidParameterValueException", + "message": "Invalid argument provided", + } + assert response.body == expected_body + + +def test_create_error_from_exception_type_validation() -> None: + """Test create_error_from_exception with non-AwsApiException raises TypeError.""" + # Test with regular Exception + regular_exception = Exception("Not an AWS exception") + + with pytest.raises( + TypeError, match="Expected AwsApiException, got " + ): + HTTPResponse.create_error_from_exception(regular_exception) # type: ignore + + # Test with AWS API exception (should work fine) + + framework_exception = InvalidParameterValueException("Framework error") + + # This should NOT raise an error since InvalidParameterValueException is an AwsApiException + response = HTTPResponse.create_error_from_exception(framework_exception) + assert response.status_code == 400 + + +def test_create_error_from_exception_no_wrapper_object() -> None: + """Test that create_error_from_exception doesn't add wrapper 'error' object.""" + + exception = InvalidParameterValueException("Test message") + response = HTTPResponse.create_error_from_exception(exception) + + # Should NOT have wrapper "error" object like the old create_error method + assert "error" not in response.body + + # Should have direct AWS-compliant structure + assert "Type" in response.body + assert "message" in response.body + assert response.body["Type"] == "InvalidParameterValueException" + assert response.body["message"] == "Test message" + + +def test_create_error_from_exception_serialization_round_trip() -> None: + """Test that create_error_from_exception produces serializable responses.""" + + exception = ResourceNotFoundException("Resource not found") + response = HTTPResponse.create_error_from_exception(exception) + + # Should be able to serialize to bytes + body_bytes = response.body_to_bytes() + + # Should be valid JSON + parsed_body = json.loads(body_bytes.decode("utf-8")) + + expected_body = { + "Type": "ResourceNotFoundException", + "Message": "Resource not found", + } + assert parsed_body == expected_body diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/web/routes_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/web/routes_test.py new file mode 100644 index 0000000..cc1be9b --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/web/routes_test.py @@ -0,0 +1,1187 @@ +"""Tests for the strongly-typed route parsing system.""" + +from __future__ import annotations + +import threading +import time +from urllib.parse import quote + +import pytest + +from aws_durable_execution_sdk_python_testing.exceptions import ( + UnknownRouteError, +) +from aws_durable_execution_sdk_python_testing.web.routes import ( + CallbackFailureRoute, + CallbackHeartbeatRoute, + CallbackSuccessRoute, + CheckpointDurableExecutionRoute, + GetDurableExecutionHistoryRoute, + GetDurableExecutionRoute, + GetDurableExecutionStateRoute, + HealthRoute, + ListDurableExecutionsByFunctionRoute, + ListDurableExecutionsRoute, + MetricsRoute, + Route, + Router, + StartExecutionRoute, + StopDurableExecutionRoute, +) + + +def test_route_from_string_basic(): + """Test basic route creation from string.""" + route = Route.from_string("/test/path") + assert route.raw_path == "/test/path" + assert route.segments == ["test", "path"] + + +def test_route_from_string_with_leading_trailing_slashes(): + """Test route creation handles leading and trailing slashes.""" + route = Route.from_string("///test/path///") + assert route.raw_path == "///test/path///" + assert route.segments == ["test", "path"] + + +def test_route_from_string_empty_segments(): + """Test route creation filters out empty segments.""" + route = Route.from_string("/test//path/") + assert route.raw_path == "/test//path/" + assert route.segments == ["test", "path"] + + +def test_route_from_string_root(): + """Test route creation for root path.""" + route = Route.from_string("/") + assert route.raw_path == "/" + assert route.segments == [] + + +def test_route_matches_pattern_exact(): + """Test pattern matching with exact segments.""" + route = Route.from_string("/test/path") + assert route.matches_pattern(["test", "path"]) is True + assert route.matches_pattern(["test", "other"]) is False + + +def test_route_matches_pattern_wildcard(): + """Test pattern matching with wildcards.""" + route = Route.from_string("/test/123/path") + assert route.matches_pattern(["test", "*", "path"]) is True + assert route.matches_pattern(["test", "*", "other"]) is False + + +def test_route_matches_pattern_length_mismatch(): + """Test pattern matching fails with different lengths.""" + route = Route.from_string("/test/path") + assert route.matches_pattern(["test"]) is False + assert route.matches_pattern(["test", "path", "extra"]) is False + + +def test_start_execution_route_is_match(): + """Test StartExecutionRoute pattern matching.""" + route = Route.from_string("/start-durable-execution") + assert StartExecutionRoute.is_match(route, "POST") is True + assert StartExecutionRoute.is_match(route, "GET") is False + + route = Route.from_string("/start-execution") + assert StartExecutionRoute.is_match(route, "POST") is False + + +def test_start_execution_route_from_route(): + """Test StartExecutionRoute creation from base route.""" + base_route = Route.from_string("/start-durable-execution") + start_route = StartExecutionRoute.from_route(base_route) + + assert start_route.raw_path == "/start-durable-execution" + assert start_route.segments == ["start-durable-execution"] + + +# Removed test_start_execution_route_from_route_invalid - from_route() no longer validates +# Call is_match() first to ensure route is valid + + +def test_get_durable_execution_route_is_match(): + """Test GetDurableExecutionRoute pattern matching.""" + route = Route.from_string( + "/2025-12-01/durable-executions/arn:aws:lambda:us-east-1:123456789012:function:my-function" + ) + assert GetDurableExecutionRoute.is_match(route, "GET") is True + assert GetDurableExecutionRoute.is_match(route, "POST") is False + + route = Route.from_string("/2025-12-01/executions/some-arn") + assert GetDurableExecutionRoute.is_match(route, "GET") is False + + route = Route.from_string("/2025-12-01/durable-executions") + assert GetDurableExecutionRoute.is_match(route, "GET") is False + + +def test_get_durable_execution_route_from_route(): + """Test GetDurableExecutionRoute creation from base route.""" + arn = "arn:aws:lambda:us-east-1:123456789012:function:my-function" + base_route = Route.from_string(f"/2025-12-01/durable-executions/{arn}") + get_route = GetDurableExecutionRoute.from_route(base_route) + + assert get_route.raw_path == f"/2025-12-01/durable-executions/{arn}" + assert get_route.segments == ["2025-12-01", "durable-executions", arn] + assert get_route.arn == arn + + +# Removed test_get_durable_execution_route_from_route_invalid - from_route() no longer validates +# Call is_match() first to ensure route is valid + + +def test_checkpoint_durable_execution_route_is_match(): + """Test CheckpointDurableExecutionRoute pattern matching.""" + route = Route.from_string("/2025-12-01/durable-executions/some-arn/checkpoint") + assert CheckpointDurableExecutionRoute.is_match(route, "POST") is True + assert CheckpointDurableExecutionRoute.is_match(route, "GET") is False + + route = Route.from_string("/2025-12-01/durable-executions/some-arn/stop") + assert CheckpointDurableExecutionRoute.is_match(route, "POST") is False + + +def test_checkpoint_durable_execution_route_from_route(): + """Test CheckpointDurableExecutionRoute creation from base route.""" + arn = "test-arn" + base_route = Route.from_string(f"/2025-12-01/durable-executions/{arn}/checkpoint") + checkpoint_route = CheckpointDurableExecutionRoute.from_route(base_route) + + assert ( + checkpoint_route.raw_path == f"/2025-12-01/durable-executions/{arn}/checkpoint" + ) + assert checkpoint_route.segments == [ + "2025-12-01", + "durable-executions", + arn, + "checkpoint", + ] + assert checkpoint_route.arn == arn + + +# Removed test_checkpoint_durable_execution_route_from_route_invalid - from_route() no longer validates +# Call is_match() first to ensure route is valid + + +def test_stop_durable_execution_route_is_match(): + """Test StopDurableExecutionRoute pattern matching.""" + route = Route.from_string("/2025-12-01/durable-executions/some-arn/stop") + assert StopDurableExecutionRoute.is_match(route, "POST") is True + assert StopDurableExecutionRoute.is_match(route, "GET") is False + + route = Route.from_string("/2025-12-01/durable-executions/some-arn/checkpoint") + assert StopDurableExecutionRoute.is_match(route, "POST") is False + + +def test_stop_durable_execution_route_from_route(): + """Test StopDurableExecutionRoute creation from base route.""" + arn = "test-arn" + base_route = Route.from_string(f"/2025-12-01/durable-executions/{arn}/stop") + stop_route = StopDurableExecutionRoute.from_route(base_route) + + assert stop_route.raw_path == f"/2025-12-01/durable-executions/{arn}/stop" + assert stop_route.segments == ["2025-12-01", "durable-executions", arn, "stop"] + assert stop_route.arn == arn + + +# Removed test_stop_durable_execution_route_from_route_invalid - from_route() no longer validates +# Call is_match() first to ensure route is valid + + +def test_get_durable_execution_state_route_is_match(): + """Test GetDurableExecutionStateRoute pattern matching.""" + route = Route.from_string("/2025-12-01/durable-executions/some-arn/state") + assert GetDurableExecutionStateRoute.is_match(route, "GET") is True + assert GetDurableExecutionStateRoute.is_match(route, "POST") is False + + route = Route.from_string("/2025-12-01/durable-executions/some-arn/history") + assert GetDurableExecutionStateRoute.is_match(route, "GET") is False + + +def test_get_durable_execution_state_route_from_route(): + """Test GetDurableExecutionStateRoute creation from base route.""" + arn = "test-arn" + base_route = Route.from_string(f"/2025-12-01/durable-executions/{arn}/state") + state_route = GetDurableExecutionStateRoute.from_route(base_route) + + assert state_route.raw_path == f"/2025-12-01/durable-executions/{arn}/state" + assert state_route.segments == ["2025-12-01", "durable-executions", arn, "state"] + assert state_route.arn == arn + + +# Removed test_get_durable_execution_state_route_from_route_invalid - from_route() no longer validates +# Call is_match() first to ensure route is valid + + +def test_get_durable_execution_history_route_is_match(): + """Test GetDurableExecutionHistoryRoute pattern matching.""" + route = Route.from_string("/2025-12-01/durable-executions/some-arn/history") + assert GetDurableExecutionHistoryRoute.is_match(route, "GET") is True + assert GetDurableExecutionHistoryRoute.is_match(route, "POST") is False + + route = Route.from_string("/2025-12-01/durable-executions/some-arn/state") + assert GetDurableExecutionHistoryRoute.is_match(route, "GET") is False + + +def test_get_durable_execution_history_route_from_route(): + """Test GetDurableExecutionHistoryRoute creation from base route.""" + arn = "test-arn" + base_route = Route.from_string(f"/2025-12-01/durable-executions/{arn}/history") + history_route = GetDurableExecutionHistoryRoute.from_route(base_route) + + assert history_route.raw_path == f"/2025-12-01/durable-executions/{arn}/history" + assert history_route.segments == [ + "2025-12-01", + "durable-executions", + arn, + "history", + ] + assert history_route.arn == arn + + +# Removed test_get_durable_execution_history_route_from_route_invalid - from_route() no longer validates +# Call is_match() first to ensure route is valid + + +def test_list_durable_executions_route_is_match(): + """Test ListDurableExecutionsRoute pattern matching.""" + route = Route.from_string("/2025-12-01/durable-executions") + assert ListDurableExecutionsRoute.is_match(route, "GET") is True + assert ListDurableExecutionsRoute.is_match(route, "POST") is False + + route = Route.from_string("/2025-12-01/durable-executions/some-arn") + assert ListDurableExecutionsRoute.is_match(route, "GET") is False + + +def test_list_durable_executions_route_from_route(): + """Test ListDurableExecutionsRoute creation from base route.""" + base_route = Route.from_string("/2025-12-01/durable-executions") + list_route = ListDurableExecutionsRoute.from_route(base_route) + + assert list_route.raw_path == "/2025-12-01/durable-executions" + assert list_route.segments == ["2025-12-01", "durable-executions"] + + +# Removed test_list_durable_executions_route_from_route_invalid - from_route() no longer validates +# Call is_match() first to ensure route is valid + + +def test_list_durable_executions_by_function_route_is_match(): + """Test ListDurableExecutionsByFunctionRoute pattern matching.""" + route = Route.from_string("/2025-12-01/functions/my-function/durable-executions") + assert ListDurableExecutionsByFunctionRoute.is_match(route, "GET") is True + assert ListDurableExecutionsByFunctionRoute.is_match(route, "POST") is False + + route = Route.from_string("/2025-12-01/functions/my-function") + assert ListDurableExecutionsByFunctionRoute.is_match(route, "GET") is False + + +def test_list_durable_executions_by_function_route_from_route(): + """Test ListDurableExecutionsByFunctionRoute creation from base route.""" + function_name = "my-function" + base_route = Route.from_string( + f"/2025-12-01/functions/{function_name}/durable-executions" + ) + list_route = ListDurableExecutionsByFunctionRoute.from_route(base_route) + + assert ( + list_route.raw_path + == f"/2025-12-01/functions/{function_name}/durable-executions" + ) + assert list_route.segments == [ + "2025-12-01", + "functions", + function_name, + "durable-executions", + ] + assert list_route.function_name == function_name + + +# Removed test_list_durable_executions_by_function_route_from_route_invalid - from_route() no longer validates +# Call is_match() first to ensure route is valid + + +def test_callback_success_route_is_match(): + """Test CallbackSuccessRoute pattern matching.""" + route = Route.from_string( + "/2025-12-01/durable-execution-callbacks/callback-123/succeed" + ) + assert CallbackSuccessRoute.is_match(route, "POST") is True + assert CallbackSuccessRoute.is_match(route, "GET") is False + + route = Route.from_string( + "/2025-12-01/durable-execution-callbacks/callback-123/fail" + ) + assert CallbackSuccessRoute.is_match(route, "POST") is False + + +def test_callback_success_route_from_route(): + """Test CallbackSuccessRoute creation from base route.""" + callback_id = "callback-123" + base_route = Route.from_string( + f"/2025-12-01/durable-execution-callbacks/{callback_id}/succeed" + ) + callback_route = CallbackSuccessRoute.from_route(base_route) + + assert ( + callback_route.raw_path + == f"/2025-12-01/durable-execution-callbacks/{callback_id}/succeed" + ) + assert callback_route.segments == [ + "2025-12-01", + "durable-execution-callbacks", + callback_id, + "succeed", + ] + assert callback_route.callback_id == callback_id + + +# Removed test_callback_success_route_from_route_invalid - from_route() no longer validates +# Call is_match() first to ensure route is valid + + +def test_callback_failure_route_is_match(): + """Test CallbackFailureRoute pattern matching.""" + route = Route.from_string( + "/2025-12-01/durable-execution-callbacks/callback-123/fail" + ) + assert CallbackFailureRoute.is_match(route, "POST") is True + assert CallbackFailureRoute.is_match(route, "GET") is False + + route = Route.from_string( + "/2025-12-01/durable-execution-callbacks/callback-123/succeed" + ) + assert CallbackFailureRoute.is_match(route, "POST") is False + + +def test_callback_failure_route_from_route(): + """Test CallbackFailureRoute creation from base route.""" + callback_id = "callback-123" + base_route = Route.from_string( + f"/2025-12-01/durable-execution-callbacks/{callback_id}/fail" + ) + callback_route = CallbackFailureRoute.from_route(base_route) + + assert ( + callback_route.raw_path + == f"/2025-12-01/durable-execution-callbacks/{callback_id}/fail" + ) + assert callback_route.segments == [ + "2025-12-01", + "durable-execution-callbacks", + callback_id, + "fail", + ] + assert callback_route.callback_id == callback_id + + +# Removed test_callback_failure_route_from_route_invalid - from_route() no longer validates +# Call is_match() first to ensure route is valid + + +def test_callback_heartbeat_route_is_match(): + """Test CallbackHeartbeatRoute pattern matching.""" + route = Route.from_string( + "/2025-12-01/durable-execution-callbacks/callback-123/heartbeat" + ) + assert CallbackHeartbeatRoute.is_match(route, "POST") is True + assert CallbackHeartbeatRoute.is_match(route, "GET") is False + + route = Route.from_string( + "/2025-12-01/durable-execution-callbacks/callback-123/succeed" + ) + assert CallbackHeartbeatRoute.is_match(route, "POST") is False + + +def test_callback_heartbeat_route_from_route(): + """Test CallbackHeartbeatRoute creation from base route.""" + callback_id = "callback-123" + base_route = Route.from_string( + f"/2025-12-01/durable-execution-callbacks/{callback_id}/heartbeat" + ) + callback_route = CallbackHeartbeatRoute.from_route(base_route) + + assert ( + callback_route.raw_path + == f"/2025-12-01/durable-execution-callbacks/{callback_id}/heartbeat" + ) + assert callback_route.segments == [ + "2025-12-01", + "durable-execution-callbacks", + callback_id, + "heartbeat", + ] + assert callback_route.callback_id == callback_id + + +# Removed test_callback_heartbeat_route_from_route_invalid - from_route() no longer validates +# Call is_match() first to ensure route is valid + + +def test_health_route_is_match(): + """Test HealthRoute pattern matching.""" + route = Route.from_string("/health") + assert HealthRoute.is_match(route, "GET") is True + assert HealthRoute.is_match(route, "POST") is False + + route = Route.from_string("/metrics") + assert HealthRoute.is_match(route, "GET") is False + + +def test_health_route_from_route(): + """Test HealthRoute creation from base route.""" + base_route = Route.from_string("/health") + health_route = HealthRoute.from_route(base_route) + + assert health_route.raw_path == "/health" + assert health_route.segments == ["health"] + + +# Removed test_health_route_from_route_invalid - from_route() no longer validates +# Call is_match() first to ensure route is valid + + +def test_metrics_route_is_match(): + """Test MetricsRoute pattern matching.""" + route = Route.from_string("/metrics") + assert MetricsRoute.is_match(route, "GET") is True + assert MetricsRoute.is_match(route, "POST") is False + + route = Route.from_string("/health") + assert MetricsRoute.is_match(route, "GET") is False + + +def test_metrics_route_from_route(): + """Test MetricsRoute creation from base route.""" + base_route = Route.from_string("/metrics") + metrics_route = MetricsRoute.from_route(base_route) + + assert metrics_route.raw_path == "/metrics" + assert metrics_route.segments == ["metrics"] + + +# Removed test_metrics_route_from_route_invalid - from_route() no longer validates +# Call is_match() first to ensure route is valid + + +def test_route_immutability(): + """Test that route objects are immutable (frozen dataclasses).""" + route = StartExecutionRoute.from_route( + Route.from_string("/start-durable-execution") + ) + + # Should not be able to modify frozen dataclass + with pytest.raises(AttributeError): + route.raw_path = "/modified" # type: ignore[misc] + + with pytest.raises(AttributeError): + route.segments = ["modified"] # type: ignore[misc] + + +def test_route_with_special_characters(): + """Test route parsing with special characters in ARNs and IDs. + + URL-decoding happens once in ``Route.from_string`` so every captured + path segment (``segments[N]`` and any named field that mirrors it, + such as ``arn`` or ``callback_id``) carries the literal value the + caller passed to boto. ``raw_path`` keeps the original wire string. + """ + # ARN with %20-encoded spaces should round-trip back to a literal space. + encoded_arn = ( + "arn:aws:lambda:us-east-1:123456789012:function:my-function%20with%20spaces" + ) + decoded_arn = ( + "arn:aws:lambda:us-east-1:123456789012:function:my-function with spaces" + ) + raw_path = f"/2025-12-01/durable-executions/{encoded_arn}" + router = Router() + route = router.find_route(raw_path, "GET") + assert isinstance(route, GetDurableExecutionRoute) + assert route.arn == decoded_arn + assert route.segments[2] == decoded_arn + # raw_path is preserved as the original wire form for logging/debugging. + assert route.raw_path == raw_path + + # Test with callback ID containing special characters + callback_id = "callback-123-abc_def" + route = router.find_route( + f"/2025-12-01/durable-execution-callbacks/{callback_id}/succeed", "POST" + ) + assert isinstance(route, CallbackSuccessRoute) + assert route.callback_id == callback_id + + +def test_route_edge_cases(): + """Test route parsing edge cases.""" + router = Router() + + # Empty path + with pytest.raises(UnknownRouteError, match="Unknown path pattern"): + router.find_route("", "GET") + + # Root path + with pytest.raises(UnknownRouteError, match="Unknown path pattern"): + router.find_route("/", "GET") + + # Path with only slashes + with pytest.raises(UnknownRouteError, match="Unknown path pattern"): + router.find_route("///", "GET") + + +def test_route_case_sensitivity(): + """Test that route matching is case-sensitive.""" + router = Router() + + # Should not match due to case difference + with pytest.raises(UnknownRouteError, match="Unknown path pattern"): + router.find_route("/START-DURABLE-EXECUTION", "POST") + + with pytest.raises(UnknownRouteError, match="Unknown path pattern"): + router.find_route("/Health", "GET") + + +def test_router_find_route_method_validation(): + """Test that Router.find_route validates HTTP methods correctly.""" + router = Router() + + # Valid method combinations + route = router.find_route("/start-durable-execution", "POST") + assert isinstance(route, StartExecutionRoute) + + route = router.find_route("/2025-12-01/durable-executions/test-arn", "GET") + assert isinstance(route, GetDurableExecutionRoute) + + # Invalid method combinations + with pytest.raises( + UnknownRouteError, + match="Unknown path pattern: GET /start-durable-execution", + ): + router.find_route("/start-durable-execution", "GET") + + with pytest.raises( + UnknownRouteError, + match="Unknown path pattern: POST /2025-12-01/durable-executions/test-arn", + ): + router.find_route("/2025-12-01/durable-executions/test-arn", "POST") + + with pytest.raises(UnknownRouteError, match="Unknown path pattern: DELETE /health"): + router.find_route("/health", "DELETE") + + +def test_router_find_route_method_case_sensitivity(): + """Test that HTTP method matching is case-sensitive.""" + router = Router() + + # Should work with uppercase methods + route = router.find_route("/start-durable-execution", "POST") + assert isinstance(route, StartExecutionRoute) + + # Should not work with lowercase methods + with pytest.raises( + UnknownRouteError, + match="Unknown path pattern: post /start-durable-execution", + ): + router.find_route("/start-durable-execution", "post") + + with pytest.raises(UnknownRouteError, match="Unknown path pattern: get /health"): + router.find_route("/health", "get") + + +def test_router_initialization_default(): + """Test Router initialization with default route types.""" + router = Router() + + # Should work with default route types + route = router.find_route("/start-durable-execution", "POST") + assert isinstance(route, StartExecutionRoute) + + route = router.find_route("/health", "GET") + assert isinstance(route, HealthRoute) + + +def test_router_initialization_custom_route_types(): + """Test Router initialization with custom route types.""" + # Create router with only health and metrics routes + custom_route_types = [HealthRoute, MetricsRoute] + router = Router(route_types=custom_route_types) + + # Should work with included route types + route = router.find_route("/health", "GET") + assert isinstance(route, HealthRoute) + + route = router.find_route("/metrics", "GET") + assert isinstance(route, MetricsRoute) + + # Should not work with excluded route types + with pytest.raises( + UnknownRouteError, + match="Unknown path pattern: POST /start-durable-execution", + ): + router.find_route("/start-durable-execution", "POST") + + +def test_router_initialization_empty_route_types(): + """Test Router initialization with empty route types list.""" + router = Router(route_types=[]) + + # Should not match any routes + with pytest.raises( + UnknownRouteError, + match="Unknown path pattern: GET /health", + ): + router.find_route("/health", "GET") + + +def test_router_find_route_basic(): + """Test Router.find_route with basic routes.""" + router = Router() + + # Test various route types + route = router.find_route("/start-durable-execution", "POST") + assert isinstance(route, StartExecutionRoute) + + arn = "test-arn" + route = router.find_route(f"/2025-12-01/durable-executions/{arn}", "GET") + assert isinstance(route, GetDurableExecutionRoute) + assert route.arn == arn + + route = router.find_route("/health", "GET") + assert isinstance(route, HealthRoute) + + +def test_router_find_route_with_parameters(): + """Test Router.find_route extracts route parameters correctly.""" + router = Router() + + # Test ARN extraction + arn = "arn:aws:lambda:us-east-1:123456789012:function:my-function" + route = router.find_route( + f"/2025-12-01/durable-executions/{arn}/checkpoint", "POST" + ) + assert isinstance(route, CheckpointDurableExecutionRoute) + assert route.arn == arn + + # Test function name extraction + function_name = "my-test-function" + route = router.find_route( + f"/2025-12-01/functions/{function_name}/durable-executions", "GET" + ) + assert isinstance(route, ListDurableExecutionsByFunctionRoute) + assert route.function_name == function_name + + # Test callback ID extraction + callback_id = "callback-123-abc" + route = router.find_route( + f"/2025-12-01/durable-execution-callbacks/{callback_id}/succeed", "POST" + ) + assert isinstance(route, CallbackSuccessRoute) + assert route.callback_id == callback_id + + +def test_router_find_route_unknown_route(): + """Test Router.find_route with unknown route patterns.""" + router = Router() + + with pytest.raises( + UnknownRouteError, + match="Unknown path pattern: GET /unknown/path", + ): + router.find_route("/unknown/path", "GET") + + with pytest.raises( + UnknownRouteError, + match="Unknown path pattern: DELETE /health", + ): + router.find_route("/health", "DELETE") + + +def test_unknown_route_error_attributes(): + """Test UnknownRouteError provides structured access to method and path.""" + router = Router() + with pytest.raises(UnknownRouteError) as exc_info: + router.find_route("/unknown/path", "POST") + + e = exc_info.value + assert e.method == "POST" + assert e.path == "/unknown/path" + assert str(e) == "Unknown path pattern: POST /unknown/path" + + +def test_router_find_route_priority_order(): + """Test Router.find_route respects priority order for overlapping patterns.""" + router = Router() + + # Test that more specific patterns are matched before general ones + # This tests the order in DEFAULT_ROUTE_TYPES registry + + # Should match GetDurableExecutionRoute, not ListDurableExecutionsRoute + route = router.find_route("/2025-12-01/durable-executions/some-arn", "GET") + assert isinstance(route, GetDurableExecutionRoute) + assert route.arn == "some-arn" + + # Should match ListDurableExecutionsRoute + route = router.find_route("/2025-12-01/durable-executions", "GET") + assert isinstance(route, ListDurableExecutionsRoute) + + +def test_router_multiple_instances(): + """Test that multiple Router instances work independently.""" + # Create two routers with different route types + health_router = Router(route_types=[HealthRoute]) + full_router = Router() # Uses default route types + + # Health router should only handle health routes + route = health_router.find_route("/health", "GET") + assert isinstance(route, HealthRoute) + + with pytest.raises(UnknownRouteError): + health_router.find_route("/metrics", "GET") + + # Full router should handle all routes + route = full_router.find_route("/health", "GET") + assert isinstance(route, HealthRoute) + + route = full_router.find_route("/metrics", "GET") + assert isinstance(route, MetricsRoute) + + route = full_router.find_route("/start-durable-execution", "POST") + assert isinstance(route, StartExecutionRoute) + + +def test_router_find_route_start_execution(): + """Test Router.find_route with start execution route.""" + router = Router() + route = router.find_route("/start-durable-execution", "POST") + assert isinstance(route, StartExecutionRoute) + assert route.raw_path == "/start-durable-execution" + + +def test_router_find_route_get_durable_execution(): + """Test Router.find_route with get durable execution route.""" + router = Router() + arn = "test-arn" + route = router.find_route(f"/2025-12-01/durable-executions/{arn}", "GET") + assert isinstance(route, GetDurableExecutionRoute) + assert route.arn == arn + + +def test_router_find_route_checkpoint_durable_execution(): + """Test Router.find_route with checkpoint durable execution route.""" + router = Router() + arn = "test-arn" + route = router.find_route( + f"/2025-12-01/durable-executions/{arn}/checkpoint", "POST" + ) + assert isinstance(route, CheckpointDurableExecutionRoute) + assert route.arn == arn + + +def test_router_find_route_stop_durable_execution(): + """Test Router.find_route with stop durable execution route.""" + router = Router() + arn = "test-arn" + route = router.find_route(f"/2025-12-01/durable-executions/{arn}/stop", "POST") + assert isinstance(route, StopDurableExecutionRoute) + assert route.arn == arn + + +def test_router_find_route_get_durable_execution_state(): + """Test Router.find_route with get durable execution state route.""" + router = Router() + arn = "test-arn" + route = router.find_route(f"/2025-12-01/durable-executions/{arn}/state", "GET") + assert isinstance(route, GetDurableExecutionStateRoute) + assert route.arn == arn + + +def test_router_find_route_get_durable_execution_history(): + """Test Router.find_route with get durable execution history route.""" + router = Router() + arn = "test-arn" + route = router.find_route(f"/2025-12-01/durable-executions/{arn}/history", "GET") + assert isinstance(route, GetDurableExecutionHistoryRoute) + assert route.arn == arn + + +def test_router_find_route_list_durable_executions(): + """Test Router.find_route with list durable executions route.""" + router = Router() + route = router.find_route("/2025-12-01/durable-executions", "GET") + assert isinstance(route, ListDurableExecutionsRoute) + + +def test_router_find_route_list_durable_executions_by_function(): + """Test Router.find_route with list durable executions by function route.""" + router = Router() + function_name = "my-function" + route = router.find_route( + f"/2025-12-01/functions/{function_name}/durable-executions", "GET" + ) + assert isinstance(route, ListDurableExecutionsByFunctionRoute) + assert route.function_name == function_name + + +def test_router_find_route_callback_success(): + """Test Router.find_route with callback success route.""" + router = Router() + callback_id = "callback-123" + route = router.find_route( + f"/2025-12-01/durable-execution-callbacks/{callback_id}/succeed", "POST" + ) + assert isinstance(route, CallbackSuccessRoute) + assert route.callback_id == callback_id + + +def test_router_find_route_callback_failure(): + """Test Router.find_route with callback failure route.""" + router = Router() + callback_id = "callback-123" + route = router.find_route( + f"/2025-12-01/durable-execution-callbacks/{callback_id}/fail", "POST" + ) + assert isinstance(route, CallbackFailureRoute) + assert route.callback_id == callback_id + + +def test_router_find_route_callback_heartbeat(): + """Test Router.find_route with callback heartbeat route.""" + router = Router() + callback_id = "callback-123" + route = router.find_route( + f"/2025-12-01/durable-execution-callbacks/{callback_id}/heartbeat", "POST" + ) + assert isinstance(route, CallbackHeartbeatRoute) + assert route.callback_id == callback_id + + +def test_router_find_route_health(): + """Test Router.find_route with health route.""" + router = Router() + route = router.find_route("/health", "GET") + assert isinstance(route, HealthRoute) + + +def test_router_find_route_metrics(): + """Test Router.find_route with metrics route.""" + router = Router() + route = router.find_route("/metrics", "GET") + assert isinstance(route, MetricsRoute) + + +def test_router_find_route_unknown(): + """Test Router.find_route with unknown route pattern.""" + router = Router() + with pytest.raises( + UnknownRouteError, + match="Unknown path pattern: GET /unknown/path", + ): + router.find_route("/unknown/path", "GET") + + +def test_router_constructor_with_all_default_route_types(): + """Test Router constructor includes all expected default route types.""" + router = Router() + + # Test that all route types are included by trying to match each one + test_cases = [ + ("/start-durable-execution", "POST", StartExecutionRoute), + ("/2025-12-01/durable-executions/test-arn", "GET", GetDurableExecutionRoute), + ( + "/2025-12-01/durable-executions/test-arn/checkpoint", + "POST", + CheckpointDurableExecutionRoute, + ), + ( + "/2025-12-01/durable-executions/test-arn/stop", + "POST", + StopDurableExecutionRoute, + ), + ( + "/2025-12-01/durable-executions/test-arn/state", + "GET", + GetDurableExecutionStateRoute, + ), + ( + "/2025-12-01/durable-executions/test-arn/history", + "GET", + GetDurableExecutionHistoryRoute, + ), + ("/2025-12-01/durable-executions", "GET", ListDurableExecutionsRoute), + ( + "/2025-12-01/functions/test-func/durable-executions", + "GET", + ListDurableExecutionsByFunctionRoute, + ), + ( + "/2025-12-01/durable-execution-callbacks/test-id/succeed", + "POST", + CallbackSuccessRoute, + ), + ( + "/2025-12-01/durable-execution-callbacks/test-id/fail", + "POST", + CallbackFailureRoute, + ), + ( + "/2025-12-01/durable-execution-callbacks/test-id/heartbeat", + "POST", + CallbackHeartbeatRoute, + ), + ("/health", "GET", HealthRoute), + ("/metrics", "GET", MetricsRoute), + ] + + for path, method, expected_type in test_cases: + route = router.find_route(path, method) + assert isinstance(route, expected_type), ( + f"Expected {expected_type.__name__} for {method} {path}" + ) + + +def test_router_constructor_with_subset_of_route_types(): + """Test Router constructor with a subset of route types.""" + # Create router with only callback routes + callback_route_types = [ + CallbackSuccessRoute, + CallbackFailureRoute, + CallbackHeartbeatRoute, + ] + router = Router(route_types=callback_route_types) + + # Should work with callback routes + route = router.find_route( + "/2025-12-01/durable-execution-callbacks/test-id/succeed", "POST" + ) + assert isinstance(route, CallbackSuccessRoute) + + route = router.find_route( + "/2025-12-01/durable-execution-callbacks/test-id/fail", "POST" + ) + assert isinstance(route, CallbackFailureRoute) + + route = router.find_route( + "/2025-12-01/durable-execution-callbacks/test-id/heartbeat", "POST" + ) + assert isinstance(route, CallbackHeartbeatRoute) + + # Should not work with other route types + with pytest.raises(UnknownRouteError): + router.find_route("/health", "GET") + + with pytest.raises(UnknownRouteError): + router.find_route("/start-durable-execution", "POST") + + +def test_router_constructor_with_single_route_type(): + """Test Router constructor with a single route type.""" + router = Router(route_types=[HealthRoute]) + + # Should work with the single route type + route = router.find_route("/health", "GET") + assert isinstance(route, HealthRoute) + + # Should not work with any other route types + with pytest.raises(UnknownRouteError): + router.find_route("/metrics", "GET") + + with pytest.raises(UnknownRouteError): + router.find_route("/start-durable-execution", "POST") + + +def test_router_constructor_with_duplicate_route_types(): + """Test Router constructor handles duplicate route types gracefully.""" + # Include HealthRoute twice + duplicate_route_types = [HealthRoute, MetricsRoute, HealthRoute] + router = Router(route_types=duplicate_route_types) + + # Should still work correctly (first match wins) + route = router.find_route("/health", "GET") + assert isinstance(route, HealthRoute) + + route = router.find_route("/metrics", "GET") + assert isinstance(route, MetricsRoute) + + +def test_router_find_route_error_handling_comprehensive(): + """Test Router.find_route error handling with various invalid inputs.""" + router = Router() + + # Test various invalid path/method combinations + invalid_cases = [ + ("", "GET", "Unknown path pattern: GET "), + ("/", "GET", "Unknown path pattern: GET /"), + ("///", "GET", "Unknown path pattern: GET ///"), + ("/unknown", "GET", "Unknown path pattern: GET /unknown"), + ( + "/start-durable-execution", + "GET", + "Unknown path pattern: GET /start-durable-execution", + ), + ("/health", "POST", "Unknown path pattern: POST /health"), + ("/metrics", "DELETE", "Unknown path pattern: DELETE /metrics"), + ( + "/2025-12-01/durable-executions/test-arn", + "POST", + "Unknown path pattern: POST /2025-12-01/durable-executions/test-arn", + ), + ( + "/2025-12-01/durable-executions/test-arn/checkpoint", + "GET", + "Unknown path pattern: GET /2025-12-01/durable-executions/test-arn/checkpoint", + ), + ] + + for path, method, expected_message in invalid_cases: + with pytest.raises(UnknownRouteError, match=expected_message): + router.find_route(path, method) + + +def test_router_find_route_with_complex_parameters(): + """Test Router.find_route with complex parameter extraction.""" + router = Router() + + # Test with complex ARN + complex_arn = "arn:aws:lambda:us-west-2:123456789012:function:my-complex-function-name_with_underscores-and-dashes" + route = router.find_route(f"/2025-12-01/durable-executions/{complex_arn}", "GET") + assert isinstance(route, GetDurableExecutionRoute) + assert route.arn == complex_arn + + # Test with complex function name + complex_function_name = "my_complex-function.name123" + route = router.find_route( + f"/2025-12-01/functions/{complex_function_name}/durable-executions", "GET" + ) + assert isinstance(route, ListDurableExecutionsByFunctionRoute) + assert route.function_name == complex_function_name + + # Test with complex callback ID + complex_callback_id = "callback-123_abc-def.456" + route = router.find_route( + f"/2025-12-01/durable-execution-callbacks/{complex_callback_id}/succeed", "POST" + ) + assert isinstance(route, CallbackSuccessRoute) + assert route.callback_id == complex_callback_id + + +def test_router_find_route_order_dependency(): + """Test that Router.find_route respects route type ordering for disambiguation.""" + router = Router() + + # These paths could potentially match multiple patterns if ordering is wrong + # The more specific patterns should match first + + # Should match GetDurableExecutionRoute, not ListDurableExecutionsRoute + route = router.find_route("/2025-12-01/durable-executions/specific-arn", "GET") + assert isinstance(route, GetDurableExecutionRoute) + assert route.arn == "specific-arn" + + # Should match ListDurableExecutionsRoute + route = router.find_route("/2025-12-01/durable-executions", "GET") + assert isinstance(route, ListDurableExecutionsRoute) + + # Should match CheckpointDurableExecutionRoute, not GetDurableExecutionRoute + route = router.find_route( + "/2025-12-01/durable-executions/test-arn/checkpoint", "POST" + ) + assert isinstance(route, CheckpointDurableExecutionRoute) + assert route.arn == "test-arn" + + +def test_router_thread_safety(): + """Test that Router instances are thread-safe for concurrent access.""" + + router = Router() + results = [] + errors = [] + + def worker(worker_id: int): + try: + for i in range(10): + # Test different route types to ensure no interference + route = router.find_route( + f"/2025-12-01/durable-executions/arn-{worker_id}-{i}", "GET" + ) + assert isinstance(route, GetDurableExecutionRoute) + assert route.arn == f"arn-{worker_id}-{i}" + + route = router.find_route("/health", "GET") + assert isinstance(route, HealthRoute) + + time.sleep(0.001) # Small delay to increase chance of race conditions + + results.append(f"Worker {worker_id} completed successfully") + except (UnknownRouteError, AssertionError) as e: + errors.append(f"Worker {worker_id} failed: {e}") + + # Create multiple threads + threads = [] + for i in range(5): + thread = threading.Thread(target=worker, args=(i,)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Check results + assert len(errors) == 0, f"Thread safety test failed with errors: {errors}" + assert len(results) == 5, f"Expected 5 successful workers, got {len(results)}" + + +def test_callback_routes_url_decoding(): + """Test that callback routes properly URL-decode callback IDs.""" + # Test callback ID with special characters that need URL encoding + callback_id = "eyJhcm4iOiJhcm4iLCJvcCI6ImVhNjZjMDZjMWUxYzA1ZmEifQ==" + encoded_callback_id = quote(callback_id, safe="") + + # Test CallbackSuccessRoute + base_route = Route.from_string( + f"/2025-12-01/durable-execution-callbacks/{encoded_callback_id}/succeed" + ) + success_route = CallbackSuccessRoute.from_route(base_route) + assert success_route.callback_id == callback_id # Should be decoded + + # Test CallbackFailureRoute + base_route = Route.from_string( + f"/2025-12-01/durable-execution-callbacks/{encoded_callback_id}/fail" + ) + failure_route = CallbackFailureRoute.from_route(base_route) + assert failure_route.callback_id == callback_id # Should be decoded + + # Test CallbackHeartbeatRoute + base_route = Route.from_string( + f"/2025-12-01/durable-execution-callbacks/{encoded_callback_id}/heartbeat" + ) + heartbeat_route = CallbackHeartbeatRoute.from_route(base_route) + assert heartbeat_route.callback_id == callback_id # Should be decoded + + +def test_router_callback_routes_url_decoding(): + """Test Router properly handles URL-encoded callback IDs.""" + router = Router() + callback_id = "eyJhcm4iOiJhcm4iLCJvcCI6ImVhNjZjMDZjMWUxYzA1ZmEifQ==" + encoded_callback_id = quote(callback_id, safe="") + + # Test success route + route = router.find_route( + f"/2025-12-01/durable-execution-callbacks/{encoded_callback_id}/succeed", "POST" + ) + assert isinstance(route, CallbackSuccessRoute) + assert route.callback_id == callback_id # Should be decoded + + # Test failure route + route = router.find_route( + f"/2025-12-01/durable-execution-callbacks/{encoded_callback_id}/fail", "POST" + ) + assert isinstance(route, CallbackFailureRoute) + assert route.callback_id == callback_id # Should be decoded + + # Test heartbeat route + route = router.find_route( + f"/2025-12-01/durable-execution-callbacks/{encoded_callback_id}/heartbeat", + "POST", + ) + assert isinstance(route, CallbackHeartbeatRoute) + assert route.callback_id == callback_id # Should be decoded diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/web/serialization_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/web/serialization_test.py new file mode 100644 index 0000000..7c981f3 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/web/serialization_test.py @@ -0,0 +1,576 @@ +"""Tests for serialization interfaces and AWS boto integration.""" + +from __future__ import annotations + +from unittest.mock import Mock, patch + +import pytest +import json +from datetime import datetime, timezone + +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) +from aws_durable_execution_sdk_python_testing.web.serialization import ( + JSONSerializer, + AwsRestJsonDeserializer, + AwsRestJsonSerializer, +) + + +def test_aws_rest_json_serializer_should_initialize_and_serialize_data(): + """Test that serializer initializes and can serialize data through public API.""" + # Arrange + operation_name = "StartDurableExecution" + mock_serializer = Mock() + mock_operation_model = Mock() + mock_serializer.serialize_to_request.return_value = {"body": '{"test": "data"}'} + + # Act + serializer = AwsRestJsonSerializer( + operation_name, mock_serializer, mock_operation_model + ) + result = serializer.to_bytes({"test": "data"}) + + # Assert - Test public behavior only + assert isinstance(result, bytes) + assert result == b'{"test": "data"}' + mock_serializer.serialize_to_request.assert_called_once_with( + {"test": "data"}, mock_operation_model + ) + + +@patch("aws_durable_execution_sdk_python_testing.web.serialization.create_serializer") +@patch("aws_durable_execution_sdk_python_testing.web.serialization.ServiceModel") +@patch( + "aws_durable_execution_sdk_python_testing.web.serialization.botocore.loaders.Loader" +) +@patch("aws_durable_execution_sdk_python_testing.web.serialization.os.path.dirname") +def test_aws_rest_json_serializer_should_create_serializer_with_boto_components( + mock_dirname, + mock_loader_class, + mock_service_model_class, + mock_create_serializer, +): + """Test that create method sets up boto components correctly.""" + # Arrange + operation_name = "StartDurableExecution" + mock_package_path = "/path/to/package" + mock_dirname.return_value = mock_package_path + + mock_loader = Mock() + mock_loader_class.return_value = mock_loader + mock_raw_model = {"operations": {}} + mock_loader.load_service_model.return_value = mock_raw_model + + mock_service_model = Mock() + mock_service_model_class.return_value = mock_service_model + mock_operation_model = Mock() + mock_service_model.operation_model.return_value = mock_operation_model + + mock_serializer = Mock() + mock_create_serializer.return_value = mock_serializer + + # Act + result = AwsRestJsonSerializer.create(operation_name) + + # Assert - Test public behavior only + assert isinstance(result, AwsRestJsonSerializer) + + # Test that the created serializer can actually serialize data + mock_serializer.serialize_to_request.return_value = {"body": '{"test": "value"}'} + serialized_data = result.to_bytes({"test": "value"}) + assert isinstance(serialized_data, bytes) + assert serialized_data == b'{"test": "value"}' + + # Verify boto setup calls + mock_loader.load_service_model.assert_called_once_with("lambda", "service-2") + mock_service_model_class.assert_called_once_with(mock_raw_model) + mock_create_serializer.assert_called_once_with("rest-json", include_validation=True) + mock_service_model.operation_model.assert_called_once_with(operation_name) + + +@patch("aws_durable_execution_sdk_python_testing.web.serialization.create_serializer") +def test_aws_rest_json_serializer_should_raise_serialization_error_when_create_fails( + mock_create_serializer, +): + """Test that create method raises InvalidParameterValueException when boto setup fails.""" + # Arrange + operation_name = "StartDurableExecution" + mock_create_serializer.side_effect = Exception("Boto error") + + # Act & Assert + with pytest.raises(InvalidParameterValueException) as exc_info: + AwsRestJsonSerializer.create(operation_name) + + assert "Failed to create serializer for StartDurableExecution" in str( + exc_info.value + ) + + +def test_aws_rest_json_serializer_should_serialize_data_to_bytes(): + """Test that to_bytes method serializes data using boto serializer.""" + # Arrange + mock_serializer = Mock() + mock_operation_model = Mock() + serializer = AwsRestJsonSerializer("test", mock_serializer, mock_operation_model) + + test_data = {"key": "value"} + serialized_response = {"body": '{"key": "value"}'} + mock_serializer.serialize_to_request.return_value = serialized_response + + # Act + result = serializer.to_bytes(test_data) + + # Assert + assert result == b'{"key": "value"}' + mock_serializer.serialize_to_request.assert_called_once_with( + test_data, mock_operation_model + ) + + +def test_aws_rest_json_serializer_should_handle_bytes_body_in_serialization(): + """Test that to_bytes method handles bytes body from boto serializer.""" + # Arrange + mock_serializer = Mock() + mock_operation_model = Mock() + serializer = AwsRestJsonSerializer("test", mock_serializer, mock_operation_model) + + test_data = {"key": "value"} + serialized_response = {"body": b'{"key": "value"}'} + mock_serializer.serialize_to_request.return_value = serialized_response + + # Act + result = serializer.to_bytes(test_data) + + # Assert + assert result == b'{"key": "value"}' + + +def test_aws_rest_json_serializer_should_handle_empty_body_in_serialization(): + """Test that to_bytes method handles empty body from boto serializer.""" + # Arrange + mock_serializer = Mock() + mock_operation_model = Mock() + serializer = AwsRestJsonSerializer("test", mock_serializer, mock_operation_model) + + test_data = {"key": "value"} + serialized_response = {} + mock_serializer.serialize_to_request.return_value = serialized_response + + # Act + result = serializer.to_bytes(test_data) + + # Assert + assert result == b"" + + +def test_aws_rest_json_serializer_should_raise_error_when_serializer_not_initialized(): + """Test that to_bytes raises error when serializer is not initialized.""" + # Arrange + serializer = AwsRestJsonSerializer("test", None, None) + test_data = {"key": "value"} + + # Act & Assert + with pytest.raises(InvalidParameterValueException) as exc_info: + serializer.to_bytes(test_data) + + assert "Serializer not initialized for test" in str(exc_info.value) + + +def test_aws_rest_json_serializer_should_raise_error_when_serialization_fails(): + """Test that to_bytes raises InvalidParameterValueException when boto serialization fails.""" + # Arrange + mock_serializer = Mock() + mock_operation_model = Mock() + serializer = AwsRestJsonSerializer("test", mock_serializer, mock_operation_model) + + test_data = {"key": "value"} + mock_serializer.serialize_to_request.side_effect = Exception("Serialization failed") + + # Act & Assert + with pytest.raises(InvalidParameterValueException) as exc_info: + serializer.to_bytes(test_data) + + assert "Failed to serialize data for test" in str(exc_info.value) + + +def test_aws_rest_json_deserializer_should_initialize_and_deserialize_data(): + """Test that deserializer initializes and can deserialize data through public API.""" + # Arrange + operation_name = "StartDurableExecution" + mock_parser = Mock() + mock_operation_model = Mock() + mock_output_shape = Mock() + mock_operation_model.output_shape = mock_output_shape + mock_parser.parse.return_value = {"test": "data"} + + # Act + deserializer = AwsRestJsonDeserializer( + operation_name, mock_parser, mock_operation_model + ) + result = deserializer.from_bytes(b'{"test": "data"}') + + # Assert - Test public behavior only + assert isinstance(result, dict) + assert result == {"test": "data"} + expected_response_dict = { + "body": b'{"test": "data"}', + "headers": {"content-type": "application/json"}, + "status_code": 200, + } + mock_parser.parse.assert_called_once_with(expected_response_dict, mock_output_shape) + + +@patch("aws_durable_execution_sdk_python_testing.web.serialization.create_parser") +@patch("aws_durable_execution_sdk_python_testing.web.serialization.ServiceModel") +@patch( + "aws_durable_execution_sdk_python_testing.web.serialization.botocore.loaders.Loader" +) +@patch("aws_durable_execution_sdk_python_testing.web.serialization.os.path.dirname") +def test_aws_rest_json_deserializer_should_create_deserializer_with_boto_components( + mock_dirname, + mock_loader_class, + mock_service_model_class, + mock_create_parser, +): + """Test that create method sets up boto components correctly.""" + # Arrange + operation_name = "StartDurableExecution" + mock_package_path = "/path/to/package" + mock_dirname.return_value = mock_package_path + + mock_loader = Mock() + mock_loader_class.return_value = mock_loader + mock_raw_model = {"operations": {}} + mock_loader.load_service_model.return_value = mock_raw_model + + mock_service_model = Mock() + mock_service_model_class.return_value = mock_service_model + mock_operation_model = Mock() + mock_service_model.operation_model.return_value = mock_operation_model + + mock_parser = Mock() + mock_create_parser.return_value = mock_parser + + # Act + result = AwsRestJsonDeserializer.create(operation_name) + + # Assert - Test public behavior only + assert isinstance(result, AwsRestJsonDeserializer) + + # Test that the created deserializer can actually deserialize data + mock_output_shape = Mock() + mock_operation_model.output_shape = mock_output_shape + mock_parser.parse.return_value = {"test": "value"} + deserialized_data = result.from_bytes(b'{"test": "value"}') + assert isinstance(deserialized_data, dict) + assert deserialized_data == {"test": "value"} + + # Verify boto setup calls + mock_loader.load_service_model.assert_called_once_with("lambda", "service-2") + mock_service_model_class.assert_called_once_with(mock_raw_model) + mock_create_parser.assert_called_once_with("rest-json") + mock_service_model.operation_model.assert_called_once_with(operation_name) + + +@patch("aws_durable_execution_sdk_python_testing.web.serialization.create_parser") +def test_aws_rest_json_deserializer_should_raise_serialization_error_when_create_fails( + mock_create_parser, +): + """Test that create method raises InvalidParameterValueException when boto setup fails.""" + # Arrange + operation_name = "StartDurableExecution" + mock_create_parser.side_effect = Exception("Boto error") + + # Act & Assert + with pytest.raises(InvalidParameterValueException) as exc_info: + AwsRestJsonDeserializer.create(operation_name) + + assert "Failed to create deserializer for StartDurableExecution" in str( + exc_info.value + ) + + +def test_aws_rest_json_deserializer_should_deserialize_bytes_with_output_shape(): + """Test that from_bytes method deserializes data using boto parser with output shape.""" + # Arrange + mock_parser = Mock() + mock_operation_model = Mock() + mock_output_shape = Mock() + mock_operation_model.output_shape = mock_output_shape + deserializer = AwsRestJsonDeserializer("test", mock_parser, mock_operation_model) + + test_bytes = b'{"key": "value"}' + parsed_data = {"key": "value"} + mock_parser.parse.return_value = parsed_data + + # Act + result = deserializer.from_bytes(test_bytes) + + # Assert + assert result == parsed_data + expected_response_dict = { + "body": test_bytes, + "headers": {"content-type": "application/json"}, + "status_code": 200, + } + mock_parser.parse.assert_called_once_with(expected_response_dict, mock_output_shape) + + +def test_aws_rest_json_deserializer_should_deserialize_bytes_without_output_shape(): + """Test that from_bytes method falls back to JSON parsing when no output shape.""" + # Arrange + mock_parser = Mock() + mock_operation_model = Mock() + mock_operation_model.output_shape = None + deserializer = AwsRestJsonDeserializer("test", mock_parser, mock_operation_model) + + test_bytes = b'{"key": "value"}' + + # Act + result = deserializer.from_bytes(test_bytes) + + # Assert + assert result == {"key": "value"} + mock_parser.parse.assert_not_called() + + +def test_aws_rest_json_deserializer_should_raise_error_when_parser_not_initialized(): + """Test that from_bytes raises error when parser is not initialized.""" + # Arrange + deserializer = AwsRestJsonDeserializer("test", None, None) + test_bytes = b'{"key": "value"}' + + # Act & Assert + with pytest.raises(InvalidParameterValueException) as exc_info: + deserializer.from_bytes(test_bytes) + + assert "Parser not initialized for test" in str(exc_info.value) + + +def test_aws_rest_json_deserializer_should_raise_error_when_deserialization_fails(): + """Test that from_bytes raises InvalidParameterValueException when boto parsing fails.""" + # Arrange + mock_parser = Mock() + mock_operation_model = Mock() + mock_output_shape = Mock() + mock_operation_model.output_shape = mock_output_shape + deserializer = AwsRestJsonDeserializer("test", mock_parser, mock_operation_model) + + test_bytes = b'{"key": "value"}' + mock_parser.parse.side_effect = Exception("Parsing failed") + + # Act & Assert + with pytest.raises(InvalidParameterValueException) as exc_info: + deserializer.from_bytes(test_bytes) + + assert "Failed to deserialize data for test" in str(exc_info.value) + + +def test_aws_rest_json_deserializer_should_raise_error_when_json_parsing_fails(): + """Test that from_bytes raises InvalidParameterValueException when JSON parsing fails.""" + # Arrange + mock_parser = Mock() + mock_operation_model = Mock() + mock_operation_model.output_shape = None + deserializer = AwsRestJsonDeserializer("test", mock_parser, mock_operation_model) + + test_bytes = b"invalid json" + + # Act & Assert + with pytest.raises(InvalidParameterValueException) as exc_info: + deserializer.from_bytes(test_bytes) + + assert "Failed to deserialize data for test" in str(exc_info.value) + + +def test_serialize_simple_dict(): + """Test serialization of simple dictionary.""" + serializer = JSONSerializer() + data = {"key": "value", "number": 42} + result = serializer.to_bytes(data) + + expected = b'{"key":"value","number":42}' + assert result == expected + assert isinstance(result, bytes) + assert json.loads(result.decode("utf-8")) == data + + +def test_serialize_datetime(): + """Test serialization of datetime objects.""" + serializer = JSONSerializer() + now = datetime(2025, 11, 5, 16, 30, 9, 895000, tzinfo=timezone.utc) + data = {"timestamp": now} + + result = serializer.to_bytes(data) + expected = b'{"timestamp":1762360209.895}' + + assert result == expected + assert isinstance(result, bytes) + + deserialized = json.loads(result.decode("utf-8")) + assert deserialized["timestamp"] == now.timestamp() + + +def test_serialize_nested_datetime(): + """Test serialization of nested structures with datetime.""" + serializer = JSONSerializer() + now = datetime(2025, 11, 5, 16, 30, 9, tzinfo=timezone.utc) + data = { + "event": "user_login", + "timestamp": now, + "metadata": {"created_at": now, "updated_at": now}, + } + + result = serializer.to_bytes(data) + expected = ( + b'{"event":"user_login",' + b'"timestamp":1762360209.0,' + b'"metadata":{"created_at":1762360209.0,' + b'"updated_at":1762360209.0}}' + ) + + assert result == expected + + deserialized = json.loads(result.decode("utf-8")) + assert deserialized["timestamp"] == now.timestamp() + assert deserialized["metadata"]["created_at"] == now.timestamp() + + +def test_serialize_list_with_datetime(): + """Test serialization of list containing datetime.""" + serializer = JSONSerializer() + now = datetime(2025, 11, 5, 16, 30, 9, tzinfo=timezone.utc) + data = { + "events": [{"time": now, "action": "login"}, {"time": now, "action": "logout"}] + } + + result = serializer.to_bytes(data) + expected = ( + b'{"events":[' + b'{"time":1762360209.0,"action":"login"},' + b'{"time":1762360209.0,"action":"logout"}' + b"]}" + ) + + assert result == expected + + deserialized = json.loads(result.decode("utf-8")) + assert deserialized["events"][0]["time"] == now.timestamp() + assert deserialized["events"][1]["time"] == now.timestamp() + + +def test_serialize_mixed_types(): + """Test serialization of mixed data types.""" + serializer = JSONSerializer() + now = datetime(2025, 11, 5, 16, 30, 9, tzinfo=timezone.utc) + data = { + "string": "test", + "number": 42, + "float": 3.14, + "boolean": True, + "null": None, + "list": [1, 2, 3], + "datetime": now, + } + + result = serializer.to_bytes(data) + expected = ( + b'{"string":"test",' + b'"number":42,' + b'"float":3.14,' + b'"boolean":true,' + b'"null":null,' + b'"list":[1,2,3],' + b'"datetime":1762360209.0}' + ) + + assert result == expected + + deserialized = json.loads(result.decode("utf-8")) + assert deserialized["string"] == "test" + assert deserialized["number"] == 42 + assert deserialized["float"] == 3.14 + assert deserialized["boolean"] is True + assert deserialized["null"] is None + assert deserialized["list"] == [1, 2, 3] + assert deserialized["datetime"] == now.timestamp() + + +def test_serialize_returns_bytes(): + """Test that serialization returns bytes.""" + serializer = JSONSerializer() + data = {"test": "value"} + result = serializer.to_bytes(data) + expected = b'{"test":"value"}' + + assert result == expected + assert isinstance(result, bytes) + + +def test_serialize_non_serializable_object_raises_exception(): + """Test that non-serializable objects raise InvalidParameterValueException.""" + serializer = JSONSerializer() + + class CustomObject: + pass + + data = {"custom": CustomObject()} + + with pytest.raises(InvalidParameterValueException) as exc_info: + serializer.to_bytes(data) + + assert ( + "Failed to serialize data to JSON: Object of type CustomObject is not JSON serializable" + in str(exc_info.value) + ) + + +def test_serialize_circular_reference_raises_exception(): + """Test that circular references raise InvalidParameterValueException.""" + serializer = JSONSerializer() + data = {"key": "value"} + data["self"] = data # Create circular reference + + with pytest.raises(InvalidParameterValueException) as exc_info: + serializer.to_bytes(data) + + assert "Failed to serialize data to JSON" in str(exc_info.value) + + +def test_serialize_datetime_with_microseconds(): + """Test serialization of datetime with microseconds.""" + serializer = JSONSerializer() + now = datetime(2025, 11, 5, 16, 30, 9, 123456, tzinfo=timezone.utc) + data = {"timestamp": now} + + result = serializer.to_bytes(data) + expected = b'{"timestamp":1762360209.123456}' + + assert result == expected + + +def test_serialize_datetime_without_microseconds(): + """Test serialization of datetime without microseconds.""" + serializer = JSONSerializer() + now = datetime(2025, 11, 5, 16, 30, 9, tzinfo=timezone.utc) + data = {"timestamp": now} + + result = serializer.to_bytes(data) + expected = b'{"timestamp":1762360209.0}' + + assert result == expected + + +def test_serialize_multiple_datetimes(): + """Test multiple datetime objects.""" + serializer = JSONSerializer() + dt1 = datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc) + dt2 = datetime(2025, 12, 31, 23, 59, 59, tzinfo=timezone.utc) + + data = {"start": dt1, "end": dt2} + result = serializer.to_bytes(data) + expected = b'{"start":1735689600.0,"end":1767225599.0}' + + assert result == expected diff --git a/packages/aws-durable-execution-sdk-python-testing/tests/web/server_test.py b/packages/aws-durable-execution-sdk-python-testing/tests/web/server_test.py new file mode 100644 index 0000000..8a15990 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-testing/tests/web/server_test.py @@ -0,0 +1,252 @@ +"""Tests for web server implementation.""" + +from __future__ import annotations + +import logging +import threading +import time +from unittest.mock import Mock, patch + +import pytest + +from aws_durable_execution_sdk_python_testing.exceptions import ( + IllegalStateException, + InvalidParameterValueException, + ResourceNotFoundException, + SerializationError, + UnknownRouteError, +) +from aws_durable_execution_sdk_python_testing.web.models import HTTPResponse +from aws_durable_execution_sdk_python_testing.web.routes import ( + GetDurableExecutionRoute, + HealthRoute, + Router, + StartExecutionRoute, +) +from aws_durable_execution_sdk_python_testing.web.server import ( + RequestHandler, + WebServer, + WebServiceConfig, +) + + +def test_web_service_config_default_values(): + """Test that default configuration values are correct.""" + + config = WebServiceConfig() + + assert config.host == "localhost" + assert config.port == 5000 + assert config.log_level == logging.INFO + assert config.max_request_size == 10 * 1024 * 1024 + + +def test_web_service_config_custom_values(): + """Test that custom configuration values are set correctly.""" + + config = WebServiceConfig( + host="127.0.0.1", + port=9000, + log_level=logging.DEBUG, + max_request_size=5 * 1024 * 1024, + ) + + assert config.host == "127.0.0.1" + assert config.port == 9000 + assert config.log_level == logging.DEBUG + assert config.max_request_size == 5 * 1024 * 1024 + + +def test_web_service_config_frozen_dataclass(): + """Test that WebServiceConfig is immutable.""" + config = WebServiceConfig() + + with pytest.raises(AttributeError): + config.port = 9000 + + +def test_web_server_initialization(): + """Test that WebServer initializes correctly.""" + config = WebServiceConfig(port=0) # Use port 0 for testing + executor = Mock() + + with WebServer(config, executor) as server: + assert server.config == config + assert server.executor == executor + + +def test_web_server_context_manager(): + """Test that WebServer works as a context manager.""" + config = WebServiceConfig(port=0) + executor = Mock() + + # Test context manager entry and exit + with WebServer(config, executor) as server: + assert isinstance(server, WebServer) + assert server.config == config + assert server.executor == executor + + # Server should be cleaned up after context exit + + +def test_web_server_background_usage(): + """Test that server can be used in background thread for testing.""" + config = WebServiceConfig(port=0) + executor = Mock() + + with WebServer(config, executor) as server: + # Start server in background thread + server_thread = threading.Thread(target=server.serve_forever, daemon=True) + server_thread.start() + + # Give it a moment to start + time.sleep(0.1) + assert server_thread.is_alive() + + # Stop the server + server.shutdown() + + # Give it a moment to shutdown + time.sleep(0.1) + server_thread.join(timeout=1) + assert not server_thread.is_alive() + + +def test_web_server_has_executor_reference(): + """Test that WebServer stores executor reference correctly.""" + config = WebServiceConfig(port=0) + executor = Mock() + + with WebServer(config, executor) as server: + # Verify server has executor reference + assert server.executor == executor + + # Verify RequestHandler class is set correctly + + assert server.RequestHandlerClass == RequestHandler + + +def test_web_server_has_router_and_handlers(): + """Test that WebServer creates router and handlers correctly.""" + + executor = Mock() + config = WebServiceConfig(port=0) # Use port 0 to get any available port + + with WebServer(config, executor) as server: + # Verify router is created + assert server.router is not None + assert isinstance(server.router, Router) + + # Verify handlers are created + assert server.endpoint_handlers is not None + assert len(server.endpoint_handlers) > 0 + + # Verify specific handlers exist + assert StartExecutionRoute in server.endpoint_handlers + assert HealthRoute in server.endpoint_handlers + + # Verify handlers have executor reference + start_handler = server.endpoint_handlers[StartExecutionRoute] + assert start_handler.executor is executor + + +def test_web_server_all_routes_have_handlers(): + """Test that all routes in the router have corresponding handlers.""" + executor = Mock() + config = WebServiceConfig(port=0) # Use port 0 to get any available port + + with WebServer(config, executor) as server: + # Test that router can find routes for all handler types + handler_route_types = set(server.endpoint_handlers.keys()) + + # Test a sample of routes to verify router functionality + test_routes = [ + ("/start-durable-execution", "POST", StartExecutionRoute), + ("/health", "GET", HealthRoute), + ( + "/2025-12-01/durable-executions/test-arn", + "GET", + GetDurableExecutionRoute, + ), + ] + + for path, method, expected_route_type in test_routes: + # Verify router can find the route (tests public API) + found_route = server.router.find_route(path, method) + assert isinstance(found_route, expected_route_type) + + # Verify handler exists for this route type + assert expected_route_type in handler_route_types + + +def test_request_handler_exception_mapping(): + """Test that RequestHandler has proper exception handling capabilities.""" + + # Verify that all the required exception types are available for import + assert SerializationError is not None + assert InvalidParameterValueException is not None + assert ResourceNotFoundException is not None + assert IllegalStateException is not None + assert UnknownRouteError is not None + + +def test_http_response_create_error_from_exception(): + """Test HTTPResponse.create_error_from_exception method directly.""" + test_exception = InvalidParameterValueException("Test error message") + response = HTTPResponse.create_error_from_exception(test_exception) + + assert response.status_code == 400 + assert response.headers["Content-Type"] == "application/json" + + # AWS-compliant format without wrapper + expected_body = { + "Type": "InvalidParameterValueException", + "message": "Test error message", + } + assert response.body == expected_body + + +def test_request_handler_error_response_through_public_api(): + """Test error response handling through public do_POST method.""" + import io + from unittest.mock import MagicMock + + # Create a mock request handler with minimal setup + mock_server = MagicMock() + mock_server.executor = Mock() + mock_server.router = Mock() + mock_server.endpoint_handlers = {} + + # Mock the router to raise an exception + mock_server.router.find_route.side_effect = InvalidParameterValueException( + "Test error message" + ) + + # Create handler instance + with patch.object(RequestHandler, "__init__", return_value=None): + handler = RequestHandler.__new__(RequestHandler) + handler.executor = mock_server.executor + handler.router = mock_server.router + handler.endpoint_handlers = mock_server.endpoint_handlers + handler.path = "/test-path" + handler.headers = {"Content-Length": "0"} + handler.rfile = io.BytesIO(b"") + + # Mock the response sending + with patch.object(handler, "_send_response") as mock_send_response: + # Call the public method that should trigger error handling + handler.do_POST() + + # Verify _send_response was called with correct error response + mock_send_response.assert_called_once() + response = mock_send_response.call_args[0][0] + + assert response.status_code == 400 + assert response.headers["Content-Type"] == "application/json" + + # AWS-compliant format without wrapper + expected_body = { + "Type": "InvalidParameterValueException", + "message": "Test error message", + } + assert response.body == expected_body diff --git a/pyproject.toml b/pyproject.toml index b10df01..c744ed5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ dependencies = [ "pytest", "pytest-cov", "opentelemetry-sdk>=1.20.0", - "aws_durable_execution_sdk_python_testing" + "aws-durable-execution-sdk-python-testing", ] [tool.hatch.envs.test.scripts] @@ -27,6 +27,7 @@ addopts = "-v --strict-markers --import-mode=importlib" testpaths = [ "packages/aws-durable-execution-sdk-python/tests", "packages/aws-durable-execution-sdk-python-otel/tests", + "packages/aws-durable-execution-sdk-python-testing/tests", "packages/aws-durable-execution-sdk-python-examples/test", ] markers = [ @@ -77,11 +78,29 @@ test = "pytest packages/aws-durable-execution-sdk-python-otel/tests {args}" cov = "pytest --cov-report=term-missing --cov-config=packages/aws-durable-execution-sdk-python-otel/pyproject.toml --cov=packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel packages/aws-durable-execution-sdk-python-otel/tests {args}" typecheck = "mypy --install-types --non-interactive packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel packages/aws-durable-execution-sdk-python-otel/tests" +[tool.hatch.envs.dev-testing] +workspace.members = [ + "packages/aws-durable-execution-sdk-python", + "packages/aws-durable-execution-sdk-python-testing", +] +dependencies = [ + "pytest", + "pytest-cov", + "coverage[toml]", + "mypy>=1.0.0", +] + +[tool.hatch.envs.dev-testing.scripts] +test = "pytest packages/aws-durable-execution-sdk-python-testing/tests {args}" +cov = "pytest --cov-report=term-missing --cov-config=packages/aws-durable-execution-sdk-python-testing/pyproject.toml --cov=packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing --cov-fail-under=95 packages/aws-durable-execution-sdk-python-testing/tests {args}" +typecheck = "mypy --install-types --non-interactive packages/aws-durable-execution-sdk-python-testing/src/aws_durable_execution_sdk_python_testing packages/aws-durable-execution-sdk-python-testing/tests" + [tool.hatch.envs.dev-examples] workspace.members = [ "packages/aws-durable-execution-sdk-python", "packages/aws-durable-execution-sdk-python-examples", "packages/aws-durable-execution-sdk-python-otel", + "packages/aws-durable-execution-sdk-python-testing", ] dependencies = [ "pytest", @@ -149,7 +168,11 @@ preview = true select = ["TID252"] # Enforce absolute imports (ban relative imports) [tool.ruff.lint.isort] -known-first-party = ["aws_durable_execution_sdk_python", "aws_durable_execution_sdk_python_otel"] +known-first-party = [ + "aws_durable_execution_sdk_python", + "aws_durable_execution_sdk_python_otel", + "aws_durable_execution_sdk_python_testing", +] force-single-line = false lines-after-imports = 2