diff --git a/packages/aws-durable-execution-sdk-python-examples/cli.py b/packages/aws-durable-execution-sdk-python-examples/cli.py index 4cf75cf9..ff2e8394 100755 --- a/packages/aws-durable-execution-sdk-python-examples/cli.py +++ b/packages/aws-durable-execution-sdk-python-examples/cli.py @@ -5,6 +5,7 @@ import logging import os import shutil +import subprocess import sys import time import zipfile @@ -35,6 +36,7 @@ def build_examples(): build_dir = Path(__file__).parent / "build" src_dir = Path(__file__).parent / "src" + packages_dir = Path(__file__).parent.parent logger.info("Building examples...") @@ -57,15 +59,29 @@ def build_examples(): logger.exception("Failed to copy testing library") return False - # Copy SDK source from the main SDK package - testing_src = ( - Path(__file__).parent.parent - / "aws-durable-execution-sdk-python" - / "src" - / "aws_durable_execution_sdk_python" - ) - logger.info("Copying SDK from %s", testing_src) - shutil.copytree(testing_src, build_dir / "aws_durable_execution_sdk_python") + # Install local packages so their runtime dependencies are included in + # the Lambda deployment package. 
+ runtime_packages = [ + packages_dir / "aws-durable-execution-sdk-python", + packages_dir / "aws-durable-execution-sdk-python-otel", + ] + try: + subprocess.run( + [ + sys.executable, + "-m", + "pip", + "install", + "--upgrade", + "--target", + str(build_dir), + *[str(package) for package in runtime_packages], + ], + check=True, + ) + except subprocess.CalledProcessError: + logger.exception("Failed to install runtime dependencies") + return False # Copy example functions logger.info("Copying examples from %s", src_dir) diff --git a/packages/aws-durable-execution-sdk-python-examples/examples-catalog.json b/packages/aws-durable-execution-sdk-python-examples/examples-catalog.json index fe8ccfd7..d921bce8 100644 --- a/packages/aws-durable-execution-sdk-python-examples/examples-catalog.json +++ b/packages/aws-durable-execution-sdk-python-examples/examples-catalog.json @@ -613,6 +613,17 @@ "ExecutionTimeout": 300 }, "path": "./src/plugin/execution_with_plugin.py" + }, + { + "name": "Otel Plugin", + "description": "Test Otel plugin", + "handler": "execution_with_otel.handler", + "integration": true, + "durableConfig": { + "RetentionPeriodInDays": 7, + "ExecutionTimeout": 300 + }, + "path": "./src/plugin/execution_with_otel.py" } ] } diff --git a/packages/aws-durable-execution-sdk-python-examples/src/plugin/execution_with_otel.py b/packages/aws-durable-execution-sdk-python-examples/src/plugin/execution_with_otel.py new file mode 100644 index 00000000..6bceefff --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-examples/src/plugin/execution_with_otel.py @@ -0,0 +1,53 @@ +"""Demonstrates handler execution without any durable operations.""" + +from typing import Any + +from opentelemetry import trace + +from aws_durable_execution_sdk_python import StepContext +from aws_durable_execution_sdk_python.config import Duration, StepConfig, StepSemantics +from aws_durable_execution_sdk_python.context import ( + DurableContext, + durable_step, + durable_with_child_context, 
@durable_execution(plugins=[otel])
def handler(_event: Any, context: DurableContext) -> int:
    """Run three child contexts, sum their results, and add 2 in a final step.

    Child context ``context-<i>`` executes ``add_numbers_in_child(6, i)`` for
    ``i`` in 0, 1, 2. The running total is then passed to the ``final-step``
    step, whose result is returned. The incoming event is ignored.
    """
    subtotal = sum(
        context.run_in_child_context(
            add_numbers_in_child(6, index),
            name=f"context-{index}",
        )
        for index in range(3)
    )
    final_step = add_numbers(subtotal, 2)
    return context.step(final_step, name="final-step")
b/packages/aws-durable-execution-sdk-python-examples/test/plugin/test_otel_plugin.py @@ -0,0 +1,24 @@ +"""Tests for step example.""" + +import pytest +from aws_durable_execution_sdk_python.execution import InvocationStatus + +from src.plugin import execution_with_otel +from test.conftest import deserialize_operation_payload + + +@pytest.mark.example +@pytest.mark.durable_execution( + handler=execution_with_otel.handler, + lambda_function_name="Otel Plugin", +) +def test_plugin(durable_runner): + """Test basic step example.""" + with durable_runner: + result = durable_runner.run(input="{}", timeout=10) + + assert result.status is InvocationStatus.SUCCEEDED + assert deserialize_operation_payload(result.result) == 23 + + step_result = result.get_step("final-step") + assert deserialize_operation_payload(step_result.result) == 23 diff --git a/packages/aws-durable-execution-sdk-python-otel/pyproject.toml b/packages/aws-durable-execution-sdk-python-otel/pyproject.toml index bd9a1231..72d64176 100644 --- a/packages/aws-durable-execution-sdk-python-otel/pyproject.toml +++ b/packages/aws-durable-execution-sdk-python-otel/pyproject.toml @@ -25,6 +25,8 @@ dependencies = [ "aws-durable-execution-sdk-python>=1.5.0", "opentelemetry-api>=1.20.0", "opentelemetry-sdk>=1.20.0", + "opentelemetry-exporter-otlp", + "opentelemetry-propagator-aws-xray", ] [project.urls] diff --git a/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/__init__.py b/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/__init__.py index 63b1b9cc..7ba31caa 100644 --- a/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/__init__.py +++ b/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/__init__.py @@ -1,8 +1,24 @@ """OpenTelemetry instrumentation for AWS Lambda Durable Executions Python SDK.""" from aws_durable_execution_sdk_python_otel.__about__ import 
def xray_context_extractor(info: "InvocationStartInfo") -> "Context":
    """Build the parent context from the ``_X_AMZN_TRACE_ID`` environment variable.

    When the variable holds a non-empty value, that value is passed to
    ``propagate.extract`` under the ``X-Amzn-Trace-Id`` carrier key, on top of
    the current OpenTelemetry context. When the variable is unset or empty the
    current context is returned unchanged.

    ``info`` is accepted so the function matches the ``ContextExtractor``
    signature; it is not read.
    """
    header_value = os.environ.get("_X_AMZN_TRACE_ID", "")
    active_context = otel_context.get_current()
    if header_value:
        return propagate.extract(
            carrier={"X-Amzn-Trace-Id": header_value},
            context=active_context,
        )
    return active_context
def _xray_trace_id_to_otel(xray_trace_id: str) -> int:
    """Convert an X-Ray trace ID into an OpenTelemetry trace ID integer.

    X-Ray format: ``1-<8 hex>-<24 hex>`` (35 characters including the version
    prefix and the dashes). The result is the integer whose 32-digit hex form
    is ``<8 hex><24 hex>``; hex digits are accepted in either case.

    Args:
        xray_trace_id: X-Ray root trace ID, e.g.
            ``"1-5759e988-bd862e3fe1be46a994272793"``.

    Returns:
        The trace ID as an integer.

    Raises:
        ValueError: If the text left after removing the version prefix and
            dashes is not valid hexadecimal.
    """
    # Strip the version marker only when it is a real prefix. An unanchored
    # replace("1-", "", 1) would also eat a "1-" that happens to sit at the
    # boundary between the two hex groups of an ID passed without its prefix.
    hex_digits = xray_trace_id.removeprefix("1-").replace("-", "")
    return int(hex_digits, 16)
    def set_next_span_id(self, span_id: int | None) -> None:
        """Set the span ID that the next ``generate_span_id`` call returns.

        The value is consumed by a single ``generate_span_id`` call; after
        that, span IDs come from the random generator again. Passing ``None``
        clears any pending value.

        Args:
            span_id: The span ID to hand out next (not an operation ID), or
                ``None`` to fall back to random generation.
        """
        self._next_span_id = span_id
class DurableExecutionOtelPlugin(DurableInstrumentationPlugin):
    """OpenTelemetry instrumentation plugin for durable executions.

    The plugin creates spans for Lambda invocations, durable operations, and
    user-function attempts. At the start of every invocation the execution ARN
    and invocation start time are handed to the plugin's
    ``DeterministicIdGenerator``, which supplies the trace ID for the spans
    created during that invocation.

    Operation IDs are converted into deterministic span IDs. The first observed
    span for an operation uses that deterministic ID; later continuation spans
    use newly generated span IDs and link back to the deterministic span ID so
    trace viewers can relate retries and replay-created terminal spans to the
    original logical operation.

    Args:
        trace_provider: OpenTelemetry tracer provider used to create spans.
        context_extractor: Optional extractor for upstream context. Defaults to
            AWS X-Ray header extraction.
        sampling_rate: Ratio used by ``TraceIdRatioBased`` sampling.
        instrument_name: Instrumentation scope name registered with the tracer.
    """

    DEFAULT_INSTRUMENT_NAME = "aws-durable-execution-sdk-python"

    # Operation types whose spans are driven by the user-function callbacks
    # instead of the operation start/end callbacks.
    _USER_FUNCTION_OPERATION_TYPES = (OperationType.CONTEXT, OperationType.STEP)

    def __init__(
        self,
        trace_provider: TracerProvider,
        context_extractor: ContextExtractor | None = None,
        sampling_rate: float = 1.0,
        instrument_name: str = DEFAULT_INSTRUMENT_NAME,
    ) -> None:
        """Initialize the plugin with an OpenTelemetry tracer provider.

        The provided tracer provider is configured with this plugin's
        deterministic ID generator and sampling strategy so spans for a durable
        execution share stable trace and logical operation identifiers.
        """
        self._context_extractor: ContextExtractor = (
            context_extractor or xray_context_extractor
        )

        self._id_generator: DeterministicIdGenerator = DeterministicIdGenerator()

        # The caller's provider object is mutated in place.
        # NOTE(review): presumably only an SDK TracerProvider honours these two
        # attributes; confirm behaviour for proxy / no-op providers.
        self._provider = trace_provider
        self._provider.id_generator = self._id_generator
        self._provider.sampler = TraceIdRatioBased(sampling_rate)
        self._tracer: Tracer = self._provider.get_tracer(instrument_name)

        # Per-invocation state, reset in on_invocation_end.
        self._execution_arn = ""
        self._extracted_context: Context | None = None
        # Maps operation ID (None for the invocation span) to the active span.
        self._operation_spans: dict[str | None, Span] = {}
        self._operation_spans_lock = threading.RLock()

    def _set_span(self, operation_id: str | None, span: Span) -> None:
        """Register the active span for an operation ID."""
        with self._operation_spans_lock:
            self._operation_spans[operation_id] = span

    def _delete_span(self, operation_id: str | None) -> None:
        """Remove the active span for an operation ID if one is stored."""
        with self._operation_spans_lock:
            self._operation_spans.pop(operation_id, None)

    def _get_span(self, operation_id: str | None) -> Span | None:
        """Return the active span for an operation ID, if present."""
        with self._operation_spans_lock:
            return self._operation_spans.get(operation_id)

    @staticmethod
    def _distinct_end_time(
        start_time: datetime.datetime | None,
        end_time: datetime.datetime | None,
    ) -> datetime.datetime | None:
        """Return ``end_time``, moved 1 microsecond later if it equals ``start_time``."""
        if end_time is not None and end_time == start_time:
            return end_time + datetime.timedelta(microseconds=1)
        return end_time

    # ------------------------------------------------------------------
    # Context resolution
    # ------------------------------------------------------------------
    def _resolve_parent_span(self, parent_id: str | None = None) -> Span:
        """Resolve the active parent span for a durable operation.

        ``parent_id`` is ``None`` for root-level durable operations beneath the
        invocation span. For child operations, the parent operation must already
        have an active span in the current invocation.

        Raises:
            ValueError: If the requested parent span is not active.
        """
        parent_span = self._get_span(parent_id)
        if parent_span is None:
            raise ValueError(f"No parent span found for parent_id={parent_id!r}")
        return parent_span

    def _start_span(
        self,
        operation_id: str | None,
        name: str,
        attributes: dict[str, str | int],
        start_time: datetime.datetime | None = None,
        parent_span: Span | None = None,
        existed: bool = False,
    ) -> Span:
        """Start and store a span for an invocation or durable operation.

        Args:
            operation_id: Durable operation ID. ``None`` is used for the root
                invocation span.
            name: Span display name.
            attributes: Span attributes.
            start_time: Optional durable start timestamp.
            parent_span: Active parent span. When omitted, the extracted
                upstream context is used as the parent.
            existed: Whether the logical operation already had a previous span.
                Continuation spans link back to the deterministic span ID for
                the operation while using a fresh generated span ID.

        Returns:
            The started OpenTelemetry span.

        Raises:
            ValueError: If ``existed`` is true but no operation ID is given.
        """
        logger.info(
            "starting a span: operation_id=%s, name=%s, parent_span=%s",
            operation_id,
            name,
            parent_span,
        )
        # The lock covers both priming the ID generator and starting the span,
        # so another caller of this method cannot consume the primed span ID.
        with self._operation_spans_lock:
            links: list[Link] = []
            next_span_id: int | None = None
            if existed:
                if not operation_id:
                    raise ValueError("operation id is required")
                # Continuation span: random span ID, linked to the
                # deterministic ID of the logical operation.
                links.append(
                    Link(
                        context=SpanContext(
                            trace_id=self._id_generator.generate_trace_id(),
                            span_id=operation_id_to_span_id(operation_id),
                            is_remote=False,
                            trace_flags=TraceFlags(TraceFlags.SAMPLED),
                        )
                    )
                )
            elif operation_id:
                next_span_id = operation_id_to_span_id(operation_id)
            self._id_generator.set_next_span_id(next_span_id)

            if parent_span is None:
                # Root span: parented directly on the extracted upstream context.
                parent_context = self._extracted_context
            else:
                parent_context = trace.set_span_in_context(
                    parent_span, self._extracted_context
                )
            span = self._tracer.start_span(
                name=name,
                attributes=attributes,
                start_time=_to_otel_timestamp(start_time),
                context=parent_context,
                links=links,
            )
            self._operation_spans[operation_id] = span

        logger.info("started a span: %s", span)
        return span

    def _end_span(
        self, operation_id: str | None, end_timestamp: datetime.datetime | None = None
    ) -> None:
        """End and unregister the active span for an operation ID.

        Does nothing when no span is registered for ``operation_id``.

        Args:
            operation_id: Durable operation ID, or ``None`` for the invocation
                span.
            end_timestamp: Optional durable end timestamp to use as the span end
                time. When omitted, ``span.end`` is called with ``end_time=None``.
        """
        logger.info("ending a span for operation: %s", operation_id)
        with self._operation_spans_lock:
            span = self._operation_spans.pop(operation_id, None)
        if span:
            end_time = _to_otel_timestamp(end_timestamp) if end_timestamp else None
            span.end(end_time=end_time)
            logger.info("ended otel span: %s", span)

    # ------------------------------------------------------------------
    # Plugin lifecycle callbacks
    # ------------------------------------------------------------------
    def on_invocation_start(self, info: InvocationStartInfo) -> None:
        """Called at the start of each invocation. Creates the invocation span."""
        logger.info("Invocation started: %s", info)
        self._execution_arn = info.execution_arn or ""
        self._extracted_context = self._context_extractor(info)
        self._id_generator.set_trace_id(self._execution_arn, info.start_time)

        self._start_span(
            operation_id=None,
            name="invocation",
            attributes=self._extract_attributes(info),
        )

    def on_invocation_end(self, info: InvocationEndInfo) -> None:
        """Called at the end of each invocation. Ends all open spans and flushes.

        Operation spans that are still open are ended first, then the
        invocation span. Per-invocation state is cleared and the provider is
        flushed even if ending a span raises.
        """
        logger.info("Invocation ended: %s", info)
        end_time = info.end_time
        try:
            with self._operation_spans_lock:
                pending_ids = [op_id for op_id in self._operation_spans if op_id]
            for operation_id in pending_ids:
                self._end_span(operation_id, end_time)

            # The invocation span is stored under the ``None`` key.
            self._end_span(None, end_time)
        finally:
            # Clear all per-invocation state to prevent leaks across warm
            # Lambda reuses.
            self._execution_arn = ""
            self._extracted_context = None
            with self._operation_spans_lock:
                self._operation_spans = {}

            # Flush before Lambda freeze.
            if hasattr(self._provider, "force_flush"):
                self._provider.force_flush()

    def on_operation_start(self, info: OperationStartInfo) -> None:
        """Called when an operation begins. Creates a span for the operation.

        CONTEXT and STEP operations are skipped here; their spans are started
        by ``on_user_function_start``.

        Raises:
            ValueError: If the operation's parent has no active span.
        """
        logger.info("Operation started: %s", info)
        if info.operation_type in self._USER_FUNCTION_OPERATION_TYPES:
            return

        self._start_span(
            operation_id=info.operation_id,
            name=info.name or info.operation_id,
            attributes=self._extract_attributes(info),
            start_time=info.start_time,
            parent_span=self._resolve_parent_span(info.parent_id),
        )

    def on_operation_end(self, info: OperationEndInfo) -> None:
        """Called when an operation reaches a terminal durable status.

        Non-user-function operations are started by ``on_operation_start``. If
        an operation end is observed without a matching in-memory span, this
        invocation is completing an operation that began earlier, so a short
        continuation span is created and linked to the deterministic logical
        operation span before being ended.

        Raises:
            ValueError: If a continuation span is needed and the operation's
                parent has no active span.
        """
        logger.info("Operation ended: %s", info)
        if info.operation_type in self._USER_FUNCTION_OPERATION_TYPES:
            # Context and Step operations are tracked using on_user_function_end.
            return

        span = self._get_span(info.operation_id)
        if not span:
            # The span was not started in the current invocation, so create a
            # new one that links to the previous one.
            span = self._start_span(
                operation_id=info.operation_id,
                name=info.name or info.operation_id,
                attributes=self._extract_attributes(info),
                start_time=datetime.datetime.now(datetime.UTC),
                parent_span=self._resolve_parent_span(info.parent_id),
                existed=True,
            )

        if info.error:
            span.set_status(StatusCode.ERROR, info.error.message or "")
            span.record_exception(
                Exception(info.error.message or info.error.type or "Unknown error")
            )
        else:
            span.set_status(StatusCode.OK)

        self._end_span(
            info.operation_id,
            self._distinct_end_time(info.start_time, info.end_time),
        )

    def on_user_function_start(self, info: UserFunctionStartInfo) -> None:
        """Called when a context or step operation starts user code.

        The started span is attached to the current OpenTelemetry context so
        instrumentation used by the user code can pick it up. Attempts after
        the first are emitted as continuation spans linked to the logical
        operation span.

        Args:
            info: Information about the operation attempt.

        Raises:
            RuntimeError: If the operation is not a CONTEXT or STEP operation.
            ValueError: If the operation's parent has no active span.
        """
        logger.info("User function started: %s", info)
        if info.operation_type not in self._USER_FUNCTION_OPERATION_TYPES:
            raise RuntimeError(
                "on_user_function_start should only be called for CONTEXT and STEP operations"
            )

        span = self._start_span(
            operation_id=info.operation_id,
            name=info.name or info.operation_id,
            attributes=self._extract_attributes(info),
            start_time=info.start_time,
            parent_span=self._resolve_parent_span(info.parent_id),
            existed=info.attempt != 1,
        )
        # NOTE(review): the token returned by attach() is discarded, so the
        # context is never detached and the (eventually ended) span stays
        # current on this thread until something else attaches. Consider
        # keeping the token and detaching in on_user_function_end once it is
        # confirmed that both callbacks run on the same thread.
        context.attach(trace.set_span_in_context(span, self._extracted_context))

    def on_user_function_end(self, info: UserFunctionEndInfo) -> None:
        """Called when a context or step operation finishes user code.

        Records the attempt outcome on the span started by
        ``on_user_function_start``, captures an exception for failed attempts,
        and ends the span.

        Args:
            info: Information about the operation attempt.

        Raises:
            RuntimeError: If the operation is not a CONTEXT or STEP operation,
                or if no span is active for it.
        """
        logger.info("User function ended: %s", info)
        if info.operation_type not in self._USER_FUNCTION_OPERATION_TYPES:
            raise RuntimeError(
                "on_user_function_end should only be called for CONTEXT and STEP operations"
            )
        span = self._get_span(info.operation_id)
        if not span:
            raise RuntimeError(
                "on_user_function_end called without matching on_user_function_start"
            )

        span.set_attributes(self._extract_attributes(info))
        if info.outcome is UserFunctionOutcome.FAILED:
            error = info.error
            span.set_status(StatusCode.ERROR, error.message if error else "")
            span.record_exception(
                Exception(
                    (error.message or error.type or "Unknown error")
                    if error
                    else "Unknown error"
                )
            )
        elif info.outcome is UserFunctionOutcome.SUCCEEDED:
            span.set_status(StatusCode.OK)
        # Any other outcome (e.g. pending) leaves the status unset. A status
        # description is only meaningful together with StatusCode.ERROR, and
        # the outcome is already recorded as "durable.attempt.outcome".

        self._end_span(
            info.operation_id,
            self._distinct_end_time(info.start_time, info.end_time),
        )

    def _extract_attributes(self, info: Any) -> dict[str, str | int]:
        """Extract durable execution fields as OpenTelemetry span attributes.

        Only fields that exist on ``info`` and are not ``None`` are included;
        the execution ARN is always included.

        Args:
            info: Invocation, operation, or user-function callback payload.

        Returns:
            A dictionary of durable execution attributes suitable for a span.
            The attempt number is stored as given (an integer in practice);
            the other values are strings.
        """
        attributes: dict[str, str | int] = {
            "durable.execution.arn": self._execution_arn,
        }

        operation_id = getattr(info, "operation_id", None)
        if operation_id is not None:
            attributes["durable.operation.id"] = operation_id

        operation_type = getattr(info, "operation_type", None)
        if operation_type is not None:
            attributes["durable.operation.type"] = operation_type.value

        name = getattr(info, "name", None)
        if name is not None:
            attributes["durable.operation.name"] = name

        attempt = getattr(info, "attempt", None)
        if attempt is not None:
            attributes["durable.attempt.number"] = attempt

        outcome = getattr(info, "outcome", None)
        if outcome is not None:
            attributes["durable.attempt.outcome"] = outcome.value

        return attributes
"Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" + ) + current_context = Context({"durable": "current"}) + extracted_context = Context({"durable": "extracted"}) + extract_calls = [] + monkeypatch.setenv("_X_AMZN_TRACE_ID", trace_header) + monkeypatch.setattr( + context_extractors.otel_context, + "get_current", + lambda: current_context, + ) + + def extract(*, carrier, context): + extract_calls.append({"carrier": carrier, "context": context}) + return extracted_context + + monkeypatch.setattr(context_extractors.propagate, "extract", extract) + + assert context_extractors.xray_context_extractor(object()) is extracted_context + assert extract_calls == [ + { + "carrier": {"X-Amzn-Trace-Id": trace_header}, + "context": current_context, + } + ] + + +def test_w3c_client_context_extractor_returns_current_context(monkeypatch): + """Verify the placeholder W3C extractor leaves the active context unchanged.""" + current_context = Context({"durable": "current"}) + monkeypatch.setattr( + context_extractors.otel_context, + "get_current", + lambda: current_context, + ) + + assert context_extractors.w3c_client_context_extractor(object()) is current_context diff --git a/packages/aws-durable-execution-sdk-python-otel/tests/test_deterministic_id_generator.py b/packages/aws-durable-execution-sdk-python-otel/tests/test_deterministic_id_generator.py new file mode 100644 index 00000000..3f4e53f7 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-otel/tests/test_deterministic_id_generator.py @@ -0,0 +1,134 @@ +"""Tests for deterministic OpenTelemetry ID generation.""" + +from __future__ import annotations + +from datetime import UTC, datetime + +from aws_durable_execution_sdk_python_otel.deterministic_id_generator import ( + HASHED_ID_PATTERN, + DeterministicIdGenerator, + _parse_xray_root_trace_id, + _to_otel_trace_id, + _xray_trace_id_to_otel, + operation_id_to_span_id, +) + + +def test_parse_xray_root_trace_id_returns_root_from_header(): + 
"""Verify X-Ray Root trace ID parsing ignores other header fields.""" + trace_header = ( + "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1" + ) + + assert ( + _parse_xray_root_trace_id(trace_header) == "1-5759e988-bd862e3fe1be46a994272793" + ) + + +def test_parse_xray_root_trace_id_returns_none_for_missing_or_malformed_header(): + """Verify absent or malformed X-Ray headers are ignored.""" + assert _parse_xray_root_trace_id(None) is None + assert _parse_xray_root_trace_id("") is None + assert _parse_xray_root_trace_id("Parent=53995c3f42cd8ad8;Sampled=1") is None + assert ( + _parse_xray_root_trace_id( + "Root=1-5759e988-not-enough-hex;Parent=53995c3f42cd8ad8" + ) + is None + ) + + +def test_xray_trace_id_to_otel_removes_xray_prefix_and_normalizes_case(): + """Verify X-Ray trace IDs are converted into OTel-compatible integers.""" + trace_id = "1-5759E988-BD862E3FE1BE46A994272793" + + assert _xray_trace_id_to_otel(trace_id) == int( + "5759e988bd862e3fe1be46a994272793", 16 + ) + + +def test_to_otel_trace_id_uses_xray_root_header_when_available(monkeypatch): + """Verify Lambda's X-Ray trace header takes precedence over fallback IDs.""" + monkeypatch.setenv( + "_X_AMZN_TRACE_ID", + "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1", + ) + start_timestamp = datetime(2024, 1, 2, 3, 4, 5, tzinfo=UTC) + + assert _to_otel_trace_id("different-execution-arn", start_timestamp) == int( + "5759e988bd862e3fe1be46a994272793", 16 + ) + + +def test_to_otel_trace_id_falls_back_to_timestamp_and_execution_arn(monkeypatch): + """Verify fallback trace IDs are deterministic for the same execution.""" + monkeypatch.delenv("_X_AMZN_TRACE_ID", raising=False) + execution_arn = "arn:aws:lambda:us-west-2:123456789012:function:workflow:$LATEST" + start_timestamp = datetime(2024, 1, 2, 3, 4, 5, tzinfo=UTC) + + assert _to_otel_trace_id(execution_arn, start_timestamp) == int( + "65937d253aa8c3f7ffe36c50d65b1a6d", 16 + ) + + +def 
test_operation_id_to_span_id_returns_deterministic_64_bit_id(): + """Verify operation IDs map to stable 64-bit span IDs.""" + assert operation_id_to_span_id("my-operation") == int("ab1f94a6d3c668f3", 16) + + +def test_deterministic_id_generator_returns_cached_trace_id(monkeypatch): + """Verify trace IDs are cached after being set for an execution.""" + monkeypatch.delenv("_X_AMZN_TRACE_ID", raising=False) + generator = DeterministicIdGenerator() + + generator.set_trace_id( + "arn:aws:lambda:us-west-2:123456789012:function:workflow:$LATEST", + datetime(2024, 1, 2, 3, 4, 5, tzinfo=UTC), + ) + + assert generator.generate_trace_id() == int("65937d253aa8c3f7ffe36c50d65b1a6d", 16) + + +def test_deterministic_id_generator_falls_back_to_random_trace_id(monkeypatch): + """Verify trace IDs are random until an execution trace ID is set.""" + expected_trace_id = int("1" * 32, 16) + generator = DeterministicIdGenerator() + monkeypatch.setattr( + generator._random_id_generator, + "generate_trace_id", + lambda: expected_trace_id, + ) + + assert generator.generate_trace_id() == expected_trace_id + + +def test_deterministic_id_generator_uses_next_span_id_once(monkeypatch): + """Verify a configured span ID only applies to the next generated span.""" + deterministic_span_id = int("2" * 16, 16) + random_span_id = int("3" * 16, 16) + generator = DeterministicIdGenerator() + monkeypatch.setattr( + generator._random_id_generator, + "generate_span_id", + lambda: random_span_id, + ) + + generator.set_next_span_id(deterministic_span_id) + + assert generator.generate_span_id() == deterministic_span_id + assert generator.generate_span_id() == random_span_id + + +def test_deterministic_id_generator_accepts_cleared_next_span_id(monkeypatch): + """Verify clearing the next span ID preserves random span generation.""" + expected_span_id = int("4" * 16, 16) + generator = DeterministicIdGenerator() + monkeypatch.setattr( + generator._random_id_generator, + "generate_span_id", + lambda: 
expected_span_id, + ) + + generator.set_next_span_id(None) + + assert generator.generate_span_id() == expected_span_id diff --git a/packages/aws-durable-execution-sdk-python-otel/tests/test_plugin.py b/packages/aws-durable-execution-sdk-python-otel/tests/test_plugin.py new file mode 100644 index 00000000..5fb8a430 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-otel/tests/test_plugin.py @@ -0,0 +1,225 @@ +"""Tests for the OpenTelemetry durable execution plugin.""" + +from __future__ import annotations + +from concurrent.futures import ThreadPoolExecutor +from datetime import UTC, datetime + +from opentelemetry.context import Context +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + +from aws_durable_execution_sdk_python.lambda_service import ( + InvocationStatus, + OperationStatus, + OperationType, +) +from aws_durable_execution_sdk_python.plugin import ( + InvocationEndInfo, + InvocationStartInfo, + OperationEndInfo, + OperationStartInfo, + UserFunctionEndInfo, + UserFunctionOutcome, + UserFunctionStartInfo, +) +from aws_durable_execution_sdk_python_otel.deterministic_id_generator import ( + operation_id_to_span_id, +) +from aws_durable_execution_sdk_python_otel.plugin import DurableExecutionOtelPlugin + + +START_TIME = datetime(2024, 1, 2, 3, 4, 5, tzinfo=UTC) +END_TIME = datetime(2024, 1, 2, 3, 4, 6, tzinfo=UTC) +EXECUTION_ARN = "arn:aws:lambda:us-west-2:123456789012:function:workflow:$LATEST" + + +def _create_plugin() -> tuple[DurableExecutionOtelPlugin, InMemorySpanExporter]: + """Create a plugin wired to an in-memory span exporter.""" + exporter = InMemorySpanExporter() + trace_provider = TracerProvider() + trace_provider.add_span_processor(SimpleSpanProcessor(exporter)) + plugin = DurableExecutionOtelPlugin( + trace_provider=trace_provider, + context_extractor=lambda _: Context(), + ) + 
return plugin, exporter + + +def _invocation_start_info() -> InvocationStartInfo: + """Create standard invocation start info for tests.""" + return InvocationStartInfo( + request_id="request-1", + execution_arn=EXECUTION_ARN, + start_time=START_TIME, + is_first_invocation=True, + ) + + +def _invocation_end_info() -> InvocationEndInfo: + """Create standard invocation end info for tests.""" + return InvocationEndInfo( + request_id="request-1", + execution_arn=EXECUTION_ARN, + start_time=START_TIME, + is_first_invocation=True, + status=InvocationStatus.SUCCEEDED, + end_time=END_TIME, + error=None, + ) + + +def test_invocation_start_and_end_emit_invocation_span(): + """Verify invocation lifecycle callbacks create and finish the root span.""" + plugin, exporter = _create_plugin() + + plugin.on_invocation_start(_invocation_start_info()) + assert plugin._get_span(None) is not None + + plugin.on_invocation_end(_invocation_end_info()) + + spans = exporter.get_finished_spans() + assert [span.name for span in spans] == ["invocation"] + assert spans[0].attributes["durable.execution.arn"] == EXECUTION_ARN + assert plugin._get_span(None) is None + + +def test_operation_callbacks_emit_child_span_with_deterministic_span_id(): + """Verify non-user-function operations are traced beneath the invocation.""" + plugin, exporter = _create_plugin() + plugin.on_invocation_start(_invocation_start_info()) + operation_id = "wait-1" + + plugin.on_operation_start( + OperationStartInfo( + operation_id=operation_id, + operation_type=OperationType.WAIT, + sub_type=None, + name="wait-for-signal", + parent_id=None, + start_time=START_TIME, + ) + ) + plugin.on_operation_end( + OperationEndInfo( + operation_id=operation_id, + operation_type=OperationType.WAIT, + sub_type=None, + name="wait-for-signal", + parent_id=None, + start_time=START_TIME, + status=OperationStatus.SUCCEEDED, + end_time=END_TIME, + error=None, + ) + ) + plugin.on_invocation_end(_invocation_end_info()) + + spans_by_name = 
{span.name: span for span in exporter.get_finished_spans()} + wait_span = spans_by_name["wait-for-signal"] + invocation_span = spans_by_name["invocation"] + assert wait_span.context.span_id == operation_id_to_span_id(operation_id) + assert wait_span.parent.span_id == invocation_span.context.span_id + assert wait_span.attributes["durable.operation.id"] == operation_id + assert wait_span.attributes["durable.operation.type"] == OperationType.WAIT.value + + +def test_operation_end_without_start_emits_continuation_span_with_link(): + """Verify completed existing operations link to their logical operation span.""" + plugin, exporter = _create_plugin() + plugin.on_invocation_start(_invocation_start_info()) + operation_id = "wait-existing" + random_span_id = int("1234567890abcdef", 16) + plugin._id_generator._random_id_generator.generate_span_id = lambda: random_span_id + + plugin.on_operation_end( + OperationEndInfo( + operation_id=operation_id, + operation_type=OperationType.WAIT, + sub_type=None, + name="existing-wait", + parent_id=None, + start_time=START_TIME, + status=OperationStatus.SUCCEEDED, + end_time=END_TIME, + error=None, + ) + ) + + span = exporter.get_finished_spans()[0] + assert span.name == "existing-wait" + assert span.context.span_id == random_span_id + assert span.links[0].context.span_id == operation_id_to_span_id(operation_id) + + +def test_user_function_callbacks_emit_attempt_span_attributes(): + """Verify user-function end refreshes all extractable span attributes.""" + plugin, exporter = _create_plugin() + plugin.on_invocation_start(_invocation_start_info()) + operation_id = "step-1" + + start_info = UserFunctionStartInfo( + operation_id=operation_id, + operation_type=OperationType.STEP, + sub_type=None, + name="fetch-user", + parent_id=None, + start_time=START_TIME, + is_replay_children=False, + attempt=1, + ) + plugin.on_user_function_start(start_info) + active_span = plugin._get_span(operation_id) + assert active_span is not None + 
active_span.set_attributes( + { + "durable.operation.name": "stale-name", + "durable.attempt.number": 99, + } + ) + plugin.on_user_function_end( + UserFunctionEndInfo( + operation_id=operation_id, + operation_type=OperationType.STEP, + sub_type=None, + name="fetch-user", + parent_id=None, + start_time=START_TIME, + is_replay_children=False, + attempt=1, + outcome=UserFunctionOutcome.SUCCEEDED, + end_time=END_TIME, + error=None, + ) + ) + + span = exporter.get_finished_spans()[0] + assert span.name == "fetch-user" + assert span.context.span_id == operation_id_to_span_id(operation_id) + assert span.attributes["durable.execution.arn"] == EXECUTION_ARN + assert span.attributes["durable.operation.id"] == operation_id + assert span.attributes["durable.operation.type"] == OperationType.STEP.value + assert span.attributes["durable.operation.name"] == "fetch-user" + assert span.attributes["durable.attempt.number"] == 1 + assert ( + span.attributes["durable.attempt.outcome"] + == UserFunctionOutcome.SUCCEEDED.value + ) + + +def test_span_registry_helpers_can_be_called_from_multiple_threads(): + """Verify active span registry helpers are safe under concurrent access.""" + plugin, _ = _create_plugin() + + def update_span(index: int) -> None: + operation_id = f"operation-{index}" + plugin._set_span(operation_id, object()) # type: ignore[arg-type] + assert plugin._get_span(operation_id) is not None + plugin._delete_span(operation_id) + + with ThreadPoolExecutor(max_workers=8) as executor: + list(executor.map(update_span, range(100))) + + with plugin._operation_spans_lock: + assert plugin._operation_spans == {} diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/plugin.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/plugin.py index 1a7ecd27..0deff94b 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/plugin.py +++ 
b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/plugin.py @@ -318,6 +318,7 @@ def on_operation_action(self, update: OperationUpdate): Args: update: the operation update that is checkpointed """ + # todo: this could be called more than once for step when it's retried if update.action is OperationAction.START: # we handle only START action here because on_operation_update may not be able to see a STARTED update # when START is checkpointed in batch with terminal status updates. @@ -330,7 +331,7 @@ def on_operation_action(self, update: OperationUpdate): parent_id=update.parent_id, start_time=datetime.datetime.now(datetime.UTC), ), - sync=False, + sync=True, ) def on_operation_update(self, operation: Operation | None): @@ -357,7 +358,7 @@ def on_operation_update(self, operation: Operation | None): status=operation.status, error=self._extract_error(operation), ), - sync=False, + sync=True, ) @staticmethod