From f8d4575ffa1a9e1117aa62e06157012f761af544 Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Tue, 9 Jun 2026 14:27:49 -0500 Subject: [PATCH 01/15] feat(durabletask): host MAF workflows on a standalone Durable Task worker Add a host-agnostic workflow execution engine to agent-framework-durabletask so a MAF Workflow can run as a durable orchestration outside Azure Functions: - WorkflowOrchestrationContext protocol + DurableTaskWorkflowContext adapter, the superstep orchestrator, serialization helpers, capturing runner context, and the shared non-agent activity body (including the yield-output classifier so intermediate executors are not surfaced as final outputs). - DurableAIAgentWorker.configure_workflow auto-registers agent executors as entities, non-agent executors as activities, and the workflow orchestrator. - plan_workflow_registration centralizes the 'what to register' decision so it can be shared across hosts. - run_agent_coroutine runs all agent coroutines on one persistent event loop, fixing a cross-loop hang when shared chat clients/credentials bind their asyncio primitives to a dead loop. - DurableWorkflowClient (start/await workflow + HITL discover/respond); DurableAIAgentClient stays agent-only. --- .../agent_framework_durabletask/__init__.py | 18 + .../_async_bridge.py | 73 ++ .../agent_framework_durabletask/_worker.py | 119 ++- .../_workflow_activity.py | 179 ++++ .../_workflow_client.py | 176 ++++ .../_workflow_context.py | 143 ++++ .../_workflow_dt_context.py | 91 ++ .../_workflow_orchestrator.py | 786 ++++++++++++++++++ .../_workflow_registration.py | 69 ++ .../_workflow_runner_context.py | 147 ++++ .../_workflow_serialization.py | 183 ++++ .../tests/test_workflow_activity.py | 123 +++ 12 files changed, 2099 insertions(+), 8 deletions(-) create mode 100644 python/packages/durabletask/agent_framework_durabletask/_async_bridge.py create mode 100644 python/packages/durabletask/agent_framework_durabletask/_workflow_activity.py create mode 100644 python/packages/durabletask/agent_framework_durabletask/_workflow_client.py create mode 100644 python/packages/durabletask/agent_framework_durabletask/_workflow_context.py create mode 100644 python/packages/durabletask/agent_framework_durabletask/_workflow_dt_context.py create mode 100644 python/packages/durabletask/agent_framework_durabletask/_workflow_orchestrator.py create mode 100644 python/packages/durabletask/agent_framework_durabletask/_workflow_registration.py create mode 100644 python/packages/durabletask/agent_framework_durabletask/_workflow_runner_context.py create mode 100644 python/packages/durabletask/agent_framework_durabletask/_workflow_serialization.py create mode 100644 python/packages/durabletask/tests/test_workflow_activity.py diff --git a/python/packages/durabletask/agent_framework_durabletask/__init__.py b/python/packages/durabletask/agent_framework_durabletask/__init__.py index a518b5ad235..59872a2216b 100644 --- a/python/packages/durabletask/agent_framework_durabletask/__init__.py +++ b/python/packages/durabletask/agent_framework_durabletask/__init__.py @@ -4,6 +4,7 @@ import importlib.metadata +from ._async_bridge import run_agent_coroutine from ._callbacks import AgentCallbackContext, AgentResponseCallbackProtocol from ._client import DurableAIAgentClient from ._constants import ( @@ -50,6 +51,13 @@ from ._response_utils import ensure_response_format, load_agent_response from ._shim import DurableAIAgent from ._worker import DurableAIAgentWorker +from ._workflow_activity import execute_workflow_activity +from ._workflow_client import DurableWorkflowClient +from ._workflow_context import WorkflowOrchestrationContext +from ._workflow_dt_context import DurableTaskWorkflowContext +from ._workflow_orchestrator import WORKFLOW_ORCHESTRATOR_NAME, run_workflow_orchestrator +from ._workflow_registration import WorkflowRegistrationPlan, plan_workflow_registration +from ._workflow_runner_context import CapturingRunnerContext try: __version__ = importlib.metadata.version(__name__) @@ -67,12 +75,14 @@ "THREAD_ID_HEADER", "WAIT_FOR_RESPONSE_FIELD", "WAIT_FOR_RESPONSE_HEADER", + "WORKFLOW_ORCHESTRATOR_NAME", "AgentCallbackContext", "AgentEntity", "AgentEntityStateProviderMixin", "AgentResponseCallbackProtocol", "AgentSessionId", "ApiResponseFields", + "CapturingRunnerContext", "ContentTypes", "DurableAIAgent", "DurableAIAgentClient", @@ -101,8 +111,16 @@ "DurableAgentStateUsage", "DurableAgentStateUsageContent", "DurableStateFields", + "DurableTaskWorkflowContext", + "DurableWorkflowClient", "RunRequest", + "WorkflowOrchestrationContext", + "WorkflowRegistrationPlan", "__version__", "ensure_response_format", + "execute_workflow_activity", "load_agent_response", + "plan_workflow_registration", + "run_agent_coroutine", + "run_workflow_orchestrator", ] diff --git a/python/packages/durabletask/agent_framework_durabletask/_async_bridge.py b/python/packages/durabletask/agent_framework_durabletask/_async_bridge.py new file mode 100644 index 00000000000..72b1140d02c --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_async_bridge.py @@ -0,0 +1,73 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Persistent background event loop for running agent coroutines. + +Durable entity (and agent) handlers are invoked synchronously by the host on +arbitrary worker threads. Agent clients and their async credentials create +asyncio primitives (locks, connection pools, futures) that are bound to the +event loop on which they are *first* used. Running a later invocation on a +*different* event loop causes those primitives to await futures attached to a +now-idle loop, which results in a silent, permanent hang. + +This module provides a single, process-wide persistent event loop running on a +dedicated daemon thread. All agent coroutines are submitted to this loop via +``run_coroutine_threadsafe`` so shared async resources remain valid across +invocations regardless of which worker thread the host happens to use. +""" + +from __future__ import annotations + +import asyncio +import threading +from collections.abc import Coroutine +from typing import Any, TypeVar + +_T = TypeVar("_T") + +_loop: asyncio.AbstractEventLoop | None = None +_thread: threading.Thread | None = None +_lock = threading.Lock() + + +def _ensure_loop() -> asyncio.AbstractEventLoop: + """Return the shared persistent event loop, starting it on first use.""" + global _loop, _thread + + if _loop is not None and not _loop.is_closed(): + return _loop + + with _lock: + if _loop is not None and not _loop.is_closed(): + return _loop + + loop = asyncio.new_event_loop() + + def _run() -> None: + asyncio.set_event_loop(loop) + loop.run_forever() + + thread = threading.Thread(target=_run, name="dafx-agent-loop", daemon=True) + thread.start() + + _loop = loop + _thread = thread + return loop + + +def run_agent_coroutine(coro: Coroutine[Any, Any, _T]) -> _T: + """Run a coroutine on the shared persistent event loop and return its result. + + The calling (worker) thread blocks until the coroutine completes. Because + every agent coroutine runs on the same loop, async resources created by + shared agent clients/credentials (locks, connection pools) remain bound to a + live loop across all invocations, preventing cross-loop hangs. + + Args: + coro: The coroutine to execute. + + Returns: + The coroutine's result. + """ + loop = _ensure_loop() + future = asyncio.run_coroutine_threadsafe(coro, loop) + return future.result() diff --git a/python/packages/durabletask/agent_framework_durabletask/_worker.py b/python/packages/durabletask/agent_framework_durabletask/_worker.py index 728ae17629a..74a5c8401f4 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_worker.py +++ b/python/packages/durabletask/agent_framework_durabletask/_worker.py @@ -3,29 +3,47 @@ """Worker wrapper for Durable Task Agent Framework. This module provides the DurableAIAgentWorker class that wraps a durabletask worker -and enables registration of agents as durable entities. +and enables registration of agents as durable entities, and optionally workflows +as durable orchestrations with automatically generated activity functions. """ from __future__ import annotations -import asyncio +import json import logging from typing import Any -from agent_framework import SupportsAgentRun +from agent_framework import SupportsAgentRun, Workflow +from durabletask.task import ActivityContext, OrchestrationContext from durabletask.worker import TaskHubGrpcWorker +from ._async_bridge import run_agent_coroutine from ._callbacks import AgentResponseCallbackProtocol from ._entities import AgentEntity, DurableTaskEntityStateProvider +from ._workflow_activity import execute_workflow_activity +from ._workflow_dt_context import DurableTaskWorkflowContext +from ._workflow_orchestrator import WORKFLOW_ORCHESTRATOR_NAME, run_workflow_orchestrator +from ._workflow_registration import plan_workflow_registration logger = logging.getLogger("agent_framework.durabletask") class DurableAIAgentWorker: - """Wrapper for durabletask worker that enables agent registration. + """Wrapper for a durabletask worker that hosts agents and workflows. - This class wraps an existing TaskHubGrpcWorker instance and provides - a convenient interface for registering agents as durable entities. + This class wraps an existing TaskHubGrpcWorker instance and is the single + host-side registration surface for a worker process. It supports two + complementary kinds of work: + + - **Agents** via :meth:`add_agent`, which registers each agent as a durable entity. + - **Workflows** via :meth:`configure_workflow`, which registers a MAF + ``Workflow`` (its agent executors as entities, its non-agent executors as + activities, and the workflow orchestrator). + + A single worker process commonly hosts both, so registration is intentionally + aggregated on one object rather than split per kind. (On the *client* side the + surfaces are split into :class:`DurableAIAgentClient` and ``DurableWorkflowClient``, + because a caller invokes one or the other.) Example: ```python @@ -40,7 +58,7 @@ class DurableAIAgentWorker: # Wrap it with the agent worker agent_worker = DurableAIAgentWorker(worker) - # Register agents + # Register agents (or call configure_workflow(workflow) to host a workflow) client = OpenAIChatCompletionClient() my_agent = Agent(client=client, name="assistant") agent_worker.add_agent(my_agent) @@ -64,6 +82,7 @@ def __init__( self._worker = worker self._callback = callback self._registered_agents: dict[str, SupportsAgentRun] = {} + self._workflow: Workflow | None = None logger.debug("[DurableAIAgentWorker] Initialized with worker type: %s", type(worker).__name__) def add_agent( @@ -140,6 +159,87 @@ def registered_agent_names(self) -> list[str]: """ return list(self._registered_agents.keys()) + # ----------------------------------------------------------------- + # Workflow support + # ----------------------------------------------------------------- + + def configure_workflow( + self, + workflow: Workflow, + callback: AgentResponseCallbackProtocol | None = None, + ) -> None: + """Register a :class:`Workflow` for automatic orchestration. + + This extracts agents from the workflow and registers them as durable + entities, registers non-agent executors as activities, and creates an + orchestrator function that drives the workflow graph. + + Args: + workflow: The MAF :class:`Workflow` to register. + callback: Optional callback for agent response notifications. + """ + self._workflow = workflow + + # The "what to register" decision (agent -> entity, non-agent -> activity) + # is shared with the Azure Functions host via plan_workflow_registration. + plan = plan_workflow_registration(workflow) + + # Register agent executors as durable entities. + for agent in plan.agents: + if agent.name and agent.name not in self._registered_agents: + self.add_agent(agent, callback=callback) + + # Register non-agent executors as durable activities. + for executor in plan.activity_executors: + self._register_executor_activity(executor) + + # Register the workflow orchestrator. + self._register_workflow_orchestrator() + + logger.info( + "[DurableAIAgentWorker] Workflow configured with %d executors (%d agents, %d activities)", + len(workflow.executors), + len(plan.agents), + len(plan.activity_executors), + ) + + def _register_executor_activity(self, executor: Any) -> None: + """Register a non-agent executor as a durabletask activity.""" + captured_executor = executor + captured_workflow = self._workflow + activity_name = f"dafx-{executor.id}" + + def executor_activity(ctx: ActivityContext, input_data: str) -> str: + return execute_workflow_activity(captured_executor, input_data, captured_workflow) + + # Give the function the expected name for registration + executor_activity.__name__ = activity_name + executor_activity.__qualname__ = activity_name + + self._worker.add_activity(executor_activity) # type: ignore[arg-type] + logger.debug("[DurableAIAgentWorker] Registered activity: %s", activity_name) + + def _register_workflow_orchestrator(self) -> None: + """Register the workflow orchestrator function with the worker.""" + captured_workflow = self._workflow + + def workflow_orchestrator(context: OrchestrationContext, input_data: Any) -> Any: # type: ignore[type-arg] + if captured_workflow is None: + raise RuntimeError("Workflow not configured") + + initial_message = json.dumps(input_data) if isinstance(input_data, (dict, list)) else str(input_data) + shared_state: dict[str, Any] = {} + + dt_ctx = DurableTaskWorkflowContext(context) + outputs = yield from run_workflow_orchestrator(dt_ctx, captured_workflow, initial_message, shared_state) + return outputs # noqa: B901 + + workflow_orchestrator.__name__ = WORKFLOW_ORCHESTRATOR_NAME + workflow_orchestrator.__qualname__ = WORKFLOW_ORCHESTRATOR_NAME + + self._worker.add_orchestrator(workflow_orchestrator) # type: ignore[arg-type] + logger.debug("[DurableAIAgentWorker] Registered workflow orchestrator") + def __create_agent_entity( self, agent: SupportsAgentRun, @@ -187,7 +287,10 @@ def run(self, request: Any) -> Any: AgentResponse as dict """ logger.debug("[ConfiguredAgentEntity.run] Executing agent: %s", agent_name) - response = asyncio.run(self._agent_entity.run(request)) + # Run on the shared persistent loop so async resources created by + # shared agent clients/credentials stay bound to a live loop across + # successive entity invocations (avoids cross-loop hangs). + response = run_agent_coroutine(self._agent_entity.run(request)) return response.to_dict() def reset(self) -> None: diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflow_activity.py b/python/packages/durabletask/agent_framework_durabletask/_workflow_activity.py new file mode 100644 index 00000000000..15d6182c064 --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_workflow_activity.py @@ -0,0 +1,179 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Host-agnostic execution of non-agent workflow executors as durable activities. + +When a MAF :class:`Workflow` runs as a durable orchestration, each non-agent +executor is dispatched as a durable *activity*. The activity body is identical +regardless of host (Azure Functions or a standalone durabletask worker): it +deserializes the activity input, runs the executor (or a human-in-the-loop +response handler), diffs the shared state, and serializes the executor's +outputs, sent messages, shared-state changes, and any pending HITL requests back +to the orchestrator. + +This module provides that shared body as :func:`execute_workflow_activity` so +both host adapters call one implementation instead of duplicating it. +""" + +from __future__ import annotations + +import asyncio +import json +from copy import deepcopy +from typing import Any, cast + +from agent_framework import Executor, Workflow, WorkflowEvent +from agent_framework._workflows._runner_context import YieldOutputEventType +from agent_framework._workflows._state import State + +from ._workflow_orchestrator import ( + SOURCE_HITL_RESPONSE, + SOURCE_ORCHESTRATOR, + execute_hitl_response_handler, +) +from ._workflow_runner_context import CapturingRunnerContext +from ._workflow_serialization import deserialize_value, serialize_value + + +def execute_workflow_activity(executor: Executor, input_json: str, workflow: Workflow | None = None) -> str: + """Execute a single non-agent workflow executor and return its serialized result. + + This is the host-agnostic activity body shared by the Azure Functions and + standalone durabletask workflow hosts. + + Args: + executor: The non-agent executor instance to run. + input_json: JSON-encoded activity input with keys ``message``, + ``shared_state_snapshot``, and ``source_executor_ids``. + workflow: The owning workflow, used to classify the executor's + ``yield_output`` payloads as final ``output`` vs ``intermediate``. + When omitted, all yielded outputs are treated as final outputs. + + Returns: + A JSON string with keys ``sent_messages``, ``outputs``, + ``shared_state_updates``, ``shared_state_deletes``, and + ``pending_request_info_events``. + + Raises: + ValueError: If the input does not decode to a JSON object, or a HITL + message payload is not a JSON object. + """ + data_obj = json.loads(input_json) + if not isinstance(data_obj, dict): + raise ValueError("Activity input must decode to a JSON object") + data = cast(dict[str, Any], data_obj) + + message_data = data.get("message") + shared_state_snapshot = data.get("shared_state_snapshot", {}) + source_executor_ids = cast(list[str], data.get("source_executor_ids", [SOURCE_ORCHESTRATOR])) + + # Reconstruct the message - deserialize_value restores the original typed + # objects from the encoded data (with type markers). + message = deserialize_value(message_data) + + # A HITL response is identified by a source id starting with the HITL prefix. + is_hitl_response = any(s.startswith(SOURCE_HITL_RESPONSE) for s in source_executor_ids) + + def classify_yielded_output(executor_id: str) -> YieldOutputEventType | None: + # Mirror the core runner's classification so intermediate executors' + # yields are not surfaced as final workflow outputs. + if workflow is None: + return "output" + if workflow.is_terminal_executor(executor_id): + return "output" + if workflow.is_intermediate_executor(executor_id): + return "intermediate" + return None + + async def _run() -> dict[str, Any]: + runner_context = CapturingRunnerContext() + runner_context.set_yield_output_classifier(classify_yielded_output) + shared_state = State() + + # Deserialize shared state values to reconstruct dataclasses / Pydantic models. + deserialized_state: dict[str, Any] = { + str(k): deserialize_value(v) for k, v in shared_state_snapshot.items() + } + # Snapshot the deserialized (in-memory) state for diffing. State.export_state() + # returns the in-memory committed objects, so the snapshot must hold objects + # too (deepcopy) - comparing against a serialized snapshot would mark every + # key as changed. + original_snapshot = deepcopy(deserialized_state) + shared_state.import_state(deserialized_state) + + if is_hitl_response: + if not isinstance(message_data, dict): + raise ValueError("HITL message payload must be a JSON object") + await execute_hitl_response_handler( + executor=executor, + hitl_message=cast(dict[str, Any], message_data), + shared_state=shared_state, + runner_context=runner_context, + ) + else: + await executor.execute( + message=message, + source_executor_ids=source_executor_ids, + state=shared_state, + runner_context=runner_context, + ) + + # Commit pending state changes and compute the diff vs the original snapshot. + shared_state.commit() + current_state = shared_state.export_state() + original_keys: set[str] = set(original_snapshot.keys()) + current_keys: set[str] = set(current_state.keys()) + + # Deleted = was in original, not in current. + deletes: set[str] = original_keys - current_keys + + # Updates = keys that are new or whose value changed. + updates: dict[str, Any] = {} + for key in current_keys: + if key not in original_keys or current_state[key] != original_snapshot.get(key): + updates[key] = current_state[key] + + sent_messages = await runner_context.drain_messages() + events = await runner_context.drain_events() + + # Extract outputs from WorkflowEvent instances with type='output'. + outputs: list[Any] = [] + for event in events: + if isinstance(event, WorkflowEvent) and event.type == "output": + outputs.append(serialize_value(event.data)) + + # Serialize pending HITL request info events for the orchestrator. + pending_request_info_events = await runner_context.get_pending_request_info_events() + serialized_pending_requests: list[dict[str, Any]] = [] + for _request_id, event in pending_request_info_events.items(): + serialized_pending_requests.append({ + "request_id": event.request_id, + "source_executor_id": event.source_executor_id, + "data": serialize_value(event.data), + "request_type": f"{type(event.data).__module__}:{type(event.data).__name__}", + "response_type": f"{event.response_type.__module__}:{event.response_type.__name__}" + if event.response_type + else None, + }) + + # Serialize sent messages for JSON compatibility. + serialized_sent_messages: list[dict[str, Any]] = [] + for _source_id, msg_list in sent_messages.items(): + for msg in msg_list: + serialized_sent_messages.append({ + "message": serialize_value(msg.data), + "target_id": msg.target_id, + "source_id": msg.source_id, + }) + + serialized_updates = {k: serialize_value(v) for k, v in updates.items()} + + return { + "sent_messages": serialized_sent_messages, + "outputs": outputs, + "shared_state_updates": serialized_updates, + "shared_state_deletes": list(deletes), + "pending_request_info_events": serialized_pending_requests, + } + + result = asyncio.run(_run()) + return json.dumps(result) diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflow_client.py b/python/packages/durabletask/agent_framework_durabletask/_workflow_client.py new file mode 100644 index 00000000000..92e6a8b225d --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_workflow_client.py @@ -0,0 +1,176 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Workflow client wrapper for Durable Task Agent Framework. + +This module provides :class:`DurableWorkflowClient` for external clients to start, +await, and drive (including human-in-the-loop) workflows registered on a worker via +``DurableAIAgentWorker.configure_workflow``. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any + +from durabletask.client import TaskHubGrpcClient + +from ._workflow_orchestrator import WORKFLOW_ORCHESTRATOR_NAME +from ._workflow_serialization import strip_pickle_markers + +logger = logging.getLogger("agent_framework.durabletask") + + +class DurableWorkflowClient: + """Client wrapper for starting and driving durable workflows externally. + + This class wraps a durabletask ``TaskHubGrpcClient`` and provides a convenient + interface for the workflow registered by ``DurableAIAgentWorker.configure_workflow``: + starting it, awaiting its output, and responding to human-in-the-loop (HITL) pauses. + + For interacting with individual durable *agents*, use + :class:`~agent_framework_durabletask.DurableAIAgentClient` instead. Both wrap the + same underlying ``TaskHubGrpcClient``, so an application that needs both can + construct both over one client. + + Example: + ```python + from durabletask.azuremanaged.client import DurableTaskSchedulerClient + from agent_framework.azure import DurableWorkflowClient + + # Create the underlying client + client = DurableTaskSchedulerClient(host_address="localhost:8080", taskhub="default") + + # Wrap it with the workflow client + workflow_client = DurableWorkflowClient(client) + + # Start a workflow and wait for its output + instance_id = workflow_client.start_workflow(input="some input") + output = workflow_client.await_workflow_output(instance_id) + print(output) + ``` + """ + + def __init__(self, client: TaskHubGrpcClient): + """Initialize the workflow client wrapper. + + Args: + client: The durabletask client instance to wrap. + """ + self._client = client + logger.debug("[DurableWorkflowClient] Initialized with client type: %s", type(client).__name__) + + def start_workflow(self, input: Any = None, *, instance_id: str | None = None) -> str: + """Start the workflow orchestration registered by ``configure_workflow``. + + This schedules the orchestrator that ``DurableAIAgentWorker.configure_workflow`` + auto-registers, so callers do not need to know its internal name. + + Args: + input: The initial message/payload for the workflow. + instance_id: Optional explicit orchestration instance ID. If omitted, one + is generated. + + Returns: + The orchestration instance ID, for use with ``await_workflow_output``. + """ + new_instance_id = self._client.schedule_new_orchestration( + WORKFLOW_ORCHESTRATOR_NAME, + input=input, + instance_id=instance_id, + ) + logger.debug("[DurableWorkflowClient] Started workflow instance: %s", new_instance_id) + return new_instance_id + + def await_workflow_output(self, instance_id: str, *, timeout_seconds: int = 300) -> Any: + """Wait for a workflow orchestration to complete and return its output. + + Args: + instance_id: The instance ID returned by ``start_workflow``. + timeout_seconds: Maximum time, in seconds, to wait for completion. + + Returns: + The deserialized workflow output (typically a list of yielded outputs), + or ``None`` if the workflow produced no output. + + Raises: + TimeoutError: If the workflow does not complete within ``timeout_seconds``. + RuntimeError: If the workflow completes with a non-successful status. + """ + metadata = self._client.wait_for_orchestration_completion(instance_id, timeout=timeout_seconds) + if metadata is None: + raise TimeoutError(f"Workflow '{instance_id}' did not complete within {timeout_seconds}s") + + status = metadata.runtime_status.name + if status != "COMPLETED": + raise RuntimeError(f"Workflow '{instance_id}' ended with status {status}: {metadata.serialized_output}") + + if metadata.serialized_output is None: + return None + return json.loads(metadata.serialized_output) + + def get_pending_hitl_requests(self, instance_id: str) -> list[dict[str, Any]]: + """Return the workflow's pending human-in-the-loop (HITL) requests, if any. + + While a workflow is paused awaiting human input, the orchestrator records the + open requests in its custom status. This method reads and normalizes that + status so callers do not need to know its internal schema. + + Args: + instance_id: The workflow instance ID returned by ``start_workflow``. + + Returns: + A list of pending requests. Each entry contains ``request_id``, + ``source_executor_id``, ``data``, ``request_type``, and ``response_type``. + Empty if the workflow is not currently waiting for human input. + """ + state = self._client.get_orchestration_state(instance_id) + if state is None or not state.serialized_custom_status: + return [] + + try: + custom_status = json.loads(state.serialized_custom_status) + except (json.JSONDecodeError, TypeError): + return [] + + if not isinstance(custom_status, dict): + return [] + + pending = custom_status.get("pending_requests") + if not isinstance(pending, dict): + return [] + + requests: list[dict[str, Any]] = [] + for request_id, req_data in pending.items(): + if not isinstance(req_data, dict): + continue + requests.append({ + "request_id": req_data.get("request_id", request_id), + "source_executor_id": req_data.get("source_executor_id"), + "data": req_data.get("data"), + "request_type": req_data.get("request_type"), + "response_type": req_data.get("response_type"), + }) + return requests + + def send_hitl_response(self, instance_id: str, request_id: str, response: Any) -> None: + """Send a response to a pending HITL request, resuming the workflow. + + The orchestrator correlates the response by using ``request_id`` as the + external-event name, so callers do not need to know that convention. + + Args: + instance_id: The workflow instance ID. + request_id: The pending request's ID (from ``get_pending_hitl_requests``). + response: The response payload (e.g. a dict matching the expected + response type the executor's ``@response_handler`` expects). + + Note: + The payload is sanitized with ``strip_pickle_markers`` before delivery to + neutralize pickle-marker injection, since the worker deserializes it. + """ + safe_response = strip_pickle_markers(response) + self._client.raise_orchestration_event(instance_id, event_name=request_id, data=safe_response) + logger.debug( + "[DurableWorkflowClient] Sent HITL response for request %s on instance %s", request_id, instance_id + ) diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflow_context.py b/python/packages/durabletask/agent_framework_durabletask/_workflow_context.py new file mode 100644 index 00000000000..3d31cf90e27 --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_workflow_context.py @@ -0,0 +1,143 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Protocol definition for workflow orchestration contexts. + +This module defines the ``WorkflowOrchestrationContext`` protocol that abstracts +the differences between Azure Functions' ``DurableOrchestrationContext`` and the +standalone ``durabletask.task.OrchestrationContext``. The shared workflow +orchestrator (:func:`run_workflow_orchestrator`) programs against this protocol +so that the same orchestration logic works on any host. + +Each host provides a thin adapter that maps its native context to this protocol: + +- ``DurableTaskWorkflowContext`` (this package) — wraps ``OrchestrationContext`` +- ``AzureFunctionsWorkflowContext`` (azurefunctions package) — wraps + ``DurableOrchestrationContext`` +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any, Protocol, runtime_checkable + + +@runtime_checkable +class WorkflowOrchestrationContext(Protocol): + """Host-agnostic interface for workflow orchestration primitives. + + All methods that return yieldable tasks return ``Any`` because the concrete + task types differ between hosting SDKs (``TaskBase`` for Azure Functions, + ``Task[T]`` for durabletask). The generator-based orchestrator simply + yields these opaque objects back to the hosting framework. + """ + + @property + def instance_id(self) -> str: + """The unique ID of the current orchestration instance.""" + ... + + @property + def current_utc_datetime(self) -> datetime: + """The current replay-safe UTC datetime.""" + ... + + def prepare_agent_task(self, executor_id: str, message: str, orchestration_instance_id: str) -> Any: + """Create a yieldable task that runs an agent executor. + + Args: + executor_id: Agent name / executor ID. + message: The text message to send to the agent. + orchestration_instance_id: Instance ID used as the entity session key. + + Returns: + A yieldable task whose result is an ``AgentResponse``. + """ + ... + + def prepare_activity_task(self, activity_name: str, input_json: str) -> Any: + """Create a yieldable task that runs an activity executor. + + Args: + activity_name: The registered activity function name. + input_json: JSON-serialized activity input. + + Returns: + A yieldable task whose result is a JSON string. + """ + ... + + def task_all(self, tasks: list[Any]) -> Any: + """Create a yieldable composite task that completes when *all* tasks complete. + + Args: + tasks: List of yieldable tasks. + + Returns: + A yieldable task whose result is a list of individual results. + """ + ... + + def task_any(self, tasks: list[Any]) -> Any: + """Create a yieldable composite task that completes when *any* task completes. + + Args: + tasks: List of yieldable tasks. + + Returns: + A yieldable task whose result is the winning task. + """ + ... + + def wait_for_external_event(self, name: str) -> Any: + """Create a yieldable task that waits for a named external event. + + Args: + name: Event name to wait for. + + Returns: + A yieldable task whose result is the event payload. + """ + ... + + def create_timer(self, fire_at: datetime) -> Any: + """Create a yieldable timer task. + + Args: + fire_at: UTC datetime when the timer should fire. + + Returns: + A yieldable timer task. + """ + ... + + def set_custom_status(self, status: Any) -> None: + """Set the orchestration's custom status (visible to external clients). + + Args: + status: JSON-serializable status object. + """ + ... + + def new_uuid(self) -> str: + """Generate a replay-safe UUID.""" + ... + + def cancel_task(self, task: Any) -> None: + """Best-effort cancellation of a pending task. + + Args: + task: The task to cancel. If the underlying SDK does not support + cancellation this is a no-op. + """ + ... + + def get_task_result(self, task: Any) -> Any: + """Extract the result from a completed task. + + Args: + task: A completed task object. + + Returns: + The result value. + """ + ... diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflow_dt_context.py b/python/packages/durabletask/agent_framework_durabletask/_workflow_dt_context.py new file mode 100644 index 00000000000..f8695a71450 --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_workflow_dt_context.py @@ -0,0 +1,91 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""DurableTask SDK adapter for WorkflowOrchestrationContext. + +Wraps ``durabletask.task.OrchestrationContext`` to satisfy the +:class:`WorkflowOrchestrationContext` protocol. +""" + +from __future__ import annotations + +import logging +from datetime import datetime +from typing import Any + +from durabletask.task import OrchestrationContext, Task, when_all, when_any + +from ._executors import OrchestrationAgentExecutor +from ._models import AgentSessionId, DurableAgentSession +from ._shim import DurableAIAgent +from ._workflow_context import WorkflowOrchestrationContext + +logger = logging.getLogger(__name__) + + +class DurableTaskWorkflowContext: + """Adapter that maps ``OrchestrationContext`` to :class:`WorkflowOrchestrationContext`.""" + + def __init__(self, context: OrchestrationContext) -> None: + self._context = context + self._executor = OrchestrationAgentExecutor(context) + + # -- Properties ----------------------------------------------------------- + + @property + def instance_id(self) -> str: + return self._context.instance_id + + @property + def current_utc_datetime(self) -> datetime: + return self._context.current_utc_datetime + + # -- Agent / Activity dispatch -------------------------------------------- + + def prepare_agent_task(self, executor_id: str, message: str, orchestration_instance_id: str) -> Any: + session_id = AgentSessionId(name=executor_id, key=orchestration_instance_id) + session = DurableAgentSession(durable_session_id=session_id) + agent = DurableAIAgent(self._executor, executor_id) + return agent.run(message, session=session) + + def prepare_activity_task(self, activity_name: str, input_json: str) -> Any: + return self._context.call_activity(activity_name, input=input_json) + + # -- Composite tasks ------------------------------------------------------ + + def task_all(self, tasks: list[Any]) -> Any: + return when_all(tasks) + + def task_any(self, tasks: list[Any]) -> Any: + return when_any(tasks) + + # -- External events / timers --------------------------------------------- + + def wait_for_external_event(self, name: str) -> Any: + return self._context.wait_for_external_event(name) + + def create_timer(self, fire_at: datetime) -> Any: + return self._context.create_timer(fire_at) + + # -- Status / utility ----------------------------------------------------- + + def set_custom_status(self, status: Any) -> None: + self._context.set_custom_status(status) + + def new_uuid(self) -> str: + return self._context.new_uuid() + + def cancel_task(self, task: Any) -> None: + # durabletask Task doesn't expose cancel(); this is a best-effort no-op. + cancel_fn = getattr(task, "cancel", None) + if callable(cancel_fn): + cancel_fn() + + def get_task_result(self, task: Any) -> Any: + if isinstance(task, Task): + return task.get_result() + return getattr(task, "result", None) + + +# Ensure the adapter satisfies the protocol. Validated statically by the type +# checker (and at every ``run_workflow_orchestrator`` call site) with no runtime cost. +_protocol_check: type[WorkflowOrchestrationContext] = DurableTaskWorkflowContext diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflow_orchestrator.py b/python/packages/durabletask/agent_framework_durabletask/_workflow_orchestrator.py new file mode 100644 index 00000000000..2d7ecbe4a18 --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_workflow_orchestrator.py @@ -0,0 +1,786 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Host-agnostic workflow orchestration engine. + +This module provides the shared workflow orchestration logic that executes MAF +Workflows as durable task orchestrations. It programs against the +:class:`WorkflowOrchestrationContext` protocol so that the same code runs on +both Azure Functions and standalone durabletask hosts. + +Key components: + +* :func:`run_workflow_orchestrator` — main generator-based orchestrator +* Routing helpers (edge groups, fan-in, HITL) +* Result processing helpers + +All host-specific task creation (agent dispatch, activity dispatch, task_all / +task_any) is delegated to the ``WorkflowOrchestrationContext`` adapter. +""" + +from __future__ import annotations + +import json +import logging +from collections import defaultdict +from collections.abc import Generator +from dataclasses import dataclass +from datetime import timedelta +from enum import Enum +from typing import Any + +from agent_framework import ( + AgentExecutor, + AgentExecutorRequest, + AgentExecutorResponse, + AgentResponse, + Message, + Workflow, +) +from agent_framework._workflows._edge import ( + Edge, + EdgeGroup, + FanInEdgeGroup, + FanOutEdgeGroup, + SingleEdgeGroup, + SwitchCaseEdgeGroup, +) +from agent_framework._workflows._state import State + +from ._workflow_context import WorkflowOrchestrationContext +from ._workflow_serialization import ( + deserialize_value, + reconstruct_to_type, + resolve_type, + serialize_value, + strip_pickle_markers, +) + +logger = logging.getLogger(__name__) + + +# ============================================================================ +# Source Marker Constants +# ============================================================================ + +SOURCE_WORKFLOW_START = "__workflow_start__" +SOURCE_ORCHESTRATOR = "__orchestrator__" +SOURCE_HITL_RESPONSE = "__hitl_response__" + +# Name of the auto-generated orchestrator registered by +# ``DurableAIAgentWorker.configure_workflow`` (and the Azure Functions host). +# Standalone clients start a configured workflow by scheduling an orchestration +# with this name, e.g. +# ``client.schedule_new_orchestration(WORKFLOW_ORCHESTRATOR_NAME, input=...)``. +WORKFLOW_ORCHESTRATOR_NAME = "workflow_orchestrator" + + +# ============================================================================ +# Task Types and Data Structures +# ============================================================================ + + +class TaskType(Enum): + """Type of executor task.""" + + AGENT = "agent" + ACTIVITY = "activity" + + +@dataclass +class TaskMetadata: + """Metadata for a pending task.""" + + executor_id: str + message: Any + source_executor_id: str + task_type: TaskType + remaining_messages: list[tuple[str, Any, str]] | None = None + + +@dataclass +class ExecutorResult: + """Result from executing an agent or activity.""" + + executor_id: str + output_message: AgentExecutorResponse | None + activity_result: dict[str, Any] | None + task_type: TaskType + + +@dataclass +class PendingHITLRequest: + """Tracks a pending Human-in-the-Loop request.""" + + request_id: str + source_executor_id: str + request_data: Any + request_type: str | None + response_type: str | None + + +DEFAULT_HITL_TIMEOUT_HOURS = 72.0 + + +# ============================================================================ +# Routing Functions +# ============================================================================ + + +def _evaluate_edge_condition_sync(edge: Edge, message: Any) -> bool: + """Evaluate an edge's condition synchronously. + + Durable orchestrators use generators, not async/await, so we cannot call + async methods like ``edge.should_route()``. + """ + condition = edge._condition # pyright: ignore[reportPrivateUsage] + if condition is None: + return True + result = condition(message) + if hasattr(result, "__await__"): + import warnings + + warnings.warn( + f"Edge condition for {edge.source_id}->{edge.target_id} is async, " + "which is not supported in durable orchestrators. " + "The edge will be traversed unconditionally.", + RuntimeWarning, + stacklevel=2, + ) + return True + return bool(result) + + +def route_message_through_edge_groups( + edge_groups: list[EdgeGroup], + source_id: str, + message: Any, +) -> list[str]: + """Route a message through edge groups to find target executor IDs.""" + targets: list[str] = [] + + for group in edge_groups: + if source_id not in group.source_executor_ids: + continue + + if isinstance(group, (SwitchCaseEdgeGroup, FanOutEdgeGroup)): + if group.selection_func is not None: + selected = group.selection_func(message, group.target_executor_ids) + targets.extend(selected) + else: + targets.extend(group.target_executor_ids) + + elif isinstance(group, SingleEdgeGroup): + edge = group.edges[0] + if _evaluate_edge_condition_sync(edge, message): + targets.append(edge.target_id) + + elif isinstance(group, FanInEdgeGroup): + pass # Handled separately in the orchestrator loop + + else: + for edge in group.edges: + if edge.source_id == source_id and _evaluate_edge_condition_sync(edge, message): + targets.append(edge.target_id) + + return targets + + +def build_agent_executor_response( + executor_id: str, + response_text: str | None, + structured_response: dict[str, Any] | None, + previous_message: Any, +) -> AgentExecutorResponse: + """Build an AgentExecutorResponse from entity response data.""" + final_text: str = response_text or "" + if structured_response: + final_text = json.dumps(structured_response) + + assistant_message = Message(role="assistant", contents=[final_text]) + agent_response = AgentResponse(messages=[assistant_message]) + + full_conversation: list[Message] = [] + if isinstance(previous_message, AgentExecutorResponse) and previous_message.full_conversation: + full_conversation.extend(previous_message.full_conversation) + elif isinstance(previous_message, str): + full_conversation.append(Message(role="user", contents=[previous_message])) + full_conversation.append(assistant_message) + + return AgentExecutorResponse( + executor_id=executor_id, + agent_response=agent_response, + full_conversation=full_conversation, + ) + + +# ============================================================================ +# Task Preparation Helpers +# ============================================================================ + + +def _prepare_agent_task( + ctx: WorkflowOrchestrationContext, + executor_id: str, + message: Any, +) -> Any: + """Prepare an agent task for execution via the context adapter.""" + message_content = _extract_message_content(message) + return ctx.prepare_agent_task(executor_id, message_content, ctx.instance_id) + + +def _prepare_activity_task( + ctx: WorkflowOrchestrationContext, + executor_id: str, + message: Any, + source_executor_id: str, + shared_state_snapshot: dict[str, Any] | None, +) -> Any: + """Prepare an activity task for execution via the context adapter.""" + activity_input = { + "executor_id": executor_id, + "message": serialize_value(message), + "shared_state_snapshot": shared_state_snapshot, + "source_executor_ids": [source_executor_id], + } + activity_input_json = json.dumps(activity_input) + activity_name = f"dafx-{executor_id}" + return ctx.prepare_activity_task(activity_name, activity_input_json) + + +# ============================================================================ +# Result Processing Helpers +# ============================================================================ + + +def _process_agent_response( + agent_response: AgentResponse, + executor_id: str, + message: Any, +) -> ExecutorResult: + """Process an agent response into an ExecutorResult.""" + response_text = agent_response.text if agent_response else None + structured_response: dict[str, Any] | None = None + + if agent_response and agent_response.value is not None: + model_dump = getattr(agent_response.value, "model_dump", None) + if callable(model_dump): + dumped = model_dump() + if isinstance(dumped, dict): + structured_response = dumped # type: ignore[assignment] + elif isinstance(agent_response.value, dict): + structured_response = agent_response.value # type: ignore[assignment] + + output_message = build_agent_executor_response( + executor_id=executor_id, + response_text=response_text, + structured_response=structured_response, + previous_message=message, + ) + + return ExecutorResult( + executor_id=executor_id, + output_message=output_message, + activity_result=None, + task_type=TaskType.AGENT, + ) + + +def _process_activity_result( + result_json: str | None, + executor_id: str, + shared_state: dict[str, Any] | None, + workflow_outputs: list[Any], +) -> ExecutorResult: + """Process an activity result and apply shared state updates.""" + result = json.loads(result_json) if result_json else None + + if shared_state is not None and result: + if result.get("shared_state_updates"): + updates = result["shared_state_updates"] + logger.debug("[workflow] Applying SharedState updates from %s: %s", executor_id, updates) + shared_state.update(updates) + if result.get("shared_state_deletes"): + deletes = result["shared_state_deletes"] + logger.debug("[workflow] Applying SharedState deletes from %s: %s", executor_id, deletes) + for key in deletes: + shared_state.pop(key, None) + + if result and result.get("outputs"): + workflow_outputs.extend(result["outputs"]) + + return ExecutorResult( + executor_id=executor_id, + output_message=None, + activity_result=result, + task_type=TaskType.ACTIVITY, + ) + + +# ============================================================================ +# Routing Helpers +# ============================================================================ + + +def _route_result_messages( + result: ExecutorResult, + workflow: Workflow, + next_pending_messages: dict[str, list[tuple[Any, str]]], + fan_in_pending: dict[str, dict[str, list[tuple[Any, str]]]], +) -> None: + """Route messages from an executor result to their targets.""" + executor_id = result.executor_id + messages_to_route: list[tuple[Any, str | None]] = [] + + if result.output_message: + messages_to_route.append((result.output_message, None)) + + if result.activity_result and result.activity_result.get("sent_messages"): + for msg_data in result.activity_result["sent_messages"]: + sent_msg = msg_data.get("message") + target_id = msg_data.get("target_id") + if sent_msg: + sent_msg = deserialize_value(sent_msg) + messages_to_route.append((sent_msg, target_id)) + + for msg_to_route, explicit_target in messages_to_route: + logger.debug("Routing output from %s", executor_id) + + if explicit_target: + if explicit_target not in next_pending_messages: + next_pending_messages[explicit_target] = [] + next_pending_messages[explicit_target].append((msg_to_route, executor_id)) + logger.debug("Routed message from %s to explicit target %s", executor_id, explicit_target) + continue + + for group in workflow.edge_groups: + if isinstance(group, FanInEdgeGroup) and executor_id in group.source_executor_ids: + fan_in_pending[group.id][executor_id].append((msg_to_route, executor_id)) + logger.debug("Accumulated message for FanIn group %s from %s", group.id, executor_id) + + targets = route_message_through_edge_groups(workflow.edge_groups, executor_id, msg_to_route) + + for target_id in targets: + logger.debug("Routing to %s", target_id) + if target_id not in next_pending_messages: + next_pending_messages[target_id] = [] + next_pending_messages[target_id].append((msg_to_route, executor_id)) + + +def _check_fan_in_ready( + workflow: Workflow, + fan_in_pending: dict[str, dict[str, list[tuple[Any, str]]]], + next_pending_messages: dict[str, list[tuple[Any, str]]], +) -> None: + """Check if any FanInEdgeGroups are ready and deliver their messages.""" + for group in workflow.edge_groups: + if not isinstance(group, FanInEdgeGroup): + continue + + pending_sources = fan_in_pending.get(group.id, {}) + + if not all(src in pending_sources and pending_sources[src] for src in group.source_executor_ids): + continue + + aggregated: list[Any] = [] + aggregated_sources: list[str] = [] + for src in group.source_executor_ids: + for msg, msg_source in pending_sources[src]: + aggregated.append(msg) + aggregated_sources.append(msg_source) + + target_id = group.target_executor_ids[0] + logger.debug("FanIn group %s ready, delivering %d messages to %s", group.id, len(aggregated), target_id) + + if target_id not in next_pending_messages: + next_pending_messages[target_id] = [] + + first_source = aggregated_sources[0] if aggregated_sources else "__fan_in__" + next_pending_messages[target_id].append((aggregated, first_source)) + + fan_in_pending[group.id] = defaultdict(list) + + +# ============================================================================ +# HITL Helpers +# ============================================================================ + + +def _collect_hitl_requests( + result: ExecutorResult, + pending_hitl_requests: dict[str, PendingHITLRequest], +) -> None: + """Collect pending HITL requests from an activity result.""" + if result.activity_result and result.activity_result.get("pending_request_info_events"): + for req_data in result.activity_result["pending_request_info_events"]: + request_id = req_data.get("request_id") + if request_id: + pending_hitl_requests[request_id] = PendingHITLRequest( + request_id=request_id, + source_executor_id=req_data.get("source_executor_id", result.executor_id), + request_data=req_data.get("data"), + request_type=req_data.get("request_type"), + response_type=req_data.get("response_type"), + ) + logger.debug( + "Collected HITL request %s from executor %s", + request_id, + result.executor_id, + ) + + +def _route_hitl_response( + hitl_request: PendingHITLRequest, + raw_response: Any, + pending_messages: dict[str, list[tuple[Any, str]]], +) -> None: + """Route a HITL response back to the source executor's @response_handler.""" + response_message = { + "request_id": hitl_request.request_id, + "original_request": hitl_request.request_data, + "response": raw_response, + "response_type": hitl_request.response_type, + } + + target_id = hitl_request.source_executor_id + if target_id not in pending_messages: + pending_messages[target_id] = [] + + source_id = f"{SOURCE_HITL_RESPONSE}_{hitl_request.request_id}" + pending_messages[target_id].append((response_message, source_id)) + + logger.debug( + "Routed HITL response for request %s to executor %s", + hitl_request.request_id, + target_id, + ) + + +# ============================================================================ +# Message Content Extraction +# ============================================================================ + + +def _extract_message_content(message: Any) -> str: + """Extract text content from various message types.""" + message_content = "" + if isinstance(message, AgentExecutorResponse) and message.agent_response: + if message.agent_response.text: + message_content = message.agent_response.text + elif message.agent_response.messages: + message_content = message.agent_response.messages[-1].text or "" + elif isinstance(message, AgentExecutorRequest) and message.messages: + message_content = message.messages[-1].text or "" + elif isinstance(message, dict): + key_names = list(message.keys()) # type: ignore[union-attr] + logger.warning("Unexpected dict message in _extract_message_content. Keys: %s", key_names) # type: ignore + elif isinstance(message, str): + message_content = message + return message_content + + +# ============================================================================ +# HITL Response Handler Execution +# ============================================================================ + + +async def execute_hitl_response_handler( + executor: Any, + hitl_message: dict[str, Any], + shared_state: State, + runner_context: Any, +) -> None: + """Execute a HITL response handler on an executor. + + Args: + executor: The executor instance that has a @response_handler. + hitl_message: The HITL response message dict. + shared_state: The shared state for the workflow context. + runner_context: The runner context for capturing outputs. + """ + from agent_framework._workflows._workflow_context import WorkflowContext + + original_request_data = hitl_message.get("original_request") + response_data = hitl_message.get("response") + response_type_str = hitl_message.get("response_type") + + original_request = deserialize_value(original_request_data) + response = _deserialize_hitl_response(response_data, response_type_str) + + handler = executor._find_response_handler(original_request, response) # pyright: ignore[reportPrivateUsage] + + if handler is None: + logger.warning( + "No response handler found for HITL response in executor %s. Request type: %s, Response type: %s", + executor.id, + type(original_request).__name__, + type(response).__name__, + ) + return + + ctx = WorkflowContext( + executor=executor, + source_executor_ids=[SOURCE_HITL_RESPONSE], + runner_context=runner_context, + state=shared_state, + ) + + logger.debug( + "Invoking response handler for HITL request in executor %s", + executor.id, + ) + await handler(response, ctx) + + +def _deserialize_hitl_response(response_data: Any, response_type_str: str | None) -> Any: + """Deserialize a HITL response to its expected type.""" + logger.debug( + "Deserializing HITL response. response_type_str=%s, response_data type=%s", + response_type_str, + type(response_data).__name__, + ) + + if response_data is None: + return None + + response_data = strip_pickle_markers(response_data) + if response_data is None: + return None + + if not isinstance(response_data, dict): + logger.debug("Response data is not a dict, returning as-is: %s", type(response_data).__name__) + return response_data + + if response_type_str: + response_type = resolve_type(response_type_str) + if response_type: + logger.debug("Found response type %s, attempting reconstruction", response_type) + result = reconstruct_to_type(response_data, response_type) + logger.debug("Reconstructed response type: %s", type(result).__name__) + return result + logger.warning("Could not resolve response type: %s", response_type_str) + + logger.debug("No type hint; returning sanitized data as-is") + return response_data # type: ignore[reportUnknownVariableType] + + +# ============================================================================ +# Task Preparation (All Tasks) +# ============================================================================ + + +def _prepare_all_tasks( + ctx: WorkflowOrchestrationContext, + workflow: Workflow, + pending_messages: dict[str, list[tuple[Any, str]]], + shared_state: dict[str, Any] | None, +) -> tuple[list[Any], list[TaskMetadata], list[tuple[str, Any, str]]]: + """Prepare all pending tasks for parallel execution. + + Groups agent messages by executor ID so that only the first message per agent + runs in the parallel batch. Additional messages to the same agent are returned + for sequential processing. + """ + all_tasks: list[Any] = [] + task_metadata_list: list[TaskMetadata] = [] + remaining_agent_messages: list[tuple[str, Any, str]] = [] + + agent_messages_by_executor: dict[str, list[tuple[str, Any, str]]] = defaultdict(list) + + for executor_id, messages_with_sources in pending_messages.items(): + executor = workflow.executors[executor_id] + is_agent = isinstance(executor, AgentExecutor) + + for message, source_executor_id in messages_with_sources: + if is_agent: + agent_messages_by_executor[executor_id].append((executor_id, message, source_executor_id)) + else: + logger.debug("Preparing activity task: %s", executor_id) + task = _prepare_activity_task(ctx, executor_id, message, source_executor_id, shared_state) + all_tasks.append(task) + task_metadata_list.append( + TaskMetadata( + executor_id=executor_id, + message=message, + source_executor_id=source_executor_id, + task_type=TaskType.ACTIVITY, + ) + ) + + for executor_id, messages_list in agent_messages_by_executor.items(): + first_msg = messages_list[0] + remaining = messages_list[1:] + + logger.debug("Preparing agent task: %s", executor_id) + task = _prepare_agent_task(ctx, first_msg[0], first_msg[1]) + all_tasks.append(task) + task_metadata_list.append( + TaskMetadata( + executor_id=first_msg[0], + message=first_msg[1], + source_executor_id=first_msg[2], + task_type=TaskType.AGENT, + ) + ) + + remaining_agent_messages.extend(remaining) + + return all_tasks, task_metadata_list, remaining_agent_messages + + +# ============================================================================ +# Main Orchestrator +# ============================================================================ + + +def run_workflow_orchestrator( + ctx: WorkflowOrchestrationContext, + workflow: Workflow, + initial_message: Any, + shared_state: dict[str, Any] | None = None, + hitl_timeout_hours: float = DEFAULT_HITL_TIMEOUT_HOURS, +) -> Generator[Any, Any, list[Any]]: + """Traverse and execute the workflow graph as a durable orchestration. + + This is a generator-based orchestrator that works with any host by + programming against the :class:`WorkflowOrchestrationContext` protocol. + + Supports: + - SingleEdgeGroup: Direct 1:1 routing with optional condition + - SwitchCaseEdgeGroup: First matching condition wins + - FanOutEdgeGroup: Broadcast to multiple targets (parallel execution) + - FanInEdgeGroup: Aggregates messages from multiple sources + - SharedState: Cross-executor state sharing (local to orchestration) + - HITL: Human-in-the-loop via request_info / @response_handler + + Args: + ctx: Host-specific orchestration context adapter. + workflow: The MAF Workflow instance to execute. + initial_message: Initial message to send to the start executor. + shared_state: Optional dict for cross-executor state sharing. + hitl_timeout_hours: Timeout in hours for HITL requests. + + Returns: + List of workflow outputs collected from executor activities. + """ + pending_messages: dict[str, list[tuple[Any, str]]] = { + workflow.start_executor_id: [(initial_message, SOURCE_WORKFLOW_START)] + } + workflow_outputs: list[Any] = [] + iteration = 0 + + fan_in_pending: dict[str, dict[str, list[tuple[Any, str]]]] = { + group.id: defaultdict(list) for group in workflow.edge_groups if isinstance(group, FanInEdgeGroup) + } + + pending_hitl_requests: dict[str, PendingHITLRequest] = {} + + while pending_messages and iteration < workflow.max_iterations: + logger.debug("Orchestrator iteration %d", iteration) + next_pending_messages: dict[str, list[tuple[Any, str]]] = {} + + # Phase 1: Prepare all tasks + all_tasks, task_metadata_list, remaining_agent_messages = _prepare_all_tasks( + ctx, workflow, pending_messages, shared_state + ) + + # Phase 2: Execute all tasks in parallel + all_results: list[ExecutorResult] = [] + if all_tasks: + logger.debug("Executing %d tasks in parallel (agents + activities)", len(all_tasks)) + raw_results = yield ctx.task_all(all_tasks) + logger.debug("All %d tasks completed", len(all_tasks)) + + for idx, raw_result in enumerate(raw_results): + metadata = task_metadata_list[idx] + if metadata.task_type == TaskType.AGENT: + result = _process_agent_response(raw_result, metadata.executor_id, metadata.message) + else: + result = _process_activity_result(raw_result, metadata.executor_id, shared_state, workflow_outputs) + all_results.append(result) + + # Phase 3: Process sequential agent messages + for executor_id, message, _source_executor_id in remaining_agent_messages: + logger.debug("Processing sequential message for agent: %s", executor_id) + task = _prepare_agent_task(ctx, executor_id, message) + agent_response: AgentResponse = yield task + logger.debug("Agent %s sequential response completed", executor_id) + + result = _process_agent_response(agent_response, executor_id, message) + all_results.append(result) + + # Phase 4: Collect HITL requests + for result in all_results: + _collect_hitl_requests(result, pending_hitl_requests) + + # Phase 5: Route results + for result in all_results: + _route_result_messages(result, workflow, next_pending_messages, fan_in_pending) + + # Phase 6: Check fan-in readiness + _check_fan_in_ready(workflow, fan_in_pending, next_pending_messages) + + pending_messages = next_pending_messages + + # Phase 7: HITL wait + if not pending_messages and pending_hitl_requests: + logger.debug("Workflow paused for HITL - %d pending requests", len(pending_hitl_requests)) + + ctx.set_custom_status({ + "state": "waiting_for_human_input", + "pending_requests": { + req_id: { + "request_id": req.request_id, + "source_executor_id": req.source_executor_id, + "data": req.request_data, + "request_type": req.request_type, + "response_type": req.response_type, + } + for req_id, req in pending_hitl_requests.items() + }, + }) + + for request_id, hitl_request in list(pending_hitl_requests.items()): + logger.debug("Waiting for HITL response for request: %s", request_id) + + approval_task = ctx.wait_for_external_event(request_id) + timeout_task = ctx.create_timer(ctx.current_utc_datetime + timedelta(hours=hitl_timeout_hours)) + + winner = yield ctx.task_any([approval_task, timeout_task]) + + if winner == approval_task: + ctx.cancel_task(timeout_task) + + raw_response = ctx.get_task_result(approval_task) + logger.debug( + "Received HITL response for request %s. Type: %s, Value: %s", + request_id, + type(raw_response).__name__, + raw_response, + ) + + if isinstance(raw_response, str): + try: + raw_response = json.loads(raw_response) + logger.debug("Parsed JSON string response to: %s", type(raw_response).__name__) + except (json.JSONDecodeError, TypeError): + logger.debug("Response is not JSON, keeping as string") + + del pending_hitl_requests[request_id] + + _route_hitl_response( + hitl_request, + raw_response, + pending_messages, + ) + else: + ctx.cancel_task(approval_task) + logger.warning("HITL request %s timed out after %s hours", request_id, hitl_timeout_hours) + raise TimeoutError( + f"Human-in-the-loop request '{request_id}' timed out after {hitl_timeout_hours} hours." + ) + + ctx.set_custom_status({"state": "running"}) + + iteration += 1 + + return workflow_outputs # noqa: B901 diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflow_registration.py b/python/packages/durabletask/agent_framework_durabletask/_workflow_registration.py new file mode 100644 index 00000000000..032ef022965 --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_workflow_registration.py @@ -0,0 +1,69 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Host-agnostic plan for registering a MAF Workflow as a durable orchestration. + +A MAF :class:`Workflow` is hosted by turning each graph node into a durable +primitive: + +- each :class:`AgentExecutor` becomes a durable **entity**, and +- each other :class:`Executor` becomes a durable **activity**, + +driven by a single workflow **orchestrator**. + +The *decision* of which executor maps to which primitive is identical on every +host (Azure Functions or a standalone durabletask worker); only the *mechanism* +for registering them differs (Functions trigger decorators vs. +``worker.add_*``). :func:`plan_workflow_registration` captures the shared +decision so each host applies one consistent plan with its own registration +mechanism — analogous to .NET's shared ``DurableWorkflowOptions`` feeding +host-specific trigger generation. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +from agent_framework import AgentExecutor, Executor, SupportsAgentRun, Workflow + +from ._workflow_orchestrator import WORKFLOW_ORCHESTRATOR_NAME + + +@dataclass +class WorkflowRegistrationPlan: + """The durable primitives a workflow registers, independent of host. + + Attributes: + agents: Agents (from agent executors) to register as durable entities. + activity_executors: Non-agent executors to register as durable activities. + orchestrator_name: The orchestrator name to register and to start runs with. + """ + + agents: list[SupportsAgentRun] + activity_executors: list[Executor] + orchestrator_name: str + + +def plan_workflow_registration(workflow: Workflow) -> WorkflowRegistrationPlan: + """Classify a workflow's executors into the durable primitives to register. + + Args: + workflow: The MAF :class:`Workflow` to host. + + Returns: + A :class:`WorkflowRegistrationPlan` describing the agents (entities), + non-agent executors (activities), and the orchestrator name. + """ + agents: list[SupportsAgentRun] = [] + activity_executors: list[Executor] = [] + + for executor in workflow.executors.values(): + if isinstance(executor, AgentExecutor): + agents.append(executor.agent) + else: + activity_executors.append(executor) + + return WorkflowRegistrationPlan( + agents=agents, + activity_executors=activity_executors, + orchestrator_name=WORKFLOW_ORCHESTRATOR_NAME, + ) diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflow_runner_context.py b/python/packages/durabletask/agent_framework_durabletask/_workflow_runner_context.py new file mode 100644 index 00000000000..1851339bb37 --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_workflow_runner_context.py @@ -0,0 +1,147 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Runner context for activity execution within durable orchestrations. + +This module provides the :class:`CapturingRunnerContext` class that captures +messages and events produced during executor execution within activities. +It is host-agnostic and works on any durable task host. +""" + +from __future__ import annotations + +import asyncio +from copy import copy +from typing import Any + +from agent_framework import ( + CheckpointStorage, + RunnerContext, + WorkflowCheckpoint, + WorkflowEvent, + WorkflowMessage, +) +from agent_framework._workflows._runner_context import YieldOutputClassifier, YieldOutputEventType +from agent_framework._workflows._state import State + + +class CapturingRunnerContext(RunnerContext): + """A RunnerContext that captures messages and events for durable activities. + + This context captures all messages and events produced during execution + without requiring durable entity storage, allowing the results to be + returned to the orchestrator. + + Checkpointing is not supported — the orchestrator manages state. + """ + + def __init__(self) -> None: + self._messages: dict[str, list[WorkflowMessage]] = {} + self._event_queue: asyncio.Queue[WorkflowEvent] = asyncio.Queue() + self._pending_request_info_events: dict[str, WorkflowEvent[Any]] = {} + self._workflow_id: str | None = None + self._streaming: bool = False + self._yield_output_classifier: YieldOutputClassifier = lambda _executor_id: "output" + + # -- Messaging ------------------------------------------------------------ + + async def send_message(self, message: WorkflowMessage) -> None: + self._messages.setdefault(message.source_id, []) + self._messages[message.source_id].append(message) + + async def drain_messages(self) -> dict[str, list[WorkflowMessage]]: + messages = copy(self._messages) + self._messages.clear() + return messages + + async def has_messages(self) -> bool: + return bool(self._messages) + + # -- Events --------------------------------------------------------------- + + async def add_event(self, event: WorkflowEvent) -> None: + await self._event_queue.put(event) + + async def drain_events(self) -> list[WorkflowEvent]: + events: list[WorkflowEvent] = [] + while True: + try: + events.append(self._event_queue.get_nowait()) + except asyncio.QueueEmpty: + break + return events + + async def has_events(self) -> bool: + return not self._event_queue.empty() + + async def next_event(self) -> WorkflowEvent: + return await self._event_queue.get() + + # -- Checkpointing (not supported) ---------------------------------------- + + def has_checkpointing(self) -> bool: + return False + + def set_runtime_checkpoint_storage(self, storage: CheckpointStorage) -> None: + pass + + def clear_runtime_checkpoint_storage(self) -> None: + pass + + async def create_checkpoint( + self, + workflow_name: str, + graph_signature_hash: str, + state: State, + previous_checkpoint_id: str | None, + iteration_count: int, + metadata: dict[str, Any] | None = None, + ) -> str: + raise NotImplementedError("Checkpointing is not supported in activity context") + + async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None: + raise NotImplementedError("Checkpointing is not supported in activity context") + + async def apply_checkpoint(self, checkpoint: WorkflowCheckpoint) -> None: + raise NotImplementedError("Checkpointing is not supported in activity context") + + # -- Workflow configuration ----------------------------------------------- + + def set_workflow_id(self, workflow_id: str) -> None: + self._workflow_id = workflow_id + + def reset_for_new_run(self) -> None: + self._messages.clear() + self._event_queue = asyncio.Queue() + self._pending_request_info_events.clear() + self._streaming = False + + def set_streaming(self, streaming: bool) -> None: + self._streaming = streaming + + def is_streaming(self) -> bool: + return self._streaming + + # -- Yield-output classification ------------------------------------------- + + def set_yield_output_classifier(self, classifier: YieldOutputClassifier) -> None: + """Set the classifier used by ``WorkflowContext.yield_output()``.""" + self._yield_output_classifier = classifier + + def classify_yielded_output(self, executor_id: str) -> YieldOutputEventType | None: + """Classify an executor's yield_output payload as output, intermediate, or hidden.""" + return self._yield_output_classifier(executor_id) + + # -- Request Info Events -------------------------------------------------- + + async def add_request_info_event(self, event: WorkflowEvent[Any]) -> None: + self._pending_request_info_events[event.request_id] = event + await self.add_event(event) + + async def send_request_info_response(self, request_id: str, response: Any) -> None: + raise NotImplementedError( + "send_request_info_response is not supported in activity context. " + "Human-in-the-loop scenarios should be handled at the orchestrator level." + ) + + async def get_pending_request_info_events(self) -> dict[str, WorkflowEvent[Any]]: + return dict(self._pending_request_info_events) diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflow_serialization.py b/python/packages/durabletask/agent_framework_durabletask/_workflow_serialization.py new file mode 100644 index 00000000000..b0d5e647875 --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_workflow_serialization.py @@ -0,0 +1,183 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Serialization utilities for workflow execution. + +This module provides thin wrappers around the core checkpoint encoding system +(encode_checkpoint_value / decode_checkpoint_value) from agent_framework._workflows. + +The core checkpoint encoding uses pickle + base64 for type-safe roundtripping of +arbitrary Python objects (dataclasses, Pydantic models, Message, etc.) while +keeping JSON-native types (str, int, float, bool, None) as-is. + +This module adds: +- serialize_value / deserialize_value: convenience aliases for encode/decode +- reconstruct_to_type: for HITL responses where external data (without type markers) + needs to be reconstructed to a known type +- resolve_type: resolves 'module:class' type keys to Python types +""" + +from __future__ import annotations + +import importlib +import logging +from contextlib import suppress +from dataclasses import is_dataclass +from typing import Any, cast + +from agent_framework._workflows._checkpoint_encoding import ( + _PICKLE_MARKER, # pyright: ignore[reportPrivateUsage] + _TYPE_MARKER, # pyright: ignore[reportPrivateUsage] + decode_checkpoint_value, + encode_checkpoint_value, +) +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + + +def resolve_type(type_key: str) -> type | None: + """Resolve a 'module:class' type key to its Python type. + + Args: + type_key: Fully qualified type reference in 'module_name:class_name' format. + + Returns: + The resolved type, or None if resolution fails. + """ + try: + module_name, class_name = type_key.split(":", 1) + module = importlib.import_module(module_name) + return getattr(module, class_name, None) + except Exception: + logger.debug("Could not resolve type %s", type_key) + return None + + +# ============================================================================ +# Pickle marker sanitization (security) +# ============================================================================ + + +def strip_pickle_markers(data: Any) -> Any: + """Recursively strip pickle/type markers from untrusted data. + + The core checkpoint encoding uses ``__pickled__`` and ``__type__`` markers to + roundtrip arbitrary Python objects via *pickle*. If an attacker crafts an + HTTP payload that contains these markers, the data would flow into + ``pickle.loads()`` and enable **arbitrary code execution**. + + This function walks the incoming data structure and replaces any ``dict`` + that contains either marker key with ``None``, neutralising the attack + vector while leaving all other data untouched. + + It **must** be called on every value that originates from an untrusted + source (e.g. ``req.get_json()``) *before* the value is passed to + ``deserialize_value`` / ``decode_checkpoint_value``. + """ + if isinstance(data, dict): + if _PICKLE_MARKER in data or _TYPE_MARKER in data: + logger.debug("Stripped pickle/type markers from untrusted input.") + return None + typed_dict = cast(dict[str, Any], data) + return {k: strip_pickle_markers(v) for k, v in typed_dict.items()} + + if isinstance(data, list): + typed_list = cast(list[Any], data) # type: ignore[redundant-cast] + return [strip_pickle_markers(item) for item in typed_list] + + return data + + +# ============================================================================ +# Serialize / Deserialize +# ============================================================================ + + +def serialize_value(value: Any) -> Any: + """Serialize a value for JSON-compatible cross-activity communication. + + Delegates to core checkpoint encoding which uses pickle + base64 for + non-JSON-native types (dataclasses, Pydantic models, Message, etc.). + + Args: + value: Any Python value (primitive, dataclass, Pydantic model, Message, etc.) + + Returns: + A JSON-serializable representation with embedded type metadata for reconstruction. + """ + return encode_checkpoint_value(value) + + +def deserialize_value(value: Any) -> Any: + """Deserialize a value previously serialized with serialize_value(). + + Delegates to core checkpoint decoding which unpickles base64-encoded values + and verifies type integrity. + + Args: + value: The serialized data (dict with pickle markers, list, or primitive) + + Returns: + Reconstructed typed object if type metadata found, otherwise original value. + """ + return decode_checkpoint_value(value) + + +# ============================================================================ +# HITL Type Reconstruction +# ============================================================================ + + +def reconstruct_to_type(value: Any, target_type: type) -> Any: + """Reconstruct a value to a known target type. + + Used for HITL responses where external data (without checkpoint type markers) + needs to be reconstructed to a specific type determined by the response_type hint. + + Tries strategies in order: + 1. Return as-is if already the correct type + 2. deserialize_value (for data with any type markers) + 3. Pydantic model_validate (for Pydantic models) + 4. Dataclass constructor (for dataclasses) + + Args: + value: The value to reconstruct (typically a dict from JSON) + target_type: The expected type to reconstruct to + + Returns: + Reconstructed value if possible, otherwise the original value + """ + if value is None: + return None + + with suppress(TypeError): + if isinstance(value, target_type): + return value + + if not isinstance(value, dict): + return value + + # Try decoding if data has pickle markers (from checkpoint encoding). + # NOTE: This function is general-purpose. Callers that handle untrusted + # data (e.g. HITL responses) MUST call strip_pickle_markers() before + # passing data here. See _deserialize_hitl_response in _workflow_orchestrator.py. + decoded = deserialize_value(value) + if not isinstance(decoded, dict): + return decoded + + # Try Pydantic model validation (for unmarked dicts, e.g., external HITL data) + if issubclass(target_type, BaseModel): + try: + return target_type.model_validate(value) + except Exception: + logger.debug("Could not validate Pydantic model %s", target_type) + return value # type: ignore[return-value] + + # Try dataclass construction (for unmarked dicts, e.g., external HITL data) + if is_dataclass(target_type) and isinstance(target_type, type): # type: ignore + try: + return target_type(**value) + except Exception: + logger.debug("Could not construct dataclass %s", target_type) + + return value # type: ignore[return-value] diff --git a/python/packages/durabletask/tests/test_workflow_activity.py b/python/packages/durabletask/tests/test_workflow_activity.py new file mode 100644 index 00000000000..d92a0a82372 --- /dev/null +++ b/python/packages/durabletask/tests/test_workflow_activity.py @@ -0,0 +1,123 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for execute_workflow_activity (shared non-agent executor activity body). + +These tests exercise the host-agnostic activity execution shared by the Azure +Functions and standalone durabletask workflow hosts. In particular they protect +the state snapshot/diff semantics: the snapshot must be a *deep* copy so that +in-place mutations to nested objects (dicts, lists) are correctly detected as +updates (regression guard for the shallow-copy bug, #4500). +""" + +import json +from typing import Any +from unittest.mock import AsyncMock, Mock + +from agent_framework_durabletask import execute_workflow_activity +from agent_framework_durabletask._workflow_orchestrator import SOURCE_ORCHESTRATOR + + +def _make_executor(executor_id: str, mutate: Any) -> Mock: + """Build a mock non-agent executor whose execute() mutates shared state.""" + executor = Mock() + executor.id = executor_id + executor.execute = AsyncMock(side_effect=mutate) + return executor + + +def _run(executor: Mock, snapshot: dict[str, Any]) -> dict[str, Any]: + """Invoke execute_workflow_activity and return the parsed result dict.""" + input_data = json.dumps({ + "message": "test", + "shared_state_snapshot": snapshot, + "source_executor_ids": [SOURCE_ORCHESTRATOR], + }) + return json.loads(execute_workflow_activity(executor, input_data)) + + +class TestExecuteWorkflowActivityStateDiff: + """State snapshot/diff behavior of the shared workflow activity body.""" + + def test_nested_dict_mutation_detected(self) -> None: + """In-place mutation of a nested dict is reported as an update.""" + + async def mutate(message: Any, source_executor_ids: Any, state: Any, runner_context: Any) -> None: + config = state.get("Local.config") + config["code"] = "SOMECODEXXX" + config["enabled"] = True + state.commit() + + executor = _make_executor("test-exec", mutate) + result = _run(executor, {"Local.config": {"code": "", "enabled": False}, "simple_key": "simple_value"}) + + updates = result["shared_state_updates"] + assert "Local.config" in updates, "nested mutation not detected — snapshot may be a shallow copy" + assert updates["Local.config"]["code"] == "SOMECODEXXX" + assert updates["Local.config"]["enabled"] is True + + def test_new_key_in_nested_dict_detected(self) -> None: + """Adding a key to a nested dict is reported as an update.""" + + async def mutate(message: Any, source_executor_ids: Any, state: Any, runner_context: Any) -> None: + state.get("Local.data")["code"] = "NEW_CODE" + state.commit() + + executor = _make_executor("test-exec", mutate) + result = _run(executor, {"Local.data": {"existing": "value"}}) + + assert result["shared_state_updates"]["Local.data"]["code"] == "NEW_CODE" + + def test_nested_list_mutation_detected(self) -> None: + """Appending to a nested list is reported as an update.""" + + async def mutate(message: Any, source_executor_ids: Any, state: Any, runner_context: Any) -> None: + state.get("Local.items").append(4) + state.commit() + + executor = _make_executor("test-exec", mutate) + result = _run(executor, {"Local.items": [1, 2, 3]}) + + assert result["shared_state_updates"]["Local.items"] == [1, 2, 3, 4] + + def test_new_top_level_key_detected(self) -> None: + """Setting a new top-level key is reported as an update.""" + + async def mutate(message: Any, source_executor_ids: Any, state: Any, runner_context: Any) -> None: + state.set("Local.code", "SOMECODEXXX") + state.commit() + + executor = _make_executor("test-exec", mutate) + result = _run(executor, {"existing": "value"}) + + assert result["shared_state_updates"]["Local.code"] == "SOMECODEXXX" + + def test_unchanged_state_produces_empty_diff(self) -> None: + """Unmodified state produces no updates.""" + + async def mutate(message: Any, source_executor_ids: Any, state: Any, runner_context: Any) -> None: + # No mutations performed. + state.commit() + + executor = _make_executor("test-exec", mutate) + result = _run(executor, {"Local.config": {"code": "existing", "enabled": True}, "simple_key": "v"}) + + assert result["shared_state_updates"] == {} + + def test_deleted_key_reported(self) -> None: + """A key removed during execution is reported as a delete.""" + + async def mutate(message: Any, source_executor_ids: Any, state: Any, runner_context: Any) -> None: + state.delete("to_remove") + state.commit() + + executor = _make_executor("test-exec", mutate) + result = _run(executor, {"to_remove": "value", "keep": "value"}) + + assert "to_remove" in result["shared_state_deletes"] + assert "keep" not in result["shared_state_deletes"] + + +if __name__ == "__main__": + import pytest + + pytest.main([__file__, "-v", "--tb=short"]) From 68449309dd3040b1062c1b5667f9bda84cd6b64f Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Tue, 9 Jun 2026 14:28:05 -0500 Subject: [PATCH 02/15] refactor(azurefunctions): delegate workflow execution to agent-framework-durabletask AgentFunctionApp now reuses the shared orchestrator, activity body, and registration planner from agent_framework_durabletask instead of maintaining its own copies; _workflow.py becomes a thin host-specific adapter (AzureFunctionsWorkflowContext). - Run agent entity coroutines on the shared persistent event loop, fixing the cross-loop hang. - Relocate state-diff unit tests to the durabletask package; update entity loop tests. --- .../agent_framework_azurefunctions/_app.py | 169 +-- .../_entities.py | 27 +- .../_workflow.py | 1012 +---------------- .../_workflow_af_context.py | 82 ++ .../packages/azurefunctions/tests/test_app.py | 283 +---- .../azurefunctions/tests/test_entities.py | 70 +- 6 files changed, 191 insertions(+), 1452 deletions(-) create mode 100644 python/packages/azurefunctions/agent_framework_azurefunctions/_workflow_af_context.py diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index c25c2461cee..bc524674a22 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -14,15 +14,13 @@ import re import uuid from collections.abc import Callable, Mapping -from copy import deepcopy from dataclasses import dataclass from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, TypeVar, cast import azure.durable_functions as df import azure.functions as func -from agent_framework import AgentExecutor, SupportsAgentRun, Workflow, WorkflowEvent -from agent_framework._workflows._runner_context import YieldOutputEventType +from agent_framework import SupportsAgentRun, Workflow from agent_framework_durabletask import ( DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS, @@ -34,25 +32,22 @@ THREAD_ID_HEADER, WAIT_FOR_RESPONSE_FIELD, WAIT_FOR_RESPONSE_HEADER, + WORKFLOW_ORCHESTRATOR_NAME, AgentResponseCallbackProtocol, AgentSessionId, ApiResponseFields, DurableAgentState, DurableAIAgent, RunRequest, + execute_workflow_activity, + plan_workflow_registration, ) -from ._context import CapturingRunnerContext from ._entities import create_agent_entity from ._errors import IncomingRequestError from ._orchestration import AgentOrchestrationContextType, AgentTask, AzureFunctionsAgentExecutor -from ._serialization import deserialize_value, serialize_value, strip_pickle_markers -from ._workflow import ( - SOURCE_HITL_RESPONSE, - SOURCE_ORCHESTRATOR, - execute_hitl_response_handler, - run_workflow_orchestrator, -) +from ._serialization import strip_pickle_markers +from ._workflow import run_workflow_orchestrator logger = logging.getLogger("agent_framework.azurefunctions") @@ -60,11 +55,6 @@ HandlerT = TypeVar("HandlerT", bound=Callable[..., Any]) -def _create_state_snapshot(state: dict[str, Any]) -> dict[str, Any]: - """Create a deep copy of the deserialized state for later diffing.""" - return deepcopy(state) - - @dataclass class AgentMetadata: """Metadata for a registered agent. @@ -234,17 +224,18 @@ def __init__( interval = DEFAULT_POLL_INTERVAL_SECONDS self.poll_interval_seconds = interval if interval > 0 else DEFAULT_POLL_INTERVAL_SECONDS - # If workflow is provided, extract agents and set up orchestration + # If workflow is provided, extract agents and set up orchestration. + # The "what to register" decision (agent -> entity, non-agent -> activity) + # is shared with the standalone durabletask host via plan_workflow_registration. if workflow: if agents is None: agents = [] logger.debug("[AgentFunctionApp] Extracting agents from workflow") - for executor in workflow.executors.values(): - if isinstance(executor, AgentExecutor): - agents.append(executor.agent) - else: - # Setup individual activity for each non-agent executor - self._setup_executor_activity(executor.id) + plan = plan_workflow_registration(workflow) + agents.extend(plan.agents) + for executor in plan.activity_executors: + # Set up a Functions activity trigger for each non-agent executor. + self._setup_executor_activity(executor.id) self._setup_workflow_orchestration() @@ -279,18 +270,9 @@ def executor_activity(inputData: str) -> str: Note: We use str type annotations instead of dict to work around Azure Functions worker type validation issues with dict[str, Any]. + The execution body is shared with the standalone durabletask host via + ``execute_workflow_activity``. """ - from agent_framework._workflows._state import State - - data_obj = json.loads(inputData) - if not isinstance(data_obj, dict): - raise ValueError("Activity inputData must decode to a JSON object") - data = cast(dict[str, Any], data_obj) - - message_data = data.get("message") - shared_state_snapshot = data.get("shared_state_snapshot", {}) - source_executor_ids = cast(list[str], data.get("source_executor_ids", [SOURCE_ORCHESTRATOR])) - if not self.workflow: raise RuntimeError("Workflow not initialized in AgentFunctionApp") @@ -298,120 +280,7 @@ def executor_activity(inputData: str) -> str: if not executor: raise ValueError(f"Unknown executor: {captured_executor_id}") - # Reconstruct message - deserialize_value restores the original typed objects - # from the encoded data (with type markers) - message = deserialize_value(message_data) - - # Check if this is a HITL response message by examining source_executor_ids - is_hitl_response = any(s.startswith(SOURCE_HITL_RESPONSE) for s in source_executor_ids) - - async def run() -> dict[str, Any]: - # Create runner context and shared state - runner_context = CapturingRunnerContext() - workflow = self.workflow - - def classify_yielded_output(executor_id: str) -> YieldOutputEventType | None: - if workflow is None: - return "output" - if workflow.is_terminal_executor(executor_id): - return "output" - if workflow.is_intermediate_executor(executor_id): - return "intermediate" - return None - - runner_context.set_yield_output_classifier(classify_yielded_output) - shared_state = State() - - # Deserialize shared state values to reconstruct dataclasses/Pydantic models - deserialized_state: dict[str, Any] = { - str(k): deserialize_value(v) for k, v in shared_state_snapshot.items() - } - original_snapshot = _create_state_snapshot(deserialized_state) - shared_state.import_state(deserialized_state) - - if is_hitl_response: - # Handle HITL response by calling the executor's @response_handler - if not isinstance(message_data, dict): - raise ValueError("HITL message payload must be a JSON object") - - await execute_hitl_response_handler( - executor=executor, - hitl_message=cast(dict[str, Any], message_data), - shared_state=shared_state, - runner_context=runner_context, - ) - else: - # Execute using the public execute() method - await executor.execute( - message=message, - source_executor_ids=source_executor_ids, - state=shared_state, - runner_context=runner_context, - ) - - # Commit pending state changes and export - shared_state.commit() - current_state = shared_state.export_state() - original_keys: set[str] = set(original_snapshot.keys()) - current_keys: set[str] = set(current_state.keys()) - - # Deleted = was in original, not in current - deletes: set[str] = original_keys - current_keys - - # Updates = keys in current that are new or have different values - updates: dict[str, Any] = {} - for key in current_keys: - if key not in original_keys or current_state[key] != original_snapshot.get(key): - updates[key] = current_state[key] - - # Drain messages and events from runner context - sent_messages = await runner_context.drain_messages() - events = await runner_context.drain_events() - - # Extract outputs from WorkflowEvent instances with type='output' - outputs: list[Any] = [] - for event in events: - if isinstance(event, WorkflowEvent) and event.type == "output": - outputs.append(serialize_value(event.data)) - - # Get pending request info events for HITL - pending_request_info_events = await runner_context.get_pending_request_info_events() - - # Serialize pending request info events for orchestrator - serialized_pending_requests: list[dict[str, Any]] = [] - for _request_id, event in pending_request_info_events.items(): - serialized_pending_requests.append({ - "request_id": event.request_id, - "source_executor_id": event.source_executor_id, - "data": serialize_value(event.data), - "request_type": f"{type(event.data).__module__}:{type(event.data).__name__}", - "response_type": f"{event.response_type.__module__}:{event.response_type.__name__}" - if event.response_type - else None, - }) - - # Serialize messages for JSON compatibility - serialized_sent_messages: list[dict[str, Any]] = [] - for _source_id, msg_list in sent_messages.items(): - for msg in msg_list: - serialized_sent_messages.append({ - "message": serialize_value(msg.data), - "target_id": msg.target_id, - "source_id": msg.source_id, - }) - - serialized_updates = {k: serialize_value(v) for k, v in updates.items()} - - return { - "sent_messages": serialized_sent_messages, - "outputs": outputs, - "shared_state_updates": serialized_updates, - "shared_state_deletes": list(deletes), - "pending_request_info_events": serialized_pending_requests, - } - - result = asyncio.run(run()) - return json.dumps(result) + return execute_workflow_activity(executor, inputData, self.workflow) # Ensure the function is registered (prevents garbage collection) _ = executor_activity @@ -448,7 +317,7 @@ async def start_workflow_orchestration( except ValueError: return self._build_error_response("Invalid JSON body") - instance_id = await client.start_new("workflow_orchestrator", client_input=req_body) + instance_id = await client.start_new(WORKFLOW_ORCHESTRATOR_NAME, client_input=req_body) base_url = self._build_base_url(req.url) status_url = f"{base_url}/api/workflow/status/{instance_id}" @@ -1092,8 +961,6 @@ async def _get_response_from_entity( thread_id: str, ) -> dict[str, Any]: """Poll the entity state until a response is available or timeout occurs.""" - import asyncio - max_retries = self.max_poll_retries interval = self.poll_interval_seconds retry_count = 0 diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py index d734950979d..9de242efaa6 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py @@ -9,7 +9,6 @@ from __future__ import annotations -import asyncio import logging from collections.abc import Callable from typing import Any, cast @@ -20,6 +19,7 @@ AgentEntity, AgentEntityStateProviderMixin, AgentResponseCallbackProtocol, + run_agent_coroutine, ) logger = logging.getLogger("agent_framework.azurefunctions") @@ -101,23 +101,16 @@ async def _entity_coroutine(context: df.DurableEntityContext) -> None: context.set_result({"error": str(exc), "status": "error"}) def entity_function(context: df.DurableEntityContext) -> None: - """Synchronous wrapper invoked by the Durable Functions runtime.""" + """Synchronous wrapper invoked by the Durable Functions runtime. + + All agent coroutines run on a single process-wide persistent event loop + (see ``run_agent_coroutine``). This keeps async resources created by + shared agent clients/credentials bound to a live loop across every + invocation, preventing cross-loop hangs when the host dispatches + successive entity operations onto different worker threads. + """ try: - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - if loop.is_running(): - temp_loop = asyncio.new_event_loop() - try: - temp_loop.run_until_complete(_entity_coroutine(context)) - finally: - temp_loop.close() - else: - loop.run_until_complete(_entity_coroutine(context)) - + run_agent_coroutine(_entity_coroutine(context)) except Exception as exc: # pragma: no cover - defensive logging logger.error("[entity_function] Unexpected error executing entity: %s", exc, exc_info=True) context.set_result({"error": str(exc), "status": "error"}) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py index 6fbdaf44f7e..7a7c687460c 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py @@ -2,603 +2,63 @@ """Workflow Execution for Durable Functions. -This module provides the workflow orchestration engine that executes MAF Workflows -using Azure Durable Functions. It reuses MAF's edge group routing logic while -adapting execution to the DF generator-based model (yield instead of await). - -Key components: -- run_workflow_orchestrator: Main orchestration function for workflow execution -- route_message_through_edge_groups: Routing helper using MAF edge group APIs -- build_agent_executor_response: Helper to construct AgentExecutorResponse - -HITL (Human-in-the-Loop) Support: -- Detects pending RequestInfoEvents from executor activities -- Uses wait_for_external_event to pause for human input -- Routes responses back to executor's @response_handler methods +This module provides the Azure Functions entry point for workflow orchestration. +The actual orchestration logic lives in the shared module +``agent_framework_durabletask._workflow_orchestrator`` and is host-agnostic. +This module re-exports the public API and provides the AF-specific +``run_workflow_orchestrator`` wrapper that creates an +:class:`AzureFunctionsWorkflowContext` before delegating. """ from __future__ import annotations -import json import logging -from collections import defaultdict from collections.abc import Generator -from dataclasses import dataclass -from datetime import timedelta -from enum import Enum from typing import Any -from agent_framework import ( - AgentExecutor, - AgentExecutorRequest, - AgentExecutorResponse, - AgentResponse, - Message, - Workflow, +from agent_framework import Workflow +from agent_framework_durabletask._workflow_orchestrator import ( + DEFAULT_HITL_TIMEOUT_HOURS, + SOURCE_HITL_RESPONSE, + SOURCE_ORCHESTRATOR, + SOURCE_WORKFLOW_START, + ExecutorResult, + PendingHITLRequest, + TaskMetadata, + TaskType, + _extract_message_content, + build_agent_executor_response, + execute_hitl_response_handler, + route_message_through_edge_groups, ) -from agent_framework._workflows._edge import ( - Edge, - EdgeGroup, - FanInEdgeGroup, - FanOutEdgeGroup, - SingleEdgeGroup, - SwitchCaseEdgeGroup, +from agent_framework_durabletask._workflow_orchestrator import ( + run_workflow_orchestrator as _run_workflow_orchestrator_shared, ) -from agent_framework._workflows._state import State -from agent_framework_durabletask import AgentSessionId, DurableAgentSession, DurableAIAgent from azure.durable_functions import DurableOrchestrationContext -from ._context import CapturingRunnerContext -from ._orchestration import AzureFunctionsAgentExecutor -from ._serialization import deserialize_value, reconstruct_to_type, resolve_type, serialize_value, strip_pickle_markers +# Re-export serialization helpers so existing AF code (e.g. _app.py) keeps working +# via ``from ._workflow import ...`` or ``from ._serialization import ...``. +from ._serialization import deserialize_value, serialize_value # noqa: F401 +from ._workflow_af_context import AzureFunctionsWorkflowContext logger = logging.getLogger(__name__) - -# ============================================================================ -# Source Marker Constants -# ============================================================================ -# These markers identify the origin of messages in the workflow orchestration. -# They are used to track message provenance and handle special cases like HITL. - -# Marker indicating the message originated from the workflow start (initial user input) -SOURCE_WORKFLOW_START = "__workflow_start__" - -# Marker indicating the message originated from the orchestrator itself -# (used as default when executor is called directly by orchestrator, not via another executor) -SOURCE_ORCHESTRATOR = "__orchestrator__" - -# Marker indicating the message is a human-in-the-loop response. -# Used as a source ID prefix. To detect HITL responses, check if any source_executor_id -# starts with this prefix. -SOURCE_HITL_RESPONSE = "__hitl_response__" - - -# ============================================================================ -# Task Types and Data Structures -# ============================================================================ - - -class TaskType(Enum): - """Type of executor task.""" - - AGENT = "agent" - ACTIVITY = "activity" - - -@dataclass -class TaskMetadata: - """Metadata for a pending task.""" - - executor_id: str - message: Any - source_executor_id: str - task_type: TaskType - remaining_messages: list[tuple[str, Any, str]] | None = None # For agents with multiple messages - - -@dataclass -class ExecutorResult: - """Result from executing an agent or activity.""" - - executor_id: str - output_message: AgentExecutorResponse | None - activity_result: dict[str, Any] | None - task_type: TaskType - - -@dataclass -class PendingHITLRequest: - """Tracks a pending Human-in-the-Loop request in the orchestrator. - - Attributes: - request_id: Unique identifier for correlation with external events - source_executor_id: The executor that called ctx.request_info() - request_data: The serialized request payload - request_type: Fully qualified type name of the request data - response_type: Fully qualified type name of expected response - """ - - request_id: str - source_executor_id: str - request_data: Any - request_type: str | None - response_type: str | None - - -# Default timeout for HITL requests (72 hours) -DEFAULT_HITL_TIMEOUT_HOURS = 72.0 - - -# ============================================================================ -# Routing Functions -# ============================================================================ - - -def _evaluate_edge_condition_sync(edge: Edge, message: Any) -> bool: - """Evaluate an edge's condition synchronously. - - This is needed because Durable Functions orchestrators use generators, - not async/await, so we cannot call async methods like edge.should_route(). - - Args: - edge: The Edge with an optional _condition callable - message: The message to evaluate against the condition - - Returns: - True if the edge should be traversed, False otherwise - """ - # Access the internal condition directly since should_route is async - condition = edge._condition # pyright: ignore[reportPrivateUsage] - if condition is None: - return True - result = condition(message) - # If the condition is async, we cannot await it in a generator context - # Log a warning and assume True (or False for safety) - if hasattr(result, "__await__"): - import warnings - - warnings.warn( - f"Edge condition for {edge.source_id}->{edge.target_id} is async, " - "which is not supported in Durable Functions orchestrators. " - "The edge will be traversed unconditionally.", - RuntimeWarning, - stacklevel=2, - ) - return True - return bool(result) - - -def route_message_through_edge_groups( - edge_groups: list[EdgeGroup], - source_id: str, - message: Any, -) -> list[str]: - """Route a message through edge groups to find target executor IDs. - - Delegates to MAF's edge group routing logic instead of manual inspection. - - Args: - edge_groups: List of EdgeGroup instances from the workflow - source_id: The ID of the source executor - message: The message to route - - Returns: - List of target executor IDs that should receive the message - """ - targets: list[str] = [] - - for group in edge_groups: - if source_id not in group.source_executor_ids: - continue - - # SwitchCaseEdgeGroup and FanOutEdgeGroup use selection_func - if isinstance(group, (SwitchCaseEdgeGroup, FanOutEdgeGroup)): - if group.selection_func is not None: - selected = group.selection_func(message, group.target_executor_ids) - targets.extend(selected) - else: - # No selection func means broadcast to all targets - targets.extend(group.target_executor_ids) - - elif isinstance(group, SingleEdgeGroup): - # SingleEdgeGroup has exactly one edge - edge = group.edges[0] - if _evaluate_edge_condition_sync(edge, message): - targets.append(edge.target_id) - - elif isinstance(group, FanInEdgeGroup): - # FanIn is handled separately in the orchestrator loop - # since it requires aggregation - pass - - else: - # Generic EdgeGroup: check each edge's condition - for edge in group.edges: - if edge.source_id == source_id and _evaluate_edge_condition_sync(edge, message): - targets.append(edge.target_id) - - return targets - - -def build_agent_executor_response( - executor_id: str, - response_text: str | None, - structured_response: dict[str, Any] | None, - previous_message: Any, -) -> AgentExecutorResponse: - """Build an AgentExecutorResponse from entity response data. - - Shared helper to construct the response object consistently. - - Args: - executor_id: The ID of the executor that produced the response - response_text: Plain text response from the agent (if any) - structured_response: Structured JSON response (if any) - previous_message: The input message that triggered this response - - Returns: - AgentExecutorResponse with reconstructed conversation - """ - final_text: str = response_text or "" - if structured_response: - final_text = json.dumps(structured_response) - - assistant_message = Message(role="assistant", contents=[final_text]) - - agent_response = AgentResponse( - messages=[assistant_message], - ) - - # Build conversation history - full_conversation: list[Message] = [] - if isinstance(previous_message, AgentExecutorResponse) and previous_message.full_conversation: - full_conversation.extend(previous_message.full_conversation) - elif isinstance(previous_message, str): - full_conversation.append(Message(role="user", contents=[previous_message])) - - full_conversation.append(assistant_message) - - return AgentExecutorResponse( - executor_id=executor_id, - agent_response=agent_response, - full_conversation=full_conversation, - ) - - -# ============================================================================ -# Task Preparation Helpers -# ============================================================================ - - -def _prepare_agent_task( - context: DurableOrchestrationContext, - executor_id: str, - message: Any, -) -> Any: - """Prepare an agent task for execution. - - Args: - context: The Durable Functions orchestration context - executor_id: The agent executor ID (agent name) - message: The input message for the agent - - Returns: - A task that can be yielded to execute the agent - """ - message_content = _extract_message_content(message) - session_id = AgentSessionId(name=executor_id, key=context.instance_id) - session = DurableAgentSession(durable_session_id=session_id) - - az_executor = AzureFunctionsAgentExecutor(context) - agent = DurableAIAgent(az_executor, executor_id) - return agent.run(message_content, session=session) - - -def _prepare_activity_task( - context: DurableOrchestrationContext, - executor_id: str, - message: Any, - source_executor_id: str, - shared_state_snapshot: dict[str, Any] | None, -) -> Any: - """Prepare an activity task for execution. - - Args: - context: The Durable Functions orchestration context - executor_id: The activity executor ID - message: The input message for the activity - source_executor_id: The ID of the executor that sent the message - shared_state_snapshot: Current shared state snapshot - - Returns: - A task that can be yielded to execute the activity - """ - activity_input = { - "executor_id": executor_id, - "message": serialize_value(message), - "shared_state_snapshot": shared_state_snapshot, - "source_executor_ids": [source_executor_id], - } - activity_input_json = json.dumps(activity_input) - # Use the prefixed activity name that matches the registered function - activity_name = f"dafx-{executor_id}" - orchestration_context: Any = context - return orchestration_context.call_activity(activity_name, activity_input_json) - - -# ============================================================================ -# Result Processing Helpers -# ============================================================================ - - -def _process_agent_response( - agent_response: AgentResponse, - executor_id: str, - message: Any, -) -> ExecutorResult: - """Process an agent response into an ExecutorResult. - - Args: - agent_response: The response from the agent - executor_id: The agent executor ID - message: The original input message - - Returns: - ExecutorResult containing the processed response - """ - response_text = agent_response.text if agent_response else None - structured_response: dict[str, Any] | None = None - - if agent_response and agent_response.value is not None: - model_dump = getattr(agent_response.value, "model_dump", None) - if callable(model_dump): - dumped = model_dump() - if isinstance(dumped, dict): - structured_response = dumped # type: ignore[assignment] - elif isinstance(agent_response.value, dict): - structured_response = agent_response.value # type: ignore[assignment] - - output_message = build_agent_executor_response( - executor_id=executor_id, - response_text=response_text, - structured_response=structured_response, - previous_message=message, - ) - - return ExecutorResult( - executor_id=executor_id, - output_message=output_message, - activity_result=None, - task_type=TaskType.AGENT, - ) - - -def _process_activity_result( - result_json: str | None, - executor_id: str, - shared_state: dict[str, Any] | None, - workflow_outputs: list[Any], -) -> ExecutorResult: - """Process an activity result and apply shared state updates. - - Args: - result_json: The JSON result from the activity - executor_id: The activity executor ID - shared_state: The shared state dict to update (mutated in place) - workflow_outputs: List to append outputs to (mutated in place) - - Returns: - ExecutorResult containing the processed result - """ - result = json.loads(result_json) if result_json else None - - # Apply shared state updates - if shared_state is not None and result: - if result.get("shared_state_updates"): - updates = result["shared_state_updates"] - logger.debug("[workflow] Applying SharedState updates from %s: %s", executor_id, updates) - shared_state.update(updates) - if result.get("shared_state_deletes"): - deletes = result["shared_state_deletes"] - logger.debug("[workflow] Applying SharedState deletes from %s: %s", executor_id, deletes) - for key in deletes: - shared_state.pop(key, None) - - # Collect outputs - if result and result.get("outputs"): - workflow_outputs.extend(result["outputs"]) - - return ExecutorResult( - executor_id=executor_id, - output_message=None, - activity_result=result, - task_type=TaskType.ACTIVITY, - ) - - -# ============================================================================ -# Routing Helpers -# ============================================================================ - - -def _route_result_messages( - result: ExecutorResult, - workflow: Workflow, - next_pending_messages: dict[str, list[tuple[Any, str]]], - fan_in_pending: dict[str, dict[str, list[tuple[Any, str]]]], -) -> None: - """Route messages from an executor result to their targets. - - Args: - result: The executor result containing messages to route - workflow: The workflow definition - next_pending_messages: Dict to accumulate next iteration's messages (mutated) - fan_in_pending: Dict tracking fan-in state (mutated) - """ - executor_id = result.executor_id - messages_to_route: list[tuple[Any, str | None]] = [] - - # Collect messages from agent response - if result.output_message: - messages_to_route.append((result.output_message, None)) - - # Collect sent_messages from activity results - if result.activity_result and result.activity_result.get("sent_messages"): - for msg_data in result.activity_result["sent_messages"]: - sent_msg = msg_data.get("message") - target_id = msg_data.get("target_id") - if sent_msg: - sent_msg = deserialize_value(sent_msg) - messages_to_route.append((sent_msg, target_id)) - - # Route each message - for msg_to_route, explicit_target in messages_to_route: - logger.debug("Routing output from %s", executor_id) - - # If explicit target specified, route directly - if explicit_target: - if explicit_target not in next_pending_messages: - next_pending_messages[explicit_target] = [] - next_pending_messages[explicit_target].append((msg_to_route, executor_id)) - logger.debug("Routed message from %s to explicit target %s", executor_id, explicit_target) - continue - - # Check for FanInEdgeGroup sources - for group in workflow.edge_groups: - if isinstance(group, FanInEdgeGroup) and executor_id in group.source_executor_ids: - fan_in_pending[group.id][executor_id].append((msg_to_route, executor_id)) - logger.debug("Accumulated message for FanIn group %s from %s", group.id, executor_id) - - # Use MAF's edge group routing for other edge types - targets = route_message_through_edge_groups(workflow.edge_groups, executor_id, msg_to_route) - - for target_id in targets: - logger.debug("Routing to %s", target_id) - if target_id not in next_pending_messages: - next_pending_messages[target_id] = [] - next_pending_messages[target_id].append((msg_to_route, executor_id)) - - -def _check_fan_in_ready( - workflow: Workflow, - fan_in_pending: dict[str, dict[str, list[tuple[Any, str]]]], - next_pending_messages: dict[str, list[tuple[Any, str]]], -) -> None: - """Check if any FanInEdgeGroups are ready and deliver their messages. - - Args: - workflow: The workflow definition - fan_in_pending: Dict tracking fan-in state (mutated - cleared when delivered) - next_pending_messages: Dict to add aggregated messages to (mutated) - """ - for group in workflow.edge_groups: - if not isinstance(group, FanInEdgeGroup): - continue - - pending_sources = fan_in_pending.get(group.id, {}) - - # Check if all sources have contributed at least one message - if not all(src in pending_sources and pending_sources[src] for src in group.source_executor_ids): - continue - - # Aggregate all messages into a single list - aggregated: list[Any] = [] - aggregated_sources: list[str] = [] - for src in group.source_executor_ids: - for msg, msg_source in pending_sources[src]: - aggregated.append(msg) - aggregated_sources.append(msg_source) - - target_id = group.target_executor_ids[0] - logger.debug("FanIn group %s ready, delivering %d messages to %s", group.id, len(aggregated), target_id) - - if target_id not in next_pending_messages: - next_pending_messages[target_id] = [] - - first_source = aggregated_sources[0] if aggregated_sources else "__fan_in__" - next_pending_messages[target_id].append((aggregated, first_source)) - - # Clear the pending sources for this group - fan_in_pending[group.id] = defaultdict(list) - - -# ============================================================================ -# HITL (Human-in-the-Loop) Helpers -# ============================================================================ - - -def _collect_hitl_requests( - result: ExecutorResult, - pending_hitl_requests: dict[str, PendingHITLRequest], -) -> None: - """Collect pending HITL requests from an activity result. - - Args: - result: The executor result that may contain pending request info events - pending_hitl_requests: Dict to accumulate pending requests (mutated) - """ - if result.activity_result and result.activity_result.get("pending_request_info_events"): - for req_data in result.activity_result["pending_request_info_events"]: - request_id = req_data.get("request_id") - if request_id: - pending_hitl_requests[request_id] = PendingHITLRequest( - request_id=request_id, - source_executor_id=req_data.get("source_executor_id", result.executor_id), - request_data=req_data.get("data"), - request_type=req_data.get("request_type"), - response_type=req_data.get("response_type"), - ) - logger.debug( - "Collected HITL request %s from executor %s", - request_id, - result.executor_id, - ) - - -def _route_hitl_response( - hitl_request: PendingHITLRequest, - raw_response: Any, - pending_messages: dict[str, list[tuple[Any, str]]], -) -> None: - """Route a HITL response back to the source executor's @response_handler. - - The response is packaged as a special HITL response message that the executor - activity can recognize and route to the appropriate @response_handler method. - - Args: - hitl_request: The original HITL request - raw_response: The raw response data from the external event - pending_messages: Dict to add the response message to (mutated) - """ - # Create a message structure that the executor can recognize - # This mimics what the InProcRunnerContext does for request_info responses - # Note: HITL origin is identified via source_executor_ids (starting with SOURCE_HITL_RESPONSE) - response_message = { - "request_id": hitl_request.request_id, - "original_request": hitl_request.request_data, - "response": raw_response, - "response_type": hitl_request.response_type, - } - - target_id = hitl_request.source_executor_id - if target_id not in pending_messages: - pending_messages[target_id] = [] - - # Use a special source ID to indicate this is a HITL response - source_id = f"{SOURCE_HITL_RESPONSE}_{hitl_request.request_id}" - pending_messages[target_id].append((response_message, source_id)) - - logger.debug( - "Routed HITL response for request %s to executor %s", - hitl_request.request_id, - target_id, - ) - - -# ============================================================================ -# Main Orchestrator -# ============================================================================ +# Re-export shared symbols for backward compatibility +__all__ = [ + "DEFAULT_HITL_TIMEOUT_HOURS", + "SOURCE_HITL_RESPONSE", + "SOURCE_ORCHESTRATOR", + "SOURCE_WORKFLOW_START", + "ExecutorResult", + "PendingHITLRequest", + "TaskMetadata", + "TaskType", + "_extract_message_content", + "build_agent_executor_response", + "execute_hitl_response_handler", + "route_message_through_edge_groups", + "run_workflow_orchestrator", +] def run_workflow_orchestrator( @@ -608,386 +68,20 @@ def run_workflow_orchestrator( shared_state: dict[str, Any] | None = None, hitl_timeout_hours: float = DEFAULT_HITL_TIMEOUT_HOURS, ) -> Generator[Any, Any, list[Any]]: - """Traverse and execute the workflow graph using Durable Functions. - - This orchestrator reuses MAF's edge group routing logic while adapting - execution to the DF generator-based model (yield instead of await). - - Supports: - - SingleEdgeGroup: Direct 1:1 routing with optional condition - - SwitchCaseEdgeGroup: First matching condition wins - - FanOutEdgeGroup: Broadcast to multiple targets - **executed in parallel** - - FanInEdgeGroup: Aggregates messages from multiple sources before delivery - - SharedState: Local shared state accessible to all executors - - HITL: Human-in-the-loop via request_info / @response_handler pattern - - Execution model: - - All pending executors (agents AND activities) run in parallel via single task_all() - - Multiple messages to the SAME agent are processed sequentially for conversation coherence - - SharedState updates are applied in order after parallel tasks complete - - HITL requests pause the orchestration until external events are received - - Args: - context: The Durable Functions orchestration context - workflow: The MAF Workflow instance to execute - initial_message: The initial message to send to the start executor - shared_state: Optional dict for cross-executor state sharing (local to orchestration) - hitl_timeout_hours: Timeout in hours for HITL requests (default: 72 hours) - - Returns: - List of workflow outputs collected from executor activities - """ - pending_messages: dict[str, list[tuple[Any, str]]] = { - workflow.start_executor_id: [(initial_message, SOURCE_WORKFLOW_START)] - } - workflow_outputs: list[Any] = [] - iteration = 0 - - # Track pending sources for FanInEdgeGroups using defaultdict for cleaner access - fan_in_pending: dict[str, dict[str, list[tuple[Any, str]]]] = { - group.id: defaultdict(list) for group in workflow.edge_groups if isinstance(group, FanInEdgeGroup) - } - - # Track pending HITL requests - pending_hitl_requests: dict[str, PendingHITLRequest] = {} - - while pending_messages and iteration < workflow.max_iterations: - logger.debug("Orchestrator iteration %d", iteration) - next_pending_messages: dict[str, list[tuple[Any, str]]] = {} - - # Phase 1: Prepare all tasks (agents and activities unified) - all_tasks, task_metadata_list, remaining_agent_messages = _prepare_all_tasks( - context, workflow, pending_messages, shared_state - ) - - # Phase 2: Execute all tasks in parallel (single task_all for true parallelism) - all_results: list[ExecutorResult] = [] - if all_tasks: - logger.debug("Executing %d tasks in parallel (agents + activities)", len(all_tasks)) - raw_results = yield context.task_all(all_tasks) - logger.debug("All %d tasks completed", len(all_tasks)) - - # Process results based on task type - for idx, raw_result in enumerate(raw_results): - metadata = task_metadata_list[idx] - if metadata.task_type == TaskType.AGENT: - result = _process_agent_response(raw_result, metadata.executor_id, metadata.message) - else: - result = _process_activity_result(raw_result, metadata.executor_id, shared_state, workflow_outputs) - all_results.append(result) - - # Phase 3: Process sequential agent messages (for same-agent conversation coherence) - for executor_id, message, _source_executor_id in remaining_agent_messages: - logger.debug("Processing sequential message for agent: %s", executor_id) - task = _prepare_agent_task(context, executor_id, message) - agent_response: AgentResponse = yield task - logger.debug("Agent %s sequential response completed", executor_id) - - result = _process_agent_response(agent_response, executor_id, message) - all_results.append(result) - - # Phase 4: Collect pending HITL requests from activity results - for result in all_results: - _collect_hitl_requests(result, pending_hitl_requests) - - # Phase 5: Route all results to next iteration - for result in all_results: - _route_result_messages(result, workflow, next_pending_messages, fan_in_pending) - - # Phase 6: Check if any FanInEdgeGroups are ready to deliver - _check_fan_in_ready(workflow, fan_in_pending, next_pending_messages) - - pending_messages = next_pending_messages + """Azure Functions wrapper around the shared workflow orchestrator. - # Phase 7: Handle HITL - if no pending work but HITL requests exist, wait for responses - if not pending_messages and pending_hitl_requests: - logger.debug("Workflow paused for HITL - %d pending requests", len(pending_hitl_requests)) - - # Update custom status to expose pending requests - context.set_custom_status({ - "state": "waiting_for_human_input", - "pending_requests": { - req_id: { - "request_id": req.request_id, - "source_executor_id": req.source_executor_id, - "data": req.request_data, - "request_type": req.request_type, - "response_type": req.response_type, - } - for req_id, req in pending_hitl_requests.items() - }, - }) - - # Wait for external events for each pending request - # Process responses one at a time to maintain ordering - for request_id, hitl_request in list(pending_hitl_requests.items()): - logger.debug("Waiting for HITL response for request: %s", request_id) - - # Create tasks for approval and timeout - approval_task = context.wait_for_external_event(request_id) - timeout_task = context.create_timer(context.current_utc_datetime + timedelta(hours=hitl_timeout_hours)) - - winner = yield context.task_any([approval_task, timeout_task]) - - if winner == approval_task: - # Cancel the timeout - timeout_task.cancel() # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue] - - # Get the response - raw_response = approval_task.result - logger.debug( - "Received HITL response for request %s. Type: %s, Value: %s", - request_id, - type(raw_response).__name__, - raw_response, - ) - - # Durable Functions may return a JSON string; parse it if so - if isinstance(raw_response, str): - try: - raw_response = json.loads(raw_response) - logger.debug("Parsed JSON string response to: %s", type(raw_response).__name__) - except (json.JSONDecodeError, TypeError): - logger.debug("Response is not JSON, keeping as string") - - # Remove from pending - del pending_hitl_requests[request_id] - - # Route the response back to the source executor's @response_handler - _route_hitl_response( - hitl_request, - raw_response, - pending_messages, - ) - else: - # Timeout occurred — cancel the dangling external event listener - approval_task.cancel() # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue] - logger.warning("HITL request %s timed out after %s hours", request_id, hitl_timeout_hours) - raise TimeoutError( - f"Human-in-the-loop request '{request_id}' timed out after {hitl_timeout_hours} hours." - ) - - # Clear custom status after HITL is resolved - context.set_custom_status({"state": "running"}) - - iteration += 1 - - # Durable Functions runtime extracts return value from StopIteration - return workflow_outputs # noqa: B901 - - -def _prepare_all_tasks( - context: DurableOrchestrationContext, - workflow: Workflow, - pending_messages: dict[str, list[tuple[Any, str]]], - shared_state: dict[str, Any] | None, -) -> tuple[list[Any], list[TaskMetadata], list[tuple[str, Any, str]]]: - """Prepare all pending tasks for parallel execution. - - Groups agent messages by executor ID so that only the first message per agent - runs in the parallel batch. Additional messages to the same agent are returned - for sequential processing. - - Args: - context: The Durable Functions orchestration context - workflow: The workflow definition - pending_messages: Messages pending for each executor - shared_state: Current shared state snapshot - - Returns: - Tuple of (tasks, metadata, remaining_agent_messages): - - tasks: List of tasks ready for task_all() - - metadata: TaskMetadata for each task (same order as tasks) - - remaining_agent_messages: Agent messages requiring sequential processing - """ - all_tasks: list[Any] = [] - task_metadata_list: list[TaskMetadata] = [] - remaining_agent_messages: list[tuple[str, Any, str]] = [] - - # Group agent messages by executor_id for sequential handling of same-agent messages - agent_messages_by_executor: dict[str, list[tuple[str, Any, str]]] = defaultdict(list) - - # Categorize all pending messages - for executor_id, messages_with_sources in pending_messages.items(): - executor = workflow.executors[executor_id] - is_agent = isinstance(executor, AgentExecutor) - - for message, source_executor_id in messages_with_sources: - if is_agent: - agent_messages_by_executor[executor_id].append((executor_id, message, source_executor_id)) - else: - # Activity tasks can all run in parallel - logger.debug("Preparing activity task: %s", executor_id) - task = _prepare_activity_task(context, executor_id, message, source_executor_id, shared_state) - all_tasks.append(task) - task_metadata_list.append( - TaskMetadata( - executor_id=executor_id, - message=message, - source_executor_id=source_executor_id, - task_type=TaskType.ACTIVITY, - ) - ) - - # Process agent messages: first message per agent goes to parallel batch - for executor_id, messages_list in agent_messages_by_executor.items(): - first_msg = messages_list[0] - remaining = messages_list[1:] - - logger.debug("Preparing agent task: %s", executor_id) - task = _prepare_agent_task(context, first_msg[0], first_msg[1]) - all_tasks.append(task) - task_metadata_list.append( - TaskMetadata( - executor_id=first_msg[0], - message=first_msg[1], - source_executor_id=first_msg[2], - task_type=TaskType.AGENT, - ) - ) - - # Queue remaining messages for sequential processing - remaining_agent_messages.extend(remaining) - - return all_tasks, task_metadata_list, remaining_agent_messages - - -# ============================================================================ -# Message Content Extraction -# ============================================================================ - - -def _extract_message_content(message: Any) -> str: - """Extract text content from various message types.""" - message_content = "" - if isinstance(message, AgentExecutorResponse) and message.agent_response: - if message.agent_response.text: - message_content = message.agent_response.text - elif message.agent_response.messages: - message_content = message.agent_response.messages[-1].text or "" - elif isinstance(message, AgentExecutorRequest) and message.messages: - # Extract text from the last message in the request - message_content = message.messages[-1].text or "" - elif isinstance(message, dict): - key_names = list(message.keys()) # type: ignore[union-attr] - logger.warning("Unexpected dict message in _extract_message_content. Keys: %s", key_names) # type: ignore - elif isinstance(message, str): - message_content = message - - return message_content - - -# ============================================================================ -# HITL Response Handler Execution -# ============================================================================ - - -async def execute_hitl_response_handler( - executor: Any, - hitl_message: dict[str, Any], - shared_state: State, - runner_context: CapturingRunnerContext, -) -> None: - """Execute a HITL response handler on an executor. - - This function handles the delivery of a HITL response to the executor's - @response_handler method. It: - 1. Deserializes the original request and response - 2. Finds the matching response handler based on types - 3. Creates a WorkflowContext and invokes the handler - - Args: - executor: The executor instance that has a @response_handler - hitl_message: The HITL response message containing original_request and response - shared_state: The shared state for the workflow context - runner_context: The runner context for capturing outputs - """ - from agent_framework._workflows._workflow_context import WorkflowContext - - # Extract the response data - original_request_data = hitl_message.get("original_request") - response_data = hitl_message.get("response") - response_type_str = hitl_message.get("response_type") - - # Deserialize the original request - original_request = deserialize_value(original_request_data) - - # Deserialize the response - try to match expected type - response = _deserialize_hitl_response(response_data, response_type_str) - - # Find the matching response handler - handler = executor._find_response_handler(original_request, response) # pyright: ignore[reportPrivateUsage] - - if handler is None: - logger.warning( - "No response handler found for HITL response in executor %s. Request type: %s, Response type: %s", - executor.id, - type(original_request).__name__, - type(response).__name__, - ) - return - - # Create a WorkflowContext for the handler - # Use a special source ID to indicate this is a HITL response - ctx = WorkflowContext( - executor=executor, - source_executor_ids=[SOURCE_HITL_RESPONSE], - runner_context=runner_context, - state=shared_state, - ) - - # Call the response handler - # Note: handler is already a partial with original_request bound - logger.debug( - "Invoking response handler for HITL request in executor %s", - executor.id, - ) - await handler(response, ctx) - - -def _deserialize_hitl_response(response_data: Any, response_type_str: str | None) -> Any: - """Deserialize a HITL response to its expected type. + Creates an :class:`AzureFunctionsWorkflowContext` and delegates to the + host-agnostic :func:`run_workflow_orchestrator` in the durabletask package. Args: - response_data: The raw response data (typically a dict from JSON) - response_type_str: The fully qualified type name (module:classname) + context: The Azure Functions ``DurableOrchestrationContext``. + workflow: The MAF Workflow instance to execute. + initial_message: Initial message to send to the start executor. + shared_state: Optional dict for cross-executor state sharing. + hitl_timeout_hours: Timeout in hours for HITL requests. Returns: - The deserialized response, or the original data if deserialization fails + List of workflow outputs collected from executor activities. """ - logger.debug( - "Deserializing HITL response. response_type_str=%s, response_data type=%s", - response_type_str, - type(response_data).__name__, - ) - - if response_data is None: - return None - - # Sanitize untrusted external input before deserialization. - # HITL response data originates from an HTTP POST and must not contain - # pickle/type markers that would reach pickle.loads(). - response_data = strip_pickle_markers(response_data) - if response_data is None: - return None - - # If already a primitive, return as-is - if not isinstance(response_data, dict): - logger.debug("Response data is not a dict, returning as-is: %s", type(response_data).__name__) - return response_data - - # Try to reconstruct using the type hint (Pydantic / dataclass) - if response_type_str: - response_type = resolve_type(response_type_str) - if response_type: - logger.debug("Found response type %s, attempting reconstruction", response_type) - result = reconstruct_to_type(response_data, response_type) - logger.debug("Reconstructed response type: %s", type(result).__name__) - return result - logger.warning("Could not resolve response type: %s", response_type_str) - - # No type hint available - return the sanitized dict as-is. - # We intentionally do NOT call deserialize_value() here because HITL - # response data is untrusted and must never flow into pickle.loads(). - logger.debug("No type hint; returning sanitized data as-is") - return response_data # type: ignore[reportUnknownVariableType] + af_ctx = AzureFunctionsWorkflowContext(context) + return _run_workflow_orchestrator_shared(af_ctx, workflow, initial_message, shared_state, hitl_timeout_hours) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow_af_context.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow_af_context.py new file mode 100644 index 00000000000..a035628d998 --- /dev/null +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow_af_context.py @@ -0,0 +1,82 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Azure Functions adapter for WorkflowOrchestrationContext. + +Wraps ``azure.durable_functions.DurableOrchestrationContext`` to satisfy the +:class:`~agent_framework_durabletask.WorkflowOrchestrationContext` protocol. +""" + +from __future__ import annotations + +import logging +from datetime import datetime +from typing import Any + +from agent_framework_durabletask import AgentSessionId, DurableAgentSession, DurableAIAgent +from azure.durable_functions import DurableOrchestrationContext + +from ._orchestration import AzureFunctionsAgentExecutor + +logger = logging.getLogger(__name__) + + +class AzureFunctionsWorkflowContext: + """Adapter that maps ``DurableOrchestrationContext`` to ``WorkflowOrchestrationContext``.""" + + def __init__(self, context: DurableOrchestrationContext) -> None: + self._context = context + + # -- Properties ----------------------------------------------------------- + + @property + def instance_id(self) -> str: + return self._context.instance_id + + @property + def current_utc_datetime(self) -> datetime: + return self._context.current_utc_datetime + + # -- Agent / Activity dispatch -------------------------------------------- + + def prepare_agent_task(self, executor_id: str, message: str, orchestration_instance_id: str) -> Any: + session_id = AgentSessionId(name=executor_id, key=orchestration_instance_id) + session = DurableAgentSession(durable_session_id=session_id) + az_executor = AzureFunctionsAgentExecutor(self._context) + agent = DurableAIAgent(az_executor, executor_id) + return agent.run(message, session=session) + + def prepare_activity_task(self, activity_name: str, input_json: str) -> Any: + orchestration_context: Any = self._context + return orchestration_context.call_activity(activity_name, input_json) + + # -- Composite tasks ------------------------------------------------------ + + def task_all(self, tasks: list[Any]) -> Any: + return self._context.task_all(tasks) + + def task_any(self, tasks: list[Any]) -> Any: + return self._context.task_any(tasks) + + # -- External events / timers --------------------------------------------- + + def wait_for_external_event(self, name: str) -> Any: + return self._context.wait_for_external_event(name) + + def create_timer(self, fire_at: datetime) -> Any: + return self._context.create_timer(fire_at) + + # -- Status / utility ----------------------------------------------------- + + def set_custom_status(self, status: Any) -> None: + self._context.set_custom_status(status) + + def new_uuid(self) -> str: + return self._context.new_uuid() + + def cancel_task(self, task: Any) -> None: + cancel_fn = getattr(task, "cancel", None) + if callable(cancel_fn): + cancel_fn() + + def get_task_result(self, task: Any) -> Any: + return getattr(task, "result", None) diff --git a/python/packages/azurefunctions/tests/test_app.py b/python/packages/azurefunctions/tests/test_app.py index 61518fa44b7..62f15582266 100644 --- a/python/packages/azurefunctions/tests/test_app.py +++ b/python/packages/azurefunctions/tests/test_app.py @@ -26,7 +26,6 @@ from agent_framework_azurefunctions import AgentFunctionApp from agent_framework_azurefunctions._entities import create_agent_entity -from agent_framework_azurefunctions._workflow import SOURCE_ORCHESTRATOR FuncT = TypeVar("FuncT", bound=Callable[..., Any]) @@ -1450,285 +1449,9 @@ def test_build_status_url_handles_trailing_slash(self) -> None: assert "instance-456" in url -def _compute_state_updates(original_snapshot: dict[str, Any], current_state: dict[str, Any]) -> dict[str, Any]: - """Compute state updates by comparing current state against the original snapshot. - - This mirrors the inlined logic in ``_app.py``'s ``executor_activity.run()``. - """ - original_keys = set(original_snapshot.keys()) - current_keys = set(current_state.keys()) - updates: dict[str, Any] = {} - for key in current_keys: - if key not in original_keys or current_state[key] != original_snapshot.get(key): - updates[key] = current_state[key] - return updates - - -class TestStateSnapshotDiff: - """Test suite for state snapshot diffing in activity execution. - - The activity executor snapshots state before execution and diffs against the - post-execution state to determine which keys were updated. These tests exercise - the production snapshot helper and the state-update diffing logic to ensure that - in-place mutations to nested objects (dicts, lists) are correctly detected as changes. - """ - - def test_nested_dict_mutation_detected_in_diff(self) -> None: - """Test that mutating values inside a nested dict appears in the diff.""" - from agent_framework._workflows._state import State - - from agent_framework_azurefunctions._app import _create_state_snapshot - - deserialized_state: dict[str, Any] = { - "Local.config": {"code": "", "enabled": False}, - "simple_key": "simple_value", - } - - original_snapshot = _create_state_snapshot(deserialized_state) - - shared_state = State() - shared_state.import_state(deserialized_state) - - config = shared_state.get("Local.config") - config["code"] = "SOMECODEXXX" - config["enabled"] = True - - shared_state.commit() - current_state = shared_state.export_state() - - updates = _compute_state_updates(original_snapshot, current_state) - - assert "Local.config" in updates - assert updates["Local.config"]["code"] == "SOMECODEXXX" - assert updates["Local.config"]["enabled"] is True - - def test_new_key_in_nested_dict_detected_in_diff(self) -> None: - """Test that adding a key to a nested dict appears in the diff.""" - from agent_framework._workflows._state import State - - from agent_framework_azurefunctions._app import _create_state_snapshot - - deserialized_state: dict[str, Any] = { - "Local.data": {"existing": "value"}, - } - - original_snapshot = _create_state_snapshot(deserialized_state) - - shared_state = State() - shared_state.import_state(deserialized_state) - - data = shared_state.get("Local.data") - data["code"] = "NEW_CODE" - - shared_state.commit() - current_state = shared_state.export_state() - - updates = _compute_state_updates(original_snapshot, current_state) - - assert "Local.data" in updates - assert updates["Local.data"]["code"] == "NEW_CODE" - - def test_nested_list_mutation_detected_in_diff(self) -> None: - """Test that appending to a nested list appears in the diff.""" - from agent_framework._workflows._state import State - - from agent_framework_azurefunctions._app import _create_state_snapshot - - deserialized_state: dict[str, Any] = { - "Local.items": [1, 2, 3], - } - - original_snapshot = _create_state_snapshot(deserialized_state) - - shared_state = State() - shared_state.import_state(deserialized_state) - - items = shared_state.get("Local.items") - items.append(4) - - shared_state.commit() - current_state = shared_state.export_state() - - updates = _compute_state_updates(original_snapshot, current_state) - - assert "Local.items" in updates - assert updates["Local.items"] == [1, 2, 3, 4] - - def test_new_top_level_key_detected_in_diff(self) -> None: - """Test that setting a new top-level key appears in the diff.""" - from agent_framework._workflows._state import State - - from agent_framework_azurefunctions._app import _create_state_snapshot - - deserialized_state: dict[str, Any] = { - "existing": "value", - } - - original_snapshot = _create_state_snapshot(deserialized_state) - - shared_state = State() - shared_state.import_state(deserialized_state) - - shared_state.set("Local.code", "SOMECODEXXX") - - shared_state.commit() - current_state = shared_state.export_state() - - updates = _compute_state_updates(original_snapshot, current_state) - - assert "Local.code" in updates - assert updates["Local.code"] == "SOMECODEXXX" - - def test_unchanged_nested_state_produces_empty_diff(self) -> None: - """Test that unmodified nested state produces no updates.""" - from agent_framework._workflows._state import State - - from agent_framework_azurefunctions._app import _create_state_snapshot - - deserialized_state: dict[str, Any] = { - "Local.config": {"code": "existing", "enabled": True}, - "simple_key": "simple_value", - } - - original_snapshot = _create_state_snapshot(deserialized_state) - - shared_state = State() - shared_state.import_state(deserialized_state) - - # No mutations performed - shared_state.commit() - current_state = shared_state.export_state() - - updates = _compute_state_updates(original_snapshot, current_state) - - assert updates == {} - - def test_shallow_copy_would_miss_nested_mutations(self) -> None: - """Regression test: a shallow copy (dict()) shares nested refs, hiding mutations. - - This reproduces the original bug from #4500 where ``dict(deserialized_state)`` - was used instead of ``copy.deepcopy()``. With a shallow copy the snapshot and - the live state share nested objects, so in-place mutations appear in both and - the diff produces an empty update set. - """ - from agent_framework._workflows._state import State - - deserialized_state: dict[str, Any] = { - "Local.config": {"code": "", "enabled": False}, - } - - # Shallow copy (the OLD, buggy behaviour) - shallow_snapshot = dict(deserialized_state) - - shared_state = State() - shared_state.import_state(deserialized_state) - - config = shared_state.get("Local.config") - config["code"] = "SOMECODEXXX" - config["enabled"] = True - - shared_state.commit() - current_state = shared_state.export_state() - - # With a shallow copy the mutation leaks into the snapshot → empty diff - updates_shallow = _compute_state_updates(shallow_snapshot, current_state) - assert updates_shallow == {}, "shallow copy should miss nested mutations (demonstrating the bug)" - - def test_create_state_snapshot_isolates_nested_objects(self) -> None: - """Verify _create_state_snapshot produces a deep copy that is mutation-proof. - - This ensures the production snapshot helper is not equivalent to ``dict()`` - and will correctly isolate nested objects so that later mutations are detected. - """ - from agent_framework_azurefunctions._app import _create_state_snapshot - - original: dict[str, Any] = { - "nested_dict": {"a": 1}, - "nested_list": [1, 2, 3], - } - - snapshot = _create_state_snapshot(original) - - # Mutate the originals in place - original["nested_dict"]["a"] = 999 - original["nested_list"].append(4) - - # Snapshot must be unaffected - assert snapshot["nested_dict"]["a"] == 1 - assert snapshot["nested_list"] == [1, 2, 3] - - def test_executor_activity_detects_nested_state_mutations(self) -> None: - """Integration test: the full activity wrapper detects nested mutations. - - This exercises the actual executor_activity function registered by - _setup_executor_activity to verify the production code path uses - _create_state_snapshot (deep copy) rather than dict() (shallow copy). - If the implementation regressed to using a shallow copy such as - ``dict(deserialized_state)``, this test would fail because in-place - mutations would leak into the snapshot and produce an empty diff. - """ - mock_executor = Mock() - mock_executor.id = "test-exec" - - async def mutate_nested_state( - message: Any, - source_executor_ids: Any, - state: Any, - runner_context: Any, - ) -> None: - config = state.get("Local.config") - config["code"] = "MUTATED" - config["enabled"] = True - state.commit() - - mock_executor.execute = AsyncMock(side_effect=mutate_nested_state) - - mock_workflow = Mock() - mock_workflow.executors = {"test-exec": mock_executor} - - # Capture the activity function by making decorators pass-through - captured_activity: dict[str, Any] = {} - - def passthrough_function_name(name: str) -> Callable[[FuncT], FuncT]: - def decorator(fn: FuncT) -> FuncT: - captured_activity["fn"] = fn - return fn - - return decorator - - def passthrough_activity_trigger(input_name: str) -> Callable[[FuncT], FuncT]: - def decorator(fn: FuncT) -> FuncT: - return fn - - return decorator - - with ( - patch.object(AgentFunctionApp, "function_name", side_effect=passthrough_function_name), - patch.object(AgentFunctionApp, "activity_trigger", side_effect=passthrough_activity_trigger), - patch.object(AgentFunctionApp, "_setup_workflow_orchestration"), - ): - AgentFunctionApp(workflow=mock_workflow) - - assert "fn" in captured_activity, "activity function was not captured" - - # Call the activity with nested state that the executor will mutate - input_data = json.dumps({ - "message": "test", - "shared_state_snapshot": { - "Local.config": {"code": "", "enabled": False}, - }, - "source_executor_ids": [SOURCE_ORCHESTRATOR], - }) - - result = json.loads(captured_activity["fn"](input_data)) - - # The deep copy snapshot must detect the in-place nested mutations - assert "Local.config" in result["shared_state_updates"], ( - "nested mutation not detected — snapshot may be using shallow copy" - ) - updated_config = result["shared_state_updates"]["Local.config"] - assert updated_config["code"] == "MUTATED" - assert updated_config["enabled"] is True +# NOTE: State snapshot/diff tests were moved to durabletask once the activity +# execution body was extracted into the host-agnostic execute_workflow_activity. +# See packages/durabletask/tests/test_workflow_activity.py. if __name__ == "__main__": diff --git a/python/packages/azurefunctions/tests/test_entities.py b/python/packages/azurefunctions/tests/test_entities.py index bfd17ba176e..cc2fc75320e 100644 --- a/python/packages/azurefunctions/tests/test_entities.py +++ b/python/packages/azurefunctions/tests/test_entities.py @@ -241,10 +241,8 @@ def test_entity_function_handles_none_input(self) -> None: # Verify the result was set (likely error result) assert mock_context.set_result.called - def test_entity_function_handles_event_loop_runtime_error(self) -> None: - """Test that the entity function handles RuntimeError from get_event_loop by creating a new loop.""" - from unittest.mock import patch - + def test_entity_function_runs_on_persistent_loop(self) -> None: + """Entity coroutines run on the shared persistent loop and set a result.""" mock_agent = Mock() mock_agent.run = AsyncMock(return_value=_agent_response("Response")) @@ -253,60 +251,42 @@ def test_entity_function_handles_event_loop_runtime_error(self) -> None: mock_context = Mock() mock_context.operation_name = "run" mock_context.entity_key = "conv-loop-test" - mock_context.get_input.return_value = {"message": "Test"} + mock_context.get_input.return_value = {"message": "Test", "correlationId": "corr-loop-test"} mock_context.get_state.return_value = None - # Simulate RuntimeError when getting event loop - with ( - patch("asyncio.get_event_loop", side_effect=RuntimeError("No event loop")), - patch("asyncio.new_event_loop") as mock_new_loop, - patch("asyncio.set_event_loop") as mock_set_loop, - ): - mock_loop = Mock() - mock_loop.is_running.return_value = False - mock_loop.run_until_complete = Mock() - mock_new_loop.return_value = mock_loop + # Execute - the persistent loop should run the coroutine and set a result. + entity_function(mock_context) - # Execute - entity_function(mock_context) + assert mock_context.set_result.called + mock_agent.run.assert_awaited() - # Verify new event loop was created - mock_new_loop.assert_called_once() - mock_set_loop.assert_called_once_with(mock_loop) + def test_entity_function_runs_across_threads_without_hang(self) -> None: + """Successive entity invocations from different threads must not hang. - def test_entity_function_handles_running_event_loop(self) -> None: - """Test that the entity function handles a running event loop by creating a temporary loop.""" - from unittest.mock import patch + This reproduces the cross-loop scenario that previously deadlocked: a + shared async resource bound to one loop being awaited from a different + worker thread. The persistent loop keeps every invocation on one loop. + """ + from concurrent.futures import ThreadPoolExecutor mock_agent = Mock() mock_agent.run = AsyncMock(return_value=_agent_response("Response")) entity_function = create_agent_entity(mock_agent) - mock_context = Mock() - mock_context.operation_name = "run" - mock_context.entity_key = "conv-running-loop" - mock_context.get_input.return_value = {"message": "Test"} - mock_context.get_state.return_value = None - - # Simulate a running event loop - mock_existing_loop = Mock() - mock_existing_loop.is_running.return_value = True - - mock_temp_loop = Mock() - mock_temp_loop.run_until_complete = Mock() - mock_temp_loop.close = Mock() + def invoke(i: int) -> bool: + ctx = Mock() + ctx.operation_name = "run" + ctx.entity_key = f"conv-{i}" + ctx.get_input.return_value = {"message": f"Test {i}", "correlationId": f"corr-{i}"} + ctx.get_state.return_value = None + entity_function(ctx) + return ctx.set_result.called - with ( - patch("asyncio.get_event_loop", return_value=mock_existing_loop), - patch("asyncio.new_event_loop", return_value=mock_temp_loop), - ): - # Execute - entity_function(mock_context) + with ThreadPoolExecutor(max_workers=8) as ex: + results = list(ex.map(invoke, range(16))) - # Verify temporary loop was created and closed - mock_temp_loop.run_until_complete.assert_called_once() - mock_temp_loop.close.assert_called_once() + assert all(results) if __name__ == "__main__": From 7e167be12f3cd1fff1843b8813dc6519d8a2f8e0 Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Tue, 9 Jun 2026 14:28:16 -0500 Subject: [PATCH 03/15] feat(core): expose durabletask workflow symbols via agent_framework.azure Lazily re-export WORKFLOW_ORCHESTRATOR_NAME and DurableWorkflowClient from the agent_framework.azure namespace so standalone hosts can import them without depending on internal module paths. --- python/packages/core/agent_framework/azure/__init__.py | 2 ++ python/packages/core/agent_framework/azure/__init__.pyi | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/python/packages/core/agent_framework/azure/__init__.py b/python/packages/core/agent_framework/azure/__init__.py index 7cff0150f19..a6a1de34064 100644 --- a/python/packages/core/agent_framework/azure/__init__.py +++ b/python/packages/core/agent_framework/azure/__init__.py @@ -19,6 +19,8 @@ "DurableAIAgentClient": ("agent_framework_durabletask", "agent-framework-durabletask"), "DurableAIAgentOrchestrationContext": ("agent_framework_durabletask", "agent-framework-durabletask"), "DurableAIAgentWorker": ("agent_framework_durabletask", "agent-framework-durabletask"), + "DurableWorkflowClient": ("agent_framework_durabletask", "agent-framework-durabletask"), + "WORKFLOW_ORCHESTRATOR_NAME": ("agent_framework_durabletask", "agent-framework-durabletask"), } diff --git a/python/packages/core/agent_framework/azure/__init__.pyi b/python/packages/core/agent_framework/azure/__init__.pyi index 12527b6db31..97a44180f8f 100644 --- a/python/packages/core/agent_framework/azure/__init__.pyi +++ b/python/packages/core/agent_framework/azure/__init__.pyi @@ -10,15 +10,18 @@ from agent_framework_azure_ai_search import ( from agent_framework_azure_cosmos import CosmosHistoryProvider from agent_framework_azurefunctions import AgentFunctionApp from agent_framework_durabletask import ( + WORKFLOW_ORCHESTRATOR_NAME, AgentCallbackContext, AgentResponseCallbackProtocol, DurableAIAgent, DurableAIAgentClient, DurableAIAgentOrchestrationContext, DurableAIAgentWorker, + DurableWorkflowClient, ) __all__ = [ + "WORKFLOW_ORCHESTRATOR_NAME", "AgentCallbackContext", "AgentFunctionApp", "AgentResponseCallbackProtocol", @@ -29,4 +32,5 @@ __all__ = [ "DurableAIAgentClient", "DurableAIAgentOrchestrationContext", "DurableAIAgentWorker", + "DurableWorkflowClient", ] From c7136cdc2b74d00a68d5222a6d84953e7ca7e31b Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Tue, 9 Jun 2026 14:28:29 -0500 Subject: [PATCH 04/15] docs(samples): add standalone durabletask workflow and HITL samples Add two samples under samples/04-hosting/durabletask demonstrating MAF workflows on a standalone Durable Task worker (no Azure Functions): - 08_workflow: conditional spam-detection workflow started via DurableWorkflowClient.start_workflow / await_workflow_output. - 09_workflow_hitl: content-moderation workflow that pauses with ctx.request_info and is resumed via DurableWorkflowClient.get_pending_hitl_requests / send_hitl_response. Also add the durabletask workflow integration test (test_08_dt_workflow). --- .../integration_tests/test_08_dt_workflow.py | 78 ++++ .../durabletask/08_workflow/README.md | 57 +++ .../durabletask/08_workflow/client.py | 73 ++++ .../durabletask/08_workflow/worker.py | 214 +++++++++++ .../durabletask/09_workflow_hitl/README.md | 78 ++++ .../durabletask/09_workflow_hitl/client.py | 118 ++++++ .../durabletask/09_workflow_hitl/worker.py | 352 ++++++++++++++++++ 7 files changed, 970 insertions(+) create mode 100644 python/packages/durabletask/tests/integration_tests/test_08_dt_workflow.py create mode 100644 python/samples/04-hosting/durabletask/08_workflow/README.md create mode 100644 python/samples/04-hosting/durabletask/08_workflow/client.py create mode 100644 python/samples/04-hosting/durabletask/08_workflow/worker.py create mode 100644 python/samples/04-hosting/durabletask/09_workflow_hitl/README.md create mode 100644 python/samples/04-hosting/durabletask/09_workflow_hitl/client.py create mode 100644 python/samples/04-hosting/durabletask/09_workflow_hitl/worker.py diff --git a/python/packages/durabletask/tests/integration_tests/test_08_dt_workflow.py b/python/packages/durabletask/tests/integration_tests/test_08_dt_workflow.py new file mode 100644 index 00000000000..edf271757eb --- /dev/null +++ b/python/packages/durabletask/tests/integration_tests/test_08_dt_workflow.py @@ -0,0 +1,78 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Integration tests for the standalone durabletask workflow sample (08_workflow). + +Exercises the standalone (non-Azure-Functions) workflow path: +- ``DurableAIAgentWorker.configure_workflow`` auto-registers the agent entities, + non-agent executor activities, and the workflow orchestrator. +- A client starts the workflow by scheduling ``WORKFLOW_ORCHESTRATOR_NAME``. +- Conditional routing sends spam to a non-agent handler and legitimate email + through a second agent and a sender executor. +""" + +import logging + +import pytest +from durabletask.client import OrchestrationStatus + +from agent_framework_durabletask import WORKFLOW_ORCHESTRATOR_NAME + +logging.basicConfig(level=logging.WARNING) + +# Module-level markers +pytestmark = [ + pytest.mark.flaky, + pytest.mark.integration, + pytest.mark.sample("08_workflow"), + pytest.mark.integration_test, + pytest.mark.requires_dts, +] + + +class TestStandaloneWorkflow: + """Standalone (non-Azure-Functions) workflow execution on a durabletask worker.""" + + @pytest.fixture(autouse=True) + def setup(self, agent_client_factory: type, orchestration_helper) -> None: + """Provide a DTS client and orchestration helper for each test.""" + self.dts_client, self.agent_client = agent_client_factory.create() + self.orch_helper = orchestration_helper + + def test_legitimate_email_drafts_response(self) -> None: + """A legitimate email routes through the email agent and is 'sent'.""" + instance_id = self.dts_client.schedule_new_orchestration( + orchestrator=WORKFLOW_ORCHESTRATOR_NAME, + input=( + "Hi team, just a reminder about our sprint planning meeting tomorrow at 10 AM. " + "Please review the agenda in Jira." + ), + ) + + metadata, output = self.orch_helper.wait_for_orchestration_with_output( + instance_id=instance_id, + timeout=180.0, + ) + + assert metadata.runtime_status == OrchestrationStatus.COMPLETED + assert output is not None + assert "Email sent" in str(output) + + def test_spam_email_handled(self) -> None: + """A spam email routes to the non-agent spam handler.""" + instance_id = self.dts_client.schedule_new_orchestration( + orchestrator=WORKFLOW_ORCHESTRATOR_NAME, + input="URGENT! You've won $1,000,000! Click here now to claim your prize! Limited time offer!", + ) + + metadata, output = self.orch_helper.wait_for_orchestration_with_output( + instance_id=instance_id, + timeout=180.0, + ) + + assert metadata.runtime_status == OrchestrationStatus.COMPLETED + assert output is not None + assert "spam" in str(output).lower() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/python/samples/04-hosting/durabletask/08_workflow/README.md b/python/samples/04-hosting/durabletask/08_workflow/README.md new file mode 100644 index 00000000000..0e032693776 --- /dev/null +++ b/python/samples/04-hosting/durabletask/08_workflow/README.md @@ -0,0 +1,57 @@ +# Workflow on a Standalone Durable Task Worker + +This sample demonstrates running an agent-framework `Workflow` as a durable +orchestration on a **standalone Durable Task worker** — no Azure Functions +required. It is the durabletask counterpart to the Azure Functions workflow +samples (`samples/04-hosting/azure_functions/10_workflow_no_shared_state`). + +## Key Concepts Demonstrated + +- Hosting a MAF `Workflow` outside Azure Functions via + `DurableAIAgentWorker.configure_workflow(workflow)`, which auto-registers: + - a durable **entity** for each agent executor, + - a durable **activity** for each non-agent executor, and + - the **workflow orchestrator** (registered as `WORKFLOW_ORCHESTRATOR_NAME`). +- Conditional routing with `add_switch_case_edge_group` (spam vs. legitimate email). +- Mixing AI agents with non-agent executors in one workflow graph. +- Starting the workflow from a client with + `DurableWorkflowClient.start_workflow(input=...)` and reading its result with + `await_workflow_output(instance_id)`. + +## Environment Setup + +See the [README.md](../README.md) in the parent directory for environment setup. + +This sample uses Azure AI Foundry credentials: + +- `FOUNDRY_PROJECT_ENDPOINT` +- `FOUNDRY_MODEL` + +It also needs a Durable Task Scheduler. For local development, start the +emulator (defaults to `http://localhost:8080`): + +```bash +docker run -d -p 8080:8080 -p 8082:8082 mcr.microsoft.com/dts/dts-emulator:latest +``` + +## Running the Sample + +Start the worker in one terminal: + +```bash +cd samples/04-hosting/durabletask/08_workflow +python worker.py +``` + +In a second terminal, run the client: + +```bash +python client.py +``` + +The client runs two cases: + +- **Legitimate email** → `SpamDetectionAgent` → `EmailAssistantAgent` → + `email_sender` → `"Email sent: ..."`. +- **Spam email** → `SpamDetectionAgent` → `spam_handler` → + `"Email marked as spam: ..."`. diff --git a/python/samples/04-hosting/durabletask/08_workflow/client.py b/python/samples/04-hosting/durabletask/08_workflow/client.py new file mode 100644 index 00000000000..118101753b6 --- /dev/null +++ b/python/samples/04-hosting/durabletask/08_workflow/client.py @@ -0,0 +1,73 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Client that starts the standalone workflow orchestration and prints the result. + +The worker (``worker.py``) must be running first. The workflow is started via +``DurableAIAgentClient.start_workflow`` - which schedules the orchestrator that +``DurableAIAgentWorker.configure_workflow`` auto-registers, so the caller never +needs to know its internal name. + +Prerequisites: +- ``worker.py`` running and connected to the same Durable Task Scheduler. +- A Durable Task Scheduler reachable at ``ENDPOINT`` (default ``http://localhost:8080``). +""" + +import asyncio +import logging +import os + +from agent_framework.azure import DurableWorkflowClient +from azure.identity import AzureCliCredential +from dotenv import load_dotenv +from durabletask.azuremanaged.client import DurableTaskSchedulerClient + +load_dotenv() + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def get_client(taskhub: str | None = None, endpoint: str | None = None) -> DurableTaskSchedulerClient: + """Create a configured DurableTaskSchedulerClient.""" + taskhub_name = taskhub or os.getenv("TASKHUB", "default") + endpoint_url = endpoint or os.getenv("ENDPOINT", "http://localhost:8080") + + credential = None if endpoint_url == "http://localhost:8080" else AzureCliCredential() + + return DurableTaskSchedulerClient( + host_address=endpoint_url, + secure_channel=endpoint_url != "http://localhost:8080", + taskhub=taskhub_name, + token_credential=credential, + ) + + +def run_workflow(client: DurableWorkflowClient, email_content: str) -> None: + """Start the workflow with an email and wait for the result.""" + instance_id = client.start_workflow(input=email_content) + logger.info("Started workflow instance: %s", instance_id) + + output = client.await_workflow_output(instance_id) + logger.info("Workflow output: %s", output) + + +async def main() -> None: + """Run the workflow against a legitimate email and a spam email.""" + client = DurableWorkflowClient(get_client()) + + logger.info("TEST 1: Legitimate email") + run_workflow( + client, + "Hi team, just a reminder about our sprint planning meeting tomorrow at 10 AM. " + "Please review the agenda in Jira.", + ) + + logger.info("TEST 2: Spam email") + run_workflow( + client, + "URGENT! You've won $1,000,000! Click here now to claim your prize! Limited time offer!", + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/04-hosting/durabletask/08_workflow/worker.py b/python/samples/04-hosting/durabletask/08_workflow/worker.py new file mode 100644 index 00000000000..48cd2c58b2e --- /dev/null +++ b/python/samples/04-hosting/durabletask/08_workflow/worker.py @@ -0,0 +1,214 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Worker that hosts a MAF Workflow as a durable orchestration (no Azure Functions). + +This sample shows how to run an agent-framework ``Workflow`` on a standalone +Durable Task worker using ``DurableAIAgentWorker.configure_workflow``. The worker +auto-registers: + +- a durable entity for each agent executor, +- a durable activity for each non-agent executor, and +- the workflow orchestrator (named ``WORKFLOW_ORCHESTRATOR_NAME``). + +The workflow classifies an email and conditionally routes it: spam is handled by +a non-agent executor, while legitimate email is drafted by a second agent and +"sent" by another non-agent executor. + +Prerequisites: +- Set ``FOUNDRY_PROJECT_ENDPOINT`` and ``FOUNDRY_MODEL``. +- Sign in with Azure CLI (``az login``) for ``AzureCliCredential``. +- Start a Durable Task Scheduler (e.g. the DTS emulator on ``localhost:8080``). + +Run the worker (this process), then run ``client.py`` in another process. +""" + +import asyncio +import logging +import os +from typing import Any + +from agent_framework import ( + Agent, + AgentExecutorResponse, + Case, + Default, + Executor, + Workflow, + WorkflowBuilder, + WorkflowContext, + handler, +) +from agent_framework.azure import DurableAIAgentWorker +from agent_framework.foundry import FoundryChatClient +from azure.identity import AzureCliCredential +from azure.identity.aio import AzureCliCredential as AsyncAzureCliCredential +from dotenv import load_dotenv +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker +from pydantic import BaseModel, ValidationError +from typing_extensions import Never + +load_dotenv() + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +SPAM_AGENT_NAME = "SpamDetectionAgent" +EMAIL_AGENT_NAME = "EmailAssistantAgent" + +SPAM_DETECTION_INSTRUCTIONS = ( + "You are a spam detection assistant that identifies spam emails. " + "Return JSON with fields is_spam (bool) and reason (string)." +) +EMAIL_ASSISTANT_INSTRUCTIONS = ( + "You are an email assistant that drafts professional replies to legitimate emails. " + "Return JSON with a single field 'response' containing the drafted reply." +) + + +class SpamDetectionResult(BaseModel): + """Structured output from the spam detection agent.""" + + is_spam: bool + reason: str + + +class EmailResponse(BaseModel): + """Structured output from the email assistant agent.""" + + response: str + + +class SpamHandlerExecutor(Executor): + """Non-agent executor that finalizes spam emails.""" + + @handler + async def handle_spam_result( + self, agent_response: AgentExecutorResponse, ctx: WorkflowContext[Never, str] + ) -> None: + text = agent_response.agent_response.text + try: + result = SpamDetectionResult.model_validate_json(text) + reason = result.reason + except ValidationError: + reason = "Invalid JSON from agent" + await ctx.yield_output(f"Email marked as spam: {reason}") + + +class EmailSenderExecutor(Executor): + """Non-agent executor that 'sends' the drafted reply.""" + + @handler + async def handle_email_response( + self, agent_response: AgentExecutorResponse, ctx: WorkflowContext[Never, str] + ) -> None: + text = agent_response.agent_response.text + try: + email = EmailResponse.model_validate_json(text) + reply = email.response + except ValidationError: + reply = "Error generating response." + await ctx.yield_output(f"Email sent: {reply}") + + +def is_spam_detected(message: Any) -> bool: + """Routing condition: True when the spam agent flagged the email as spam.""" + if not isinstance(message, AgentExecutorResponse): + return False + try: + return SpamDetectionResult.model_validate_json(message.agent_response.text).is_spam + except Exception: + return False + + +def _create_chat_client() -> FoundryChatClient: + """Create an Azure AI Foundry chat client using AzureCliCredential.""" + return FoundryChatClient( + project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"], + model=os.environ["FOUNDRY_MODEL"], + credential=AsyncAzureCliCredential(), + ) + + +def create_workflow() -> Workflow: + """Build the conditional spam-detection workflow.""" + chat_client = _create_chat_client() + + spam_agent = Agent( + client=chat_client, + name=SPAM_AGENT_NAME, + instructions=SPAM_DETECTION_INSTRUCTIONS, + default_options={"response_format": SpamDetectionResult}, + ) + email_agent = Agent( + client=chat_client, + name=EMAIL_AGENT_NAME, + instructions=EMAIL_ASSISTANT_INSTRUCTIONS, + default_options={"response_format": EmailResponse}, + ) + + spam_handler = SpamHandlerExecutor(id="spam_handler") + email_sender = EmailSenderExecutor(id="email_sender") + + return ( + WorkflowBuilder(start_executor=spam_agent) + .add_switch_case_edge_group( + spam_agent, + [ + Case(condition=is_spam_detected, target=spam_handler), + Default(target=email_agent), + ], + ) + .add_edge(email_agent, email_sender) + .build() + ) + + +def get_worker( + taskhub: str | None = None, endpoint: str | None = None, log_handler: logging.Handler | None = None +) -> DurableTaskSchedulerWorker: + """Create a configured DurableTaskSchedulerWorker.""" + taskhub_name = taskhub or os.getenv("TASKHUB", "default") + endpoint_url = endpoint or os.getenv("ENDPOINT", "http://localhost:8080") + + credential = None if endpoint_url == "http://localhost:8080" else AzureCliCredential() + + return DurableTaskSchedulerWorker( + host_address=endpoint_url, + secure_channel=endpoint_url != "http://localhost:8080", + taskhub=taskhub_name, + token_credential=credential, + log_handler=log_handler, + ) + + +def setup_worker(worker: DurableTaskSchedulerWorker) -> DurableAIAgentWorker: + """Register the workflow (agents + activities + orchestrator) on the worker.""" + agent_worker = DurableAIAgentWorker(worker) + + workflow = create_workflow() + # One call wires up: agent entities, non-agent executor activities, and the + # workflow orchestrator (registered as WORKFLOW_ORCHESTRATOR_NAME). + agent_worker.configure_workflow(workflow) + logger.info("✓ Configured workflow with %d executors", len(workflow.executors)) + + return agent_worker + + +async def main() -> None: + """Start the worker and block until interrupted.""" + worker = get_worker() + setup_worker(worker) + + logger.info("Worker is ready and listening for work items. Press Ctrl+C to stop.") + try: + worker.start() + while True: + await asyncio.sleep(1) + except KeyboardInterrupt: + logger.info("Worker shutdown initiated") + + logger.info("Worker stopped") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/04-hosting/durabletask/09_workflow_hitl/README.md b/python/samples/04-hosting/durabletask/09_workflow_hitl/README.md new file mode 100644 index 00000000000..653535bc8cd --- /dev/null +++ b/python/samples/04-hosting/durabletask/09_workflow_hitl/README.md @@ -0,0 +1,78 @@ +# Human-in-the-Loop Workflow on a Standalone Durable Task Worker + +This sample demonstrates a Human-in-the-Loop (HITL) agent-framework `Workflow` +running as a durable orchestration on a **standalone Durable Task worker** — no +Azure Functions required. It is the durabletask counterpart to the Azure +Functions sample `samples/04-hosting/azure_functions/12_workflow_hitl`. + +## Key Concepts Demonstrated + +- Pausing a workflow for human input with MAF's `ctx.request_info()` / + `@response_handler` pattern, hosted on a standalone worker via + `DurableAIAgentWorker.configure_workflow(workflow)`. +- Discovering pending HITL requests from a client with + `DurableWorkflowClient.get_pending_hitl_requests(instance_id)`. +- Resuming the workflow by sending a decision with + `DurableWorkflowClient.send_hitl_response(instance_id, request_id, response)`. +- Reading the final result with `DurableWorkflowClient.await_workflow_output(instance_id)`. + +The workflow is a content-moderation pipeline: + +``` +input_router -> ContentAnalyzerAgent -> content_analyzer_executor + -> human_review_executor (HITL pause) -> publish_executor +``` + +## How HITL Works Here + +The HITL mechanism is host-agnostic — the same shared workflow orchestrator +drives it on both Azure Functions and a standalone worker: + +1. `human_review_executor` calls `ctx.request_info(...)`, which pauses the + workflow. The orchestrator records the open request in its **custom status** + and waits for an external event named by the request's `request_id`. +2. The client reads the custom status via `get_pending_hitl_requests` and sends + a response via `send_hitl_response`, which raises that external event. +3. The orchestrator routes the response back to the executor's + `@response_handler`, and the workflow resumes. + +`send_hitl_response` sanitizes the payload (neutralizing pickle-marker +injection) before delivery, since the worker deserializes it. + +## Environment Setup + +See the [README.md](../README.md) in the parent directory for environment setup. + +This sample uses Azure AI Foundry credentials: + +- `FOUNDRY_PROJECT_ENDPOINT` +- `FOUNDRY_MODEL` + +It also needs a Durable Task Scheduler. For local development, start the +emulator (defaults to `http://localhost:8080`): + +```bash +docker run -d -p 8080:8080 -p 8082:8082 mcr.microsoft.com/dts/dts-emulator:latest +``` + +## Running the Sample + +Start the worker in one terminal: + +```bash +cd samples/04-hosting/durabletask/09_workflow_hitl +python worker.py +``` + +In a second terminal, run the client: + +```bash +python client.py +``` + +The client runs two cases: + +- **Appropriate content** → analyzed → HITL pause → client **approves** → + `"Content '...' has been APPROVED and published."` +- **Spammy content** → analyzed → HITL pause → client **rejects** → + `"Content '...' has been REJECTED."` diff --git a/python/samples/04-hosting/durabletask/09_workflow_hitl/client.py b/python/samples/04-hosting/durabletask/09_workflow_hitl/client.py new file mode 100644 index 00000000000..823da43b60b --- /dev/null +++ b/python/samples/04-hosting/durabletask/09_workflow_hitl/client.py @@ -0,0 +1,118 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Client that drives the standalone HITL workflow to completion. + +The worker (``worker.py``) must be running first. This client: + +1. Starts the workflow with ``DurableAIAgentClient.start_workflow``. +2. Polls ``get_pending_hitl_requests`` until the workflow pauses for human input. +3. Sends a decision with ``send_hitl_response`` (the request_id correlates the + response back to the paused executor). +4. Reads the final output with ``await_workflow_output``. + +It runs two cases: appropriate content (approved) and spammy content (rejected). + +Prerequisites: +- ``worker.py`` running and connected to the same Durable Task Scheduler. +- A Durable Task Scheduler reachable at ``ENDPOINT`` (default ``http://localhost:8080``). +""" + +import asyncio +import logging +import os +import time +from typing import Any + +from agent_framework.azure import DurableWorkflowClient +from azure.identity import AzureCliCredential +from dotenv import load_dotenv +from durabletask.azuremanaged.client import DurableTaskSchedulerClient + +load_dotenv() + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def get_client(taskhub: str | None = None, endpoint: str | None = None) -> DurableTaskSchedulerClient: + """Create a configured DurableTaskSchedulerClient.""" + taskhub_name = taskhub or os.getenv("TASKHUB", "default") + endpoint_url = endpoint or os.getenv("ENDPOINT", "http://localhost:8080") + + credential = None if endpoint_url == "http://localhost:8080" else AzureCliCredential() + + return DurableTaskSchedulerClient( + host_address=endpoint_url, + secure_channel=endpoint_url != "http://localhost:8080", + taskhub=taskhub_name, + token_credential=credential, + ) + + +def _wait_for_hitl_request( + client: DurableWorkflowClient, instance_id: str, timeout_seconds: int = 60 +) -> list[dict[str, Any]]: + """Poll until the workflow has at least one pending HITL request.""" + deadline = time.time() + timeout_seconds + while time.time() < deadline: + pending = client.get_pending_hitl_requests(instance_id) + if pending: + return pending + time.sleep(2) + raise TimeoutError(f"Timed out waiting for a HITL request on instance {instance_id}") + + +def run_case(client: DurableWorkflowClient, submission: dict[str, Any], *, approve: bool) -> None: + """Run one moderation case: start, respond to the HITL pause, print the result.""" + instance_id = client.start_workflow(input=submission) + logger.info("Started workflow instance: %s", instance_id) + + pending = _wait_for_hitl_request(client, instance_id) + request = pending[0] + logger.info("Pending HITL request %s from %s", request["request_id"], request["source_executor_id"]) + + decision = { + "approved": approve, + "reviewer_notes": "Looks good." if approve else "Violates content policy.", + } + client.send_hitl_response(instance_id, request["request_id"], decision) + logger.info("Sent decision: approved=%s", approve) + + output = client.await_workflow_output(instance_id) + logger.info("Workflow output: %s", output) + + +async def main() -> None: + """Run an approved case and a rejected case.""" + client = DurableWorkflowClient(get_client()) + + logger.info("CASE 1: Appropriate content (will approve)") + run_case( + client, + { + "content_id": "article-001", + "title": "Introduction to AI in Healthcare", + "body": ( + "Artificial intelligence is improving healthcare by enabling faster diagnosis, " + "personalized treatment plans, and better patient outcomes." + ), + "author": "Dr. Jane Smith", + }, + approve=True, + ) + + logger.info("CASE 2: Spammy content (will reject)") + run_case( + client, + { + "content_id": "article-002", + "title": "Get Rich Quick", + "body": "Click here NOW to make $10,000 overnight! GUARANTEED! Limited time offer!", + "author": "Definitely Not Spam", + }, + approve=False, + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/04-hosting/durabletask/09_workflow_hitl/worker.py b/python/samples/04-hosting/durabletask/09_workflow_hitl/worker.py new file mode 100644 index 00000000000..2192c5ae378 --- /dev/null +++ b/python/samples/04-hosting/durabletask/09_workflow_hitl/worker.py @@ -0,0 +1,352 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Worker that hosts a Human-in-the-Loop (HITL) MAF Workflow on a standalone worker. + +This sample is the durabletask counterpart to the Azure Functions +``12_workflow_hitl`` sample. It runs an agent-framework ``Workflow`` that pauses +for human approval using MAF's ``ctx.request_info`` / ``@response_handler`` +pattern, hosted on a standalone Durable Task worker (no Azure Functions). + +``DurableAIAgentWorker.configure_workflow`` auto-registers: + +- a durable entity for each agent executor, +- a durable activity for each non-agent executor, and +- the workflow orchestrator (named ``WORKFLOW_ORCHESTRATOR_NAME``). + +When the workflow calls ``ctx.request_info``, the orchestrator pauses and records +the open request in its custom status. An external client discovers the request +(``DurableAIAgentClient.get_pending_hitl_requests``) and resumes the workflow by +sending a response (``DurableAIAgentClient.send_hitl_response``). + +The workflow is a content-moderation pipeline: +``input_router`` -> ``ContentAnalyzerAgent`` -> ``content_analyzer_executor`` +-> ``human_review_executor`` (HITL pause) -> ``publish_executor``. + +Prerequisites: +- Set ``FOUNDRY_PROJECT_ENDPOINT`` and ``FOUNDRY_MODEL``. +- Sign in with Azure CLI (``az login``) for ``AzureCliCredential``. +- Start a Durable Task Scheduler (e.g. the DTS emulator on ``localhost:8080``). + +Run the worker (this process), then run ``client.py`` in another process. +""" + +import asyncio +import json +import logging +import os +from dataclasses import dataclass + +from agent_framework import ( + Agent, + AgentExecutorRequest, + AgentExecutorResponse, + Executor, + Message, + Workflow, + WorkflowBuilder, + WorkflowContext, + handler, + response_handler, +) +from agent_framework.azure import DurableAIAgentWorker +from agent_framework.foundry import FoundryChatClient +from azure.identity import AzureCliCredential +from azure.identity.aio import AzureCliCredential as AsyncAzureCliCredential +from dotenv import load_dotenv +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker +from pydantic import BaseModel, ValidationError +from typing_extensions import Never + +load_dotenv() + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +CONTENT_ANALYZER_AGENT_NAME = "ContentAnalyzerAgent" + +CONTENT_ANALYZER_INSTRUCTIONS = ( + "You are a content moderation assistant that analyzes user-submitted content for policy compliance. " + "Evaluate appropriateness, assign a risk level ('low', 'medium', 'high'), list any concerns, and give a " + "brief recommendation for human reviewers. Return JSON with fields is_appropriate (bool), risk_level (str), " + "concerns (list of str), and recommendation (str)." +) + + +# ============================================================================ +# Data Models +# ============================================================================ + + +class ContentAnalysisResult(BaseModel): + """Structured output from the content analysis agent.""" + + is_appropriate: bool + risk_level: str + concerns: list[str] + recommendation: str + + +@dataclass +class ContentSubmission: + """Content submitted for moderation.""" + + content_id: str + title: str + body: str + author: str + + +@dataclass +class AnalysisWithSubmission: + """Combines the AI analysis with the original submission for downstream processing.""" + + submission: ContentSubmission + analysis: ContentAnalysisResult + + +@dataclass +class HumanApprovalRequest: + """Request sent to a human reviewer. Surfaced to clients via the orchestration status.""" + + content_id: str + title: str + body: str + author: str + ai_analysis: ContentAnalysisResult + prompt: str + + +class HumanApprovalResponse(BaseModel): + """Response the external client sends back via the HITL response endpoint/method.""" + + approved: bool + reviewer_notes: str = "" + + +@dataclass +class ModerationResult: + """Final result of the moderation workflow.""" + + content_id: str + status: str + reviewer_notes: str + + +# ============================================================================ +# Executors +# ============================================================================ + + +class InputRouterExecutor(Executor): + """Parses the incoming submission and routes it to the analysis agent.""" + + def __init__(self) -> None: + super().__init__(id="input_router") + + @handler + async def route_input(self, input_json: str, ctx: WorkflowContext[AgentExecutorRequest]) -> None: + data = json.loads(input_json) if isinstance(input_json, str) else input_json + + submission = ContentSubmission( + content_id=data.get("content_id", "unknown"), + title=data.get("title", "Untitled"), + body=data.get("body", ""), + author=data.get("author", "Anonymous"), + ) + ctx.set_state("current_submission", submission) + + message = ( + f"Please analyze the following content for policy compliance:\n\n" + f"Title: {submission.title}\n" + f"Author: {submission.author}\n" + f"Content:\n{submission.body}" + ) + await ctx.send_message( + AgentExecutorRequest(messages=[Message(role="user", contents=[message])], should_respond=True) + ) + + +class ContentAnalyzerExecutor(Executor): + """Parses the AI agent's response and forwards it with the original submission.""" + + def __init__(self) -> None: + super().__init__(id="content_analyzer_executor") + + @handler + async def handle_analysis( + self, response: AgentExecutorResponse, ctx: WorkflowContext[AnalysisWithSubmission] + ) -> None: + try: + analysis = ContentAnalysisResult.model_validate_json(response.agent_response.text) + except ValidationError: + analysis = ContentAnalysisResult( + is_appropriate=False, + risk_level="high", + concerns=["Agent execution failed or yielded invalid JSON."], + recommendation="Manual review required", + ) + + submission: ContentSubmission = ctx.get_state("current_submission") + await ctx.send_message(AnalysisWithSubmission(submission=submission, analysis=analysis)) + + +class HumanReviewExecutor(Executor): + """Requests human approval using MAF's request_info / response_handler pattern.""" + + def __init__(self) -> None: + super().__init__(id="human_review_executor") + + @handler + async def request_review(self, data: AnalysisWithSubmission, ctx: WorkflowContext) -> None: + submission = data.submission + analysis = data.analysis + + prompt = ( + f"Please review the following content for publication:\n\n" + f"Title: {submission.title}\n" + f"Author: {submission.author}\n" + f"Content: {submission.body}\n\n" + f"AI Analysis:\n" + f"- Appropriate: {analysis.is_appropriate}\n" + f"- Risk Level: {analysis.risk_level}\n" + f"- Concerns: {', '.join(analysis.concerns) if analysis.concerns else 'None'}\n" + f"- Recommendation: {analysis.recommendation}\n\n" + f"Please approve or reject this content." + ) + approval_request = HumanApprovalRequest( + content_id=submission.content_id, + title=submission.title, + body=submission.body, + author=submission.author, + ai_analysis=analysis, + prompt=prompt, + ) + + # Pause the workflow and wait for a human response. + await ctx.request_info(request_data=approval_request, response_type=HumanApprovalResponse) + + @response_handler + async def handle_approval_response( + self, + original_request: HumanApprovalRequest, + response: HumanApprovalResponse, + ctx: WorkflowContext[ModerationResult], + ) -> None: + logger.info( + "Human review received for content %s: approved=%s", + original_request.content_id, + response.approved, + ) + await ctx.send_message( + ModerationResult( + content_id=original_request.content_id, + status="approved" if response.approved else "rejected", + reviewer_notes=response.reviewer_notes, + ) + ) + + +class PublishExecutor(Executor): + """Finalizes publication or rejection of the content.""" + + def __init__(self) -> None: + super().__init__(id="publish_executor") + + @handler + async def handle_result(self, result: ModerationResult, ctx: WorkflowContext[Never, str]) -> None: + if result.status == "approved": + message = ( + f"Content '{result.content_id}' has been APPROVED and published. " + f"Reviewer notes: {result.reviewer_notes or 'None'}" + ) + else: + message = ( + f"Content '{result.content_id}' has been REJECTED. " + f"Reviewer notes: {result.reviewer_notes or 'None'}" + ) + logger.info(message) + await ctx.yield_output(message) + + +def _create_chat_client() -> FoundryChatClient: + """Create an Azure AI Foundry chat client using AzureCliCredential.""" + return FoundryChatClient( + project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"], + model=os.environ["FOUNDRY_MODEL"], + credential=AsyncAzureCliCredential(), + ) + + +def create_workflow() -> Workflow: + """Build the content-moderation workflow with a human-in-the-loop pause.""" + chat_client = _create_chat_client() + + content_analyzer_agent = Agent( + client=chat_client, + name=CONTENT_ANALYZER_AGENT_NAME, + instructions=CONTENT_ANALYZER_INSTRUCTIONS, + default_options={"response_format": ContentAnalysisResult}, + ) + + input_router = InputRouterExecutor() + content_analyzer_executor = ContentAnalyzerExecutor() + human_review_executor = HumanReviewExecutor() + publish_executor = PublishExecutor() + + return ( + WorkflowBuilder(start_executor=input_router) + .add_edge(input_router, content_analyzer_agent) + .add_edge(content_analyzer_agent, content_analyzer_executor) + .add_edge(content_analyzer_executor, human_review_executor) + .add_edge(human_review_executor, publish_executor) + .build() + ) + + +def get_worker( + taskhub: str | None = None, endpoint: str | None = None, log_handler: logging.Handler | None = None +) -> DurableTaskSchedulerWorker: + """Create a configured DurableTaskSchedulerWorker.""" + taskhub_name = taskhub or os.getenv("TASKHUB", "default") + endpoint_url = endpoint or os.getenv("ENDPOINT", "http://localhost:8080") + + credential = None if endpoint_url == "http://localhost:8080" else AzureCliCredential() + + return DurableTaskSchedulerWorker( + host_address=endpoint_url, + secure_channel=endpoint_url != "http://localhost:8080", + taskhub=taskhub_name, + token_credential=credential, + log_handler=log_handler, + ) + + +def setup_worker(worker: DurableTaskSchedulerWorker) -> DurableAIAgentWorker: + """Register the workflow (agents + activities + orchestrator) on the worker.""" + agent_worker = DurableAIAgentWorker(worker) + + workflow = create_workflow() + agent_worker.configure_workflow(workflow) + logger.info("✓ Configured HITL workflow with %d executors", len(workflow.executors)) + + return agent_worker + + +async def main() -> None: + """Start the worker and block until interrupted.""" + worker = get_worker() + setup_worker(worker) + + logger.info("Worker is ready and listening for work items. Press Ctrl+C to stop.") + try: + worker.start() + while True: + await asyncio.sleep(1) + except KeyboardInterrupt: + logger.info("Worker shutdown initiated") + + logger.info("Worker stopped") + + +if __name__ == "__main__": + asyncio.run(main()) From 26dd80e61cf38fd1a7037b438e5091c02e6807ae Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Tue, 9 Jun 2026 15:18:37 -0500 Subject: [PATCH 05/15] fix: address PR review feedback - Sanitize HITL external-event responses with strip_pickle_markers in the orchestrator (defense-in-depth for callers that bypass DurableWorkflowClient). - Raise WorkflowConvergenceException when max_iterations is reached with pending messages, matching the core WorkflowRunner instead of silently returning partial output. - Route falsy 'sent' messages (use 'is not None' instead of truthiness). - Normalize None shared_state_snapshot/source_executor_ids in execute_workflow_activity. - Cast Any returns in AzureFunctionsWorkflowContext to satisfy mypy/pyright. - Fix sample docstrings to reference DurableWorkflowClient. --- .../_workflow_af_context.py | 8 ++++---- .../_workflow_activity.py | 6 ++++-- .../_workflow_orchestrator.py | 17 ++++++++++++++++- .../durabletask/08_workflow/client.py | 2 +- .../durabletask/09_workflow_hitl/client.py | 2 +- .../durabletask/09_workflow_hitl/worker.py | 4 ++-- 6 files changed, 28 insertions(+), 11 deletions(-) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow_af_context.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow_af_context.py index a035628d998..62c5336e3e0 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow_af_context.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow_af_context.py @@ -10,7 +10,7 @@ import logging from datetime import datetime -from typing import Any +from typing import Any, cast from agent_framework_durabletask import AgentSessionId, DurableAgentSession, DurableAIAgent from azure.durable_functions import DurableOrchestrationContext @@ -30,11 +30,11 @@ def __init__(self, context: DurableOrchestrationContext) -> None: @property def instance_id(self) -> str: - return self._context.instance_id + return cast(str, self._context.instance_id) @property def current_utc_datetime(self) -> datetime: - return self._context.current_utc_datetime + return cast(datetime, self._context.current_utc_datetime) # -- Agent / Activity dispatch -------------------------------------------- @@ -71,7 +71,7 @@ def set_custom_status(self, status: Any) -> None: self._context.set_custom_status(status) def new_uuid(self) -> str: - return self._context.new_uuid() + return cast(str, self._context.new_uuid()) def cancel_task(self, task: Any) -> None: cancel_fn = getattr(task, "cancel", None) diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflow_activity.py b/python/packages/durabletask/agent_framework_durabletask/_workflow_activity.py index 15d6182c064..ddac8c618d8 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_workflow_activity.py +++ b/python/packages/durabletask/agent_framework_durabletask/_workflow_activity.py @@ -63,8 +63,10 @@ def execute_workflow_activity(executor: Executor, input_json: str, workflow: Wor data = cast(dict[str, Any], data_obj) message_data = data.get("message") - shared_state_snapshot = data.get("shared_state_snapshot", {}) - source_executor_ids = cast(list[str], data.get("source_executor_ids", [SOURCE_ORCHESTRATOR])) + # The orchestrator may pass null for these when shared state / sources are + # omitted, so coerce None to the appropriate empty default. + shared_state_snapshot = data.get("shared_state_snapshot") or {} + source_executor_ids = cast(list[str], data.get("source_executor_ids") or [SOURCE_ORCHESTRATOR]) # Reconstruct the message - deserialize_value restores the original typed # objects from the encoded data (with type markers). diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflow_orchestrator.py b/python/packages/durabletask/agent_framework_durabletask/_workflow_orchestrator.py index 2d7ecbe4a18..432c7b744f9 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_workflow_orchestrator.py +++ b/python/packages/durabletask/agent_framework_durabletask/_workflow_orchestrator.py @@ -35,6 +35,7 @@ AgentResponse, Message, Workflow, + WorkflowConvergenceException, ) from agent_framework._workflows._edge import ( Edge, @@ -338,7 +339,9 @@ def _route_result_messages( for msg_data in result.activity_result["sent_messages"]: sent_msg = msg_data.get("message") target_id = msg_data.get("target_id") - if sent_msg: + # Use an explicit None check so legitimately falsy payloads + # (empty string, 0, False) are still routed. + if sent_msg is not None: sent_msg = deserialize_value(sent_msg) messages_to_route.append((sent_msg, target_id)) @@ -767,6 +770,11 @@ def run_workflow_orchestrator( del pending_hitl_requests[request_id] + # Sanitize against pickle-marker injection in case a caller + # bypassed DurableWorkflowClient.send_hitl_response and raised + # the external event directly (e.g. via the raw DTS client). + raw_response = strip_pickle_markers(raw_response) + _route_hitl_response( hitl_request, raw_response, @@ -783,4 +791,11 @@ def run_workflow_orchestrator( iteration += 1 + # Match the core WorkflowRunner: if the loop stopped because max_iterations + # was reached while messages are still pending, the workflow did not converge. + if pending_messages: + raise WorkflowConvergenceException( + f"Workflow did not converge after {workflow.max_iterations} iterations." + ) + return workflow_outputs # noqa: B901 diff --git a/python/samples/04-hosting/durabletask/08_workflow/client.py b/python/samples/04-hosting/durabletask/08_workflow/client.py index 118101753b6..7f8a40ff841 100644 --- a/python/samples/04-hosting/durabletask/08_workflow/client.py +++ b/python/samples/04-hosting/durabletask/08_workflow/client.py @@ -3,7 +3,7 @@ """Client that starts the standalone workflow orchestration and prints the result. The worker (``worker.py``) must be running first. The workflow is started via -``DurableAIAgentClient.start_workflow`` - which schedules the orchestrator that +``DurableWorkflowClient.start_workflow`` - which schedules the orchestrator that ``DurableAIAgentWorker.configure_workflow`` auto-registers, so the caller never needs to know its internal name. diff --git a/python/samples/04-hosting/durabletask/09_workflow_hitl/client.py b/python/samples/04-hosting/durabletask/09_workflow_hitl/client.py index 823da43b60b..aab91d2e991 100644 --- a/python/samples/04-hosting/durabletask/09_workflow_hitl/client.py +++ b/python/samples/04-hosting/durabletask/09_workflow_hitl/client.py @@ -4,7 +4,7 @@ The worker (``worker.py``) must be running first. This client: -1. Starts the workflow with ``DurableAIAgentClient.start_workflow``. +1. Starts the workflow with ``DurableWorkflowClient.start_workflow``. 2. Polls ``get_pending_hitl_requests`` until the workflow pauses for human input. 3. Sends a decision with ``send_hitl_response`` (the request_id correlates the response back to the paused executor). diff --git a/python/samples/04-hosting/durabletask/09_workflow_hitl/worker.py b/python/samples/04-hosting/durabletask/09_workflow_hitl/worker.py index 2192c5ae378..007e0224231 100644 --- a/python/samples/04-hosting/durabletask/09_workflow_hitl/worker.py +++ b/python/samples/04-hosting/durabletask/09_workflow_hitl/worker.py @@ -15,8 +15,8 @@ When the workflow calls ``ctx.request_info``, the orchestrator pauses and records the open request in its custom status. An external client discovers the request -(``DurableAIAgentClient.get_pending_hitl_requests``) and resumes the workflow by -sending a response (``DurableAIAgentClient.send_hitl_response``). +(``DurableWorkflowClient.get_pending_hitl_requests``) and resumes the workflow by +sending a response (``DurableWorkflowClient.send_hitl_response``). The workflow is a content-moderation pipeline: ``input_router`` -> ``ContentAnalyzerAgent`` -> ``content_analyzer_executor`` From e63d7daefd47a60b6b2cbbfdad873cdb88e2605f Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Tue, 9 Jun 2026 15:37:28 -0500 Subject: [PATCH 06/15] fix: resolve pyright Package Checks errors - Use typed locals instead of cast in AzureFunctionsWorkflowContext (mypy sees Any, pyright sees concrete types -> avoid reportUnnecessaryCast). - Annotate shared_state_snapshot and cast partially-typed durabletask SDK returns / HITL custom-status parsing to satisfy reportUnknownVariableType/reportUnknownMemberType. - Drop the dead deserialize/serialize re-export in _workflow.py and mark the intentional private _extract_message_content re-export. --- .../_workflow.py | 5 +---- .../_workflow_af_context.py | 13 +++++++++---- .../_workflow_activity.py | 2 +- .../_workflow_client.py | 19 +++++++++++-------- .../_workflow_dt_context.py | 17 +++++++++++------ 5 files changed, 33 insertions(+), 23 deletions(-) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py index 7a7c687460c..a01f35729ff 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py @@ -26,7 +26,7 @@ PendingHITLRequest, TaskMetadata, TaskType, - _extract_message_content, + _extract_message_content, # pyright: ignore[reportPrivateUsage] build_agent_executor_response, execute_hitl_response_handler, route_message_through_edge_groups, @@ -36,9 +36,6 @@ ) from azure.durable_functions import DurableOrchestrationContext -# Re-export serialization helpers so existing AF code (e.g. _app.py) keeps working -# via ``from ._workflow import ...`` or ``from ._serialization import ...``. -from ._serialization import deserialize_value, serialize_value # noqa: F401 from ._workflow_af_context import AzureFunctionsWorkflowContext logger = logging.getLogger(__name__) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow_af_context.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow_af_context.py index 62c5336e3e0..639c9332ce4 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow_af_context.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow_af_context.py @@ -10,7 +10,7 @@ import logging from datetime import datetime -from typing import Any, cast +from typing import Any from agent_framework_durabletask import AgentSessionId, DurableAgentSession, DurableAIAgent from azure.durable_functions import DurableOrchestrationContext @@ -30,11 +30,15 @@ def __init__(self, context: DurableOrchestrationContext) -> None: @property def instance_id(self) -> str: - return cast(str, self._context.instance_id) + # Typed local (not cast): mypy sees the untyped context as Any, while + # pyright sees a concrete str - the annotation satisfies both. + instance_id: str = self._context.instance_id + return instance_id @property def current_utc_datetime(self) -> datetime: - return cast(datetime, self._context.current_utc_datetime) + current: datetime = self._context.current_utc_datetime + return current # -- Agent / Activity dispatch -------------------------------------------- @@ -71,7 +75,8 @@ def set_custom_status(self, status: Any) -> None: self._context.set_custom_status(status) def new_uuid(self) -> str: - return cast(str, self._context.new_uuid()) + new_uuid: str = self._context.new_uuid() + return new_uuid def cancel_task(self, task: Any) -> None: cancel_fn = getattr(task, "cancel", None) diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflow_activity.py b/python/packages/durabletask/agent_framework_durabletask/_workflow_activity.py index ddac8c618d8..bc4147255bf 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_workflow_activity.py +++ b/python/packages/durabletask/agent_framework_durabletask/_workflow_activity.py @@ -65,7 +65,7 @@ def execute_workflow_activity(executor: Executor, input_json: str, workflow: Wor message_data = data.get("message") # The orchestrator may pass null for these when shared state / sources are # omitted, so coerce None to the appropriate empty default. - shared_state_snapshot = data.get("shared_state_snapshot") or {} + shared_state_snapshot: dict[str, Any] = data.get("shared_state_snapshot") or {} source_executor_ids = cast(list[str], data.get("source_executor_ids") or [SOURCE_ORCHESTRATOR]) # Reconstruct the message - deserialize_value restores the original typed diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflow_client.py b/python/packages/durabletask/agent_framework_durabletask/_workflow_client.py index 92e6a8b225d..c2d304fb324 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_workflow_client.py +++ b/python/packages/durabletask/agent_framework_durabletask/_workflow_client.py @@ -11,7 +11,7 @@ import json import logging -from typing import Any +from typing import Any, cast from durabletask.client import TaskHubGrpcClient @@ -135,21 +135,24 @@ def get_pending_hitl_requests(self, instance_id: str) -> list[dict[str, Any]]: if not isinstance(custom_status, dict): return [] + status_dict = cast(dict[str, Any], custom_status) - pending = custom_status.get("pending_requests") + pending = status_dict.get("pending_requests") if not isinstance(pending, dict): return [] + pending_dict = cast(dict[str, Any], pending) requests: list[dict[str, Any]] = [] - for request_id, req_data in pending.items(): + for request_id, req_data in pending_dict.items(): if not isinstance(req_data, dict): continue + req = cast(dict[str, Any], req_data) requests.append({ - "request_id": req_data.get("request_id", request_id), - "source_executor_id": req_data.get("source_executor_id"), - "data": req_data.get("data"), - "request_type": req_data.get("request_type"), - "response_type": req_data.get("response_type"), + "request_id": req.get("request_id", request_id), + "source_executor_id": req.get("source_executor_id"), + "data": req.get("data"), + "request_type": req.get("request_type"), + "response_type": req.get("response_type"), }) return requests diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflow_dt_context.py b/python/packages/durabletask/agent_framework_durabletask/_workflow_dt_context.py index f8695a71450..8d2caf0baf1 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_workflow_dt_context.py +++ b/python/packages/durabletask/agent_framework_durabletask/_workflow_dt_context.py @@ -10,9 +10,14 @@ import logging from datetime import datetime -from typing import Any +from typing import Any, cast -from durabletask.task import OrchestrationContext, Task, when_all, when_any +from durabletask.task import ( + OrchestrationContext, + Task, + when_all, + when_any, # pyright: ignore[reportUnknownVariableType] +) from ._executors import OrchestrationAgentExecutor from ._models import AgentSessionId, DurableAgentSession @@ -48,7 +53,7 @@ def prepare_agent_task(self, executor_id: str, message: str, orchestration_insta return agent.run(message, session=session) def prepare_activity_task(self, activity_name: str, input_json: str) -> Any: - return self._context.call_activity(activity_name, input=input_json) + return cast(Any, self._context.call_activity(activity_name, input=input_json)) # -- Composite tasks ------------------------------------------------------ @@ -61,10 +66,10 @@ def task_any(self, tasks: list[Any]) -> Any: # -- External events / timers --------------------------------------------- def wait_for_external_event(self, name: str) -> Any: - return self._context.wait_for_external_event(name) + return cast(Any, self._context).wait_for_external_event(name) def create_timer(self, fire_at: datetime) -> Any: - return self._context.create_timer(fire_at) + return cast(Any, self._context).create_timer(fire_at) # -- Status / utility ----------------------------------------------------- @@ -82,7 +87,7 @@ def cancel_task(self, task: Any) -> None: def get_task_result(self, task: Any) -> Any: if isinstance(task, Task): - return task.get_result() + return cast(Any, task.get_result()) return getattr(task, "result", None) From ba4c3292bfa073093a1bf280c08ecd79dbb4c962 Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Wed, 10 Jun 2026 12:39:21 -0500 Subject: [PATCH 07/15] fix(durabletask): agent-executor identity and typed workflow input Register each workflow agent entity under the executor id that the orchestrator dispatches to (instead of the agent name), so AgentExecutor(agent, id=...) works when the id differs from agent.name. The azure-functions host mirrors this. Reconstruct the start executor declared input type from the workflow initial JSON payload in the shared engine (mirroring in-process delivery) instead of string-coercing it per host. Untrusted input is stripped of pickle markers before reconstruction to prevent deserialization RCE. --- .../agent_framework_azurefunctions/_app.py | 13 +++-- .../agent_framework_durabletask/_worker.py | 50 +++++++++++------- .../_workflow_orchestrator.py | 51 ++++++++++++++++++- .../_workflow_registration.py | 21 +++++--- 4 files changed, 104 insertions(+), 31 deletions(-) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index bc524674a22..715a57601c7 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -228,11 +228,13 @@ def __init__( # The "what to register" decision (agent -> entity, non-agent -> activity) # is shared with the standalone durabletask host via plan_workflow_registration. if workflow: - if agents is None: - agents = [] logger.debug("[AgentFunctionApp] Extracting agents from workflow") plan = plan_workflow_registration(workflow) - agents.extend(plan.agents) + for agent_executor in plan.agent_executors: + # Register each agent executor's entity under its executor id (the + # identity the orchestrator dispatches to), so AgentExecutor(agent, + # id=...) works even when the id differs from agent.name. + self._setup_agent_entity(agent_executor.agent, agent_executor.id, self.default_callback) for executor in plan.activity_executors: # Set up a Functions activity trigger for each non-agent executor. self._setup_executor_activity(executor.id) @@ -296,8 +298,9 @@ def workflow_orchestrator(context: df.DurableOrchestrationContext) -> Any: # ty input_data = context.get_input() - # Ensure input is a string for the agent - initial_message = json.dumps(input_data) if isinstance(input_data, (dict, list)) else str(input_data) + # Pass the deserialized client input straight to the shared engine, which + # reconstructs the start executor's declared type (see _coerce_initial_input). + initial_message = input_data # Create local shared state dict for cross-executor state sharing shared_state: dict[str, Any] = {} diff --git a/python/packages/durabletask/agent_framework_durabletask/_worker.py b/python/packages/durabletask/agent_framework_durabletask/_worker.py index 74a5c8401f4..d0a46422f2e 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_worker.py +++ b/python/packages/durabletask/agent_framework_durabletask/_worker.py @@ -9,7 +9,6 @@ from __future__ import annotations -import json import logging from typing import Any @@ -89,37 +88,44 @@ def add_agent( self, agent: SupportsAgentRun, callback: AgentResponseCallbackProtocol | None = None, + *, + entity_id: str | None = None, ) -> None: """Register an agent with the worker. This method creates a durable entity class for the agent and registers it with the underlying durabletask worker. The entity will be accessible - by the name "dafx-{agent_name}". + by the name "dafx-{entity_id or agent_name}". Args: agent: The agent to register (must have a name) callback: Optional callback for this specific agent (overrides worker-level callback) + entity_id: Optional identity to register the entity under instead of + ``agent.name``. Workflow hosting passes the executor's ``id`` so the + entity matches the identity the orchestrator dispatches to. Raises: ValueError: If the agent doesn't have a name or is already registered """ - agent_name = agent.name - if not agent_name: + registration_name = entity_id or agent.name + if not registration_name: raise ValueError("Agent must have a name to be registered") - if agent_name in self._registered_agents: - raise ValueError(f"Agent '{agent_name}' is already registered") + if registration_name in self._registered_agents: + raise ValueError(f"Agent '{registration_name}' is already registered") - logger.info("[DurableAIAgentWorker] Registering agent: %s as entity: dafx-%s", agent_name, agent_name) + logger.info( + "[DurableAIAgentWorker] Registering agent: %s as entity: dafx-%s", registration_name, registration_name + ) # Store the agent reference - self._registered_agents[agent_name] = agent + self._registered_agents[registration_name] = agent # Use agent-specific callback if provided, otherwise use worker-level callback effective_callback = callback or self._callback # Create a configured entity class using the factory - entity_class = self.__create_agent_entity(agent, effective_callback) + entity_class = self.__create_agent_entity(agent, effective_callback, entity_id=registration_name) # Register the entity class with the worker # The worker.add_entity method takes a class @@ -128,7 +134,7 @@ def add_agent( logger.debug( "[DurableAIAgentWorker] Successfully registered entity class %s for agent: %s", entity_registered, - agent_name, + registration_name, ) def start(self) -> None: @@ -184,10 +190,13 @@ def configure_workflow( # is shared with the Azure Functions host via plan_workflow_registration. plan = plan_workflow_registration(workflow) - # Register agent executors as durable entities. - for agent in plan.agents: - if agent.name and agent.name not in self._registered_agents: - self.add_agent(agent, callback=callback) + # Register agent executors as durable entities. Each entity is keyed by + # the executor's id (the identity the orchestrator dispatches to) so + # AgentExecutor(agent, id=...) works even when the id differs from the + # agent's name. + for agent_executor in plan.agent_executors: + if agent_executor.id not in self._registered_agents: + self.add_agent(agent_executor.agent, callback=callback, entity_id=agent_executor.id) # Register non-agent executors as durable activities. for executor in plan.activity_executors: @@ -199,7 +208,7 @@ def configure_workflow( logger.info( "[DurableAIAgentWorker] Workflow configured with %d executors (%d agents, %d activities)", len(workflow.executors), - len(plan.agents), + len(plan.agent_executors), len(plan.activity_executors), ) @@ -227,7 +236,9 @@ def workflow_orchestrator(context: OrchestrationContext, input_data: Any) -> Any if captured_workflow is None: raise RuntimeError("Workflow not configured") - initial_message = json.dumps(input_data) if isinstance(input_data, (dict, list)) else str(input_data) + # Pass the deserialized client input straight to the shared engine, which + # reconstructs the start executor's declared type (see _coerce_initial_input). + initial_message = input_data shared_state: dict[str, Any] = {} dt_ctx = DurableTaskWorkflowContext(context) @@ -244,6 +255,8 @@ def __create_agent_entity( self, agent: SupportsAgentRun, callback: AgentResponseCallbackProtocol | None = None, + *, + entity_id: str | None = None, ) -> type[DurableTaskEntityStateProvider]: """Factory function to create a DurableEntity class configured with an agent. @@ -253,11 +266,14 @@ def __create_agent_entity( Args: agent: The agent instance to wrap callback: Optional callback for agent responses + entity_id: Optional identity to register the entity under instead of + ``agent.name`` (used by workflow hosting to key entities by + executor id). Returns: A new DurableEntity subclass configured for this agent """ - agent_name = agent.name or type(agent).__name__ + agent_name = entity_id or agent.name or type(agent).__name__ entity_name = f"dafx-{agent_name}" class ConfiguredAgentEntity(DurableTaskEntityStateProvider): diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflow_orchestrator.py b/python/packages/durabletask/agent_framework_durabletask/_workflow_orchestrator.py index 432c7b744f9..983b84337d8 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_workflow_orchestrator.py +++ b/python/packages/durabletask/agent_framework_durabletask/_workflow_orchestrator.py @@ -33,6 +33,7 @@ AgentExecutorRequest, AgentExecutorResponse, AgentResponse, + Executor, Message, Workflow, WorkflowConvergenceException, @@ -481,6 +482,54 @@ def _extract_message_content(message: Any) -> str: return message_content +def _select_primary_input_type(executor: Executor) -> type | None: + """Return the executor's primary concrete declared input type, if any. + + The first declared input type that is a concrete class is used; union or + unannotated types yield ``None`` (the caller then passes the value through + unchanged). + """ + for input_type in executor.input_types: + if isinstance(input_type, type): + return input_type + return None + + +def _coerce_initial_input(workflow: Workflow, raw_value: Any) -> Any: + """Coerce the client's initial workflow input to the start executor's type. + + A durable workflow runs as a durable orchestration, so its initial payload + arrives as plain JSON via ``context.get_input()`` -- without the type markers + that inter-executor messages carry (those are reconstructed by + :func:`deserialize_value`). This single entry hop therefore needs explicit + reconstruction to mirror in-process delivery, where the start executor + receives its declared type: + + * Agent start executors only consume text, so non-text input is stringified. + * Other executors get their primary declared input type reconstructed + (``dict`` -> Pydantic/dataclass, ``str`` -> ``str``, ...) via + :func:`reconstruct_to_type`; union/unannotated types pass through unchanged. + """ + start_executor = workflow.executors.get(workflow.start_executor_id) + if start_executor is None: + return raw_value + + if isinstance(start_executor, AgentExecutor): + if isinstance(raw_value, str): + return raw_value + if isinstance(raw_value, (dict, list)): + return json.dumps(raw_value) + return str(raw_value) + + input_type = _select_primary_input_type(start_executor) + if input_type is None: + return raw_value + # The initial payload is untrusted external input (HTTP body / client input) with no + # legitimate checkpoint type markers, so neutralize any pickle-marker injection before + # it can reach deserialize_value() inside reconstruct_to_type() (avoids pickle RCE). + return reconstruct_to_type(strip_pickle_markers(raw_value), input_type) + + # ============================================================================ # HITL Response Handler Execution # ============================================================================ @@ -666,7 +715,7 @@ def run_workflow_orchestrator( List of workflow outputs collected from executor activities. """ pending_messages: dict[str, list[tuple[Any, str]]] = { - workflow.start_executor_id: [(initial_message, SOURCE_WORKFLOW_START)] + workflow.start_executor_id: [(_coerce_initial_input(workflow, initial_message), SOURCE_WORKFLOW_START)] } workflow_outputs: list[Any] = [] iteration = 0 diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflow_registration.py b/python/packages/durabletask/agent_framework_durabletask/_workflow_registration.py index 032ef022965..b88be759d52 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_workflow_registration.py +++ b/python/packages/durabletask/agent_framework_durabletask/_workflow_registration.py @@ -23,7 +23,7 @@ from dataclasses import dataclass -from agent_framework import AgentExecutor, Executor, SupportsAgentRun, Workflow +from agent_framework import AgentExecutor, Executor, Workflow from ._workflow_orchestrator import WORKFLOW_ORCHESTRATOR_NAME @@ -33,12 +33,17 @@ class WorkflowRegistrationPlan: """The durable primitives a workflow registers, independent of host. Attributes: - agents: Agents (from agent executors) to register as durable entities. + agent_executors: Agent executors to register as durable entities. The + full :class:`AgentExecutor` is carried (not just its agent) so each + host can register the entity under the executor's ``id`` — the same + identity the orchestrator dispatches to — which keeps + ``AgentExecutor(agent, id=...)`` working when the id differs from + ``agent.name``. activity_executors: Non-agent executors to register as durable activities. orchestrator_name: The orchestrator name to register and to start runs with. """ - agents: list[SupportsAgentRun] + agent_executors: list[AgentExecutor] activity_executors: list[Executor] orchestrator_name: str @@ -50,20 +55,20 @@ def plan_workflow_registration(workflow: Workflow) -> WorkflowRegistrationPlan: workflow: The MAF :class:`Workflow` to host. Returns: - A :class:`WorkflowRegistrationPlan` describing the agents (entities), - non-agent executors (activities), and the orchestrator name. + A :class:`WorkflowRegistrationPlan` describing the agent executors + (entities), non-agent executors (activities), and the orchestrator name. """ - agents: list[SupportsAgentRun] = [] + agent_executors: list[AgentExecutor] = [] activity_executors: list[Executor] = [] for executor in workflow.executors.values(): if isinstance(executor, AgentExecutor): - agents.append(executor.agent) + agent_executors.append(executor) else: activity_executors.append(executor) return WorkflowRegistrationPlan( - agents=agents, + agent_executors=agent_executors, activity_executors=activity_executors, orchestrator_name=WORKFLOW_ORCHESTRATOR_NAME, ) From 98106168a02ebf5214e3986e57b4374e62a18851 Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Wed, 10 Jun 2026 12:39:37 -0500 Subject: [PATCH 08/15] fix(samples): type durable workflow start executors for reconstructed input The HITL and parallel workflow samples no longer hand-parse a JSON string. Their start executors now declare their real input type (ContentSubmission / DocumentInput), which the durable engine reconstructs from the client payload before delivery. --- .../11_workflow_parallel/function_app.py | 14 ++++---------- .../12_workflow_hitl/function_app.py | 16 +++++----------- .../durabletask/09_workflow_hitl/worker.py | 11 +---------- 3 files changed, 10 insertions(+), 31 deletions(-) diff --git a/python/samples/04-hosting/azure_functions/11_workflow_parallel/function_app.py b/python/samples/04-hosting/azure_functions/11_workflow_parallel/function_app.py index 0669d95e7b1..49ac41ffb19 100644 --- a/python/samples/04-hosting/azure_functions/11_workflow_parallel/function_app.py +++ b/python/samples/04-hosting/azure_functions/11_workflow_parallel/function_app.py @@ -26,7 +26,6 @@ - Ensure Azurite and the Durable Task Scheduler emulator are running """ -import json import logging import os from dataclasses import dataclass @@ -142,17 +141,12 @@ class FinalReport: @executor(id="input_router") -async def input_router(doc: str, ctx: WorkflowContext[DocumentInput]) -> None: - """Route input document to parallel processors. +async def input_router(document: DocumentInput, ctx: WorkflowContext[DocumentInput]) -> None: + """Route the input document to the parallel processors. - Accepts a JSON string from the HTTP request and converts to DocumentInput. + The durable engine reconstructs ``DocumentInput`` from the client's JSON + payload before delivery, mirroring in-process execution. """ - # Parse the JSON string input - data = json.loads(doc) if isinstance(doc, str) else doc - document = DocumentInput( - document_id=data.get("document_id", "unknown"), - content=data.get("content", ""), - ) logger.info("[input_router] Routing document: %s", document.document_id) await ctx.send_message(document) diff --git a/python/samples/04-hosting/azure_functions/12_workflow_hitl/function_app.py b/python/samples/04-hosting/azure_functions/12_workflow_hitl/function_app.py index e1f9389a6da..f8245d2381c 100644 --- a/python/samples/04-hosting/azure_functions/12_workflow_hitl/function_app.py +++ b/python/samples/04-hosting/azure_functions/12_workflow_hitl/function_app.py @@ -23,7 +23,6 @@ - Authentication via Azure CLI (az login) """ -import json import logging import os from dataclasses import dataclass @@ -332,19 +331,14 @@ def __init__(self): @handler async def route_input( self, - input_json: str, + submission: ContentSubmission, ctx: WorkflowContext[AgentExecutorRequest], ) -> None: - """Parse input and create agent request.""" - data = json.loads(input_json) if isinstance(input_json, str) else input_json - - submission = ContentSubmission( - content_id=data.get("content_id", "unknown"), - title=data.get("title", "Untitled"), - body=data.get("body", ""), - author=data.get("author", "Anonymous"), - ) + """Create the agent request from the submitted content. + The durable engine reconstructs this ``ContentSubmission`` from the + client's JSON payload before delivery, mirroring in-process execution. + """ # Store submission in shared state for later retrieval ctx.set_state("current_submission", submission) diff --git a/python/samples/04-hosting/durabletask/09_workflow_hitl/worker.py b/python/samples/04-hosting/durabletask/09_workflow_hitl/worker.py index 007e0224231..33e8ef54087 100644 --- a/python/samples/04-hosting/durabletask/09_workflow_hitl/worker.py +++ b/python/samples/04-hosting/durabletask/09_workflow_hitl/worker.py @@ -31,7 +31,6 @@ """ import asyncio -import json import logging import os from dataclasses import dataclass @@ -144,15 +143,7 @@ def __init__(self) -> None: super().__init__(id="input_router") @handler - async def route_input(self, input_json: str, ctx: WorkflowContext[AgentExecutorRequest]) -> None: - data = json.loads(input_json) if isinstance(input_json, str) else input_json - - submission = ContentSubmission( - content_id=data.get("content_id", "unknown"), - title=data.get("title", "Untitled"), - body=data.get("body", ""), - author=data.get("author", "Anonymous"), - ) + async def route_input(self, submission: ContentSubmission, ctx: WorkflowContext[AgentExecutorRequest]) -> None: ctx.set_state("current_submission", submission) message = ( From 81a5c740330bff067a81a0e79eed7aaf5344ff52 Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Wed, 10 Jun 2026 12:39:46 -0500 Subject: [PATCH 09/15] test(durabletask): unit coverage for registration, client, worker, and input coercion Add unit tests for plan_workflow_registration, DurableWorkflowClient, the agent-executor identity registration (entity keyed by executor id), and the typed initial-input coercion including pickle-marker neutralization. --- .../packages/azurefunctions/tests/test_app.py | 24 +- .../packages/durabletask/tests/test_worker.py | 57 +++++ .../durabletask/tests/test_workflow_client.py | 213 ++++++++++++++++++ .../tests/test_workflow_input_coercion.py | 135 +++++++++++ .../tests/test_workflow_registration.py | 97 ++++++++ 5 files changed, 518 insertions(+), 8 deletions(-) create mode 100644 python/packages/durabletask/tests/test_workflow_client.py create mode 100644 python/packages/durabletask/tests/test_workflow_input_coercion.py create mode 100644 python/packages/durabletask/tests/test_workflow_registration.py diff --git a/python/packages/azurefunctions/tests/test_app.py b/python/packages/azurefunctions/tests/test_app.py index 62f15582266..c31ff005a48 100644 --- a/python/packages/azurefunctions/tests/test_app.py +++ b/python/packages/azurefunctions/tests/test_app.py @@ -1341,8 +1341,8 @@ def test_init_with_workflow_stores_workflow(self) -> None: assert app.workflow is mock_workflow - def test_init_with_workflow_extracts_agents(self) -> None: - """Test that agents are extracted from workflow executors.""" + def test_init_with_workflow_registers_agent_entity_by_executor_id(self) -> None: + """Workflow agent executors are registered as entities keyed by executor id.""" from agent_framework import AgentExecutor mock_agent = Mock() @@ -1350,18 +1350,24 @@ def test_init_with_workflow_extracts_agents(self) -> None: mock_executor = Mock(spec=AgentExecutor) mock_executor.agent = mock_agent + # Executor id intentionally differs from the agent name to exercise the + # identity fix: dispatch uses the executor id, so registration must too. + mock_executor.id = "custom-executor-id" mock_workflow = Mock() - mock_workflow.executors = {"WorkflowAgent": mock_executor} + mock_workflow.executors = {"custom-executor-id": mock_executor} with ( patch.object(AgentFunctionApp, "_setup_executor_activity"), patch.object(AgentFunctionApp, "_setup_workflow_orchestration"), - patch.object(AgentFunctionApp, "_setup_agent_functions"), + patch.object(AgentFunctionApp, "_setup_agent_entity") as setup_entity, ): - app = AgentFunctionApp(workflow=mock_workflow) + AgentFunctionApp(workflow=mock_workflow) - assert "WorkflowAgent" in app.agents + setup_entity.assert_called_once() + call_args = setup_entity.call_args.args + assert call_args[0] is mock_agent + assert call_args[1] == "custom-executor-id" def test_init_with_workflow_calls_setup_methods(self) -> None: """Test that workflow setup methods are called.""" @@ -1395,8 +1401,8 @@ def test_init_without_workflow_does_not_call_workflow_setup(self) -> None: setup_exec.assert_not_called() setup_orch.assert_not_called() - def test_init_with_workflow_deduplicates_agents(self) -> None: - """Test that agents in both 'agents' and workflow are not double-registered.""" + def test_init_with_workflow_and_explicit_agent_does_not_raise(self) -> None: + """An agent passed explicitly and present in the workflow registers without error.""" from agent_framework import AgentExecutor mock_agent = Mock() @@ -1404,6 +1410,7 @@ def test_init_with_workflow_deduplicates_agents(self) -> None: mock_executor = Mock(spec=AgentExecutor) mock_executor.agent = mock_agent + mock_executor.id = "SharedAgent" mock_workflow = Mock() mock_workflow.executors = {"SharedAgent": mock_executor} @@ -1412,6 +1419,7 @@ def test_init_with_workflow_deduplicates_agents(self) -> None: patch.object(AgentFunctionApp, "_setup_executor_activity"), patch.object(AgentFunctionApp, "_setup_workflow_orchestration"), patch.object(AgentFunctionApp, "_setup_agent_functions"), + patch.object(AgentFunctionApp, "_setup_agent_entity"), ): # Same agent passed explicitly AND present in workflow — should not raise app = AgentFunctionApp(agents=[mock_agent], workflow=mock_workflow) diff --git a/python/packages/durabletask/tests/test_worker.py b/python/packages/durabletask/tests/test_worker.py index e6dabcdfdf0..315690a821e 100644 --- a/python/packages/durabletask/tests/test_worker.py +++ b/python/packages/durabletask/tests/test_worker.py @@ -164,5 +164,62 @@ def test_start_works_with_multiple_agents(self, agent_worker: DurableAIAgentWork assert len(agent_worker.registered_agent_names) == 2 +class TestDurableAIAgentWorkerWorkflow: + """Test workflow registration, including the agent-executor identity fix.""" + + def test_add_agent_with_entity_id_registers_under_override( + self, agent_worker: DurableAIAgentWorker, mock_agent: Mock + ) -> None: + """An explicit entity_id overrides the agent name as the entity identity.""" + agent_worker.add_agent(mock_agent, entity_id="node-7") + + assert "node-7" in agent_worker.registered_agent_names + assert "test_agent" not in agent_worker.registered_agent_names + + def test_configure_workflow_registers_agent_entity_by_executor_id( + self, agent_worker: DurableAIAgentWorker, mock_grpc_worker: Mock + ) -> None: + """Workflow agent executors register entities keyed by executor id. + + The orchestrator dispatches by executor id, so an + ``AgentExecutor(agent, id=...)`` whose id differs from the agent name must + still be reachable. + """ + from agent_framework import AgentExecutor + + agent = Mock() + agent.name = "Reviewer" + agent_executor = Mock(spec=AgentExecutor) + agent_executor.id = "custom-executor-id" + agent_executor.agent = agent + + workflow = Mock() + workflow.executors = {"custom-executor-id": agent_executor} + + agent_worker.configure_workflow(workflow) + + assert "custom-executor-id" in agent_worker.registered_agent_names + assert "Reviewer" not in agent_worker.registered_agent_names + mock_grpc_worker.add_orchestrator.assert_called_once() + + def test_configure_workflow_registers_non_agent_executor_as_activity( + self, agent_worker: DurableAIAgentWorker, mock_grpc_worker: Mock + ) -> None: + """Non-agent executors are registered as activities, not entities.""" + from agent_framework import Executor + + activity_executor = Mock(spec=Executor) + activity_executor.id = "router-node" + + workflow = Mock() + workflow.executors = {"router-node": activity_executor} + + agent_worker.configure_workflow(workflow) + + assert agent_worker.registered_agent_names == [] + mock_grpc_worker.add_activity.assert_called_once() + mock_grpc_worker.add_orchestrator.assert_called_once() + + if __name__ == "__main__": pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/durabletask/tests/test_workflow_client.py b/python/packages/durabletask/tests/test_workflow_client.py new file mode 100644 index 00000000000..45e87cb65b9 --- /dev/null +++ b/python/packages/durabletask/tests/test_workflow_client.py @@ -0,0 +1,213 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for DurableWorkflowClient. + +Covers starting workflows, awaiting output (including error/timeout paths), +parsing pending human-in-the-loop (HITL) requests from custom status, and +sanitizing HITL responses before delivery. +""" + +import json +from unittest.mock import Mock + +import pytest + +from agent_framework_durabletask import DurableWorkflowClient +from agent_framework_durabletask._workflow_orchestrator import WORKFLOW_ORCHESTRATOR_NAME + + +@pytest.fixture +def mock_client() -> Mock: + """Create a mock TaskHubGrpcClient.""" + return Mock() + + +@pytest.fixture +def workflow_client(mock_client: Mock) -> DurableWorkflowClient: + """Create a DurableWorkflowClient wrapping the mock client.""" + return DurableWorkflowClient(mock_client) + + +class TestStartWorkflow: + """Test starting workflow orchestrations.""" + + def test_start_workflow_schedules_orchestrator( + self, workflow_client: DurableWorkflowClient, mock_client: Mock + ) -> None: + """start_workflow schedules the auto-registered orchestrator by name.""" + mock_client.schedule_new_orchestration.return_value = "instance-1" + + result = workflow_client.start_workflow(input="hello") + + assert result == "instance-1" + mock_client.schedule_new_orchestration.assert_called_once_with( + WORKFLOW_ORCHESTRATOR_NAME, input="hello", instance_id=None + ) + + def test_start_workflow_passes_non_string_input_unchanged( + self, workflow_client: DurableWorkflowClient, mock_client: Mock + ) -> None: + """Non-string payloads are forwarded as-is (no string coercion).""" + mock_client.schedule_new_orchestration.return_value = "instance-2" + payload = {"order_id": 42, "items": ["a", "b"]} + + workflow_client.start_workflow(input=payload) + + _, kwargs = mock_client.schedule_new_orchestration.call_args + assert kwargs["input"] == payload + + def test_start_workflow_forwards_instance_id( + self, workflow_client: DurableWorkflowClient, mock_client: Mock + ) -> None: + """An explicit instance id is forwarded to the underlying client.""" + mock_client.schedule_new_orchestration.return_value = "explicit-id" + + workflow_client.start_workflow(input="x", instance_id="explicit-id") + + _, kwargs = mock_client.schedule_new_orchestration.call_args + assert kwargs["instance_id"] == "explicit-id" + + +class TestAwaitWorkflowOutput: + """Test awaiting workflow completion and output.""" + + def test_returns_deserialized_output_on_completion( + self, workflow_client: DurableWorkflowClient, mock_client: Mock + ) -> None: + """A COMPLETED workflow returns its deserialized output.""" + metadata = Mock() + metadata.runtime_status.name = "COMPLETED" + metadata.serialized_output = json.dumps(["result"]) + mock_client.wait_for_orchestration_completion.return_value = metadata + + output = workflow_client.await_workflow_output("instance-1") + + assert output == ["result"] + + def test_returns_none_when_no_output(self, workflow_client: DurableWorkflowClient, mock_client: Mock) -> None: + """A COMPLETED workflow with no output returns None.""" + metadata = Mock() + metadata.runtime_status.name = "COMPLETED" + metadata.serialized_output = None + mock_client.wait_for_orchestration_completion.return_value = metadata + + assert workflow_client.await_workflow_output("instance-1") is None + + def test_raises_timeout_when_not_completed(self, workflow_client: DurableWorkflowClient, mock_client: Mock) -> None: + """A None metadata (no completion) raises TimeoutError.""" + mock_client.wait_for_orchestration_completion.return_value = None + + with pytest.raises(TimeoutError, match="did not complete"): + workflow_client.await_workflow_output("instance-1", timeout_seconds=5) + + def test_raises_runtime_error_on_failed_status( + self, workflow_client: DurableWorkflowClient, mock_client: Mock + ) -> None: + """A non-COMPLETED status raises RuntimeError.""" + metadata = Mock() + metadata.runtime_status.name = "FAILED" + metadata.serialized_output = "boom" + mock_client.wait_for_orchestration_completion.return_value = metadata + + with pytest.raises(RuntimeError, match="status FAILED"): + workflow_client.await_workflow_output("instance-1") + + +class TestGetPendingHitlRequests: + """Test parsing pending HITL requests from custom status.""" + + def _state_with_status(self, status: object) -> Mock: + state = Mock() + state.serialized_custom_status = json.dumps(status) if status is not None else None + return state + + def test_returns_empty_when_no_state(self, workflow_client: DurableWorkflowClient, mock_client: Mock) -> None: + """No orchestration state yields an empty list.""" + mock_client.get_orchestration_state.return_value = None + + assert workflow_client.get_pending_hitl_requests("instance-1") == [] + + def test_returns_empty_when_status_blank(self, workflow_client: DurableWorkflowClient, mock_client: Mock) -> None: + """A blank custom status yields an empty list.""" + state = Mock() + state.serialized_custom_status = "" + mock_client.get_orchestration_state.return_value = state + + assert workflow_client.get_pending_hitl_requests("instance-1") == [] + + def test_returns_empty_on_invalid_json(self, workflow_client: DurableWorkflowClient, mock_client: Mock) -> None: + """Malformed custom status JSON yields an empty list.""" + state = Mock() + state.serialized_custom_status = "{not-json" + mock_client.get_orchestration_state.return_value = state + + assert workflow_client.get_pending_hitl_requests("instance-1") == [] + + def test_parses_pending_requests(self, workflow_client: DurableWorkflowClient, mock_client: Mock) -> None: + """Pending requests are normalized into the documented shape.""" + status = { + "pending_requests": { + "req-1": { + "request_id": "req-1", + "source_executor_id": "approver", + "data": {"prompt": "approve?"}, + "request_type": "ApprovalRequest", + "response_type": "ApprovalResponse", + } + } + } + mock_client.get_orchestration_state.return_value = self._state_with_status(status) + + requests = workflow_client.get_pending_hitl_requests("instance-1") + + assert requests == [ + { + "request_id": "req-1", + "source_executor_id": "approver", + "data": {"prompt": "approve?"}, + "request_type": "ApprovalRequest", + "response_type": "ApprovalResponse", + } + ] + + def test_falls_back_to_dict_key_for_request_id( + self, workflow_client: DurableWorkflowClient, mock_client: Mock + ) -> None: + """When a request omits request_id, the dict key is used.""" + status = {"pending_requests": {"req-key": {"source_executor_id": "x"}}} + mock_client.get_orchestration_state.return_value = self._state_with_status(status) + + requests = workflow_client.get_pending_hitl_requests("instance-1") + + assert requests[0]["request_id"] == "req-key" + + def test_ignores_non_dict_entries(self, workflow_client: DurableWorkflowClient, mock_client: Mock) -> None: + """Non-dict request entries are skipped.""" + status = {"pending_requests": {"req-1": "not-a-dict"}} + mock_client.get_orchestration_state.return_value = self._state_with_status(status) + + assert workflow_client.get_pending_hitl_requests("instance-1") == [] + + def test_returns_empty_when_pending_not_dict( + self, workflow_client: DurableWorkflowClient, mock_client: Mock + ) -> None: + """A non-dict pending_requests field yields an empty list.""" + status = {"pending_requests": ["unexpected"]} + mock_client.get_orchestration_state.return_value = self._state_with_status(status) + + assert workflow_client.get_pending_hitl_requests("instance-1") == [] + + +class TestSendHitlResponse: + """Test delivering HITL responses.""" + + def test_raises_orchestration_event_with_request_id( + self, workflow_client: DurableWorkflowClient, mock_client: Mock + ) -> None: + """The response is delivered as an external event named by request id.""" + workflow_client.send_hitl_response("instance-1", "req-1", {"approved": True}) + + mock_client.raise_orchestration_event.assert_called_once() + _, kwargs = mock_client.raise_orchestration_event.call_args + assert kwargs["event_name"] == "req-1" + assert kwargs["data"] == {"approved": True} diff --git a/python/packages/durabletask/tests/test_workflow_input_coercion.py b/python/packages/durabletask/tests/test_workflow_input_coercion.py new file mode 100644 index 00000000000..1008b7edbf0 --- /dev/null +++ b/python/packages/durabletask/tests/test_workflow_input_coercion.py @@ -0,0 +1,135 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for workflow initial-input coercion (`_coerce_initial_input`). + +A durable workflow runs as a durable orchestration, so its initial payload +arrives as plain JSON (no type markers). The shared engine reconstructs the +start executor's declared input type from that JSON, mirroring in-process +delivery. These tests pin that behavior across the relevant start-executor +shapes. +""" + +import json +from dataclasses import dataclass +from unittest.mock import Mock + +from agent_framework import AgentExecutor, Executor, WorkflowContext, handler +from pydantic import BaseModel + +from agent_framework_durabletask._workflow_orchestrator import _coerce_initial_input + + +@dataclass +class _Submission: + content_id: str + title: str + + +class _SubmissionModel(BaseModel): + content_id: str + title: str + + +class _StrStart(Executor): + def __init__(self) -> None: + super().__init__(id="str_start") + + @handler + async def run(self, message: str, ctx: WorkflowContext) -> None: # pragma: no cover - never invoked + ... + + +class _DataclassStart(Executor): + def __init__(self) -> None: + super().__init__(id="dc_start") + + @handler + async def run(self, message: _Submission, ctx: WorkflowContext) -> None: # pragma: no cover - never invoked + ... + + +class _PydanticStart(Executor): + def __init__(self) -> None: + super().__init__(id="pyd_start") + + @handler + async def run(self, message: _SubmissionModel, ctx: WorkflowContext) -> None: # pragma: no cover - never invoked + ... + + +def _workflow_with(executor: Executor | Mock) -> Mock: + workflow = Mock() + workflow.executors = {executor.id: executor} + workflow.start_executor_id = executor.id + return workflow + + +class TestCoerceInitialInput: + """Test reconstruction of the initial workflow input by start-executor type.""" + + def test_str_start_passes_string_through(self) -> None: + workflow = _workflow_with(_StrStart()) + + assert _coerce_initial_input(workflow, "hello world") == "hello world" + + def test_dataclass_start_reconstructs_from_dict(self) -> None: + workflow = _workflow_with(_DataclassStart()) + + result = _coerce_initial_input(workflow, {"content_id": "x", "title": "T"}) + + assert isinstance(result, _Submission) + assert result.content_id == "x" + assert result.title == "T" + + def test_pydantic_start_reconstructs_from_dict(self) -> None: + workflow = _workflow_with(_PydanticStart()) + + result = _coerce_initial_input(workflow, {"content_id": "x", "title": "T"}) + + assert isinstance(result, _SubmissionModel) + assert result.content_id == "x" + + def test_str_start_leaves_dict_unchanged(self) -> None: + """A str-typed start executor declares text; a dict is not coerced to str.""" + workflow = _workflow_with(_StrStart()) + payload = {"content_id": "x"} + + assert _coerce_initial_input(workflow, payload) == payload + + def test_agent_start_passes_string_through(self) -> None: + agent_executor = Mock(spec=AgentExecutor) + agent_executor.id = "agent" + workflow = _workflow_with(agent_executor) + + assert _coerce_initial_input(workflow, "draft this email") == "draft this email" + + def test_agent_start_stringifies_dict(self) -> None: + """Agents only consume text, so a structured payload is serialized to text.""" + agent_executor = Mock(spec=AgentExecutor) + agent_executor.id = "agent" + workflow = _workflow_with(agent_executor) + + result = _coerce_initial_input(workflow, {"email": "hi"}) + + assert result == json.dumps({"email": "hi"}) + + def test_missing_start_executor_passes_through(self) -> None: + workflow = Mock() + workflow.executors = {} + workflow.start_executor_id = "missing" + payload = {"a": 1} + + assert _coerce_initial_input(workflow, payload) == payload + + def test_pickle_marker_injection_is_neutralized(self) -> None: + """A crafted pickle-marker payload is stripped before reconstruction (no pickle RCE). + + The initial workflow input is untrusted, so a dict carrying the checkpoint + ``__pickled__`` marker must be neutralized rather than flowing into + ``deserialize_value`` (which would ``pickle.loads`` it). + """ + workflow = _workflow_with(_DataclassStart()) + malicious = {"__pickled__": "", "content_id": "x", "title": "T"} + + # The marker-bearing dict is replaced with None, never unpickled or reconstructed. + assert _coerce_initial_input(workflow, malicious) is None diff --git a/python/packages/durabletask/tests/test_workflow_registration.py b/python/packages/durabletask/tests/test_workflow_registration.py new file mode 100644 index 00000000000..64d81be3f17 --- /dev/null +++ b/python/packages/durabletask/tests/test_workflow_registration.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for plan_workflow_registration. + +Verifies the host-agnostic decision of which executors become durable entities +(agent executors) versus durable activities (everything else), and that agent +executors are carried whole so each host can register entities under the +executor id the orchestrator dispatches to. +""" + +from unittest.mock import Mock + +from agent_framework import AgentExecutor, Executor + +from agent_framework_durabletask import WorkflowRegistrationPlan, plan_workflow_registration +from agent_framework_durabletask._workflow_orchestrator import WORKFLOW_ORCHESTRATOR_NAME + + +def _agent_executor(executor_id: str, agent_name: str) -> Mock: + agent = Mock() + agent.name = agent_name + executor = Mock(spec=AgentExecutor) + executor.id = executor_id + executor.agent = agent + return executor + + +def _activity_executor(executor_id: str) -> Mock: + executor = Mock(spec=Executor) + executor.id = executor_id + return executor + + +class TestPlanWorkflowRegistration: + """Test classification of workflow executors into durable primitives.""" + + def test_agent_executor_classified_as_entity(self) -> None: + """An AgentExecutor is carried whole in agent_executors.""" + agent_exec = _agent_executor("reviewer-node", "Reviewer") + workflow = Mock() + workflow.executors = {"reviewer-node": agent_exec} + + plan = plan_workflow_registration(workflow) + + assert plan.agent_executors == [agent_exec] + assert plan.activity_executors == [] + assert plan.orchestrator_name == WORKFLOW_ORCHESTRATOR_NAME + + def test_non_agent_executor_classified_as_activity(self) -> None: + """A plain Executor is classified as an activity.""" + activity_exec = _activity_executor("router-node") + workflow = Mock() + workflow.executors = {"router-node": activity_exec} + + plan = plan_workflow_registration(workflow) + + assert plan.agent_executors == [] + assert plan.activity_executors == [activity_exec] + + def test_mixed_executors_are_partitioned(self) -> None: + """Agent and non-agent executors are split into the correct buckets.""" + agent_exec = _agent_executor("agent-node", "Agent") + activity_exec = _activity_executor("activity-node") + workflow = Mock() + workflow.executors = {"agent-node": agent_exec, "activity-node": activity_exec} + + plan = plan_workflow_registration(workflow) + + assert plan.agent_executors == [agent_exec] + assert plan.activity_executors == [activity_exec] + + def test_agent_executor_id_is_preserved_when_distinct_from_name(self) -> None: + """The plan keeps the executor (and its id), not just the bare agent. + + This is the core of the identity fix: dispatch targets the executor id, + so registration must be able to use the id even when it differs from + ``agent.name``. + """ + agent_exec = _agent_executor("custom-executor-id", "ReusedAgentName") + workflow = Mock() + workflow.executors = {"custom-executor-id": agent_exec} + + plan = plan_workflow_registration(workflow) + + assert plan.agent_executors[0].id == "custom-executor-id" + assert plan.agent_executors[0].agent.name == "ReusedAgentName" + + def test_returns_workflow_registration_plan(self) -> None: + """The return value is a WorkflowRegistrationPlan.""" + workflow = Mock() + workflow.executors = {} + + plan = plan_workflow_registration(workflow) + + assert isinstance(plan, WorkflowRegistrationPlan) + assert plan.agent_executors == [] + assert plan.activity_executors == [] From bbd9b6f687bf535d09f2fd0bf4e8766188f2b02a Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Wed, 10 Jun 2026 12:39:57 -0500 Subject: [PATCH 10/15] test(durabletask): HITL and parallel durable workflow integration tests Add an integration test for the standalone durabletask HITL workflow sample via a new workflow_client fixture. Re-enable the Azure Functions parallel workflow test, consolidated into one end-to-end case so the work-stealing xdist scheduler cannot spawn multiple func hosts for this sample. --- .../test_11_workflow_parallel.py | 97 +++------------- .../tests/integration_tests/conftest.py | 9 +- .../test_09_dt_workflow_hitl.py | 108 ++++++++++++++++++ 3 files changed, 134 insertions(+), 80 deletions(-) create mode 100644 python/packages/durabletask/tests/integration_tests/test_09_dt_workflow_hitl.py diff --git a/python/packages/azurefunctions/tests/integration_tests/test_11_workflow_parallel.py b/python/packages/azurefunctions/tests/integration_tests/test_11_workflow_parallel.py index 65a96678a12..831ff6f4adc 100644 --- a/python/packages/azurefunctions/tests/integration_tests/test_11_workflow_parallel.py +++ b/python/packages/azurefunctions/tests/integration_tests/test_11_workflow_parallel.py @@ -42,103 +42,42 @@ def _setup(self, base_url: str, sample_helper) -> None: self.base_url = base_url self.helper = sample_helper - @pytest.mark.skip(reason="xdist distributes module tests across workers, each spawning a func process") - def test_parallel_workflow_document_analysis(self) -> None: - """Test parallel workflow with a standard document.""" + def test_parallel_workflow_end_to_end(self) -> None: + """Run the parallel workflow end-to-end: start, check status, verify completion. + + Consolidated into a single test on purpose: the work-stealing xdist scheduler + distributes tests (not modules) across workers, and the module-scoped + ``function_app_for_test`` fixture is created per worker -- so multiple tests in + this module would each spawn a separate ``func`` host for this resource-heavy + parallel sample. One test keeps it to a single host while still covering the + fan-out path end-to-end. + """ payload = { "document_id": "doc-test-001", "content": ( "The quarterly earnings report shows strong growth in our cloud services division. " "Revenue increased by 25% compared to last year, driven by enterprise adoption. " - "Customer satisfaction remains high at 92%. However, we face challenges in the " - "mobile segment where competition is intense. Overall, the outlook is positive " - "with expected continued growth in the coming quarters." + "Customer satisfaction remains high at 92%." ), } - # Start orchestration + # Start the orchestration. response = self.helper.post_json(f"{self.base_url}/api/workflow/run", payload) assert response.status_code == 202 data = response.json() - assert "instanceId" in data + instance_id = data["instanceId"] assert "statusQueryGetUri" in data - # Wait for completion - parallel workflows may take longer - status = self.helper.wait_for_orchestration_with_output( - data["statusQueryGetUri"], - max_wait=300, # 5 minutes for parallel execution - ) - assert status["runtimeStatus"] == "Completed" - assert "output" in status - - @pytest.mark.skip(reason="xdist distributes module tests across workers, each spawning a func process") - def test_parallel_workflow_short_document(self) -> None: - """Test parallel workflow with a short document.""" - payload = { - "document_id": "doc-test-002", - "content": "Quick update: Project completed successfully. Team performance exceeded expectations.", - } - - # Start orchestration - response = self.helper.post_json(f"{self.base_url}/api/workflow/run", payload) - assert response.status_code == 202 - data = response.json() - assert "instanceId" in data - assert "statusQueryGetUri" in data + # The status endpoint reflects the started instance. + status_response = self.helper.get(f"{self.base_url}/api/workflow/status/{instance_id}") + assert status_response.status_code == 200 + assert status_response.json()["instanceId"] == instance_id - # Wait for completion + # Fan-out to parallel processors and agents completes with an aggregated output. status = self.helper.wait_for_orchestration_with_output(data["statusQueryGetUri"], max_wait=300) assert status["runtimeStatus"] == "Completed" assert "output" in status - @pytest.mark.skip(reason="xdist distributes module tests across workers, each spawning a func process") - def test_parallel_workflow_technical_document(self) -> None: - """Test parallel workflow with a technical document.""" - payload = { - "document_id": "doc-test-003", - "content": ( - "The new microservices architecture has been deployed to production. " - "Key improvements include: reduced latency by 40%, improved scalability " - "to handle 10x traffic spikes, and enhanced monitoring with distributed tracing. " - "The Kubernetes cluster is now running on version 1.28 with auto-scaling enabled. " - "Next steps include implementing service mesh and improving CI/CD pipelines." - ), - } - - # Start orchestration - response = self.helper.post_json(f"{self.base_url}/api/workflow/run", payload) - assert response.status_code == 202 - data = response.json() - assert "instanceId" in data - - # Wait for completion - status = self.helper.wait_for_orchestration_with_output(data["statusQueryGetUri"], max_wait=300) - assert status["runtimeStatus"] == "Completed" - - @pytest.mark.skip(reason="xdist distributes module tests across workers, each spawning a func process") - def test_workflow_status_endpoint(self) -> None: - """Test that the workflow status endpoint works correctly.""" - payload = { - "document_id": "doc-test-004", - "content": "Brief status update for testing purposes.", - } - - # Start orchestration - response = self.helper.post_json(f"{self.base_url}/api/workflow/run", payload) - assert response.status_code == 202 - data = response.json() - instance_id = data["instanceId"] - - # Check status - status_response = self.helper.get(f"{self.base_url}/api/workflow/status/{instance_id}") - assert status_response.status_code == 200 - status = status_response.json() - assert "instanceId" in status - assert status["instanceId"] == instance_id - - # Wait for completion - self.helper.wait_for_orchestration(data["statusQueryGetUri"], max_wait=300) - if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/python/packages/durabletask/tests/integration_tests/conftest.py b/python/packages/durabletask/tests/integration_tests/conftest.py index 65202e29079..1a1cd142aca 100644 --- a/python/packages/durabletask/tests/integration_tests/conftest.py +++ b/python/packages/durabletask/tests/integration_tests/conftest.py @@ -21,7 +21,7 @@ from durabletask.azuremanaged.client import DurableTaskSchedulerClient from durabletask.client import OrchestrationStatus -from agent_framework_durabletask import DurableAIAgentClient +from agent_framework_durabletask import DurableAIAgentClient, DurableWorkflowClient # Load environment variables from .env file load_dotenv(Path(__file__).parent / ".env") @@ -492,3 +492,10 @@ def create(cls, max_poll_retries: int = 90) -> tuple[DurableTaskSchedulerClient, return create_agent_client(cls.endpoint, cls.taskhub, max_poll_retries) return AgentClientFactory + + +@pytest.fixture(scope="module") +def workflow_client(worker_process: dict[str, Any]) -> DurableWorkflowClient: + """Create a DurableWorkflowClient bound to the current sample worker's task hub.""" + dts_client = create_dts_client(worker_process["endpoint"], worker_process["taskhub"]) + return DurableWorkflowClient(dts_client) diff --git a/python/packages/durabletask/tests/integration_tests/test_09_dt_workflow_hitl.py b/python/packages/durabletask/tests/integration_tests/test_09_dt_workflow_hitl.py new file mode 100644 index 00000000000..6ff54222e79 --- /dev/null +++ b/python/packages/durabletask/tests/integration_tests/test_09_dt_workflow_hitl.py @@ -0,0 +1,108 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Integration tests for the standalone durabletask HITL workflow sample (09_workflow_hitl). + +Exercises the human-in-the-loop workflow path on a standalone durabletask worker: +- The ``InputRouter`` start executor receives a typed ``ContentSubmission`` that the + shared engine reconstructs from the client's JSON payload (no manual parsing). +- An analysis agent produces a recommendation, then the workflow pauses for human + approval via ``request_info``. +- The client retrieves the pending request, replies with ``send_hitl_response``, and + the workflow resumes to an approved/rejected outcome read via ``await_workflow_output``. +""" + +import logging +import time +from typing import Any + +import pytest + +from agent_framework_durabletask import DurableWorkflowClient + +logging.basicConfig(level=logging.WARNING) + +# Module-level markers +pytestmark = [ + pytest.mark.flaky, + pytest.mark.integration, + pytest.mark.sample("09_workflow_hitl"), + pytest.mark.integration_test, + pytest.mark.requires_dts, + pytest.mark.requires_azure_openai, +] + + +def _wait_for_hitl_request( + client: DurableWorkflowClient, instance_id: str, timeout_seconds: int = 90 +) -> list[dict[str, Any]]: + """Poll until the workflow records at least one pending HITL request.""" + deadline = time.time() + timeout_seconds + while time.time() < deadline: + pending = client.get_pending_hitl_requests(instance_id) + if pending: + return pending + time.sleep(2) + raise AssertionError(f"Timed out waiting for a HITL request on instance {instance_id}") + + +class TestStandaloneWorkflowHITL: + """Human-in-the-loop workflow execution on a standalone durabletask worker.""" + + @pytest.fixture(autouse=True) + def setup(self, workflow_client: DurableWorkflowClient) -> None: + """Bind the DurableWorkflowClient for the current sample worker.""" + self.client = workflow_client + + def _run_case(self, submission: dict[str, Any], *, approve: bool) -> Any: + """Start a moderation case, answer the HITL pause, and return the final output.""" + instance_id = self.client.start_workflow(input=submission) + + pending = _wait_for_hitl_request(self.client, instance_id) + request = pending[0] + assert request["request_id"] + assert request["source_executor_id"] + + self.client.send_hitl_response( + instance_id, + request["request_id"], + {"approved": approve, "reviewer_notes": "Looks good." if approve else "Violates content policy."}, + ) + + return self.client.await_workflow_output(instance_id, timeout_seconds=180) + + def test_hitl_workflow_approval(self) -> None: + """Appropriate content is approved after the reviewer says yes.""" + output = self._run_case( + { + "content_id": "article-001", + "title": "Introduction to AI in Healthcare", + "body": ( + "Artificial intelligence is improving healthcare by enabling faster diagnosis, " + "personalized treatment plans, and better patient outcomes." + ), + "author": "Dr. Jane Smith", + }, + approve=True, + ) + + assert output is not None + assert "APPROVED" in str(output).upper() + + def test_hitl_workflow_rejection(self) -> None: + """Spammy content is rejected after the reviewer says no.""" + output = self._run_case( + { + "content_id": "article-002", + "title": "Get Rich Quick", + "body": "Click here NOW to make $10,000 overnight! GUARANTEED! Limited time offer!", + "author": "Definitely Not Spam", + }, + approve=False, + ) + + assert output is not None + assert "REJECTED" in str(output).upper() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 9b213b0b7200646d0057824609371dd73f1a485f Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Wed, 10 Jun 2026 13:33:27 -0500 Subject: [PATCH 11/15] refactor(durabletask): group workflow modules into a _workflows subpackage Move the eight workflow modules into a private _workflows/ subpackage and drop the redundant _workflow_ prefix (orchestrator.py, registration.py, activity.py, client.py, context.py, dt_context.py, runner_context.py, serialization.py). The public API and __all__ are unchanged; only direct internal-module imports were repointed (package __init__, the worker, the azure-functions shared shim, and the affected unit tests). --- .../agent_framework_azurefunctions/_workflow.py | 6 +++--- .../agent_framework_durabletask/__init__.py | 14 +++++++------- .../agent_framework_durabletask/_worker.py | 8 ++++---- .../_workflows/__init__.py | 10 ++++++++++ .../activity.py} | 6 +++--- .../{_workflow_client.py => _workflows/client.py} | 4 ++-- .../context.py} | 0 .../dt_context.py} | 8 ++++---- .../orchestrator.py} | 4 ++-- .../registration.py} | 2 +- .../runner_context.py} | 0 .../serialization.py} | 2 +- .../durabletask/tests/test_workflow_activity.py | 2 +- .../durabletask/tests/test_workflow_client.py | 2 +- .../tests/test_workflow_input_coercion.py | 2 +- .../tests/test_workflow_registration.py | 2 +- 16 files changed, 41 insertions(+), 31 deletions(-) create mode 100644 python/packages/durabletask/agent_framework_durabletask/_workflows/__init__.py rename python/packages/durabletask/agent_framework_durabletask/{_workflow_activity.py => _workflows/activity.py} (97%) rename python/packages/durabletask/agent_framework_durabletask/{_workflow_client.py => _workflows/client.py} (98%) rename python/packages/durabletask/agent_framework_durabletask/{_workflow_context.py => _workflows/context.py} (100%) rename python/packages/durabletask/agent_framework_durabletask/{_workflow_dt_context.py => _workflows/dt_context.py} (94%) rename python/packages/durabletask/agent_framework_durabletask/{_workflow_orchestrator.py => _workflows/orchestrator.py} (99%) rename python/packages/durabletask/agent_framework_durabletask/{_workflow_registration.py => _workflows/registration.py} (97%) rename python/packages/durabletask/agent_framework_durabletask/{_workflow_runner_context.py => _workflows/runner_context.py} (100%) rename python/packages/durabletask/agent_framework_durabletask/{_workflow_serialization.py => _workflows/serialization.py} (98%) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py index a01f35729ff..f6fbe957ab0 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py @@ -4,7 +4,7 @@ This module provides the Azure Functions entry point for workflow orchestration. The actual orchestration logic lives in the shared module -``agent_framework_durabletask._workflow_orchestrator`` and is host-agnostic. +``agent_framework_durabletask._workflows.orchestrator`` and is host-agnostic. This module re-exports the public API and provides the AF-specific ``run_workflow_orchestrator`` wrapper that creates an :class:`AzureFunctionsWorkflowContext` before delegating. @@ -17,7 +17,7 @@ from typing import Any from agent_framework import Workflow -from agent_framework_durabletask._workflow_orchestrator import ( +from agent_framework_durabletask._workflows.orchestrator import ( DEFAULT_HITL_TIMEOUT_HOURS, SOURCE_HITL_RESPONSE, SOURCE_ORCHESTRATOR, @@ -31,7 +31,7 @@ execute_hitl_response_handler, route_message_through_edge_groups, ) -from agent_framework_durabletask._workflow_orchestrator import ( +from agent_framework_durabletask._workflows.orchestrator import ( run_workflow_orchestrator as _run_workflow_orchestrator_shared, ) from azure.durable_functions import DurableOrchestrationContext diff --git a/python/packages/durabletask/agent_framework_durabletask/__init__.py b/python/packages/durabletask/agent_framework_durabletask/__init__.py index 59872a2216b..cebd45ff4bc 100644 --- a/python/packages/durabletask/agent_framework_durabletask/__init__.py +++ b/python/packages/durabletask/agent_framework_durabletask/__init__.py @@ -51,13 +51,13 @@ from ._response_utils import ensure_response_format, load_agent_response from ._shim import DurableAIAgent from ._worker import DurableAIAgentWorker -from ._workflow_activity import execute_workflow_activity -from ._workflow_client import DurableWorkflowClient -from ._workflow_context import WorkflowOrchestrationContext -from ._workflow_dt_context import DurableTaskWorkflowContext -from ._workflow_orchestrator import WORKFLOW_ORCHESTRATOR_NAME, run_workflow_orchestrator -from ._workflow_registration import WorkflowRegistrationPlan, plan_workflow_registration -from ._workflow_runner_context import CapturingRunnerContext +from ._workflows.activity import execute_workflow_activity +from ._workflows.client import DurableWorkflowClient +from ._workflows.context import WorkflowOrchestrationContext +from ._workflows.dt_context import DurableTaskWorkflowContext +from ._workflows.orchestrator import WORKFLOW_ORCHESTRATOR_NAME, run_workflow_orchestrator +from ._workflows.registration import WorkflowRegistrationPlan, plan_workflow_registration +from ._workflows.runner_context import CapturingRunnerContext try: __version__ = importlib.metadata.version(__name__) diff --git a/python/packages/durabletask/agent_framework_durabletask/_worker.py b/python/packages/durabletask/agent_framework_durabletask/_worker.py index d0a46422f2e..bda2e3c5e8b 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_worker.py +++ b/python/packages/durabletask/agent_framework_durabletask/_worker.py @@ -19,10 +19,10 @@ from ._async_bridge import run_agent_coroutine from ._callbacks import AgentResponseCallbackProtocol from ._entities import AgentEntity, DurableTaskEntityStateProvider -from ._workflow_activity import execute_workflow_activity -from ._workflow_dt_context import DurableTaskWorkflowContext -from ._workflow_orchestrator import WORKFLOW_ORCHESTRATOR_NAME, run_workflow_orchestrator -from ._workflow_registration import plan_workflow_registration +from ._workflows.activity import execute_workflow_activity +from ._workflows.dt_context import DurableTaskWorkflowContext +from ._workflows.orchestrator import WORKFLOW_ORCHESTRATOR_NAME, run_workflow_orchestrator +from ._workflows.registration import plan_workflow_registration logger = logging.getLogger("agent_framework.durabletask") diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflows/__init__.py b/python/packages/durabletask/agent_framework_durabletask/_workflows/__init__.py new file mode 100644 index 00000000000..3ebffaa7e78 --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_workflows/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Durable hosting of Microsoft Agent Framework workflows. + +This subpackage turns a MAF :class:`~agent_framework.Workflow` into durable +primitives -- a single orchestrator, agent entities, and non-agent executor +activities -- that run on either a standalone Durable Task worker or Azure +Functions. The host-agnostic engine lives here; each host programs against the +:class:`~.context.WorkflowOrchestrationContext` protocol. +""" diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflow_activity.py b/python/packages/durabletask/agent_framework_durabletask/_workflows/activity.py similarity index 97% rename from python/packages/durabletask/agent_framework_durabletask/_workflow_activity.py rename to python/packages/durabletask/agent_framework_durabletask/_workflows/activity.py index bc4147255bf..0e48417ceef 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_workflow_activity.py +++ b/python/packages/durabletask/agent_framework_durabletask/_workflows/activity.py @@ -25,13 +25,13 @@ from agent_framework._workflows._runner_context import YieldOutputEventType from agent_framework._workflows._state import State -from ._workflow_orchestrator import ( +from .orchestrator import ( SOURCE_HITL_RESPONSE, SOURCE_ORCHESTRATOR, execute_hitl_response_handler, ) -from ._workflow_runner_context import CapturingRunnerContext -from ._workflow_serialization import deserialize_value, serialize_value +from .runner_context import CapturingRunnerContext +from .serialization import deserialize_value, serialize_value def execute_workflow_activity(executor: Executor, input_json: str, workflow: Workflow | None = None) -> str: diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflow_client.py b/python/packages/durabletask/agent_framework_durabletask/_workflows/client.py similarity index 98% rename from python/packages/durabletask/agent_framework_durabletask/_workflow_client.py rename to python/packages/durabletask/agent_framework_durabletask/_workflows/client.py index c2d304fb324..930301c9745 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_workflow_client.py +++ b/python/packages/durabletask/agent_framework_durabletask/_workflows/client.py @@ -15,8 +15,8 @@ from durabletask.client import TaskHubGrpcClient -from ._workflow_orchestrator import WORKFLOW_ORCHESTRATOR_NAME -from ._workflow_serialization import strip_pickle_markers +from .orchestrator import WORKFLOW_ORCHESTRATOR_NAME +from .serialization import strip_pickle_markers logger = logging.getLogger("agent_framework.durabletask") diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflow_context.py b/python/packages/durabletask/agent_framework_durabletask/_workflows/context.py similarity index 100% rename from python/packages/durabletask/agent_framework_durabletask/_workflow_context.py rename to python/packages/durabletask/agent_framework_durabletask/_workflows/context.py diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflow_dt_context.py b/python/packages/durabletask/agent_framework_durabletask/_workflows/dt_context.py similarity index 94% rename from python/packages/durabletask/agent_framework_durabletask/_workflow_dt_context.py rename to python/packages/durabletask/agent_framework_durabletask/_workflows/dt_context.py index 8d2caf0baf1..1e6923e78d3 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_workflow_dt_context.py +++ b/python/packages/durabletask/agent_framework_durabletask/_workflows/dt_context.py @@ -19,10 +19,10 @@ when_any, # pyright: ignore[reportUnknownVariableType] ) -from ._executors import OrchestrationAgentExecutor -from ._models import AgentSessionId, DurableAgentSession -from ._shim import DurableAIAgent -from ._workflow_context import WorkflowOrchestrationContext +from .._executors import OrchestrationAgentExecutor +from .._models import AgentSessionId, DurableAgentSession +from .._shim import DurableAIAgent +from .context import WorkflowOrchestrationContext logger = logging.getLogger(__name__) diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflow_orchestrator.py b/python/packages/durabletask/agent_framework_durabletask/_workflows/orchestrator.py similarity index 99% rename from python/packages/durabletask/agent_framework_durabletask/_workflow_orchestrator.py rename to python/packages/durabletask/agent_framework_durabletask/_workflows/orchestrator.py index 983b84337d8..b7abe690f77 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_workflow_orchestrator.py +++ b/python/packages/durabletask/agent_framework_durabletask/_workflows/orchestrator.py @@ -48,8 +48,8 @@ ) from agent_framework._workflows._state import State -from ._workflow_context import WorkflowOrchestrationContext -from ._workflow_serialization import ( +from .context import WorkflowOrchestrationContext +from .serialization import ( deserialize_value, reconstruct_to_type, resolve_type, diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflow_registration.py b/python/packages/durabletask/agent_framework_durabletask/_workflows/registration.py similarity index 97% rename from python/packages/durabletask/agent_framework_durabletask/_workflow_registration.py rename to python/packages/durabletask/agent_framework_durabletask/_workflows/registration.py index b88be759d52..684f25323a7 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_workflow_registration.py +++ b/python/packages/durabletask/agent_framework_durabletask/_workflows/registration.py @@ -25,7 +25,7 @@ from agent_framework import AgentExecutor, Executor, Workflow -from ._workflow_orchestrator import WORKFLOW_ORCHESTRATOR_NAME +from .orchestrator import WORKFLOW_ORCHESTRATOR_NAME @dataclass diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflow_runner_context.py b/python/packages/durabletask/agent_framework_durabletask/_workflows/runner_context.py similarity index 100% rename from python/packages/durabletask/agent_framework_durabletask/_workflow_runner_context.py rename to python/packages/durabletask/agent_framework_durabletask/_workflows/runner_context.py diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflow_serialization.py b/python/packages/durabletask/agent_framework_durabletask/_workflows/serialization.py similarity index 98% rename from python/packages/durabletask/agent_framework_durabletask/_workflow_serialization.py rename to python/packages/durabletask/agent_framework_durabletask/_workflows/serialization.py index b0d5e647875..cf64cc82013 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_workflow_serialization.py +++ b/python/packages/durabletask/agent_framework_durabletask/_workflows/serialization.py @@ -160,7 +160,7 @@ def reconstruct_to_type(value: Any, target_type: type) -> Any: # Try decoding if data has pickle markers (from checkpoint encoding). # NOTE: This function is general-purpose. Callers that handle untrusted # data (e.g. HITL responses) MUST call strip_pickle_markers() before - # passing data here. See _deserialize_hitl_response in _workflow_orchestrator.py. + # passing data here. See _deserialize_hitl_response in orchestrator.py. decoded = deserialize_value(value) if not isinstance(decoded, dict): return decoded diff --git a/python/packages/durabletask/tests/test_workflow_activity.py b/python/packages/durabletask/tests/test_workflow_activity.py index d92a0a82372..b2ff9c159bc 100644 --- a/python/packages/durabletask/tests/test_workflow_activity.py +++ b/python/packages/durabletask/tests/test_workflow_activity.py @@ -14,7 +14,7 @@ from unittest.mock import AsyncMock, Mock from agent_framework_durabletask import execute_workflow_activity -from agent_framework_durabletask._workflow_orchestrator import SOURCE_ORCHESTRATOR +from agent_framework_durabletask._workflows.orchestrator import SOURCE_ORCHESTRATOR def _make_executor(executor_id: str, mutate: Any) -> Mock: diff --git a/python/packages/durabletask/tests/test_workflow_client.py b/python/packages/durabletask/tests/test_workflow_client.py index 45e87cb65b9..9a8056b474a 100644 --- a/python/packages/durabletask/tests/test_workflow_client.py +++ b/python/packages/durabletask/tests/test_workflow_client.py @@ -13,7 +13,7 @@ import pytest from agent_framework_durabletask import DurableWorkflowClient -from agent_framework_durabletask._workflow_orchestrator import WORKFLOW_ORCHESTRATOR_NAME +from agent_framework_durabletask._workflows.orchestrator import WORKFLOW_ORCHESTRATOR_NAME @pytest.fixture diff --git a/python/packages/durabletask/tests/test_workflow_input_coercion.py b/python/packages/durabletask/tests/test_workflow_input_coercion.py index 1008b7edbf0..2b785a145d3 100644 --- a/python/packages/durabletask/tests/test_workflow_input_coercion.py +++ b/python/packages/durabletask/tests/test_workflow_input_coercion.py @@ -16,7 +16,7 @@ from agent_framework import AgentExecutor, Executor, WorkflowContext, handler from pydantic import BaseModel -from agent_framework_durabletask._workflow_orchestrator import _coerce_initial_input +from agent_framework_durabletask._workflows.orchestrator import _coerce_initial_input @dataclass diff --git a/python/packages/durabletask/tests/test_workflow_registration.py b/python/packages/durabletask/tests/test_workflow_registration.py index 64d81be3f17..f9cb9e190ee 100644 --- a/python/packages/durabletask/tests/test_workflow_registration.py +++ b/python/packages/durabletask/tests/test_workflow_registration.py @@ -13,7 +13,7 @@ from agent_framework import AgentExecutor, Executor from agent_framework_durabletask import WorkflowRegistrationPlan, plan_workflow_registration -from agent_framework_durabletask._workflow_orchestrator import WORKFLOW_ORCHESTRATOR_NAME +from agent_framework_durabletask._workflows.orchestrator import WORKFLOW_ORCHESTRATOR_NAME def _agent_executor(executor_id: str, agent_name: str) -> Mock: From 505f9951f96c82c30f77b864d239253da41fd8b9 Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Wed, 10 Jun 2026 14:10:11 -0500 Subject: [PATCH 12/15] fix(durabletask): harden workflow type resolution and HITL response handling - resolve_type returns only real classes (avoids issubclass TypeError in reconstruct_to_type) - re-wait on HITL responses rejected by pickle-marker sanitization instead of dropping the request and losing the run - American spelling in strip_pickle_markers docstring - unit tests for resolve_type --- .../_workflows/orchestrator.py | 49 ++++++++++++------- .../_workflows/serialization.py | 7 ++- .../tests/test_workflow_serialization.py | 31 ++++++++++++ 3 files changed, 67 insertions(+), 20 deletions(-) create mode 100644 python/packages/durabletask/tests/test_workflow_serialization.py diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflows/orchestrator.py b/python/packages/durabletask/agent_framework_durabletask/_workflows/orchestrator.py index b7abe690f77..99b577d3f52 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_workflows/orchestrator.py +++ b/python/packages/durabletask/agent_framework_durabletask/_workflows/orchestrator.py @@ -792,14 +792,25 @@ def run_workflow_orchestrator( }) for request_id, hitl_request in list(pending_hitl_requests.items()): - logger.debug("Waiting for HITL response for request: %s", request_id) + # Re-wait until a valid response arrives (or the request times out). A + # payload rejected by sanitization (pickle/type markers) does not consume + # the request, so the caller can resubmit a corrected response instead of + # losing the entire workflow run. + while True: + logger.debug("Waiting for HITL response for request: %s", request_id) - approval_task = ctx.wait_for_external_event(request_id) - timeout_task = ctx.create_timer(ctx.current_utc_datetime + timedelta(hours=hitl_timeout_hours)) + approval_task = ctx.wait_for_external_event(request_id) + timeout_task = ctx.create_timer(ctx.current_utc_datetime + timedelta(hours=hitl_timeout_hours)) - winner = yield ctx.task_any([approval_task, timeout_task]) + winner = yield ctx.task_any([approval_task, timeout_task]) + + if winner != approval_task: + ctx.cancel_task(approval_task) + logger.warning("HITL request %s timed out after %s hours", request_id, hitl_timeout_hours) + raise TimeoutError( + f"Human-in-the-loop request '{request_id}' timed out after {hitl_timeout_hours} hours." + ) - if winner == approval_task: ctx.cancel_task(timeout_task) raw_response = ctx.get_task_result(approval_task) @@ -817,24 +828,26 @@ def run_workflow_orchestrator( except (json.JSONDecodeError, TypeError): logger.debug("Response is not JSON, keeping as string") - del pending_hitl_requests[request_id] - - # Sanitize against pickle-marker injection in case a caller - # bypassed DurableWorkflowClient.send_hitl_response and raised - # the external event directly (e.g. via the raw DTS client). - raw_response = strip_pickle_markers(raw_response) + # Sanitize against pickle-marker injection in case a caller bypassed + # DurableWorkflowClient.send_hitl_response and raised the external + # event directly (e.g. via the raw DTS client). Sanitize *before* + # consuming the request so a rejected payload can be resubmitted. + sanitized_response = strip_pickle_markers(raw_response) + if sanitized_response is None and raw_response is not None: + logger.warning( + "Rejected HITL response for request %s: payload contained " + "disallowed pickle/type markers. Awaiting a new response.", + request_id, + ) + continue + del pending_hitl_requests[request_id] _route_hitl_response( hitl_request, - raw_response, + sanitized_response, pending_messages, ) - else: - ctx.cancel_task(approval_task) - logger.warning("HITL request %s timed out after %s hours", request_id, hitl_timeout_hours) - raise TimeoutError( - f"Human-in-the-loop request '{request_id}' timed out after {hitl_timeout_hours} hours." - ) + break ctx.set_custom_status({"state": "running"}) diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflows/serialization.py b/python/packages/durabletask/agent_framework_durabletask/_workflows/serialization.py index cf64cc82013..f6e10281986 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_workflows/serialization.py +++ b/python/packages/durabletask/agent_framework_durabletask/_workflows/serialization.py @@ -47,7 +47,10 @@ def resolve_type(type_key: str) -> type | None: try: module_name, class_name = type_key.split(":", 1) module = importlib.import_module(module_name) - return getattr(module, class_name, None) + resolved = getattr(module, class_name, None) + # Only return actual classes. A non-type attribute (function, module member, + # etc.) would raise TypeError in issubclass() inside reconstruct_to_type(). + return resolved if isinstance(resolved, type) else None except Exception: logger.debug("Could not resolve type %s", type_key) return None @@ -67,7 +70,7 @@ def strip_pickle_markers(data: Any) -> Any: ``pickle.loads()`` and enable **arbitrary code execution**. This function walks the incoming data structure and replaces any ``dict`` - that contains either marker key with ``None``, neutralising the attack + that contains either marker key with ``None``, neutralizing the attack vector while leaving all other data untouched. It **must** be called on every value that originates from an untrusted diff --git a/python/packages/durabletask/tests/test_workflow_serialization.py b/python/packages/durabletask/tests/test_workflow_serialization.py new file mode 100644 index 00000000000..28d127de074 --- /dev/null +++ b/python/packages/durabletask/tests/test_workflow_serialization.py @@ -0,0 +1,31 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for workflow serialization helpers (`resolve_type`). + +``resolve_type`` is annotated ``type | None`` and its result flows into +``reconstruct_to_type``, which calls ``issubclass``. A non-class attribute +(function, module member, etc.) would raise ``TypeError`` there, so the +resolver must only ever return actual classes. +""" + +from collections import OrderedDict + +from agent_framework_durabletask._workflows.serialization import resolve_type + + +class TestResolveType: + """Test that resolve_type only returns real classes.""" + + def test_resolves_a_real_class(self) -> None: + assert resolve_type("collections:OrderedDict") is OrderedDict + + def test_returns_none_for_non_class_attribute(self) -> None: + # json.dumps is a function; if resolve_type returned it, issubclass() + # inside reconstruct_to_type() would raise TypeError at runtime. + assert resolve_type("json:dumps") is None + + def test_returns_none_for_unknown_attribute(self) -> None: + assert resolve_type("json:DoesNotExist") is None + + def test_returns_none_for_malformed_key(self) -> None: + assert resolve_type("not-a-valid-key") is None From 7a2eb01d1292027043b705aeef4436a150fa3149 Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Wed, 10 Jun 2026 16:11:48 -0500 Subject: [PATCH 13/15] fix(durabletask): treat async edge conditions as not-matched on the synchronous host The durabletask orchestrator evaluates edge conditions synchronously and does not support async edge conditions. Such an edge is now treated as not matched (the edge is not traversed) rather than assuming a result. Adds unit coverage; full async-condition support will be handled separately. --- .../_workflows/orchestrator.py | 30 +++++++++------ .../tests/test_workflow_routing.py | 38 +++++++++++++++++++ 2 files changed, 56 insertions(+), 12 deletions(-) create mode 100644 python/packages/durabletask/tests/test_workflow_routing.py diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflows/orchestrator.py b/python/packages/durabletask/agent_framework_durabletask/_workflows/orchestrator.py index 99b577d3f52..7ad3c026bbf 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_workflows/orchestrator.py +++ b/python/packages/durabletask/agent_framework_durabletask/_workflows/orchestrator.py @@ -19,6 +19,7 @@ from __future__ import annotations +import inspect import json import logging from collections import defaultdict @@ -131,24 +132,29 @@ class PendingHITLRequest: def _evaluate_edge_condition_sync(edge: Edge, message: Any) -> bool: """Evaluate an edge's condition synchronously. - Durable orchestrators use generators, not async/await, so we cannot call - async methods like ``edge.should_route()``. + Durable orchestrators run as generators, so conditions are evaluated + synchronously here; the durabletask host does not support ``async`` edge + conditions. A condition that returns an awaitable cannot be evaluated in + this context, so the edge is treated as *not matched* (not traversed) + rather than assuming a result. """ condition = edge._condition # pyright: ignore[reportPrivateUsage] if condition is None: return True result = condition(message) - if hasattr(result, "__await__"): - import warnings - - warnings.warn( - f"Edge condition for {edge.source_id}->{edge.target_id} is async, " - "which is not supported in durable orchestrators. " - "The edge will be traversed unconditionally.", - RuntimeWarning, - stacklevel=2, + if inspect.isawaitable(result): + # Async conditions cannot be evaluated in a synchronous orchestrator. + # Close the unawaited coroutine to avoid a "never awaited" warning and + # decline to traverse the edge (treated as not matched). + if inspect.iscoroutine(result): + result.close() + logger.warning( + "Edge condition for %s->%s is async and cannot be evaluated by the durabletask host; " + "the edge is not traversed. Use a synchronous condition.", + edge.source_id, + edge.target_id, ) - return True + return False return bool(result) diff --git a/python/packages/durabletask/tests/test_workflow_routing.py b/python/packages/durabletask/tests/test_workflow_routing.py new file mode 100644 index 00000000000..079e0b69909 --- /dev/null +++ b/python/packages/durabletask/tests/test_workflow_routing.py @@ -0,0 +1,38 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for synchronous edge-condition evaluation on the durabletask host. + +Durable orchestrators run as generators and evaluate edge conditions +synchronously. A condition that returns an awaitable cannot be evaluated in +that context, so the edge is treated as *not matched* (not traversed). +""" + +from agent_framework._workflows._edge import Edge # pyright: ignore[reportPrivateImportUsage] + +from agent_framework_durabletask._workflows.orchestrator import _evaluate_edge_condition_sync + + +class TestEvaluateEdgeConditionSync: + """Synchronous edge-condition evaluation semantics.""" + + def test_no_condition_traverses(self) -> None: + edge = Edge("a", "b") + assert _evaluate_edge_condition_sync(edge, {"x": 1}) is True + + def test_sync_true_traverses(self) -> None: + edge = Edge("a", "b", condition=lambda m: m["ok"]) + assert _evaluate_edge_condition_sync(edge, {"ok": True}) is True + + def test_sync_false_does_not_traverse(self) -> None: + edge = Edge("a", "b", condition=lambda m: m["ok"]) + assert _evaluate_edge_condition_sync(edge, {"ok": False}) is False + + def test_async_condition_is_not_traversed(self) -> None: + # The durabletask host evaluates conditions synchronously; an async + # condition cannot be evaluated, so the edge is treated as not matched + # even though it would resolve True when awaited. + async def gate(_message: object) -> bool: + return True + + edge = Edge("a", "b", condition=gate) + assert _evaluate_edge_condition_sync(edge, {"x": 1}) is False From 4f9cf65e7b18126cc0c8a6f9a8a24fae4d3ae0ff Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Wed, 10 Jun 2026 17:50:06 -0500 Subject: [PATCH 14/15] fix(durabletask): reconstruct typed workflow outputs at the host boundary await_workflow_output and the Azure Functions status endpoint now decode the checkpoint-encoded outputs the shared activity produces, via a shared deserialize_workflow_output helper. The client returns the original objects; the AF endpoint emits clean domain JSON instead of checkpoint-marker dicts, keeping the two hosts consistent. --- .../agent_framework_azurefunctions/_app.py | 42 +++++++++++++++-- .../packages/azurefunctions/tests/test_app.py | 44 ++++++++++++++++++ .../agent_framework_durabletask/__init__.py | 2 + .../_workflows/client.py | 7 ++- .../_workflows/serialization.py | 27 +++++++++++ .../durabletask/tests/test_workflow_client.py | 25 ++++++++++ .../tests/test_workflow_serialization.py | 46 ++++++++++++++++++- 7 files changed, 185 insertions(+), 8 deletions(-) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index 715a57601c7..5fc0bfd50ce 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -14,7 +14,7 @@ import re import uuid from collections.abc import Callable, Mapping -from dataclasses import dataclass +from dataclasses import asdict, dataclass, is_dataclass from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, TypeVar, cast @@ -39,6 +39,7 @@ DurableAgentState, DurableAIAgent, RunRequest, + deserialize_workflow_output, execute_workflow_activity, plan_workflow_registration, ) @@ -55,6 +56,32 @@ HandlerT = TypeVar("HandlerT", bound=Callable[..., Any]) +def _json_default(obj: Any) -> Any: + """JSON fallback encoder for reconstructed workflow outputs. + + A workflow's yielded outputs are reconstructed (see ``deserialize_workflow_output``) + before they reach the HTTP response, so they may be framework models + (e.g. ``AgentResponse``), dataclasses, or other non-JSON-native objects. + Prefer the type's own serialization so the response carries clean domain + JSON, falling back to ``str`` for anything without one. + """ + to_dict = getattr(obj, "to_dict", None) + if callable(to_dict): + try: + return to_dict() + except Exception: + logger.debug("to_dict() failed while encoding %s for HTTP output", type(obj).__name__) + model_dump = getattr(obj, "model_dump", None) + if callable(model_dump): + try: + return model_dump(mode="json") + except Exception: + logger.debug("model_dump() failed while encoding %s for HTTP output", type(obj).__name__) + if is_dataclass(obj) and not isinstance(obj, type): + return asdict(obj) + return str(obj) + + @dataclass class AgentMetadata: """Metadata for a registered agent. @@ -351,12 +378,19 @@ async def get_workflow_status( if not status: return self._build_error_response("Instance not found", status_code=404) + # The workflow's yielded outputs are checkpoint-encoded by the shared + # activity (typed objects become pickle/type-marker dicts). Reconstruct + # the originals so the HTTP response carries clean domain JSON, matching + # what DurableWorkflowClient.await_workflow_output returns in-process. + # status.output is the workflow's own (trusted) orchestration result. + decoded_output = deserialize_workflow_output(status.output) if status.output is not None else None + response = { "instanceId": status.instance_id, "runtimeStatus": status.runtime_status.name if status.runtime_status else None, "customStatus": status.custom_status, - "output": status.output, - "error": status.output if status.runtime_status == df.OrchestrationRuntimeStatus.Failed else None, + "output": decoded_output, + "error": decoded_output if status.runtime_status == df.OrchestrationRuntimeStatus.Failed else None, "createdTime": status.created_time.isoformat() if status.created_time else None, "lastUpdatedTime": status.last_updated_time.isoformat() if status.last_updated_time else None, } @@ -384,7 +418,7 @@ async def get_workflow_status( response["pendingHumanInputRequests"] = pending_requests return func.HttpResponse( - json.dumps(response, default=str), + json.dumps(response, default=_json_default), status_code=200, mimetype="application/json", ) diff --git a/python/packages/azurefunctions/tests/test_app.py b/python/packages/azurefunctions/tests/test_app.py index c31ff005a48..fe6a0ac462b 100644 --- a/python/packages/azurefunctions/tests/test_app.py +++ b/python/packages/azurefunctions/tests/test_app.py @@ -1462,5 +1462,49 @@ def test_build_status_url_handles_trailing_slash(self) -> None: # See packages/durabletask/tests/test_workflow_activity.py. +class TestWorkflowStatusOutputEncoding: + """The workflow status endpoint emits clean domain JSON for reconstructed outputs. + + Reconstructed outputs (see ``deserialize_workflow_output``) may be framework + models or dataclasses; ``_json_default`` converts them via their own + serialization so the HTTP body is clean domain JSON rather than opaque + checkpoint-marker dicts or repr strings. + """ + + def test_encodes_framework_model_via_to_dict(self) -> None: + from agent_framework_azurefunctions._app import _json_default + + response = AgentResponse(messages=[Message(role="assistant", contents=["hello"])]) + + encoded = _json_default(response) + + assert encoded == response.to_dict() + # The result must be JSON-serializable (no marker dicts, no objects). + assert "hello" in json.dumps(encoded) + + def test_encodes_dataclass_via_asdict(self) -> None: + from dataclasses import dataclass + + from agent_framework_azurefunctions._app import _json_default + + @dataclass + class Decision: + approved: bool + note: str + + encoded = _json_default(Decision(approved=True, note="ok")) + + assert encoded == {"approved": True, "note": "ok"} + + def test_falls_back_to_str_for_plain_objects(self) -> None: + from agent_framework_azurefunctions._app import _json_default + + class Opaque: + def __str__(self) -> str: + return "opaque-value" + + assert _json_default(Opaque()) == "opaque-value" + + if __name__ == "__main__": pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/durabletask/agent_framework_durabletask/__init__.py b/python/packages/durabletask/agent_framework_durabletask/__init__.py index cebd45ff4bc..bcab531e51f 100644 --- a/python/packages/durabletask/agent_framework_durabletask/__init__.py +++ b/python/packages/durabletask/agent_framework_durabletask/__init__.py @@ -58,6 +58,7 @@ from ._workflows.orchestrator import WORKFLOW_ORCHESTRATOR_NAME, run_workflow_orchestrator from ._workflows.registration import WorkflowRegistrationPlan, plan_workflow_registration from ._workflows.runner_context import CapturingRunnerContext +from ._workflows.serialization import deserialize_workflow_output try: __version__ = importlib.metadata.version(__name__) @@ -117,6 +118,7 @@ "WorkflowOrchestrationContext", "WorkflowRegistrationPlan", "__version__", + "deserialize_workflow_output", "ensure_response_format", "execute_workflow_activity", "load_agent_response", diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflows/client.py b/python/packages/durabletask/agent_framework_durabletask/_workflows/client.py index 930301c9745..15d7179f4ea 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_workflows/client.py +++ b/python/packages/durabletask/agent_framework_durabletask/_workflows/client.py @@ -16,7 +16,7 @@ from durabletask.client import TaskHubGrpcClient from .orchestrator import WORKFLOW_ORCHESTRATOR_NAME -from .serialization import strip_pickle_markers +from .serialization import deserialize_workflow_output, strip_pickle_markers logger = logging.getLogger("agent_framework.durabletask") @@ -107,7 +107,10 @@ def await_workflow_output(self, instance_id: str, *, timeout_seconds: int = 300) if metadata.serialized_output is None: return None - return json.loads(metadata.serialized_output) + # The shared activity encodes each yielded output with serialize_value() + # before it reaches the orchestrator, so typed objects come back as + # checkpoint-marker dicts. Reconstruct the originals before returning. + return deserialize_workflow_output(json.loads(metadata.serialized_output)) def get_pending_hitl_requests(self, instance_id: str) -> list[dict[str, Any]]: """Return the workflow's pending human-in-the-loop (HITL) requests, if any. diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflows/serialization.py b/python/packages/durabletask/agent_framework_durabletask/_workflows/serialization.py index f6e10281986..0cd616b2cef 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_workflows/serialization.py +++ b/python/packages/durabletask/agent_framework_durabletask/_workflows/serialization.py @@ -126,6 +126,33 @@ def deserialize_value(value: Any) -> Any: return decode_checkpoint_value(value) +def deserialize_workflow_output(output: Any) -> Any: + """Reconstruct the workflow outputs produced by the shared activity. + + Each value an executor yields is encoded with :func:`serialize_value` before + it reaches the orchestrator, so typed objects (dataclasses, Pydantic models, + ``AgentResponse``, ...) are stored as checkpoint-marker dicts. This reverses + that encoding so callers receive the original objects. + + This is the single decode path shared by every host (the in-process + :class:`DurableWorkflowClient` and the Azure Functions status endpoint) so + they never diverge in how a completed workflow's output is reconstructed. + + ``output`` must originate from the workflow's own orchestration result + (trusted durable storage), never from untrusted external input. Markers in + untrusted input must be neutralized with :func:`strip_pickle_markers` first. + + Args: + output: The workflow's orchestration result, already JSON-decoded (a list + of yielded outputs or a single value). + + Returns: + The output with every checkpoint-encoded value reconstructed; primitives + and plain JSON structures pass through unchanged. + """ + return deserialize_value(output) + + # ============================================================================ # HITL Type Reconstruction # ============================================================================ diff --git a/python/packages/durabletask/tests/test_workflow_client.py b/python/packages/durabletask/tests/test_workflow_client.py index 9a8056b474a..7ba0e4e880a 100644 --- a/python/packages/durabletask/tests/test_workflow_client.py +++ b/python/packages/durabletask/tests/test_workflow_client.py @@ -8,12 +8,22 @@ """ import json +from dataclasses import dataclass from unittest.mock import Mock import pytest from agent_framework_durabletask import DurableWorkflowClient from agent_framework_durabletask._workflows.orchestrator import WORKFLOW_ORCHESTRATOR_NAME +from agent_framework_durabletask._workflows.serialization import serialize_value + + +@dataclass +class _Receipt: + """Module-level dataclass so it is picklable by serialize_value.""" + + order_id: int + total: float @pytest.fixture @@ -93,6 +103,21 @@ def test_returns_none_when_no_output(self, workflow_client: DurableWorkflowClien assert workflow_client.await_workflow_output("instance-1") is None + def test_reconstructs_typed_outputs(self, workflow_client: DurableWorkflowClient, mock_client: Mock) -> None: + """Typed outputs encoded by the activity come back as objects, not marker dicts.""" + receipt = _Receipt(order_id=7, total=19.99) + # The shared activity stores each yielded output via serialize_value(), so a + # typed object is persisted as a checkpoint-marker dict. + metadata = Mock() + metadata.runtime_status.name = "COMPLETED" + metadata.serialized_output = json.dumps([serialize_value(receipt)]) + mock_client.wait_for_orchestration_completion.return_value = metadata + + output = workflow_client.await_workflow_output("instance-1") + + assert output == [receipt] + assert isinstance(output[0], _Receipt) + def test_raises_timeout_when_not_completed(self, workflow_client: DurableWorkflowClient, mock_client: Mock) -> None: """A None metadata (no completion) raises TimeoutError.""" mock_client.wait_for_orchestration_completion.return_value = None diff --git a/python/packages/durabletask/tests/test_workflow_serialization.py b/python/packages/durabletask/tests/test_workflow_serialization.py index 28d127de074..246e2e25968 100644 --- a/python/packages/durabletask/tests/test_workflow_serialization.py +++ b/python/packages/durabletask/tests/test_workflow_serialization.py @@ -1,16 +1,34 @@ # Copyright (c) Microsoft. All rights reserved. -"""Unit tests for workflow serialization helpers (`resolve_type`). +"""Unit tests for workflow serialization helpers. ``resolve_type`` is annotated ``type | None`` and its result flows into ``reconstruct_to_type``, which calls ``issubclass``. A non-class attribute (function, module member, etc.) would raise ``TypeError`` there, so the resolver must only ever return actual classes. + +``deserialize_workflow_output`` reverses the per-output ``serialize_value`` +encoding the shared activity applies, so typed outputs are returned as the +original objects rather than checkpoint-marker dicts. """ +import json from collections import OrderedDict +from dataclasses import dataclass + +from agent_framework_durabletask._workflows.serialization import ( + deserialize_workflow_output, + resolve_type, + serialize_value, +) + -from agent_framework_durabletask._workflows.serialization import resolve_type +@dataclass +class _Decision: + """Module-level dataclass so it is picklable by serialize_value.""" + + approved: bool + note: str class TestResolveType: @@ -29,3 +47,27 @@ def test_returns_none_for_unknown_attribute(self) -> None: def test_returns_none_for_malformed_key(self) -> None: assert resolve_type("not-a-valid-key") is None + + +class TestDeserializeWorkflowOutput: + """Reconstruction of stored workflow outputs.""" + + def test_primitives_pass_through(self) -> None: + # Mirror the stored shape: a list of yielded outputs, JSON round-tripped. + stored = json.loads(json.dumps([serialize_value("hello"), serialize_value(42)])) + + assert deserialize_workflow_output(stored) == ["hello", 42] + + def test_typed_outputs_are_reconstructed(self) -> None: + # A typed object is stored as a checkpoint-marker dict; it must come back + # as the original object, not the marker dict. + decision = _Decision(approved=True, note="ok") + stored = json.loads(json.dumps([serialize_value(decision)])) + + result = deserialize_workflow_output(stored) + + assert result == [decision] + assert isinstance(result[0], _Decision) + + def test_none_passes_through(self) -> None: + assert deserialize_workflow_output(None) is None From b5516cb8f978285528465b05439646b8788ec326 Mon Sep 17 00:00:00 2001 From: Ahmed Muhsin Date: Wed, 10 Jun 2026 18:07:34 -0500 Subject: [PATCH 15/15] fix(durabletask): address review findings on workflow hosting - AF: register workflow agents through add_agent(entity_id=...) so they remain tracked in app.agents / get_agent() (restores documented behavior) while keying by the executor id the orchestrator dispatches to; mirrors DurableAIAgentWorker.add_agent. - async bridge: treat the shared loop as reusable only while its backing thread is alive, so a dead loop thread is replaced instead of hanging future.result() forever. - client: add get_runtime_status; the standalone HITL sample now stops polling and reports the real terminal state instead of a generic timeout. - tests: guard send_hitl_response pickle-marker stripping and add get_runtime_status coverage. --- .../agent_framework_azurefunctions/_app.py | 44 ++++++++++++++----- .../packages/azurefunctions/tests/test_app.py | 9 +++- .../_async_bridge.py | 42 ++++++++++++------ .../_workflows/client.py | 20 +++++++++ .../durabletask/tests/test_workflow_client.py | 36 +++++++++++++++ .../durabletask/09_workflow_hitl/client.py | 14 +++++- 6 files changed, 138 insertions(+), 27 deletions(-) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index 5fc0bfd50ce..e9ac2590a65 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -258,10 +258,16 @@ def __init__( logger.debug("[AgentFunctionApp] Extracting agents from workflow") plan = plan_workflow_registration(workflow) for agent_executor in plan.agent_executors: - # Register each agent executor's entity under its executor id (the - # identity the orchestrator dispatches to), so AgentExecutor(agent, - # id=...) works even when the id differs from agent.name. - self._setup_agent_entity(agent_executor.agent, agent_executor.id, self.default_callback) + # Register each workflow agent through the same surface as a + # standalone agent (so it is tracked in ``agents`` / ``get_agent``), + # but keyed by the executor id the orchestrator dispatches to, so + # AgentExecutor(agent, id=...) works when the id differs from + # agent.name. Mirrors DurableAIAgentWorker.add_agent(entity_id=...). + self.add_agent( + agent_executor.agent, + callback=self.default_callback, + entity_id=agent_executor.id, + ) for executor in plan.activity_executors: # Set up a Functions activity trigger for each non-agent executor. self._setup_executor_activity(executor.id) @@ -496,6 +502,8 @@ def add_agent( callback: AgentResponseCallbackProtocol | None = None, enable_http_endpoint: bool | None = None, enable_mcp_tool_trigger: bool | None = None, + *, + entity_id: str | None = None, ) -> None: """Add an agent to the function app after initialization. @@ -507,6 +515,11 @@ def add_agent( The app level enable_http_endpoints setting will override this setting. enable_mcp_tool_trigger: Optional flag to enable/disable MCP tool trigger for this agent. The app level enable_mcp_tool_trigger setting will override this setting. + entity_id: Optional identity to register the agent under instead of + ``agent.name``. Workflow hosting passes the executor's ``id`` so the + durable entity (and the ``agents`` / ``get_agent`` key) matches the + identity the orchestrator dispatches to. Mirrors + ``DurableAIAgentWorker.add_agent(entity_id=...)``. Raises: ValueError: If the agent doesn't have a 'name' attribute. @@ -516,8 +529,15 @@ def add_agent( if name is None: raise ValueError("Agent does not have a 'name' attribute. All agents must have a 'name' attribute.") - if name in self._agent_metadata: - logger.warning("[AgentFunctionApp] Agent '%s' is already registered, skipping duplicate.", name) + # The registration name keys the agent everywhere on this app (metadata, + # routes, entity). It defaults to the agent name but can be overridden so a + # workflow agent is keyed by its executor id. + registration_name = entity_id or name + + if registration_name in self._agent_metadata: + logger.warning( + "[AgentFunctionApp] Agent '%s' is already registered, skipping duplicate.", registration_name + ) return effective_enable_http_endpoint = ( @@ -529,19 +549,19 @@ def add_agent( else self._coerce_to_bool(enable_mcp_tool_trigger) ) - logger.debug(f"[AgentFunctionApp] Adding agent: {name}") - logger.debug(f"[AgentFunctionApp] Route: /api/agents/{name}") + logger.debug(f"[AgentFunctionApp] Adding agent: {registration_name}") + logger.debug(f"[AgentFunctionApp] Route: /api/agents/{registration_name}") logger.debug( "[AgentFunctionApp] HTTP endpoint %s for agent '%s'", "enabled" if effective_enable_http_endpoint else "disabled", - name, + registration_name, ) logger.debug( f"[AgentFunctionApp] MCP tool trigger: {'enabled' if effective_enable_mcp_endpoint else 'disabled'}" ) # Store agent metadata - self._agent_metadata[name] = AgentMetadata( + self._agent_metadata[registration_name] = AgentMetadata( agent=agent, http_endpoint_enabled=effective_enable_http_endpoint, mcp_tool_enabled=effective_enable_mcp_endpoint, @@ -550,10 +570,10 @@ def add_agent( effective_callback = callback or self.default_callback self._setup_agent_functions( - agent, name, effective_callback, effective_enable_http_endpoint, effective_enable_mcp_endpoint + agent, registration_name, effective_callback, effective_enable_http_endpoint, effective_enable_mcp_endpoint ) - logger.debug(f"[AgentFunctionApp] Agent '{name}' added successfully") + logger.debug(f"[AgentFunctionApp] Agent '{registration_name}' added successfully") def get_agent( self, diff --git a/python/packages/azurefunctions/tests/test_app.py b/python/packages/azurefunctions/tests/test_app.py index fe6a0ac462b..194f827197e 100644 --- a/python/packages/azurefunctions/tests/test_app.py +++ b/python/packages/azurefunctions/tests/test_app.py @@ -1362,13 +1362,20 @@ def test_init_with_workflow_registers_agent_entity_by_executor_id(self) -> None: patch.object(AgentFunctionApp, "_setup_workflow_orchestration"), patch.object(AgentFunctionApp, "_setup_agent_entity") as setup_entity, ): - AgentFunctionApp(workflow=mock_workflow) + app = AgentFunctionApp(workflow=mock_workflow) + # The entity is registered under the executor id (the dispatch identity). setup_entity.assert_called_once() call_args = setup_entity.call_args.args assert call_args[0] is mock_agent assert call_args[1] == "custom-executor-id" + # Regression guard: the workflow agent must also be tracked on the app's + # normal registration surface, keyed by the executor id, so it appears in + # ``agents`` and is retrievable via ``get_agent`` (as the constructor documents). + assert "custom-executor-id" in app.agents + assert app.agents["custom-executor-id"] is mock_agent + def test_init_with_workflow_calls_setup_methods(self) -> None: """Test that workflow setup methods are called.""" mock_executor = Mock() diff --git a/python/packages/durabletask/agent_framework_durabletask/_async_bridge.py b/python/packages/durabletask/agent_framework_durabletask/_async_bridge.py index 72b1140d02c..2ca52c6c94e 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_async_bridge.py +++ b/python/packages/durabletask/agent_framework_durabletask/_async_bridge.py @@ -18,6 +18,7 @@ from __future__ import annotations import asyncio +import contextlib import threading from collections.abc import Coroutine from typing import Any, TypeVar @@ -30,28 +31,43 @@ def _ensure_loop() -> asyncio.AbstractEventLoop: - """Return the shared persistent event loop, starting it on first use.""" + """Return the shared persistent event loop, starting it on first use. + + The loop is only reusable when it is open *and* its backing thread is still + alive. A loop whose thread has died (e.g. during interpreter shutdown) is not + reusable: ``run_coroutine_threadsafe`` would schedule onto a loop that will + never run again and ``future.result()`` would block forever. Such a loop is + replaced with a fresh loop + thread. + """ global _loop, _thread - if _loop is not None and not _loop.is_closed(): - return _loop + loop, thread = _loop, _thread + if loop is not None and not loop.is_closed() and thread is not None and thread.is_alive(): + return loop with _lock: - if _loop is not None and not _loop.is_closed(): - return _loop + loop, thread = _loop, _thread + if loop is not None and not loop.is_closed() and thread is not None and thread.is_alive(): + return loop - loop = asyncio.new_event_loop() + # An existing loop whose thread has died is orphaned; close it best-effort + # before replacing it so it does not leak. + if loop is not None and not loop.is_closed(): + with contextlib.suppress(Exception): + loop.close() + + new_loop = asyncio.new_event_loop() def _run() -> None: - asyncio.set_event_loop(loop) - loop.run_forever() + asyncio.set_event_loop(new_loop) + new_loop.run_forever() - thread = threading.Thread(target=_run, name="dafx-agent-loop", daemon=True) - thread.start() + new_thread = threading.Thread(target=_run, name="dafx-agent-loop", daemon=True) + new_thread.start() - _loop = loop - _thread = thread - return loop + _loop = new_loop + _thread = new_thread + return new_loop def run_agent_coroutine(coro: Coroutine[Any, Any, _T]) -> _T: diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflows/client.py b/python/packages/durabletask/agent_framework_durabletask/_workflows/client.py index 15d7179f4ea..612576c715c 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_workflows/client.py +++ b/python/packages/durabletask/agent_framework_durabletask/_workflows/client.py @@ -112,6 +112,26 @@ def await_workflow_output(self, instance_id: str, *, timeout_seconds: int = 300) # checkpoint-marker dicts. Reconstruct the originals before returning. return deserialize_workflow_output(json.loads(metadata.serialized_output)) + def get_runtime_status(self, instance_id: str) -> str | None: + """Return the workflow's current runtime status name, or ``None`` if unknown. + + Lets callers distinguish a workflow that is still running or paused for + human input from one that has reached a terminal state (for example + ``COMPLETED``, ``FAILED``, or ``TERMINATED``) — useful when polling, so a + workflow that ends without pausing is not mistaken for one that never paused. + + Args: + instance_id: The instance ID returned by ``start_workflow``. + + Returns: + The runtime status name (e.g. ``"RUNNING"``, ``"COMPLETED"``), or + ``None`` if no state is available for the instance. + """ + state = self._client.get_orchestration_state(instance_id) + if state is None: + return None + return state.runtime_status.name + def get_pending_hitl_requests(self, instance_id: str) -> list[dict[str, Any]]: """Return the workflow's pending human-in-the-loop (HITL) requests, if any. diff --git a/python/packages/durabletask/tests/test_workflow_client.py b/python/packages/durabletask/tests/test_workflow_client.py index 7ba0e4e880a..434a25a7074 100644 --- a/python/packages/durabletask/tests/test_workflow_client.py +++ b/python/packages/durabletask/tests/test_workflow_client.py @@ -138,6 +138,24 @@ def test_raises_runtime_error_on_failed_status( workflow_client.await_workflow_output("instance-1") +class TestGetRuntimeStatus: + """Test reading the workflow's runtime status.""" + + def test_returns_status_name(self, workflow_client: DurableWorkflowClient, mock_client: Mock) -> None: + """The runtime status name is returned when state is available.""" + state = Mock() + state.runtime_status.name = "RUNNING" + mock_client.get_orchestration_state.return_value = state + + assert workflow_client.get_runtime_status("instance-1") == "RUNNING" + + def test_returns_none_when_no_state(self, workflow_client: DurableWorkflowClient, mock_client: Mock) -> None: + """No orchestration state yields None (status unknown).""" + mock_client.get_orchestration_state.return_value = None + + assert workflow_client.get_runtime_status("instance-1") is None + + class TestGetPendingHitlRequests: """Test parsing pending HITL requests from custom status.""" @@ -236,3 +254,21 @@ def test_raises_orchestration_event_with_request_id( _, kwargs = mock_client.raise_orchestration_event.call_args assert kwargs["event_name"] == "req-1" assert kwargs["data"] == {"approved": True} + + def test_strips_pickle_markers_before_delivery( + self, workflow_client: DurableWorkflowClient, mock_client: Mock + ) -> None: + """A crafted pickle-marker payload is neutralized before reaching the worker. + + The HITL response is sent to the worker which deserializes it, so a payload + carrying the checkpoint ``__pickled__`` marker must be stripped client-side + (regression guard for the strip_pickle_markers call in send_hitl_response). + """ + malicious = {"__pickled__": "", "approved": True} + + workflow_client.send_hitl_response("instance-1", "req-1", malicious) + + _, kwargs = mock_client.raise_orchestration_event.call_args + # The whole marker-bearing dict is neutralized (replaced with None) rather + # than forwarded, so it can never reach pickle.loads on the worker. + assert kwargs["data"] is None diff --git a/python/samples/04-hosting/durabletask/09_workflow_hitl/client.py b/python/samples/04-hosting/durabletask/09_workflow_hitl/client.py index aab91d2e991..73d06dde3a8 100644 --- a/python/samples/04-hosting/durabletask/09_workflow_hitl/client.py +++ b/python/samples/04-hosting/durabletask/09_workflow_hitl/client.py @@ -52,12 +52,24 @@ def get_client(taskhub: str | None = None, endpoint: str | None = None) -> Durab def _wait_for_hitl_request( client: DurableWorkflowClient, instance_id: str, timeout_seconds: int = 60 ) -> list[dict[str, Any]]: - """Poll until the workflow has at least one pending HITL request.""" + """Poll until the workflow has at least one pending HITL request. + + Stops early if the workflow reaches a terminal state (e.g. completed or failed) + without pausing, so a misconfiguration or early failure surfaces the real + status instead of a misleading timeout. + """ + terminal_statuses = {"COMPLETED", "FAILED", "TERMINATED"} deadline = time.time() + timeout_seconds while time.time() < deadline: pending = client.get_pending_hitl_requests(instance_id) if pending: return pending + status = client.get_runtime_status(instance_id) + if status in terminal_statuses: + raise RuntimeError( + f"Workflow instance {instance_id} reached terminal state '{status}' " + "before pausing for human input." + ) time.sleep(2) raise TimeoutError(f"Timed out waiting for a HITL request on instance {instance_id}")