diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index c25c2461cee..e9ac2590a65 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -14,15 +14,13 @@ import re import uuid from collections.abc import Callable, Mapping -from copy import deepcopy -from dataclasses import dataclass +from dataclasses import asdict, dataclass, is_dataclass from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, TypeVar, cast import azure.durable_functions as df import azure.functions as func -from agent_framework import AgentExecutor, SupportsAgentRun, Workflow, WorkflowEvent -from agent_framework._workflows._runner_context import YieldOutputEventType +from agent_framework import SupportsAgentRun, Workflow from agent_framework_durabletask import ( DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS, @@ -34,25 +32,23 @@ THREAD_ID_HEADER, WAIT_FOR_RESPONSE_FIELD, WAIT_FOR_RESPONSE_HEADER, + WORKFLOW_ORCHESTRATOR_NAME, AgentResponseCallbackProtocol, AgentSessionId, ApiResponseFields, DurableAgentState, DurableAIAgent, RunRequest, + deserialize_workflow_output, + execute_workflow_activity, + plan_workflow_registration, ) -from ._context import CapturingRunnerContext from ._entities import create_agent_entity from ._errors import IncomingRequestError from ._orchestration import AgentOrchestrationContextType, AgentTask, AzureFunctionsAgentExecutor -from ._serialization import deserialize_value, serialize_value, strip_pickle_markers -from ._workflow import ( - SOURCE_HITL_RESPONSE, - SOURCE_ORCHESTRATOR, - execute_hitl_response_handler, - run_workflow_orchestrator, -) +from ._serialization import strip_pickle_markers +from ._workflow import run_workflow_orchestrator logger = logging.getLogger("agent_framework.azurefunctions") @@ -60,9 +56,30 @@ HandlerT = TypeVar("HandlerT", bound=Callable[..., Any]) -def _create_state_snapshot(state: dict[str, Any]) -> dict[str, Any]: - """Create a deep copy of the deserialized state for later diffing.""" - return deepcopy(state) +def _json_default(obj: Any) -> Any: + """JSON fallback encoder for reconstructed workflow outputs. + + A workflow's yielded outputs are reconstructed (see ``deserialize_workflow_output``) + before they reach the HTTP response, so they may be framework models + (e.g. ``AgentResponse``), dataclasses, or other non-JSON-native objects. + Prefer the type's own serialization so the response carries clean domain + JSON, falling back to ``str`` for anything without one. + """ + to_dict = getattr(obj, "to_dict", None) + if callable(to_dict): + try: + return to_dict() + except Exception: + logger.debug("to_dict() failed while encoding %s for HTTP output", type(obj).__name__) + model_dump = getattr(obj, "model_dump", None) + if callable(model_dump): + try: + return model_dump(mode="json") + except Exception: + logger.debug("model_dump() failed while encoding %s for HTTP output", type(obj).__name__) + if is_dataclass(obj) and not isinstance(obj, type): + return asdict(obj) + return str(obj) @dataclass @@ -234,17 +251,26 @@ def __init__( interval = DEFAULT_POLL_INTERVAL_SECONDS self.poll_interval_seconds = interval if interval > 0 else DEFAULT_POLL_INTERVAL_SECONDS - # If workflow is provided, extract agents and set up orchestration + # If workflow is provided, extract agents and set up orchestration. + # The "what to register" decision (agent -> entity, non-agent -> activity) + # is shared with the standalone durabletask host via plan_workflow_registration. if workflow: - if agents is None: - agents = [] logger.debug("[AgentFunctionApp] Extracting agents from workflow") - for executor in workflow.executors.values(): - if isinstance(executor, AgentExecutor): - agents.append(executor.agent) - else: - # Setup individual activity for each non-agent executor - self._setup_executor_activity(executor.id) + plan = plan_workflow_registration(workflow) + for agent_executor in plan.agent_executors: + # Register each workflow agent through the same surface as a + # standalone agent (so it is tracked in ``agents`` / ``get_agent``), + # but keyed by the executor id the orchestrator dispatches to, so + # AgentExecutor(agent, id=...) works when the id differs from + # agent.name. Mirrors DurableAIAgentWorker.add_agent(entity_id=...). + self.add_agent( + agent_executor.agent, + callback=self.default_callback, + entity_id=agent_executor.id, + ) + for executor in plan.activity_executors: + # Set up a Functions activity trigger for each non-agent executor. + self._setup_executor_activity(executor.id) self._setup_workflow_orchestration() @@ -279,18 +305,9 @@ def executor_activity(inputData: str) -> str: Note: We use str type annotations instead of dict to work around Azure Functions worker type validation issues with dict[str, Any]. + The execution body is shared with the standalone durabletask host via + ``execute_workflow_activity``. """ - from agent_framework._workflows._state import State - - data_obj = json.loads(inputData) - if not isinstance(data_obj, dict): - raise ValueError("Activity inputData must decode to a JSON object") - data = cast(dict[str, Any], data_obj) - - message_data = data.get("message") - shared_state_snapshot = data.get("shared_state_snapshot", {}) - source_executor_ids = cast(list[str], data.get("source_executor_ids", [SOURCE_ORCHESTRATOR])) - if not self.workflow: raise RuntimeError("Workflow not initialized in AgentFunctionApp") @@ -298,120 +315,7 @@ def executor_activity(inputData: str) -> str: if not executor: raise ValueError(f"Unknown executor: {captured_executor_id}") - # Reconstruct message - deserialize_value restores the original typed objects - # from the encoded data (with type markers) - message = deserialize_value(message_data) - - # Check if this is a HITL response message by examining source_executor_ids - is_hitl_response = any(s.startswith(SOURCE_HITL_RESPONSE) for s in source_executor_ids) - - async def run() -> dict[str, Any]: - # Create runner context and shared state - runner_context = CapturingRunnerContext() - workflow = self.workflow - - def classify_yielded_output(executor_id: str) -> YieldOutputEventType | None: - if workflow is None: - return "output" - if workflow.is_terminal_executor(executor_id): - return "output" - if workflow.is_intermediate_executor(executor_id): - return "intermediate" - return None - - runner_context.set_yield_output_classifier(classify_yielded_output) - shared_state = State() - - # Deserialize shared state values to reconstruct dataclasses/Pydantic models - deserialized_state: dict[str, Any] = { - str(k): deserialize_value(v) for k, v in shared_state_snapshot.items() - } - original_snapshot = _create_state_snapshot(deserialized_state) - shared_state.import_state(deserialized_state) - - if is_hitl_response: - # Handle HITL response by calling the executor's @response_handler - if not isinstance(message_data, dict): - raise ValueError("HITL message payload must be a JSON object") - - await execute_hitl_response_handler( - executor=executor, - hitl_message=cast(dict[str, Any], message_data), - shared_state=shared_state, - runner_context=runner_context, - ) - else: - # Execute using the public execute() method - await executor.execute( - message=message, - source_executor_ids=source_executor_ids, - state=shared_state, - runner_context=runner_context, - ) - - # Commit pending state changes and export - shared_state.commit() - current_state = shared_state.export_state() - original_keys: set[str] = set(original_snapshot.keys()) - current_keys: set[str] = set(current_state.keys()) - - # Deleted = was in original, not in current - deletes: set[str] = original_keys - current_keys - - # Updates = keys in current that are new or have different values - updates: dict[str, Any] = {} - for key in current_keys: - if key not in original_keys or current_state[key] != original_snapshot.get(key): - updates[key] = current_state[key] - - # Drain messages and events from runner context - sent_messages = await runner_context.drain_messages() - events = await runner_context.drain_events() - - # Extract outputs from WorkflowEvent instances with type='output' - outputs: list[Any] = [] - for event in events: - if isinstance(event, WorkflowEvent) and event.type == "output": - outputs.append(serialize_value(event.data)) - - # Get pending request info events for HITL - pending_request_info_events = await runner_context.get_pending_request_info_events() - - # Serialize pending request info events for orchestrator - serialized_pending_requests: list[dict[str, Any]] = [] - for _request_id, event in pending_request_info_events.items(): - serialized_pending_requests.append({ - "request_id": event.request_id, - "source_executor_id": event.source_executor_id, - "data": serialize_value(event.data), - "request_type": f"{type(event.data).__module__}:{type(event.data).__name__}", - "response_type": f"{event.response_type.__module__}:{event.response_type.__name__}" - if event.response_type - else None, - }) - - # Serialize messages for JSON compatibility - serialized_sent_messages: list[dict[str, Any]] = [] - for _source_id, msg_list in sent_messages.items(): - for msg in msg_list: - serialized_sent_messages.append({ - "message": serialize_value(msg.data), - "target_id": msg.target_id, - "source_id": msg.source_id, - }) - - serialized_updates = {k: serialize_value(v) for k, v in updates.items()} - - return { - "sent_messages": serialized_sent_messages, - "outputs": outputs, - "shared_state_updates": serialized_updates, - "shared_state_deletes": list(deletes), - "pending_request_info_events": serialized_pending_requests, - } - - result = asyncio.run(run()) - return json.dumps(result) + return execute_workflow_activity(executor, inputData, self.workflow) # Ensure the function is registered (prevents garbage collection) _ = executor_activity @@ -427,8 +331,9 @@ def workflow_orchestrator(context: df.DurableOrchestrationContext) -> Any: # ty input_data = context.get_input() - # Ensure input is a string for the agent - initial_message = json.dumps(input_data) if isinstance(input_data, (dict, list)) else str(input_data) + # Pass the deserialized client input straight to the shared engine, which + # reconstructs the start executor's declared type (see _coerce_initial_input). + initial_message = input_data # Create local shared state dict for cross-executor state sharing shared_state: dict[str, Any] = {} @@ -448,7 +353,7 @@ async def start_workflow_orchestration( except ValueError: return self._build_error_response("Invalid JSON body") - instance_id = await client.start_new("workflow_orchestrator", client_input=req_body) + instance_id = await client.start_new(WORKFLOW_ORCHESTRATOR_NAME, client_input=req_body) base_url = self._build_base_url(req.url) status_url = f"{base_url}/api/workflow/status/{instance_id}" @@ -479,12 +384,19 @@ async def get_workflow_status( if not status: return self._build_error_response("Instance not found", status_code=404) + # The workflow's yielded outputs are checkpoint-encoded by the shared + # activity (typed objects become pickle/type-marker dicts). Reconstruct + # the originals so the HTTP response carries clean domain JSON, matching + # what DurableWorkflowClient.await_workflow_output returns in-process. + # status.output is the workflow's own (trusted) orchestration result. + decoded_output = deserialize_workflow_output(status.output) if status.output is not None else None + response = { "instanceId": status.instance_id, "runtimeStatus": status.runtime_status.name if status.runtime_status else None, "customStatus": status.custom_status, - "output": status.output, - "error": status.output if status.runtime_status == df.OrchestrationRuntimeStatus.Failed else None, + "output": decoded_output, + "error": decoded_output if status.runtime_status == df.OrchestrationRuntimeStatus.Failed else None, "createdTime": status.created_time.isoformat() if status.created_time else None, "lastUpdatedTime": status.last_updated_time.isoformat() if status.last_updated_time else None, } @@ -512,7 +424,7 @@ async def get_workflow_status( response["pendingHumanInputRequests"] = pending_requests return func.HttpResponse( - json.dumps(response, default=str), + json.dumps(response, default=_json_default), status_code=200, mimetype="application/json", ) @@ -590,6 +502,8 @@ def add_agent( callback: AgentResponseCallbackProtocol | None = None, enable_http_endpoint: bool | None = None, enable_mcp_tool_trigger: bool | None = None, + *, + entity_id: str | None = None, ) -> None: """Add an agent to the function app after initialization. @@ -601,6 +515,11 @@ def add_agent( The app level enable_http_endpoints setting will override this setting. enable_mcp_tool_trigger: Optional flag to enable/disable MCP tool trigger for this agent. The app level enable_mcp_tool_trigger setting will override this setting. + entity_id: Optional identity to register the agent under instead of + ``agent.name``. Workflow hosting passes the executor's ``id`` so the + durable entity (and the ``agents`` / ``get_agent`` key) matches the + identity the orchestrator dispatches to. Mirrors + ``DurableAIAgentWorker.add_agent(entity_id=...)``. Raises: ValueError: If the agent doesn't have a 'name' attribute. @@ -610,8 +529,15 @@ def add_agent( if name is None: raise ValueError("Agent does not have a 'name' attribute. All agents must have a 'name' attribute.") - if name in self._agent_metadata: - logger.warning("[AgentFunctionApp] Agent '%s' is already registered, skipping duplicate.", name) + # The registration name keys the agent everywhere on this app (metadata, + # routes, entity). It defaults to the agent name but can be overridden so a + # workflow agent is keyed by its executor id. + registration_name = entity_id or name + + if registration_name in self._agent_metadata: + logger.warning( + "[AgentFunctionApp] Agent '%s' is already registered, skipping duplicate.", registration_name + ) return effective_enable_http_endpoint = ( @@ -623,19 +549,19 @@ def add_agent( else self._coerce_to_bool(enable_mcp_tool_trigger) ) - logger.debug(f"[AgentFunctionApp] Adding agent: {name}") - logger.debug(f"[AgentFunctionApp] Route: /api/agents/{name}") + logger.debug(f"[AgentFunctionApp] Adding agent: {registration_name}") + logger.debug(f"[AgentFunctionApp] Route: /api/agents/{registration_name}") logger.debug( "[AgentFunctionApp] HTTP endpoint %s for agent '%s'", "enabled" if effective_enable_http_endpoint else "disabled", - name, + registration_name, ) logger.debug( f"[AgentFunctionApp] MCP tool trigger: {'enabled' if effective_enable_mcp_endpoint else 'disabled'}" ) # Store agent metadata - self._agent_metadata[name] = AgentMetadata( + self._agent_metadata[registration_name] = AgentMetadata( agent=agent, http_endpoint_enabled=effective_enable_http_endpoint, mcp_tool_enabled=effective_enable_mcp_endpoint, @@ -644,10 +570,10 @@ def add_agent( effective_callback = callback or self.default_callback self._setup_agent_functions( - agent, name, effective_callback, effective_enable_http_endpoint, effective_enable_mcp_endpoint + agent, registration_name, effective_callback, effective_enable_http_endpoint, effective_enable_mcp_endpoint ) - logger.debug(f"[AgentFunctionApp] Agent '{name}' added successfully") + logger.debug(f"[AgentFunctionApp] Agent '{registration_name}' added successfully") def get_agent( self, @@ -1092,8 +1018,6 @@ async def _get_response_from_entity( thread_id: str, ) -> dict[str, Any]: """Poll the entity state until a response is available or timeout occurs.""" - import asyncio - max_retries = self.max_poll_retries interval = self.poll_interval_seconds retry_count = 0 diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py index d734950979d..9de242efaa6 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py @@ -9,7 +9,6 @@ from __future__ import annotations -import asyncio import logging from collections.abc import Callable from typing import Any, cast @@ -20,6 +19,7 @@ AgentEntity, AgentEntityStateProviderMixin, AgentResponseCallbackProtocol, + run_agent_coroutine, ) logger = logging.getLogger("agent_framework.azurefunctions") @@ -101,23 +101,16 @@ async def _entity_coroutine(context: df.DurableEntityContext) -> None: context.set_result({"error": str(exc), "status": "error"}) def entity_function(context: df.DurableEntityContext) -> None: - """Synchronous wrapper invoked by the Durable Functions runtime.""" + """Synchronous wrapper invoked by the Durable Functions runtime. + + All agent coroutines run on a single process-wide persistent event loop + (see ``run_agent_coroutine``). This keeps async resources created by + shared agent clients/credentials bound to a live loop across every + invocation, preventing cross-loop hangs when the host dispatches + successive entity operations onto different worker threads. + """ try: - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - if loop.is_running(): - temp_loop = asyncio.new_event_loop() - try: - temp_loop.run_until_complete(_entity_coroutine(context)) - finally: - temp_loop.close() - else: - loop.run_until_complete(_entity_coroutine(context)) - + run_agent_coroutine(_entity_coroutine(context)) except Exception as exc: # pragma: no cover - defensive logging logger.error("[entity_function] Unexpected error executing entity: %s", exc, exc_info=True) context.set_result({"error": str(exc), "status": "error"}) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py index 6fbdaf44f7e..f6fbe957ab0 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py @@ -2,603 +2,60 @@ """Workflow Execution for Durable Functions. -This module provides the workflow orchestration engine that executes MAF Workflows -using Azure Durable Functions. It reuses MAF's edge group routing logic while -adapting execution to the DF generator-based model (yield instead of await). - -Key components: -- run_workflow_orchestrator: Main orchestration function for workflow execution -- route_message_through_edge_groups: Routing helper using MAF edge group APIs -- build_agent_executor_response: Helper to construct AgentExecutorResponse - -HITL (Human-in-the-Loop) Support: -- Detects pending RequestInfoEvents from executor activities -- Uses wait_for_external_event to pause for human input -- Routes responses back to executor's @response_handler methods +This module provides the Azure Functions entry point for workflow orchestration. +The actual orchestration logic lives in the shared module +``agent_framework_durabletask._workflows.orchestrator`` and is host-agnostic. +This module re-exports the public API and provides the AF-specific +``run_workflow_orchestrator`` wrapper that creates an +:class:`AzureFunctionsWorkflowContext` before delegating. """ from __future__ import annotations -import json import logging -from collections import defaultdict from collections.abc import Generator -from dataclasses import dataclass -from datetime import timedelta -from enum import Enum from typing import Any -from agent_framework import ( - AgentExecutor, - AgentExecutorRequest, - AgentExecutorResponse, - AgentResponse, - Message, - Workflow, +from agent_framework import Workflow +from agent_framework_durabletask._workflows.orchestrator import ( + DEFAULT_HITL_TIMEOUT_HOURS, + SOURCE_HITL_RESPONSE, + SOURCE_ORCHESTRATOR, + SOURCE_WORKFLOW_START, + ExecutorResult, + PendingHITLRequest, + TaskMetadata, + TaskType, + _extract_message_content, # pyright: ignore[reportPrivateUsage] + build_agent_executor_response, + execute_hitl_response_handler, + route_message_through_edge_groups, ) -from agent_framework._workflows._edge import ( - Edge, - EdgeGroup, - FanInEdgeGroup, - FanOutEdgeGroup, - SingleEdgeGroup, - SwitchCaseEdgeGroup, +from agent_framework_durabletask._workflows.orchestrator import ( + run_workflow_orchestrator as _run_workflow_orchestrator_shared, ) -from agent_framework._workflows._state import State -from agent_framework_durabletask import AgentSessionId, DurableAgentSession, DurableAIAgent from azure.durable_functions import DurableOrchestrationContext -from ._context import CapturingRunnerContext -from ._orchestration import AzureFunctionsAgentExecutor -from ._serialization import deserialize_value, reconstruct_to_type, resolve_type, serialize_value, strip_pickle_markers +from ._workflow_af_context import AzureFunctionsWorkflowContext logger = logging.getLogger(__name__) - -# ============================================================================ -# Source Marker Constants -# ============================================================================ -# These markers identify the origin of messages in the workflow orchestration. -# They are used to track message provenance and handle special cases like HITL. - -# Marker indicating the message originated from the workflow start (initial user input) -SOURCE_WORKFLOW_START = "__workflow_start__" - -# Marker indicating the message originated from the orchestrator itself -# (used as default when executor is called directly by orchestrator, not via another executor) -SOURCE_ORCHESTRATOR = "__orchestrator__" - -# Marker indicating the message is a human-in-the-loop response. -# Used as a source ID prefix. To detect HITL responses, check if any source_executor_id -# starts with this prefix. -SOURCE_HITL_RESPONSE = "__hitl_response__" - - -# ============================================================================ -# Task Types and Data Structures -# ============================================================================ - - -class TaskType(Enum): - """Type of executor task.""" - - AGENT = "agent" - ACTIVITY = "activity" - - -@dataclass -class TaskMetadata: - """Metadata for a pending task.""" - - executor_id: str - message: Any - source_executor_id: str - task_type: TaskType - remaining_messages: list[tuple[str, Any, str]] | None = None # For agents with multiple messages - - -@dataclass -class ExecutorResult: - """Result from executing an agent or activity.""" - - executor_id: str - output_message: AgentExecutorResponse | None - activity_result: dict[str, Any] | None - task_type: TaskType - - -@dataclass -class PendingHITLRequest: - """Tracks a pending Human-in-the-Loop request in the orchestrator. - - Attributes: - request_id: Unique identifier for correlation with external events - source_executor_id: The executor that called ctx.request_info() - request_data: The serialized request payload - request_type: Fully qualified type name of the request data - response_type: Fully qualified type name of expected response - """ - - request_id: str - source_executor_id: str - request_data: Any - request_type: str | None - response_type: str | None - - -# Default timeout for HITL requests (72 hours) -DEFAULT_HITL_TIMEOUT_HOURS = 72.0 - - -# ============================================================================ -# Routing Functions -# ============================================================================ - - -def _evaluate_edge_condition_sync(edge: Edge, message: Any) -> bool: - """Evaluate an edge's condition synchronously. - - This is needed because Durable Functions orchestrators use generators, - not async/await, so we cannot call async methods like edge.should_route(). - - Args: - edge: The Edge with an optional _condition callable - message: The message to evaluate against the condition - - Returns: - True if the edge should be traversed, False otherwise - """ - # Access the internal condition directly since should_route is async - condition = edge._condition # pyright: ignore[reportPrivateUsage] - if condition is None: - return True - result = condition(message) - # If the condition is async, we cannot await it in a generator context - # Log a warning and assume True (or False for safety) - if hasattr(result, "__await__"): - import warnings - - warnings.warn( - f"Edge condition for {edge.source_id}->{edge.target_id} is async, " - "which is not supported in Durable Functions orchestrators. " - "The edge will be traversed unconditionally.", - RuntimeWarning, - stacklevel=2, - ) - return True - return bool(result) - - -def route_message_through_edge_groups( - edge_groups: list[EdgeGroup], - source_id: str, - message: Any, -) -> list[str]: - """Route a message through edge groups to find target executor IDs. - - Delegates to MAF's edge group routing logic instead of manual inspection. - - Args: - edge_groups: List of EdgeGroup instances from the workflow - source_id: The ID of the source executor - message: The message to route - - Returns: - List of target executor IDs that should receive the message - """ - targets: list[str] = [] - - for group in edge_groups: - if source_id not in group.source_executor_ids: - continue - - # SwitchCaseEdgeGroup and FanOutEdgeGroup use selection_func - if isinstance(group, (SwitchCaseEdgeGroup, FanOutEdgeGroup)): - if group.selection_func is not None: - selected = group.selection_func(message, group.target_executor_ids) - targets.extend(selected) - else: - # No selection func means broadcast to all targets - targets.extend(group.target_executor_ids) - - elif isinstance(group, SingleEdgeGroup): - # SingleEdgeGroup has exactly one edge - edge = group.edges[0] - if _evaluate_edge_condition_sync(edge, message): - targets.append(edge.target_id) - - elif isinstance(group, FanInEdgeGroup): - # FanIn is handled separately in the orchestrator loop - # since it requires aggregation - pass - - else: - # Generic EdgeGroup: check each edge's condition - for edge in group.edges: - if edge.source_id == source_id and _evaluate_edge_condition_sync(edge, message): - targets.append(edge.target_id) - - return targets - - -def build_agent_executor_response( - executor_id: str, - response_text: str | None, - structured_response: dict[str, Any] | None, - previous_message: Any, -) -> AgentExecutorResponse: - """Build an AgentExecutorResponse from entity response data. - - Shared helper to construct the response object consistently. - - Args: - executor_id: The ID of the executor that produced the response - response_text: Plain text response from the agent (if any) - structured_response: Structured JSON response (if any) - previous_message: The input message that triggered this response - - Returns: - AgentExecutorResponse with reconstructed conversation - """ - final_text: str = response_text or "" - if structured_response: - final_text = json.dumps(structured_response) - - assistant_message = Message(role="assistant", contents=[final_text]) - - agent_response = AgentResponse( - messages=[assistant_message], - ) - - # Build conversation history - full_conversation: list[Message] = [] - if isinstance(previous_message, AgentExecutorResponse) and previous_message.full_conversation: - full_conversation.extend(previous_message.full_conversation) - elif isinstance(previous_message, str): - full_conversation.append(Message(role="user", contents=[previous_message])) - - full_conversation.append(assistant_message) - - return AgentExecutorResponse( - executor_id=executor_id, - agent_response=agent_response, - full_conversation=full_conversation, - ) - - -# ============================================================================ -# Task Preparation Helpers -# ============================================================================ - - -def _prepare_agent_task( - context: DurableOrchestrationContext, - executor_id: str, - message: Any, -) -> Any: - """Prepare an agent task for execution. - - Args: - context: The Durable Functions orchestration context - executor_id: The agent executor ID (agent name) - message: The input message for the agent - - Returns: - A task that can be yielded to execute the agent - """ - message_content = _extract_message_content(message) - session_id = AgentSessionId(name=executor_id, key=context.instance_id) - session = DurableAgentSession(durable_session_id=session_id) - - az_executor = AzureFunctionsAgentExecutor(context) - agent = DurableAIAgent(az_executor, executor_id) - return agent.run(message_content, session=session) - - -def _prepare_activity_task( - context: DurableOrchestrationContext, - executor_id: str, - message: Any, - source_executor_id: str, - shared_state_snapshot: dict[str, Any] | None, -) -> Any: - """Prepare an activity task for execution. - - Args: - context: The Durable Functions orchestration context - executor_id: The activity executor ID - message: The input message for the activity - source_executor_id: The ID of the executor that sent the message - shared_state_snapshot: Current shared state snapshot - - Returns: - A task that can be yielded to execute the activity - """ - activity_input = { - "executor_id": executor_id, - "message": serialize_value(message), - "shared_state_snapshot": shared_state_snapshot, - "source_executor_ids": [source_executor_id], - } - activity_input_json = json.dumps(activity_input) - # Use the prefixed activity name that matches the registered function - activity_name = f"dafx-{executor_id}" - orchestration_context: Any = context - return orchestration_context.call_activity(activity_name, activity_input_json) - - -# ============================================================================ -# Result Processing Helpers -# ============================================================================ - - -def _process_agent_response( - agent_response: AgentResponse, - executor_id: str, - message: Any, -) -> ExecutorResult: - """Process an agent response into an ExecutorResult. - - Args: - agent_response: The response from the agent - executor_id: The agent executor ID - message: The original input message - - Returns: - ExecutorResult containing the processed response - """ - response_text = agent_response.text if agent_response else None - structured_response: dict[str, Any] | None = None - - if agent_response and agent_response.value is not None: - model_dump = getattr(agent_response.value, "model_dump", None) - if callable(model_dump): - dumped = model_dump() - if isinstance(dumped, dict): - structured_response = dumped # type: ignore[assignment] - elif isinstance(agent_response.value, dict): - structured_response = agent_response.value # type: ignore[assignment] - - output_message = build_agent_executor_response( - executor_id=executor_id, - response_text=response_text, - structured_response=structured_response, - previous_message=message, - ) - - return ExecutorResult( - executor_id=executor_id, - output_message=output_message, - activity_result=None, - task_type=TaskType.AGENT, - ) - - -def _process_activity_result( - result_json: str | None, - executor_id: str, - shared_state: dict[str, Any] | None, - workflow_outputs: list[Any], -) -> ExecutorResult: - """Process an activity result and apply shared state updates. - - Args: - result_json: The JSON result from the activity - executor_id: The activity executor ID - shared_state: The shared state dict to update (mutated in place) - workflow_outputs: List to append outputs to (mutated in place) - - Returns: - ExecutorResult containing the processed result - """ - result = json.loads(result_json) if result_json else None - - # Apply shared state updates - if shared_state is not None and result: - if result.get("shared_state_updates"): - updates = result["shared_state_updates"] - logger.debug("[workflow] Applying SharedState updates from %s: %s", executor_id, updates) - shared_state.update(updates) - if result.get("shared_state_deletes"): - deletes = result["shared_state_deletes"] - logger.debug("[workflow] Applying SharedState deletes from %s: %s", executor_id, deletes) - for key in deletes: - shared_state.pop(key, None) - - # Collect outputs - if result and result.get("outputs"): - workflow_outputs.extend(result["outputs"]) - - return ExecutorResult( - executor_id=executor_id, - output_message=None, - activity_result=result, - task_type=TaskType.ACTIVITY, - ) - - -# ============================================================================ -# Routing Helpers -# ============================================================================ - - -def _route_result_messages( - result: ExecutorResult, - workflow: Workflow, - next_pending_messages: dict[str, list[tuple[Any, str]]], - fan_in_pending: dict[str, dict[str, list[tuple[Any, str]]]], -) -> None: - """Route messages from an executor result to their targets. - - Args: - result: The executor result containing messages to route - workflow: The workflow definition - next_pending_messages: Dict to accumulate next iteration's messages (mutated) - fan_in_pending: Dict tracking fan-in state (mutated) - """ - executor_id = result.executor_id - messages_to_route: list[tuple[Any, str | None]] = [] - - # Collect messages from agent response - if result.output_message: - messages_to_route.append((result.output_message, None)) - - # Collect sent_messages from activity results - if result.activity_result and result.activity_result.get("sent_messages"): - for msg_data in result.activity_result["sent_messages"]: - sent_msg = msg_data.get("message") - target_id = msg_data.get("target_id") - if sent_msg: - sent_msg = deserialize_value(sent_msg) - messages_to_route.append((sent_msg, target_id)) - - # Route each message - for msg_to_route, explicit_target in messages_to_route: - logger.debug("Routing output from %s", executor_id) - - # If explicit target specified, route directly - if explicit_target: - if explicit_target not in next_pending_messages: - next_pending_messages[explicit_target] = [] - next_pending_messages[explicit_target].append((msg_to_route, executor_id)) - logger.debug("Routed message from %s to explicit target %s", executor_id, explicit_target) - continue - - # Check for FanInEdgeGroup sources - for group in workflow.edge_groups: - if isinstance(group, FanInEdgeGroup) and executor_id in group.source_executor_ids: - fan_in_pending[group.id][executor_id].append((msg_to_route, executor_id)) - logger.debug("Accumulated message for FanIn group %s from %s", group.id, executor_id) - - # Use MAF's edge group routing for other edge types - targets = route_message_through_edge_groups(workflow.edge_groups, executor_id, msg_to_route) - - for target_id in targets: - logger.debug("Routing to %s", target_id) - if target_id not in next_pending_messages: - next_pending_messages[target_id] = [] - next_pending_messages[target_id].append((msg_to_route, executor_id)) - - -def _check_fan_in_ready( - workflow: Workflow, - fan_in_pending: dict[str, dict[str, list[tuple[Any, str]]]], - next_pending_messages: dict[str, list[tuple[Any, str]]], -) -> None: - """Check if any FanInEdgeGroups are ready and deliver their messages. - - Args: - workflow: The workflow definition - fan_in_pending: Dict tracking fan-in state (mutated - cleared when delivered) - next_pending_messages: Dict to add aggregated messages to (mutated) - """ - for group in workflow.edge_groups: - if not isinstance(group, FanInEdgeGroup): - continue - - pending_sources = fan_in_pending.get(group.id, {}) - - # Check if all sources have contributed at least one message - if not all(src in pending_sources and pending_sources[src] for src in group.source_executor_ids): - continue - - # Aggregate all messages into a single list - aggregated: list[Any] = [] - aggregated_sources: list[str] = [] - for src in group.source_executor_ids: - for msg, msg_source in pending_sources[src]: - aggregated.append(msg) - aggregated_sources.append(msg_source) - - target_id = group.target_executor_ids[0] - logger.debug("FanIn group %s ready, delivering %d messages to %s", group.id, len(aggregated), target_id) - - if target_id not in next_pending_messages: - next_pending_messages[target_id] = [] - - first_source = aggregated_sources[0] if aggregated_sources else "__fan_in__" - next_pending_messages[target_id].append((aggregated, first_source)) - - # Clear the pending sources for this group - fan_in_pending[group.id] = defaultdict(list) - - -# ============================================================================ -# HITL (Human-in-the-Loop) Helpers -# ============================================================================ - - -def _collect_hitl_requests( - result: ExecutorResult, - pending_hitl_requests: dict[str, PendingHITLRequest], -) -> None: - """Collect pending HITL requests from an activity result. - - Args: - result: The executor result that may contain pending request info events - pending_hitl_requests: Dict to accumulate pending requests (mutated) - """ - if result.activity_result and result.activity_result.get("pending_request_info_events"): - for req_data in result.activity_result["pending_request_info_events"]: - request_id = req_data.get("request_id") - if request_id: - pending_hitl_requests[request_id] = PendingHITLRequest( - request_id=request_id, - source_executor_id=req_data.get("source_executor_id", result.executor_id), - request_data=req_data.get("data"), - request_type=req_data.get("request_type"), - response_type=req_data.get("response_type"), - ) - logger.debug( - "Collected HITL request %s from executor %s", - request_id, - result.executor_id, - ) - - -def _route_hitl_response( - hitl_request: PendingHITLRequest, - raw_response: Any, - pending_messages: dict[str, list[tuple[Any, str]]], -) -> None: - """Route a HITL response back to the source executor's @response_handler. - - The response is packaged as a special HITL response message that the executor - activity can recognize and route to the appropriate @response_handler method. - - Args: - hitl_request: The original HITL request - raw_response: The raw response data from the external event - pending_messages: Dict to add the response message to (mutated) - """ - # Create a message structure that the executor can recognize - # This mimics what the InProcRunnerContext does for request_info responses - # Note: HITL origin is identified via source_executor_ids (starting with SOURCE_HITL_RESPONSE) - response_message = { - "request_id": hitl_request.request_id, - "original_request": hitl_request.request_data, - "response": raw_response, - "response_type": hitl_request.response_type, - } - - target_id = hitl_request.source_executor_id - if target_id not in pending_messages: - pending_messages[target_id] = [] - - # Use a special source ID to indicate this is a HITL response - source_id = f"{SOURCE_HITL_RESPONSE}_{hitl_request.request_id}" - pending_messages[target_id].append((response_message, source_id)) - - logger.debug( - "Routed HITL response for request %s to executor %s", - hitl_request.request_id, - target_id, - ) - - -# ============================================================================ -# Main Orchestrator -# ============================================================================ +# Re-export shared symbols for backward compatibility +__all__ = [ + "DEFAULT_HITL_TIMEOUT_HOURS", + "SOURCE_HITL_RESPONSE", + "SOURCE_ORCHESTRATOR", + "SOURCE_WORKFLOW_START", + "ExecutorResult", + "PendingHITLRequest", + "TaskMetadata", + "TaskType", + "_extract_message_content", + "build_agent_executor_response", + "execute_hitl_response_handler", + "route_message_through_edge_groups", + "run_workflow_orchestrator", +] def run_workflow_orchestrator( @@ -608,386 +65,20 @@ def run_workflow_orchestrator( shared_state: dict[str, Any] | None = None, hitl_timeout_hours: float = DEFAULT_HITL_TIMEOUT_HOURS, ) -> Generator[Any, Any, list[Any]]: - """Traverse and execute the workflow graph using Durable Functions. - - This orchestrator reuses MAF's edge group routing logic while adapting - execution to the DF generator-based model (yield instead of await). - - Supports: - - SingleEdgeGroup: Direct 1:1 routing with optional condition - - SwitchCaseEdgeGroup: First matching condition wins - - FanOutEdgeGroup: Broadcast to multiple targets - **executed in parallel** - - FanInEdgeGroup: Aggregates messages from multiple sources before delivery - - SharedState: Local shared state accessible to all executors - - HITL: Human-in-the-loop via request_info / @response_handler pattern - - Execution model: - - All pending executors (agents AND activities) run in parallel via single task_all() - - Multiple messages to the SAME agent are processed sequentially for conversation coherence - - SharedState updates are applied in order after parallel tasks complete - - HITL requests pause the orchestration until external events are received - - Args: - context: The Durable Functions orchestration context - workflow: The MAF Workflow instance to execute - initial_message: The initial message to send to the start executor - shared_state: Optional dict for cross-executor state sharing (local to orchestration) - hitl_timeout_hours: Timeout in hours for HITL requests (default: 72 hours) - - Returns: - List of workflow outputs collected from executor activities - """ - pending_messages: dict[str, list[tuple[Any, str]]] = { - workflow.start_executor_id: [(initial_message, SOURCE_WORKFLOW_START)] - } - workflow_outputs: list[Any] = [] - iteration = 0 - - # Track pending sources for FanInEdgeGroups using defaultdict for cleaner access - fan_in_pending: dict[str, dict[str, list[tuple[Any, str]]]] = { - group.id: defaultdict(list) for group in workflow.edge_groups if isinstance(group, FanInEdgeGroup) - } - - # Track pending HITL requests - pending_hitl_requests: dict[str, PendingHITLRequest] = {} - - while pending_messages and iteration < workflow.max_iterations: - logger.debug("Orchestrator iteration %d", iteration) - next_pending_messages: dict[str, list[tuple[Any, str]]] = {} - - # Phase 1: Prepare all tasks (agents and activities unified) - all_tasks, task_metadata_list, remaining_agent_messages = _prepare_all_tasks( - context, workflow, pending_messages, shared_state - ) - - # Phase 2: Execute all tasks in parallel (single task_all for true parallelism) - all_results: list[ExecutorResult] = [] - if all_tasks: - logger.debug("Executing %d tasks in parallel (agents + activities)", len(all_tasks)) - raw_results = yield context.task_all(all_tasks) - logger.debug("All %d tasks completed", len(all_tasks)) - - # Process results based on task type - for idx, raw_result in enumerate(raw_results): - metadata = task_metadata_list[idx] - if metadata.task_type == TaskType.AGENT: - result = _process_agent_response(raw_result, metadata.executor_id, metadata.message) - else: - result = _process_activity_result(raw_result, metadata.executor_id, shared_state, workflow_outputs) - all_results.append(result) - - # Phase 3: Process sequential agent messages (for same-agent conversation coherence) - for executor_id, message, _source_executor_id in remaining_agent_messages: - logger.debug("Processing sequential message for agent: %s", executor_id) - task = _prepare_agent_task(context, executor_id, message) - agent_response: AgentResponse = yield task - logger.debug("Agent %s sequential response completed", executor_id) - - result = _process_agent_response(agent_response, executor_id, message) - all_results.append(result) - - # Phase 4: Collect pending HITL requests from activity results - for result in all_results: - _collect_hitl_requests(result, pending_hitl_requests) - - # Phase 5: Route all results to next iteration - for result in all_results: - _route_result_messages(result, workflow, next_pending_messages, fan_in_pending) - - # Phase 6: Check if any FanInEdgeGroups are ready to deliver - _check_fan_in_ready(workflow, fan_in_pending, next_pending_messages) - - pending_messages = next_pending_messages + """Azure Functions wrapper around the shared workflow orchestrator. - # Phase 7: Handle HITL - if no pending work but HITL requests exist, wait for responses - if not pending_messages and pending_hitl_requests: - logger.debug("Workflow paused for HITL - %d pending requests", len(pending_hitl_requests)) - - # Update custom status to expose pending requests - context.set_custom_status({ - "state": "waiting_for_human_input", - "pending_requests": { - req_id: { - "request_id": req.request_id, - "source_executor_id": req.source_executor_id, - "data": req.request_data, - "request_type": req.request_type, - "response_type": req.response_type, - } - for req_id, req in pending_hitl_requests.items() - }, - }) - - # Wait for external events for each pending request - # Process responses one at a time to maintain ordering - for request_id, hitl_request in list(pending_hitl_requests.items()): - logger.debug("Waiting for HITL response for request: %s", request_id) - - # Create tasks for approval and timeout - approval_task = context.wait_for_external_event(request_id) - timeout_task = context.create_timer(context.current_utc_datetime + timedelta(hours=hitl_timeout_hours)) - - winner = yield context.task_any([approval_task, timeout_task]) - - if winner == approval_task: - # Cancel the timeout - timeout_task.cancel() # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue] - - # Get the response - raw_response = approval_task.result - logger.debug( - "Received HITL response for request %s. Type: %s, Value: %s", - request_id, - type(raw_response).__name__, - raw_response, - ) - - # Durable Functions may return a JSON string; parse it if so - if isinstance(raw_response, str): - try: - raw_response = json.loads(raw_response) - logger.debug("Parsed JSON string response to: %s", type(raw_response).__name__) - except (json.JSONDecodeError, TypeError): - logger.debug("Response is not JSON, keeping as string") - - # Remove from pending - del pending_hitl_requests[request_id] - - # Route the response back to the source executor's @response_handler - _route_hitl_response( - hitl_request, - raw_response, - pending_messages, - ) - else: - # Timeout occurred — cancel the dangling external event listener - approval_task.cancel() # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue] - logger.warning("HITL request %s timed out after %s hours", request_id, hitl_timeout_hours) - raise TimeoutError( - f"Human-in-the-loop request '{request_id}' timed out after {hitl_timeout_hours} hours." - ) - - # Clear custom status after HITL is resolved - context.set_custom_status({"state": "running"}) - - iteration += 1 - - # Durable Functions runtime extracts return value from StopIteration - return workflow_outputs # noqa: B901 - - -def _prepare_all_tasks( - context: DurableOrchestrationContext, - workflow: Workflow, - pending_messages: dict[str, list[tuple[Any, str]]], - shared_state: dict[str, Any] | None, -) -> tuple[list[Any], list[TaskMetadata], list[tuple[str, Any, str]]]: - """Prepare all pending tasks for parallel execution. - - Groups agent messages by executor ID so that only the first message per agent - runs in the parallel batch. Additional messages to the same agent are returned - for sequential processing. - - Args: - context: The Durable Functions orchestration context - workflow: The workflow definition - pending_messages: Messages pending for each executor - shared_state: Current shared state snapshot - - Returns: - Tuple of (tasks, metadata, remaining_agent_messages): - - tasks: List of tasks ready for task_all() - - metadata: TaskMetadata for each task (same order as tasks) - - remaining_agent_messages: Agent messages requiring sequential processing - """ - all_tasks: list[Any] = [] - task_metadata_list: list[TaskMetadata] = [] - remaining_agent_messages: list[tuple[str, Any, str]] = [] - - # Group agent messages by executor_id for sequential handling of same-agent messages - agent_messages_by_executor: dict[str, list[tuple[str, Any, str]]] = defaultdict(list) - - # Categorize all pending messages - for executor_id, messages_with_sources in pending_messages.items(): - executor = workflow.executors[executor_id] - is_agent = isinstance(executor, AgentExecutor) - - for message, source_executor_id in messages_with_sources: - if is_agent: - agent_messages_by_executor[executor_id].append((executor_id, message, source_executor_id)) - else: - # Activity tasks can all run in parallel - logger.debug("Preparing activity task: %s", executor_id) - task = _prepare_activity_task(context, executor_id, message, source_executor_id, shared_state) - all_tasks.append(task) - task_metadata_list.append( - TaskMetadata( - executor_id=executor_id, - message=message, - source_executor_id=source_executor_id, - task_type=TaskType.ACTIVITY, - ) - ) - - # Process agent messages: first message per agent goes to parallel batch - for executor_id, messages_list in agent_messages_by_executor.items(): - first_msg = messages_list[0] - remaining = messages_list[1:] - - logger.debug("Preparing agent task: %s", executor_id) - task = _prepare_agent_task(context, first_msg[0], first_msg[1]) - all_tasks.append(task) - task_metadata_list.append( - TaskMetadata( - executor_id=first_msg[0], - message=first_msg[1], - source_executor_id=first_msg[2], - task_type=TaskType.AGENT, - ) - ) - - # Queue remaining messages for sequential processing - remaining_agent_messages.extend(remaining) - - return all_tasks, task_metadata_list, remaining_agent_messages - - -# ============================================================================ -# Message Content Extraction -# ============================================================================ - - -def _extract_message_content(message: Any) -> str: - """Extract text content from various message types.""" - message_content = "" - if isinstance(message, AgentExecutorResponse) and message.agent_response: - if message.agent_response.text: - message_content = message.agent_response.text - elif message.agent_response.messages: - message_content = message.agent_response.messages[-1].text or "" - elif isinstance(message, AgentExecutorRequest) and message.messages: - # Extract text from the last message in the request - message_content = message.messages[-1].text or "" - elif isinstance(message, dict): - key_names = list(message.keys()) # type: ignore[union-attr] - logger.warning("Unexpected dict message in _extract_message_content. Keys: %s", key_names) # type: ignore - elif isinstance(message, str): - message_content = message - - return message_content - - -# ============================================================================ -# HITL Response Handler Execution -# ============================================================================ - - -async def execute_hitl_response_handler( - executor: Any, - hitl_message: dict[str, Any], - shared_state: State, - runner_context: CapturingRunnerContext, -) -> None: - """Execute a HITL response handler on an executor. - - This function handles the delivery of a HITL response to the executor's - @response_handler method. It: - 1. Deserializes the original request and response - 2. Finds the matching response handler based on types - 3. Creates a WorkflowContext and invokes the handler - - Args: - executor: The executor instance that has a @response_handler - hitl_message: The HITL response message containing original_request and response - shared_state: The shared state for the workflow context - runner_context: The runner context for capturing outputs - """ - from agent_framework._workflows._workflow_context import WorkflowContext - - # Extract the response data - original_request_data = hitl_message.get("original_request") - response_data = hitl_message.get("response") - response_type_str = hitl_message.get("response_type") - - # Deserialize the original request - original_request = deserialize_value(original_request_data) - - # Deserialize the response - try to match expected type - response = _deserialize_hitl_response(response_data, response_type_str) - - # Find the matching response handler - handler = executor._find_response_handler(original_request, response) # pyright: ignore[reportPrivateUsage] - - if handler is None: - logger.warning( - "No response handler found for HITL response in executor %s. Request type: %s, Response type: %s", - executor.id, - type(original_request).__name__, - type(response).__name__, - ) - return - - # Create a WorkflowContext for the handler - # Use a special source ID to indicate this is a HITL response - ctx = WorkflowContext( - executor=executor, - source_executor_ids=[SOURCE_HITL_RESPONSE], - runner_context=runner_context, - state=shared_state, - ) - - # Call the response handler - # Note: handler is already a partial with original_request bound - logger.debug( - "Invoking response handler for HITL request in executor %s", - executor.id, - ) - await handler(response, ctx) - - -def _deserialize_hitl_response(response_data: Any, response_type_str: str | None) -> Any: - """Deserialize a HITL response to its expected type. + Creates an :class:`AzureFunctionsWorkflowContext` and delegates to the + host-agnostic :func:`run_workflow_orchestrator` in the durabletask package. Args: - response_data: The raw response data (typically a dict from JSON) - response_type_str: The fully qualified type name (module:classname) + context: The Azure Functions ``DurableOrchestrationContext``. + workflow: The MAF Workflow instance to execute. + initial_message: Initial message to send to the start executor. + shared_state: Optional dict for cross-executor state sharing. + hitl_timeout_hours: Timeout in hours for HITL requests. Returns: - The deserialized response, or the original data if deserialization fails + List of workflow outputs collected from executor activities. """ - logger.debug( - "Deserializing HITL response. response_type_str=%s, response_data type=%s", - response_type_str, - type(response_data).__name__, - ) - - if response_data is None: - return None - - # Sanitize untrusted external input before deserialization. - # HITL response data originates from an HTTP POST and must not contain - # pickle/type markers that would reach pickle.loads(). - response_data = strip_pickle_markers(response_data) - if response_data is None: - return None - - # If already a primitive, return as-is - if not isinstance(response_data, dict): - logger.debug("Response data is not a dict, returning as-is: %s", type(response_data).__name__) - return response_data - - # Try to reconstruct using the type hint (Pydantic / dataclass) - if response_type_str: - response_type = resolve_type(response_type_str) - if response_type: - logger.debug("Found response type %s, attempting reconstruction", response_type) - result = reconstruct_to_type(response_data, response_type) - logger.debug("Reconstructed response type: %s", type(result).__name__) - return result - logger.warning("Could not resolve response type: %s", response_type_str) - - # No type hint available - return the sanitized dict as-is. - # We intentionally do NOT call deserialize_value() here because HITL - # response data is untrusted and must never flow into pickle.loads(). - logger.debug("No type hint; returning sanitized data as-is") - return response_data # type: ignore[reportUnknownVariableType] + af_ctx = AzureFunctionsWorkflowContext(context) + return _run_workflow_orchestrator_shared(af_ctx, workflow, initial_message, shared_state, hitl_timeout_hours) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow_af_context.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow_af_context.py new file mode 100644 index 00000000000..639c9332ce4 --- /dev/null +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow_af_context.py @@ -0,0 +1,87 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Azure Functions adapter for WorkflowOrchestrationContext. + +Wraps ``azure.durable_functions.DurableOrchestrationContext`` to satisfy the +:class:`~agent_framework_durabletask.WorkflowOrchestrationContext` protocol. +""" + +from __future__ import annotations + +import logging +from datetime import datetime +from typing import Any + +from agent_framework_durabletask import AgentSessionId, DurableAgentSession, DurableAIAgent +from azure.durable_functions import DurableOrchestrationContext + +from ._orchestration import AzureFunctionsAgentExecutor + +logger = logging.getLogger(__name__) + + +class AzureFunctionsWorkflowContext: + """Adapter that maps ``DurableOrchestrationContext`` to ``WorkflowOrchestrationContext``.""" + + def __init__(self, context: DurableOrchestrationContext) -> None: + self._context = context + + # -- Properties ----------------------------------------------------------- + + @property + def instance_id(self) -> str: + # Typed local (not cast): mypy sees the untyped context as Any, while + # pyright sees a concrete str - the annotation satisfies both. + instance_id: str = self._context.instance_id + return instance_id + + @property + def current_utc_datetime(self) -> datetime: + current: datetime = self._context.current_utc_datetime + return current + + # -- Agent / Activity dispatch -------------------------------------------- + + def prepare_agent_task(self, executor_id: str, message: str, orchestration_instance_id: str) -> Any: + session_id = AgentSessionId(name=executor_id, key=orchestration_instance_id) + session = DurableAgentSession(durable_session_id=session_id) + az_executor = AzureFunctionsAgentExecutor(self._context) + agent = DurableAIAgent(az_executor, executor_id) + return agent.run(message, session=session) + + def prepare_activity_task(self, activity_name: str, input_json: str) -> Any: + orchestration_context: Any = self._context + return orchestration_context.call_activity(activity_name, input_json) + + # -- Composite tasks ------------------------------------------------------ + + def task_all(self, tasks: list[Any]) -> Any: + return self._context.task_all(tasks) + + def task_any(self, tasks: list[Any]) -> Any: + return self._context.task_any(tasks) + + # -- External events / timers --------------------------------------------- + + def wait_for_external_event(self, name: str) -> Any: + return self._context.wait_for_external_event(name) + + def create_timer(self, fire_at: datetime) -> Any: + return self._context.create_timer(fire_at) + + # -- Status / utility ----------------------------------------------------- + + def set_custom_status(self, status: Any) -> None: + self._context.set_custom_status(status) + + def new_uuid(self) -> str: + new_uuid: str = self._context.new_uuid() + return new_uuid + + def cancel_task(self, task: Any) -> None: + cancel_fn = getattr(task, "cancel", None) + if callable(cancel_fn): + cancel_fn() + + def get_task_result(self, task: Any) -> Any: + return getattr(task, "result", None) diff --git a/python/packages/azurefunctions/tests/integration_tests/test_11_workflow_parallel.py b/python/packages/azurefunctions/tests/integration_tests/test_11_workflow_parallel.py index 65a96678a12..831ff6f4adc 100644 --- a/python/packages/azurefunctions/tests/integration_tests/test_11_workflow_parallel.py +++ b/python/packages/azurefunctions/tests/integration_tests/test_11_workflow_parallel.py @@ -42,103 +42,42 @@ def _setup(self, base_url: str, sample_helper) -> None: self.base_url = base_url self.helper = sample_helper - @pytest.mark.skip(reason="xdist distributes module tests across workers, each spawning a func process") - def test_parallel_workflow_document_analysis(self) -> None: - """Test parallel workflow with a standard document.""" + def test_parallel_workflow_end_to_end(self) -> None: + """Run the parallel workflow end-to-end: start, check status, verify completion. + + Consolidated into a single test on purpose: the work-stealing xdist scheduler + distributes tests (not modules) across workers, and the module-scoped + ``function_app_for_test`` fixture is created per worker -- so multiple tests in + this module would each spawn a separate ``func`` host for this resource-heavy + parallel sample. One test keeps it to a single host while still covering the + fan-out path end-to-end. + """ payload = { "document_id": "doc-test-001", "content": ( "The quarterly earnings report shows strong growth in our cloud services division. " "Revenue increased by 25% compared to last year, driven by enterprise adoption. " - "Customer satisfaction remains high at 92%. However, we face challenges in the " - "mobile segment where competition is intense. Overall, the outlook is positive " - "with expected continued growth in the coming quarters." + "Customer satisfaction remains high at 92%." ), } - # Start orchestration + # Start the orchestration. response = self.helper.post_json(f"{self.base_url}/api/workflow/run", payload) assert response.status_code == 202 data = response.json() - assert "instanceId" in data + instance_id = data["instanceId"] assert "statusQueryGetUri" in data - # Wait for completion - parallel workflows may take longer - status = self.helper.wait_for_orchestration_with_output( - data["statusQueryGetUri"], - max_wait=300, # 5 minutes for parallel execution - ) - assert status["runtimeStatus"] == "Completed" - assert "output" in status - - @pytest.mark.skip(reason="xdist distributes module tests across workers, each spawning a func process") - def test_parallel_workflow_short_document(self) -> None: - """Test parallel workflow with a short document.""" - payload = { - "document_id": "doc-test-002", - "content": "Quick update: Project completed successfully. Team performance exceeded expectations.", - } - - # Start orchestration - response = self.helper.post_json(f"{self.base_url}/api/workflow/run", payload) - assert response.status_code == 202 - data = response.json() - assert "instanceId" in data - assert "statusQueryGetUri" in data + # The status endpoint reflects the started instance. + status_response = self.helper.get(f"{self.base_url}/api/workflow/status/{instance_id}") + assert status_response.status_code == 200 + assert status_response.json()["instanceId"] == instance_id - # Wait for completion + # Fan-out to parallel processors and agents completes with an aggregated output. status = self.helper.wait_for_orchestration_with_output(data["statusQueryGetUri"], max_wait=300) assert status["runtimeStatus"] == "Completed" assert "output" in status - @pytest.mark.skip(reason="xdist distributes module tests across workers, each spawning a func process") - def test_parallel_workflow_technical_document(self) -> None: - """Test parallel workflow with a technical document.""" - payload = { - "document_id": "doc-test-003", - "content": ( - "The new microservices architecture has been deployed to production. " - "Key improvements include: reduced latency by 40%, improved scalability " - "to handle 10x traffic spikes, and enhanced monitoring with distributed tracing. " - "The Kubernetes cluster is now running on version 1.28 with auto-scaling enabled. " - "Next steps include implementing service mesh and improving CI/CD pipelines." - ), - } - - # Start orchestration - response = self.helper.post_json(f"{self.base_url}/api/workflow/run", payload) - assert response.status_code == 202 - data = response.json() - assert "instanceId" in data - - # Wait for completion - status = self.helper.wait_for_orchestration_with_output(data["statusQueryGetUri"], max_wait=300) - assert status["runtimeStatus"] == "Completed" - - @pytest.mark.skip(reason="xdist distributes module tests across workers, each spawning a func process") - def test_workflow_status_endpoint(self) -> None: - """Test that the workflow status endpoint works correctly.""" - payload = { - "document_id": "doc-test-004", - "content": "Brief status update for testing purposes.", - } - - # Start orchestration - response = self.helper.post_json(f"{self.base_url}/api/workflow/run", payload) - assert response.status_code == 202 - data = response.json() - instance_id = data["instanceId"] - - # Check status - status_response = self.helper.get(f"{self.base_url}/api/workflow/status/{instance_id}") - assert status_response.status_code == 200 - status = status_response.json() - assert "instanceId" in status - assert status["instanceId"] == instance_id - - # Wait for completion - self.helper.wait_for_orchestration(data["statusQueryGetUri"], max_wait=300) - if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/python/packages/azurefunctions/tests/test_app.py b/python/packages/azurefunctions/tests/test_app.py index 61518fa44b7..194f827197e 100644 --- a/python/packages/azurefunctions/tests/test_app.py +++ b/python/packages/azurefunctions/tests/test_app.py @@ -26,7 +26,6 @@ from agent_framework_azurefunctions import AgentFunctionApp from agent_framework_azurefunctions._entities import create_agent_entity -from agent_framework_azurefunctions._workflow import SOURCE_ORCHESTRATOR FuncT = TypeVar("FuncT", bound=Callable[..., Any]) @@ -1342,8 +1341,8 @@ def test_init_with_workflow_stores_workflow(self) -> None: assert app.workflow is mock_workflow - def test_init_with_workflow_extracts_agents(self) -> None: - """Test that agents are extracted from workflow executors.""" + def test_init_with_workflow_registers_agent_entity_by_executor_id(self) -> None: + """Workflow agent executors are registered as entities keyed by executor id.""" from agent_framework import AgentExecutor mock_agent = Mock() @@ -1351,18 +1350,31 @@ def test_init_with_workflow_extracts_agents(self) -> None: mock_executor = Mock(spec=AgentExecutor) mock_executor.agent = mock_agent + # Executor id intentionally differs from the agent name to exercise the + # identity fix: dispatch uses the executor id, so registration must too. + mock_executor.id = "custom-executor-id" mock_workflow = Mock() - mock_workflow.executors = {"WorkflowAgent": mock_executor} + mock_workflow.executors = {"custom-executor-id": mock_executor} with ( patch.object(AgentFunctionApp, "_setup_executor_activity"), patch.object(AgentFunctionApp, "_setup_workflow_orchestration"), - patch.object(AgentFunctionApp, "_setup_agent_functions"), + patch.object(AgentFunctionApp, "_setup_agent_entity") as setup_entity, ): app = AgentFunctionApp(workflow=mock_workflow) - assert "WorkflowAgent" in app.agents + # The entity is registered under the executor id (the dispatch identity). + setup_entity.assert_called_once() + call_args = setup_entity.call_args.args + assert call_args[0] is mock_agent + assert call_args[1] == "custom-executor-id" + + # Regression guard: the workflow agent must also be tracked on the app's + # normal registration surface, keyed by the executor id, so it appears in + # ``agents`` and is retrievable via ``get_agent`` (as the constructor documents). + assert "custom-executor-id" in app.agents + assert app.agents["custom-executor-id"] is mock_agent def test_init_with_workflow_calls_setup_methods(self) -> None: """Test that workflow setup methods are called.""" @@ -1396,8 +1408,8 @@ def test_init_without_workflow_does_not_call_workflow_setup(self) -> None: setup_exec.assert_not_called() setup_orch.assert_not_called() - def test_init_with_workflow_deduplicates_agents(self) -> None: - """Test that agents in both 'agents' and workflow are not double-registered.""" + def test_init_with_workflow_and_explicit_agent_does_not_raise(self) -> None: + """An agent passed explicitly and present in the workflow registers without error.""" from agent_framework import AgentExecutor mock_agent = Mock() @@ -1405,6 +1417,7 @@ def test_init_with_workflow_deduplicates_agents(self) -> None: mock_executor = Mock(spec=AgentExecutor) mock_executor.agent = mock_agent + mock_executor.id = "SharedAgent" mock_workflow = Mock() mock_workflow.executors = {"SharedAgent": mock_executor} @@ -1413,6 +1426,7 @@ def test_init_with_workflow_deduplicates_agents(self) -> None: patch.object(AgentFunctionApp, "_setup_executor_activity"), patch.object(AgentFunctionApp, "_setup_workflow_orchestration"), patch.object(AgentFunctionApp, "_setup_agent_functions"), + patch.object(AgentFunctionApp, "_setup_agent_entity"), ): # Same agent passed explicitly AND present in workflow — should not raise app = AgentFunctionApp(agents=[mock_agent], workflow=mock_workflow) @@ -1450,285 +1464,53 @@ def test_build_status_url_handles_trailing_slash(self) -> None: assert "instance-456" in url -def _compute_state_updates(original_snapshot: dict[str, Any], current_state: dict[str, Any]) -> dict[str, Any]: - """Compute state updates by comparing current state against the original snapshot. - - This mirrors the inlined logic in ``_app.py``'s ``executor_activity.run()``. - """ - original_keys = set(original_snapshot.keys()) - current_keys = set(current_state.keys()) - updates: dict[str, Any] = {} - for key in current_keys: - if key not in original_keys or current_state[key] != original_snapshot.get(key): - updates[key] = current_state[key] - return updates - - -class TestStateSnapshotDiff: - """Test suite for state snapshot diffing in activity execution. - - The activity executor snapshots state before execution and diffs against the - post-execution state to determine which keys were updated. These tests exercise - the production snapshot helper and the state-update diffing logic to ensure that - in-place mutations to nested objects (dicts, lists) are correctly detected as changes. - """ - - def test_nested_dict_mutation_detected_in_diff(self) -> None: - """Test that mutating values inside a nested dict appears in the diff.""" - from agent_framework._workflows._state import State - - from agent_framework_azurefunctions._app import _create_state_snapshot - - deserialized_state: dict[str, Any] = { - "Local.config": {"code": "", "enabled": False}, - "simple_key": "simple_value", - } - - original_snapshot = _create_state_snapshot(deserialized_state) - - shared_state = State() - shared_state.import_state(deserialized_state) - - config = shared_state.get("Local.config") - config["code"] = "SOMECODEXXX" - config["enabled"] = True - - shared_state.commit() - current_state = shared_state.export_state() - - updates = _compute_state_updates(original_snapshot, current_state) - - assert "Local.config" in updates - assert updates["Local.config"]["code"] == "SOMECODEXXX" - assert updates["Local.config"]["enabled"] is True - - def test_new_key_in_nested_dict_detected_in_diff(self) -> None: - """Test that adding a key to a nested dict appears in the diff.""" - from agent_framework._workflows._state import State - - from agent_framework_azurefunctions._app import _create_state_snapshot - - deserialized_state: dict[str, Any] = { - "Local.data": {"existing": "value"}, - } - - original_snapshot = _create_state_snapshot(deserialized_state) - - shared_state = State() - shared_state.import_state(deserialized_state) - - data = shared_state.get("Local.data") - data["code"] = "NEW_CODE" - - shared_state.commit() - current_state = shared_state.export_state() - - updates = _compute_state_updates(original_snapshot, current_state) - - assert "Local.data" in updates - assert updates["Local.data"]["code"] == "NEW_CODE" - - def test_nested_list_mutation_detected_in_diff(self) -> None: - """Test that appending to a nested list appears in the diff.""" - from agent_framework._workflows._state import State - - from agent_framework_azurefunctions._app import _create_state_snapshot - - deserialized_state: dict[str, Any] = { - "Local.items": [1, 2, 3], - } - - original_snapshot = _create_state_snapshot(deserialized_state) - - shared_state = State() - shared_state.import_state(deserialized_state) - - items = shared_state.get("Local.items") - items.append(4) - - shared_state.commit() - current_state = shared_state.export_state() - - updates = _compute_state_updates(original_snapshot, current_state) - - assert "Local.items" in updates - assert updates["Local.items"] == [1, 2, 3, 4] - - def test_new_top_level_key_detected_in_diff(self) -> None: - """Test that setting a new top-level key appears in the diff.""" - from agent_framework._workflows._state import State - - from agent_framework_azurefunctions._app import _create_state_snapshot - - deserialized_state: dict[str, Any] = { - "existing": "value", - } - - original_snapshot = _create_state_snapshot(deserialized_state) - - shared_state = State() - shared_state.import_state(deserialized_state) +# NOTE: State snapshot/diff tests were moved to durabletask once the activity +# execution body was extracted into the host-agnostic execute_workflow_activity. +# See packages/durabletask/tests/test_workflow_activity.py. - shared_state.set("Local.code", "SOMECODEXXX") - shared_state.commit() - current_state = shared_state.export_state() +class TestWorkflowStatusOutputEncoding: + """The workflow status endpoint emits clean domain JSON for reconstructed outputs. - updates = _compute_state_updates(original_snapshot, current_state) - - assert "Local.code" in updates - assert updates["Local.code"] == "SOMECODEXXX" - - def test_unchanged_nested_state_produces_empty_diff(self) -> None: - """Test that unmodified nested state produces no updates.""" - from agent_framework._workflows._state import State - - from agent_framework_azurefunctions._app import _create_state_snapshot - - deserialized_state: dict[str, Any] = { - "Local.config": {"code": "existing", "enabled": True}, - "simple_key": "simple_value", - } - - original_snapshot = _create_state_snapshot(deserialized_state) - - shared_state = State() - shared_state.import_state(deserialized_state) - - # No mutations performed - shared_state.commit() - current_state = shared_state.export_state() - - updates = _compute_state_updates(original_snapshot, current_state) - - assert updates == {} - - def test_shallow_copy_would_miss_nested_mutations(self) -> None: - """Regression test: a shallow copy (dict()) shares nested refs, hiding mutations. - - This reproduces the original bug from #4500 where ``dict(deserialized_state)`` - was used instead of ``copy.deepcopy()``. With a shallow copy the snapshot and - the live state share nested objects, so in-place mutations appear in both and - the diff produces an empty update set. - """ - from agent_framework._workflows._state import State - - deserialized_state: dict[str, Any] = { - "Local.config": {"code": "", "enabled": False}, - } - - # Shallow copy (the OLD, buggy behaviour) - shallow_snapshot = dict(deserialized_state) - - shared_state = State() - shared_state.import_state(deserialized_state) - - config = shared_state.get("Local.config") - config["code"] = "SOMECODEXXX" - config["enabled"] = True - - shared_state.commit() - current_state = shared_state.export_state() - - # With a shallow copy the mutation leaks into the snapshot → empty diff - updates_shallow = _compute_state_updates(shallow_snapshot, current_state) - assert updates_shallow == {}, "shallow copy should miss nested mutations (demonstrating the bug)" - - def test_create_state_snapshot_isolates_nested_objects(self) -> None: - """Verify _create_state_snapshot produces a deep copy that is mutation-proof. - - This ensures the production snapshot helper is not equivalent to ``dict()`` - and will correctly isolate nested objects so that later mutations are detected. - """ - from agent_framework_azurefunctions._app import _create_state_snapshot - - original: dict[str, Any] = { - "nested_dict": {"a": 1}, - "nested_list": [1, 2, 3], - } - - snapshot = _create_state_snapshot(original) - - # Mutate the originals in place - original["nested_dict"]["a"] = 999 - original["nested_list"].append(4) - - # Snapshot must be unaffected - assert snapshot["nested_dict"]["a"] == 1 - assert snapshot["nested_list"] == [1, 2, 3] - - def test_executor_activity_detects_nested_state_mutations(self) -> None: - """Integration test: the full activity wrapper detects nested mutations. - - This exercises the actual executor_activity function registered by - _setup_executor_activity to verify the production code path uses - _create_state_snapshot (deep copy) rather than dict() (shallow copy). - If the implementation regressed to using a shallow copy such as - ``dict(deserialized_state)``, this test would fail because in-place - mutations would leak into the snapshot and produce an empty diff. - """ - mock_executor = Mock() - mock_executor.id = "test-exec" - - async def mutate_nested_state( - message: Any, - source_executor_ids: Any, - state: Any, - runner_context: Any, - ) -> None: - config = state.get("Local.config") - config["code"] = "MUTATED" - config["enabled"] = True - state.commit() + Reconstructed outputs (see ``deserialize_workflow_output``) may be framework + models or dataclasses; ``_json_default`` converts them via their own + serialization so the HTTP body is clean domain JSON rather than opaque + checkpoint-marker dicts or repr strings. + """ - mock_executor.execute = AsyncMock(side_effect=mutate_nested_state) + def test_encodes_framework_model_via_to_dict(self) -> None: + from agent_framework_azurefunctions._app import _json_default - mock_workflow = Mock() - mock_workflow.executors = {"test-exec": mock_executor} + response = AgentResponse(messages=[Message(role="assistant", contents=["hello"])]) - # Capture the activity function by making decorators pass-through - captured_activity: dict[str, Any] = {} + encoded = _json_default(response) - def passthrough_function_name(name: str) -> Callable[[FuncT], FuncT]: - def decorator(fn: FuncT) -> FuncT: - captured_activity["fn"] = fn - return fn + assert encoded == response.to_dict() + # The result must be JSON-serializable (no marker dicts, no objects). + assert "hello" in json.dumps(encoded) - return decorator + def test_encodes_dataclass_via_asdict(self) -> None: + from dataclasses import dataclass - def passthrough_activity_trigger(input_name: str) -> Callable[[FuncT], FuncT]: - def decorator(fn: FuncT) -> FuncT: - return fn + from agent_framework_azurefunctions._app import _json_default - return decorator + @dataclass + class Decision: + approved: bool + note: str - with ( - patch.object(AgentFunctionApp, "function_name", side_effect=passthrough_function_name), - patch.object(AgentFunctionApp, "activity_trigger", side_effect=passthrough_activity_trigger), - patch.object(AgentFunctionApp, "_setup_workflow_orchestration"), - ): - AgentFunctionApp(workflow=mock_workflow) + encoded = _json_default(Decision(approved=True, note="ok")) - assert "fn" in captured_activity, "activity function was not captured" + assert encoded == {"approved": True, "note": "ok"} - # Call the activity with nested state that the executor will mutate - input_data = json.dumps({ - "message": "test", - "shared_state_snapshot": { - "Local.config": {"code": "", "enabled": False}, - }, - "source_executor_ids": [SOURCE_ORCHESTRATOR], - }) + def test_falls_back_to_str_for_plain_objects(self) -> None: + from agent_framework_azurefunctions._app import _json_default - result = json.loads(captured_activity["fn"](input_data)) + class Opaque: + def __str__(self) -> str: + return "opaque-value" - # The deep copy snapshot must detect the in-place nested mutations - assert "Local.config" in result["shared_state_updates"], ( - "nested mutation not detected — snapshot may be using shallow copy" - ) - updated_config = result["shared_state_updates"]["Local.config"] - assert updated_config["code"] == "MUTATED" - assert updated_config["enabled"] is True + assert _json_default(Opaque()) == "opaque-value" if __name__ == "__main__": diff --git a/python/packages/azurefunctions/tests/test_entities.py b/python/packages/azurefunctions/tests/test_entities.py index bfd17ba176e..cc2fc75320e 100644 --- a/python/packages/azurefunctions/tests/test_entities.py +++ b/python/packages/azurefunctions/tests/test_entities.py @@ -241,10 +241,8 @@ def test_entity_function_handles_none_input(self) -> None: # Verify the result was set (likely error result) assert mock_context.set_result.called - def test_entity_function_handles_event_loop_runtime_error(self) -> None: - """Test that the entity function handles RuntimeError from get_event_loop by creating a new loop.""" - from unittest.mock import patch - + def test_entity_function_runs_on_persistent_loop(self) -> None: + """Entity coroutines run on the shared persistent loop and set a result.""" mock_agent = Mock() mock_agent.run = AsyncMock(return_value=_agent_response("Response")) @@ -253,60 +251,42 @@ def test_entity_function_handles_event_loop_runtime_error(self) -> None: mock_context = Mock() mock_context.operation_name = "run" mock_context.entity_key = "conv-loop-test" - mock_context.get_input.return_value = {"message": "Test"} + mock_context.get_input.return_value = {"message": "Test", "correlationId": "corr-loop-test"} mock_context.get_state.return_value = None - # Simulate RuntimeError when getting event loop - with ( - patch("asyncio.get_event_loop", side_effect=RuntimeError("No event loop")), - patch("asyncio.new_event_loop") as mock_new_loop, - patch("asyncio.set_event_loop") as mock_set_loop, - ): - mock_loop = Mock() - mock_loop.is_running.return_value = False - mock_loop.run_until_complete = Mock() - mock_new_loop.return_value = mock_loop + # Execute - the persistent loop should run the coroutine and set a result. + entity_function(mock_context) - # Execute - entity_function(mock_context) + assert mock_context.set_result.called + mock_agent.run.assert_awaited() - # Verify new event loop was created - mock_new_loop.assert_called_once() - mock_set_loop.assert_called_once_with(mock_loop) + def test_entity_function_runs_across_threads_without_hang(self) -> None: + """Successive entity invocations from different threads must not hang. - def test_entity_function_handles_running_event_loop(self) -> None: - """Test that the entity function handles a running event loop by creating a temporary loop.""" - from unittest.mock import patch + This reproduces the cross-loop scenario that previously deadlocked: a + shared async resource bound to one loop being awaited from a different + worker thread. The persistent loop keeps every invocation on one loop. + """ + from concurrent.futures import ThreadPoolExecutor mock_agent = Mock() mock_agent.run = AsyncMock(return_value=_agent_response("Response")) entity_function = create_agent_entity(mock_agent) - mock_context = Mock() - mock_context.operation_name = "run" - mock_context.entity_key = "conv-running-loop" - mock_context.get_input.return_value = {"message": "Test"} - mock_context.get_state.return_value = None - - # Simulate a running event loop - mock_existing_loop = Mock() - mock_existing_loop.is_running.return_value = True - - mock_temp_loop = Mock() - mock_temp_loop.run_until_complete = Mock() - mock_temp_loop.close = Mock() + def invoke(i: int) -> bool: + ctx = Mock() + ctx.operation_name = "run" + ctx.entity_key = f"conv-{i}" + ctx.get_input.return_value = {"message": f"Test {i}", "correlationId": f"corr-{i}"} + ctx.get_state.return_value = None + entity_function(ctx) + return ctx.set_result.called - with ( - patch("asyncio.get_event_loop", return_value=mock_existing_loop), - patch("asyncio.new_event_loop", return_value=mock_temp_loop), - ): - # Execute - entity_function(mock_context) + with ThreadPoolExecutor(max_workers=8) as ex: + results = list(ex.map(invoke, range(16))) - # Verify temporary loop was created and closed - mock_temp_loop.run_until_complete.assert_called_once() - mock_temp_loop.close.assert_called_once() + assert all(results) if __name__ == "__main__": diff --git a/python/packages/core/agent_framework/azure/__init__.py b/python/packages/core/agent_framework/azure/__init__.py index 7cff0150f19..a6a1de34064 100644 --- a/python/packages/core/agent_framework/azure/__init__.py +++ b/python/packages/core/agent_framework/azure/__init__.py @@ -19,6 +19,8 @@ "DurableAIAgentClient": ("agent_framework_durabletask", "agent-framework-durabletask"), "DurableAIAgentOrchestrationContext": ("agent_framework_durabletask", "agent-framework-durabletask"), "DurableAIAgentWorker": ("agent_framework_durabletask", "agent-framework-durabletask"), + "DurableWorkflowClient": ("agent_framework_durabletask", "agent-framework-durabletask"), + "WORKFLOW_ORCHESTRATOR_NAME": ("agent_framework_durabletask", "agent-framework-durabletask"), } diff --git a/python/packages/core/agent_framework/azure/__init__.pyi b/python/packages/core/agent_framework/azure/__init__.pyi index 12527b6db31..97a44180f8f 100644 --- a/python/packages/core/agent_framework/azure/__init__.pyi +++ b/python/packages/core/agent_framework/azure/__init__.pyi @@ -10,15 +10,18 @@ from agent_framework_azure_ai_search import ( from agent_framework_azure_cosmos import CosmosHistoryProvider from agent_framework_azurefunctions import AgentFunctionApp from agent_framework_durabletask import ( + WORKFLOW_ORCHESTRATOR_NAME, AgentCallbackContext, AgentResponseCallbackProtocol, DurableAIAgent, DurableAIAgentClient, DurableAIAgentOrchestrationContext, DurableAIAgentWorker, + DurableWorkflowClient, ) __all__ = [ + "WORKFLOW_ORCHESTRATOR_NAME", "AgentCallbackContext", "AgentFunctionApp", "AgentResponseCallbackProtocol", @@ -29,4 +32,5 @@ __all__ = [ "DurableAIAgentClient", "DurableAIAgentOrchestrationContext", "DurableAIAgentWorker", + "DurableWorkflowClient", ] diff --git a/python/packages/durabletask/agent_framework_durabletask/__init__.py b/python/packages/durabletask/agent_framework_durabletask/__init__.py index a518b5ad235..bcab531e51f 100644 --- a/python/packages/durabletask/agent_framework_durabletask/__init__.py +++ b/python/packages/durabletask/agent_framework_durabletask/__init__.py @@ -4,6 +4,7 @@ import importlib.metadata +from ._async_bridge import run_agent_coroutine from ._callbacks import AgentCallbackContext, AgentResponseCallbackProtocol from ._client import DurableAIAgentClient from ._constants import ( @@ -50,6 +51,14 @@ from ._response_utils import ensure_response_format, load_agent_response from ._shim import DurableAIAgent from ._worker import DurableAIAgentWorker +from ._workflows.activity import execute_workflow_activity +from ._workflows.client import DurableWorkflowClient +from ._workflows.context import WorkflowOrchestrationContext +from ._workflows.dt_context import DurableTaskWorkflowContext +from ._workflows.orchestrator import WORKFLOW_ORCHESTRATOR_NAME, run_workflow_orchestrator +from ._workflows.registration import WorkflowRegistrationPlan, plan_workflow_registration +from ._workflows.runner_context import CapturingRunnerContext +from ._workflows.serialization import deserialize_workflow_output try: __version__ = importlib.metadata.version(__name__) @@ -67,12 +76,14 @@ "THREAD_ID_HEADER", "WAIT_FOR_RESPONSE_FIELD", "WAIT_FOR_RESPONSE_HEADER", + "WORKFLOW_ORCHESTRATOR_NAME", "AgentCallbackContext", "AgentEntity", "AgentEntityStateProviderMixin", "AgentResponseCallbackProtocol", "AgentSessionId", "ApiResponseFields", + "CapturingRunnerContext", "ContentTypes", "DurableAIAgent", "DurableAIAgentClient", @@ -101,8 +112,17 @@ "DurableAgentStateUsage", "DurableAgentStateUsageContent", "DurableStateFields", + "DurableTaskWorkflowContext", + "DurableWorkflowClient", "RunRequest", + "WorkflowOrchestrationContext", + "WorkflowRegistrationPlan", "__version__", + "deserialize_workflow_output", "ensure_response_format", + "execute_workflow_activity", "load_agent_response", + "plan_workflow_registration", + "run_agent_coroutine", + "run_workflow_orchestrator", ] diff --git a/python/packages/durabletask/agent_framework_durabletask/_async_bridge.py b/python/packages/durabletask/agent_framework_durabletask/_async_bridge.py new file mode 100644 index 00000000000..2ca52c6c94e --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_async_bridge.py @@ -0,0 +1,89 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Persistent background event loop for running agent coroutines. + +Durable entity (and agent) handlers are invoked synchronously by the host on +arbitrary worker threads. Agent clients and their async credentials create +asyncio primitives (locks, connection pools, futures) that are bound to the +event loop on which they are *first* used. Running a later invocation on a +*different* event loop causes those primitives to await futures attached to a +now-idle loop, which results in a silent, permanent hang. + +This module provides a single, process-wide persistent event loop running on a +dedicated daemon thread. All agent coroutines are submitted to this loop via +``run_coroutine_threadsafe`` so shared async resources remain valid across +invocations regardless of which worker thread the host happens to use. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import threading +from collections.abc import Coroutine +from typing import Any, TypeVar + +_T = TypeVar("_T") + +_loop: asyncio.AbstractEventLoop | None = None +_thread: threading.Thread | None = None +_lock = threading.Lock() + + +def _ensure_loop() -> asyncio.AbstractEventLoop: + """Return the shared persistent event loop, starting it on first use. + + The loop is only reusable when it is open *and* its backing thread is still + alive. A loop whose thread has died (e.g. during interpreter shutdown) is not + reusable: ``run_coroutine_threadsafe`` would schedule onto a loop that will + never run again and ``future.result()`` would block forever. Such a loop is + replaced with a fresh loop + thread. + """ + global _loop, _thread + + loop, thread = _loop, _thread + if loop is not None and not loop.is_closed() and thread is not None and thread.is_alive(): + return loop + + with _lock: + loop, thread = _loop, _thread + if loop is not None and not loop.is_closed() and thread is not None and thread.is_alive(): + return loop + + # An existing loop whose thread has died is orphaned; close it best-effort + # before replacing it so it does not leak. + if loop is not None and not loop.is_closed(): + with contextlib.suppress(Exception): + loop.close() + + new_loop = asyncio.new_event_loop() + + def _run() -> None: + asyncio.set_event_loop(new_loop) + new_loop.run_forever() + + new_thread = threading.Thread(target=_run, name="dafx-agent-loop", daemon=True) + new_thread.start() + + _loop = new_loop + _thread = new_thread + return new_loop + + +def run_agent_coroutine(coro: Coroutine[Any, Any, _T]) -> _T: + """Run a coroutine on the shared persistent event loop and return its result. + + The calling (worker) thread blocks until the coroutine completes. Because + every agent coroutine runs on the same loop, async resources created by + shared agent clients/credentials (locks, connection pools) remain bound to a + live loop across all invocations, preventing cross-loop hangs. + + Args: + coro: The coroutine to execute. + + Returns: + The coroutine's result. + """ + loop = _ensure_loop() + future = asyncio.run_coroutine_threadsafe(coro, loop) + return future.result() diff --git a/python/packages/durabletask/agent_framework_durabletask/_worker.py b/python/packages/durabletask/agent_framework_durabletask/_worker.py index 728ae17629a..bda2e3c5e8b 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_worker.py +++ b/python/packages/durabletask/agent_framework_durabletask/_worker.py @@ -3,29 +3,46 @@ """Worker wrapper for Durable Task Agent Framework. This module provides the DurableAIAgentWorker class that wraps a durabletask worker -and enables registration of agents as durable entities. +and enables registration of agents as durable entities, and optionally workflows +as durable orchestrations with automatically generated activity functions. """ from __future__ import annotations -import asyncio import logging from typing import Any -from agent_framework import SupportsAgentRun +from agent_framework import SupportsAgentRun, Workflow +from durabletask.task import ActivityContext, OrchestrationContext from durabletask.worker import TaskHubGrpcWorker +from ._async_bridge import run_agent_coroutine from ._callbacks import AgentResponseCallbackProtocol from ._entities import AgentEntity, DurableTaskEntityStateProvider +from ._workflows.activity import execute_workflow_activity +from ._workflows.dt_context import DurableTaskWorkflowContext +from ._workflows.orchestrator import WORKFLOW_ORCHESTRATOR_NAME, run_workflow_orchestrator +from ._workflows.registration import plan_workflow_registration logger = logging.getLogger("agent_framework.durabletask") class DurableAIAgentWorker: - """Wrapper for durabletask worker that enables agent registration. + """Wrapper for a durabletask worker that hosts agents and workflows. - This class wraps an existing TaskHubGrpcWorker instance and provides - a convenient interface for registering agents as durable entities. + This class wraps an existing TaskHubGrpcWorker instance and is the single + host-side registration surface for a worker process. It supports two + complementary kinds of work: + + - **Agents** via :meth:`add_agent`, which registers each agent as a durable entity. + - **Workflows** via :meth:`configure_workflow`, which registers a MAF + ``Workflow`` (its agent executors as entities, its non-agent executors as + activities, and the workflow orchestrator). + + A single worker process commonly hosts both, so registration is intentionally + aggregated on one object rather than split per kind. (On the *client* side the + surfaces are split into :class:`DurableAIAgentClient` and ``DurableWorkflowClient``, + because a caller invokes one or the other.) Example: ```python @@ -40,7 +57,7 @@ class DurableAIAgentWorker: # Wrap it with the agent worker agent_worker = DurableAIAgentWorker(worker) - # Register agents + # Register agents (or call configure_workflow(workflow) to host a workflow) client = OpenAIChatCompletionClient() my_agent = Agent(client=client, name="assistant") agent_worker.add_agent(my_agent) @@ -64,43 +81,51 @@ def __init__( self._worker = worker self._callback = callback self._registered_agents: dict[str, SupportsAgentRun] = {} + self._workflow: Workflow | None = None logger.debug("[DurableAIAgentWorker] Initialized with worker type: %s", type(worker).__name__) def add_agent( self, agent: SupportsAgentRun, callback: AgentResponseCallbackProtocol | None = None, + *, + entity_id: str | None = None, ) -> None: """Register an agent with the worker. This method creates a durable entity class for the agent and registers it with the underlying durabletask worker. The entity will be accessible - by the name "dafx-{agent_name}". + by the name "dafx-{entity_id or agent_name}". Args: agent: The agent to register (must have a name) callback: Optional callback for this specific agent (overrides worker-level callback) + entity_id: Optional identity to register the entity under instead of + ``agent.name``. Workflow hosting passes the executor's ``id`` so the + entity matches the identity the orchestrator dispatches to. Raises: ValueError: If the agent doesn't have a name or is already registered """ - agent_name = agent.name - if not agent_name: + registration_name = entity_id or agent.name + if not registration_name: raise ValueError("Agent must have a name to be registered") - if agent_name in self._registered_agents: - raise ValueError(f"Agent '{agent_name}' is already registered") + if registration_name in self._registered_agents: + raise ValueError(f"Agent '{registration_name}' is already registered") - logger.info("[DurableAIAgentWorker] Registering agent: %s as entity: dafx-%s", agent_name, agent_name) + logger.info( + "[DurableAIAgentWorker] Registering agent: %s as entity: dafx-%s", registration_name, registration_name + ) # Store the agent reference - self._registered_agents[agent_name] = agent + self._registered_agents[registration_name] = agent # Use agent-specific callback if provided, otherwise use worker-level callback effective_callback = callback or self._callback # Create a configured entity class using the factory - entity_class = self.__create_agent_entity(agent, effective_callback) + entity_class = self.__create_agent_entity(agent, effective_callback, entity_id=registration_name) # Register the entity class with the worker # The worker.add_entity method takes a class @@ -109,7 +134,7 @@ def add_agent( logger.debug( "[DurableAIAgentWorker] Successfully registered entity class %s for agent: %s", entity_registered, - agent_name, + registration_name, ) def start(self) -> None: @@ -140,10 +165,98 @@ def registered_agent_names(self) -> list[str]: """ return list(self._registered_agents.keys()) + # ----------------------------------------------------------------- + # Workflow support + # ----------------------------------------------------------------- + + def configure_workflow( + self, + workflow: Workflow, + callback: AgentResponseCallbackProtocol | None = None, + ) -> None: + """Register a :class:`Workflow` for automatic orchestration. + + This extracts agents from the workflow and registers them as durable + entities, registers non-agent executors as activities, and creates an + orchestrator function that drives the workflow graph. + + Args: + workflow: The MAF :class:`Workflow` to register. + callback: Optional callback for agent response notifications. + """ + self._workflow = workflow + + # The "what to register" decision (agent -> entity, non-agent -> activity) + # is shared with the Azure Functions host via plan_workflow_registration. + plan = plan_workflow_registration(workflow) + + # Register agent executors as durable entities. Each entity is keyed by + # the executor's id (the identity the orchestrator dispatches to) so + # AgentExecutor(agent, id=...) works even when the id differs from the + # agent's name. + for agent_executor in plan.agent_executors: + if agent_executor.id not in self._registered_agents: + self.add_agent(agent_executor.agent, callback=callback, entity_id=agent_executor.id) + + # Register non-agent executors as durable activities. + for executor in plan.activity_executors: + self._register_executor_activity(executor) + + # Register the workflow orchestrator. + self._register_workflow_orchestrator() + + logger.info( + "[DurableAIAgentWorker] Workflow configured with %d executors (%d agents, %d activities)", + len(workflow.executors), + len(plan.agent_executors), + len(plan.activity_executors), + ) + + def _register_executor_activity(self, executor: Any) -> None: + """Register a non-agent executor as a durabletask activity.""" + captured_executor = executor + captured_workflow = self._workflow + activity_name = f"dafx-{executor.id}" + + def executor_activity(ctx: ActivityContext, input_data: str) -> str: + return execute_workflow_activity(captured_executor, input_data, captured_workflow) + + # Give the function the expected name for registration + executor_activity.__name__ = activity_name + executor_activity.__qualname__ = activity_name + + self._worker.add_activity(executor_activity) # type: ignore[arg-type] + logger.debug("[DurableAIAgentWorker] Registered activity: %s", activity_name) + + def _register_workflow_orchestrator(self) -> None: + """Register the workflow orchestrator function with the worker.""" + captured_workflow = self._workflow + + def workflow_orchestrator(context: OrchestrationContext, input_data: Any) -> Any: # type: ignore[type-arg] + if captured_workflow is None: + raise RuntimeError("Workflow not configured") + + # Pass the deserialized client input straight to the shared engine, which + # reconstructs the start executor's declared type (see _coerce_initial_input). + initial_message = input_data + shared_state: dict[str, Any] = {} + + dt_ctx = DurableTaskWorkflowContext(context) + outputs = yield from run_workflow_orchestrator(dt_ctx, captured_workflow, initial_message, shared_state) + return outputs # noqa: B901 + + workflow_orchestrator.__name__ = WORKFLOW_ORCHESTRATOR_NAME + workflow_orchestrator.__qualname__ = WORKFLOW_ORCHESTRATOR_NAME + + self._worker.add_orchestrator(workflow_orchestrator) # type: ignore[arg-type] + logger.debug("[DurableAIAgentWorker] Registered workflow orchestrator") + def __create_agent_entity( self, agent: SupportsAgentRun, callback: AgentResponseCallbackProtocol | None = None, + *, + entity_id: str | None = None, ) -> type[DurableTaskEntityStateProvider]: """Factory function to create a DurableEntity class configured with an agent. @@ -153,11 +266,14 @@ def __create_agent_entity( Args: agent: The agent instance to wrap callback: Optional callback for agent responses + entity_id: Optional identity to register the entity under instead of + ``agent.name`` (used by workflow hosting to key entities by + executor id). Returns: A new DurableEntity subclass configured for this agent """ - agent_name = agent.name or type(agent).__name__ + agent_name = entity_id or agent.name or type(agent).__name__ entity_name = f"dafx-{agent_name}" class ConfiguredAgentEntity(DurableTaskEntityStateProvider): @@ -187,7 +303,10 @@ def run(self, request: Any) -> Any: AgentResponse as dict """ logger.debug("[ConfiguredAgentEntity.run] Executing agent: %s", agent_name) - response = asyncio.run(self._agent_entity.run(request)) + # Run on the shared persistent loop so async resources created by + # shared agent clients/credentials stay bound to a live loop across + # successive entity invocations (avoids cross-loop hangs). + response = run_agent_coroutine(self._agent_entity.run(request)) return response.to_dict() def reset(self) -> None: diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflows/__init__.py b/python/packages/durabletask/agent_framework_durabletask/_workflows/__init__.py new file mode 100644 index 00000000000..3ebffaa7e78 --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_workflows/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Durable hosting of Microsoft Agent Framework workflows. + +This subpackage turns a MAF :class:`~agent_framework.Workflow` into durable +primitives -- a single orchestrator, agent entities, and non-agent executor +activities -- that run on either a standalone Durable Task worker or Azure +Functions. The host-agnostic engine lives here; each host programs against the +:class:`~.context.WorkflowOrchestrationContext` protocol. +""" diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflows/activity.py b/python/packages/durabletask/agent_framework_durabletask/_workflows/activity.py new file mode 100644 index 00000000000..0e48417ceef --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_workflows/activity.py @@ -0,0 +1,181 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Host-agnostic execution of non-agent workflow executors as durable activities. + +When a MAF :class:`Workflow` runs as a durable orchestration, each non-agent +executor is dispatched as a durable *activity*. The activity body is identical +regardless of host (Azure Functions or a standalone durabletask worker): it +deserializes the activity input, runs the executor (or a human-in-the-loop +response handler), diffs the shared state, and serializes the executor's +outputs, sent messages, shared-state changes, and any pending HITL requests back +to the orchestrator. + +This module provides that shared body as :func:`execute_workflow_activity` so +both host adapters call one implementation instead of duplicating it. +""" + +from __future__ import annotations + +import asyncio +import json +from copy import deepcopy +from typing import Any, cast + +from agent_framework import Executor, Workflow, WorkflowEvent +from agent_framework._workflows._runner_context import YieldOutputEventType +from agent_framework._workflows._state import State + +from .orchestrator import ( + SOURCE_HITL_RESPONSE, + SOURCE_ORCHESTRATOR, + execute_hitl_response_handler, +) +from .runner_context import CapturingRunnerContext +from .serialization import deserialize_value, serialize_value + + +def execute_workflow_activity(executor: Executor, input_json: str, workflow: Workflow | None = None) -> str: + """Execute a single non-agent workflow executor and return its serialized result. + + This is the host-agnostic activity body shared by the Azure Functions and + standalone durabletask workflow hosts. + + Args: + executor: The non-agent executor instance to run. + input_json: JSON-encoded activity input with keys ``message``, + ``shared_state_snapshot``, and ``source_executor_ids``. + workflow: The owning workflow, used to classify the executor's + ``yield_output`` payloads as final ``output`` vs ``intermediate``. + When omitted, all yielded outputs are treated as final outputs. + + Returns: + A JSON string with keys ``sent_messages``, ``outputs``, + ``shared_state_updates``, ``shared_state_deletes``, and + ``pending_request_info_events``. + + Raises: + ValueError: If the input does not decode to a JSON object, or a HITL + message payload is not a JSON object. + """ + data_obj = json.loads(input_json) + if not isinstance(data_obj, dict): + raise ValueError("Activity input must decode to a JSON object") + data = cast(dict[str, Any], data_obj) + + message_data = data.get("message") + # The orchestrator may pass null for these when shared state / sources are + # omitted, so coerce None to the appropriate empty default. + shared_state_snapshot: dict[str, Any] = data.get("shared_state_snapshot") or {} + source_executor_ids = cast(list[str], data.get("source_executor_ids") or [SOURCE_ORCHESTRATOR]) + + # Reconstruct the message - deserialize_value restores the original typed + # objects from the encoded data (with type markers). + message = deserialize_value(message_data) + + # A HITL response is identified by a source id starting with the HITL prefix. + is_hitl_response = any(s.startswith(SOURCE_HITL_RESPONSE) for s in source_executor_ids) + + def classify_yielded_output(executor_id: str) -> YieldOutputEventType | None: + # Mirror the core runner's classification so intermediate executors' + # yields are not surfaced as final workflow outputs. + if workflow is None: + return "output" + if workflow.is_terminal_executor(executor_id): + return "output" + if workflow.is_intermediate_executor(executor_id): + return "intermediate" + return None + + async def _run() -> dict[str, Any]: + runner_context = CapturingRunnerContext() + runner_context.set_yield_output_classifier(classify_yielded_output) + shared_state = State() + + # Deserialize shared state values to reconstruct dataclasses / Pydantic models. + deserialized_state: dict[str, Any] = { + str(k): deserialize_value(v) for k, v in shared_state_snapshot.items() + } + # Snapshot the deserialized (in-memory) state for diffing. State.export_state() + # returns the in-memory committed objects, so the snapshot must hold objects + # too (deepcopy) - comparing against a serialized snapshot would mark every + # key as changed. + original_snapshot = deepcopy(deserialized_state) + shared_state.import_state(deserialized_state) + + if is_hitl_response: + if not isinstance(message_data, dict): + raise ValueError("HITL message payload must be a JSON object") + await execute_hitl_response_handler( + executor=executor, + hitl_message=cast(dict[str, Any], message_data), + shared_state=shared_state, + runner_context=runner_context, + ) + else: + await executor.execute( + message=message, + source_executor_ids=source_executor_ids, + state=shared_state, + runner_context=runner_context, + ) + + # Commit pending state changes and compute the diff vs the original snapshot. + shared_state.commit() + current_state = shared_state.export_state() + original_keys: set[str] = set(original_snapshot.keys()) + current_keys: set[str] = set(current_state.keys()) + + # Deleted = was in original, not in current. + deletes: set[str] = original_keys - current_keys + + # Updates = keys that are new or whose value changed. + updates: dict[str, Any] = {} + for key in current_keys: + if key not in original_keys or current_state[key] != original_snapshot.get(key): + updates[key] = current_state[key] + + sent_messages = await runner_context.drain_messages() + events = await runner_context.drain_events() + + # Extract outputs from WorkflowEvent instances with type='output'. + outputs: list[Any] = [] + for event in events: + if isinstance(event, WorkflowEvent) and event.type == "output": + outputs.append(serialize_value(event.data)) + + # Serialize pending HITL request info events for the orchestrator. + pending_request_info_events = await runner_context.get_pending_request_info_events() + serialized_pending_requests: list[dict[str, Any]] = [] + for _request_id, event in pending_request_info_events.items(): + serialized_pending_requests.append({ + "request_id": event.request_id, + "source_executor_id": event.source_executor_id, + "data": serialize_value(event.data), + "request_type": f"{type(event.data).__module__}:{type(event.data).__name__}", + "response_type": f"{event.response_type.__module__}:{event.response_type.__name__}" + if event.response_type + else None, + }) + + # Serialize sent messages for JSON compatibility. + serialized_sent_messages: list[dict[str, Any]] = [] + for _source_id, msg_list in sent_messages.items(): + for msg in msg_list: + serialized_sent_messages.append({ + "message": serialize_value(msg.data), + "target_id": msg.target_id, + "source_id": msg.source_id, + }) + + serialized_updates = {k: serialize_value(v) for k, v in updates.items()} + + return { + "sent_messages": serialized_sent_messages, + "outputs": outputs, + "shared_state_updates": serialized_updates, + "shared_state_deletes": list(deletes), + "pending_request_info_events": serialized_pending_requests, + } + + result = asyncio.run(_run()) + return json.dumps(result) diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflows/client.py b/python/packages/durabletask/agent_framework_durabletask/_workflows/client.py new file mode 100644 index 00000000000..612576c715c --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_workflows/client.py @@ -0,0 +1,202 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Workflow client wrapper for Durable Task Agent Framework. + +This module provides :class:`DurableWorkflowClient` for external clients to start, +await, and drive (including human-in-the-loop) workflows registered on a worker via +``DurableAIAgentWorker.configure_workflow``. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any, cast + +from durabletask.client import TaskHubGrpcClient + +from .orchestrator import WORKFLOW_ORCHESTRATOR_NAME +from .serialization import deserialize_workflow_output, strip_pickle_markers + +logger = logging.getLogger("agent_framework.durabletask") + + +class DurableWorkflowClient: + """Client wrapper for starting and driving durable workflows externally. + + This class wraps a durabletask ``TaskHubGrpcClient`` and provides a convenient + interface for the workflow registered by ``DurableAIAgentWorker.configure_workflow``: + starting it, awaiting its output, and responding to human-in-the-loop (HITL) pauses. + + For interacting with individual durable *agents*, use + :class:`~agent_framework_durabletask.DurableAIAgentClient` instead. Both wrap the + same underlying ``TaskHubGrpcClient``, so an application that needs both can + construct both over one client. + + Example: + ```python + from durabletask.azuremanaged.client import DurableTaskSchedulerClient + from agent_framework.azure import DurableWorkflowClient + + # Create the underlying client + client = DurableTaskSchedulerClient(host_address="localhost:8080", taskhub="default") + + # Wrap it with the workflow client + workflow_client = DurableWorkflowClient(client) + + # Start a workflow and wait for its output + instance_id = workflow_client.start_workflow(input="some input") + output = workflow_client.await_workflow_output(instance_id) + print(output) + ``` + """ + + def __init__(self, client: TaskHubGrpcClient): + """Initialize the workflow client wrapper. + + Args: + client: The durabletask client instance to wrap. + """ + self._client = client + logger.debug("[DurableWorkflowClient] Initialized with client type: %s", type(client).__name__) + + def start_workflow(self, input: Any = None, *, instance_id: str | None = None) -> str: + """Start the workflow orchestration registered by ``configure_workflow``. + + This schedules the orchestrator that ``DurableAIAgentWorker.configure_workflow`` + auto-registers, so callers do not need to know its internal name. + + Args: + input: The initial message/payload for the workflow. + instance_id: Optional explicit orchestration instance ID. If omitted, one + is generated. + + Returns: + The orchestration instance ID, for use with ``await_workflow_output``. + """ + new_instance_id = self._client.schedule_new_orchestration( + WORKFLOW_ORCHESTRATOR_NAME, + input=input, + instance_id=instance_id, + ) + logger.debug("[DurableWorkflowClient] Started workflow instance: %s", new_instance_id) + return new_instance_id + + def await_workflow_output(self, instance_id: str, *, timeout_seconds: int = 300) -> Any: + """Wait for a workflow orchestration to complete and return its output. + + Args: + instance_id: The instance ID returned by ``start_workflow``. + timeout_seconds: Maximum time, in seconds, to wait for completion. + + Returns: + The deserialized workflow output (typically a list of yielded outputs), + or ``None`` if the workflow produced no output. + + Raises: + TimeoutError: If the workflow does not complete within ``timeout_seconds``. + RuntimeError: If the workflow completes with a non-successful status. + """ + metadata = self._client.wait_for_orchestration_completion(instance_id, timeout=timeout_seconds) + if metadata is None: + raise TimeoutError(f"Workflow '{instance_id}' did not complete within {timeout_seconds}s") + + status = metadata.runtime_status.name + if status != "COMPLETED": + raise RuntimeError(f"Workflow '{instance_id}' ended with status {status}: {metadata.serialized_output}") + + if metadata.serialized_output is None: + return None + # The shared activity encodes each yielded output with serialize_value() + # before it reaches the orchestrator, so typed objects come back as + # checkpoint-marker dicts. Reconstruct the originals before returning. + return deserialize_workflow_output(json.loads(metadata.serialized_output)) + + def get_runtime_status(self, instance_id: str) -> str | None: + """Return the workflow's current runtime status name, or ``None`` if unknown. + + Lets callers distinguish a workflow that is still running or paused for + human input from one that has reached a terminal state (for example + ``COMPLETED``, ``FAILED``, or ``TERMINATED``) — useful when polling, so a + workflow that ends without pausing is not mistaken for one that never paused. + + Args: + instance_id: The instance ID returned by ``start_workflow``. + + Returns: + The runtime status name (e.g. ``"RUNNING"``, ``"COMPLETED"``), or + ``None`` if no state is available for the instance. + """ + state = self._client.get_orchestration_state(instance_id) + if state is None: + return None + return state.runtime_status.name + + def get_pending_hitl_requests(self, instance_id: str) -> list[dict[str, Any]]: + """Return the workflow's pending human-in-the-loop (HITL) requests, if any. + + While a workflow is paused awaiting human input, the orchestrator records the + open requests in its custom status. This method reads and normalizes that + status so callers do not need to know its internal schema. + + Args: + instance_id: The workflow instance ID returned by ``start_workflow``. + + Returns: + A list of pending requests. Each entry contains ``request_id``, + ``source_executor_id``, ``data``, ``request_type``, and ``response_type``. + Empty if the workflow is not currently waiting for human input. + """ + state = self._client.get_orchestration_state(instance_id) + if state is None or not state.serialized_custom_status: + return [] + + try: + custom_status = json.loads(state.serialized_custom_status) + except (json.JSONDecodeError, TypeError): + return [] + + if not isinstance(custom_status, dict): + return [] + status_dict = cast(dict[str, Any], custom_status) + + pending = status_dict.get("pending_requests") + if not isinstance(pending, dict): + return [] + pending_dict = cast(dict[str, Any], pending) + + requests: list[dict[str, Any]] = [] + for request_id, req_data in pending_dict.items(): + if not isinstance(req_data, dict): + continue + req = cast(dict[str, Any], req_data) + requests.append({ + "request_id": req.get("request_id", request_id), + "source_executor_id": req.get("source_executor_id"), + "data": req.get("data"), + "request_type": req.get("request_type"), + "response_type": req.get("response_type"), + }) + return requests + + def send_hitl_response(self, instance_id: str, request_id: str, response: Any) -> None: + """Send a response to a pending HITL request, resuming the workflow. + + The orchestrator correlates the response by using ``request_id`` as the + external-event name, so callers do not need to know that convention. + + Args: + instance_id: The workflow instance ID. + request_id: The pending request's ID (from ``get_pending_hitl_requests``). + response: The response payload (e.g. a dict matching the expected + response type the executor's ``@response_handler`` expects). + + Note: + The payload is sanitized with ``strip_pickle_markers`` before delivery to + neutralize pickle-marker injection, since the worker deserializes it. + """ + safe_response = strip_pickle_markers(response) + self._client.raise_orchestration_event(instance_id, event_name=request_id, data=safe_response) + logger.debug( + "[DurableWorkflowClient] Sent HITL response for request %s on instance %s", request_id, instance_id + ) diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflows/context.py b/python/packages/durabletask/agent_framework_durabletask/_workflows/context.py new file mode 100644 index 00000000000..3d31cf90e27 --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_workflows/context.py @@ -0,0 +1,143 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Protocol definition for workflow orchestration contexts. + +This module defines the ``WorkflowOrchestrationContext`` protocol that abstracts +the differences between Azure Functions' ``DurableOrchestrationContext`` and the +standalone ``durabletask.task.OrchestrationContext``. The shared workflow +orchestrator (:func:`run_workflow_orchestrator`) programs against this protocol +so that the same orchestration logic works on any host. + +Each host provides a thin adapter that maps its native context to this protocol: + +- ``DurableTaskWorkflowContext`` (this package) — wraps ``OrchestrationContext`` +- ``AzureFunctionsWorkflowContext`` (azurefunctions package) — wraps + ``DurableOrchestrationContext`` +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any, Protocol, runtime_checkable + + +@runtime_checkable +class WorkflowOrchestrationContext(Protocol): + """Host-agnostic interface for workflow orchestration primitives. + + All methods that return yieldable tasks return ``Any`` because the concrete + task types differ between hosting SDKs (``TaskBase`` for Azure Functions, + ``Task[T]`` for durabletask). The generator-based orchestrator simply + yields these opaque objects back to the hosting framework. + """ + + @property + def instance_id(self) -> str: + """The unique ID of the current orchestration instance.""" + ... + + @property + def current_utc_datetime(self) -> datetime: + """The current replay-safe UTC datetime.""" + ... + + def prepare_agent_task(self, executor_id: str, message: str, orchestration_instance_id: str) -> Any: + """Create a yieldable task that runs an agent executor. + + Args: + executor_id: Agent name / executor ID. + message: The text message to send to the agent. + orchestration_instance_id: Instance ID used as the entity session key. + + Returns: + A yieldable task whose result is an ``AgentResponse``. + """ + ... + + def prepare_activity_task(self, activity_name: str, input_json: str) -> Any: + """Create a yieldable task that runs an activity executor. + + Args: + activity_name: The registered activity function name. + input_json: JSON-serialized activity input. + + Returns: + A yieldable task whose result is a JSON string. + """ + ... + + def task_all(self, tasks: list[Any]) -> Any: + """Create a yieldable composite task that completes when *all* tasks complete. + + Args: + tasks: List of yieldable tasks. + + Returns: + A yieldable task whose result is a list of individual results. + """ + ... + + def task_any(self, tasks: list[Any]) -> Any: + """Create a yieldable composite task that completes when *any* task completes. + + Args: + tasks: List of yieldable tasks. + + Returns: + A yieldable task whose result is the winning task. + """ + ... + + def wait_for_external_event(self, name: str) -> Any: + """Create a yieldable task that waits for a named external event. + + Args: + name: Event name to wait for. + + Returns: + A yieldable task whose result is the event payload. + """ + ... + + def create_timer(self, fire_at: datetime) -> Any: + """Create a yieldable timer task. + + Args: + fire_at: UTC datetime when the timer should fire. + + Returns: + A yieldable timer task. + """ + ... + + def set_custom_status(self, status: Any) -> None: + """Set the orchestration's custom status (visible to external clients). + + Args: + status: JSON-serializable status object. + """ + ... + + def new_uuid(self) -> str: + """Generate a replay-safe UUID.""" + ... + + def cancel_task(self, task: Any) -> None: + """Best-effort cancellation of a pending task. + + Args: + task: The task to cancel. If the underlying SDK does not support + cancellation this is a no-op. + """ + ... + + def get_task_result(self, task: Any) -> Any: + """Extract the result from a completed task. + + Args: + task: A completed task object. + + Returns: + The result value. + """ + ... diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflows/dt_context.py b/python/packages/durabletask/agent_framework_durabletask/_workflows/dt_context.py new file mode 100644 index 00000000000..1e6923e78d3 --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_workflows/dt_context.py @@ -0,0 +1,96 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""DurableTask SDK adapter for WorkflowOrchestrationContext. + +Wraps ``durabletask.task.OrchestrationContext`` to satisfy the +:class:`WorkflowOrchestrationContext` protocol. +""" + +from __future__ import annotations + +import logging +from datetime import datetime +from typing import Any, cast + +from durabletask.task import ( + OrchestrationContext, + Task, + when_all, + when_any, # pyright: ignore[reportUnknownVariableType] +) + +from .._executors import OrchestrationAgentExecutor +from .._models import AgentSessionId, DurableAgentSession +from .._shim import DurableAIAgent +from .context import WorkflowOrchestrationContext + +logger = logging.getLogger(__name__) + + +class DurableTaskWorkflowContext: + """Adapter that maps ``OrchestrationContext`` to :class:`WorkflowOrchestrationContext`.""" + + def __init__(self, context: OrchestrationContext) -> None: + self._context = context + self._executor = OrchestrationAgentExecutor(context) + + # -- Properties ----------------------------------------------------------- + + @property + def instance_id(self) -> str: + return self._context.instance_id + + @property + def current_utc_datetime(self) -> datetime: + return self._context.current_utc_datetime + + # -- Agent / Activity dispatch -------------------------------------------- + + def prepare_agent_task(self, executor_id: str, message: str, orchestration_instance_id: str) -> Any: + session_id = AgentSessionId(name=executor_id, key=orchestration_instance_id) + session = DurableAgentSession(durable_session_id=session_id) + agent = DurableAIAgent(self._executor, executor_id) + return agent.run(message, session=session) + + def prepare_activity_task(self, activity_name: str, input_json: str) -> Any: + return cast(Any, self._context.call_activity(activity_name, input=input_json)) + + # -- Composite tasks ------------------------------------------------------ + + def task_all(self, tasks: list[Any]) -> Any: + return when_all(tasks) + + def task_any(self, tasks: list[Any]) -> Any: + return when_any(tasks) + + # -- External events / timers --------------------------------------------- + + def wait_for_external_event(self, name: str) -> Any: + return cast(Any, self._context).wait_for_external_event(name) + + def create_timer(self, fire_at: datetime) -> Any: + return cast(Any, self._context).create_timer(fire_at) + + # -- Status / utility ----------------------------------------------------- + + def set_custom_status(self, status: Any) -> None: + self._context.set_custom_status(status) + + def new_uuid(self) -> str: + return self._context.new_uuid() + + def cancel_task(self, task: Any) -> None: + # durabletask Task doesn't expose cancel(); this is a best-effort no-op. + cancel_fn = getattr(task, "cancel", None) + if callable(cancel_fn): + cancel_fn() + + def get_task_result(self, task: Any) -> Any: + if isinstance(task, Task): + return cast(Any, task.get_result()) + return getattr(task, "result", None) + + +# Ensure the adapter satisfies the protocol. Validated statically by the type +# checker (and at every ``run_workflow_orchestrator`` call site) with no runtime cost. +_protocol_check: type[WorkflowOrchestrationContext] = DurableTaskWorkflowContext diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflows/orchestrator.py b/python/packages/durabletask/agent_framework_durabletask/_workflows/orchestrator.py new file mode 100644 index 00000000000..7ad3c026bbf --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_workflows/orchestrator.py @@ -0,0 +1,869 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Host-agnostic workflow orchestration engine. + +This module provides the shared workflow orchestration logic that executes MAF +Workflows as durable task orchestrations. It programs against the +:class:`WorkflowOrchestrationContext` protocol so that the same code runs on +both Azure Functions and standalone durabletask hosts. + +Key components: + +* :func:`run_workflow_orchestrator` — main generator-based orchestrator +* Routing helpers (edge groups, fan-in, HITL) +* Result processing helpers + +All host-specific task creation (agent dispatch, activity dispatch, task_all / +task_any) is delegated to the ``WorkflowOrchestrationContext`` adapter. +""" + +from __future__ import annotations + +import inspect +import json +import logging +from collections import defaultdict +from collections.abc import Generator +from dataclasses import dataclass +from datetime import timedelta +from enum import Enum +from typing import Any + +from agent_framework import ( + AgentExecutor, + AgentExecutorRequest, + AgentExecutorResponse, + AgentResponse, + Executor, + Message, + Workflow, + WorkflowConvergenceException, +) +from agent_framework._workflows._edge import ( + Edge, + EdgeGroup, + FanInEdgeGroup, + FanOutEdgeGroup, + SingleEdgeGroup, + SwitchCaseEdgeGroup, +) +from agent_framework._workflows._state import State + +from .context import WorkflowOrchestrationContext +from .serialization import ( + deserialize_value, + reconstruct_to_type, + resolve_type, + serialize_value, + strip_pickle_markers, +) + +logger = logging.getLogger(__name__) + + +# ============================================================================ +# Source Marker Constants +# ============================================================================ + +SOURCE_WORKFLOW_START = "__workflow_start__" +SOURCE_ORCHESTRATOR = "__orchestrator__" +SOURCE_HITL_RESPONSE = "__hitl_response__" + +# Name of the auto-generated orchestrator registered by +# ``DurableAIAgentWorker.configure_workflow`` (and the Azure Functions host). +# Standalone clients start a configured workflow by scheduling an orchestration +# with this name, e.g. +# ``client.schedule_new_orchestration(WORKFLOW_ORCHESTRATOR_NAME, input=...)``. +WORKFLOW_ORCHESTRATOR_NAME = "workflow_orchestrator" + + +# ============================================================================ +# Task Types and Data Structures +# ============================================================================ + + +class TaskType(Enum): + """Type of executor task.""" + + AGENT = "agent" + ACTIVITY = "activity" + + +@dataclass +class TaskMetadata: + """Metadata for a pending task.""" + + executor_id: str + message: Any + source_executor_id: str + task_type: TaskType + remaining_messages: list[tuple[str, Any, str]] | None = None + + +@dataclass +class ExecutorResult: + """Result from executing an agent or activity.""" + + executor_id: str + output_message: AgentExecutorResponse | None + activity_result: dict[str, Any] | None + task_type: TaskType + + +@dataclass +class PendingHITLRequest: + """Tracks a pending Human-in-the-Loop request.""" + + request_id: str + source_executor_id: str + request_data: Any + request_type: str | None + response_type: str | None + + +DEFAULT_HITL_TIMEOUT_HOURS = 72.0 + + +# ============================================================================ +# Routing Functions +# ============================================================================ + + +def _evaluate_edge_condition_sync(edge: Edge, message: Any) -> bool: + """Evaluate an edge's condition synchronously. + + Durable orchestrators run as generators, so conditions are evaluated + synchronously here; the durabletask host does not support ``async`` edge + conditions. A condition that returns an awaitable cannot be evaluated in + this context, so the edge is treated as *not matched* (not traversed) + rather than assuming a result. + """ + condition = edge._condition # pyright: ignore[reportPrivateUsage] + if condition is None: + return True + result = condition(message) + if inspect.isawaitable(result): + # Async conditions cannot be evaluated in a synchronous orchestrator. + # Close the unawaited coroutine to avoid a "never awaited" warning and + # decline to traverse the edge (treated as not matched). + if inspect.iscoroutine(result): + result.close() + logger.warning( + "Edge condition for %s->%s is async and cannot be evaluated by the durabletask host; " + "the edge is not traversed. Use a synchronous condition.", + edge.source_id, + edge.target_id, + ) + return False + return bool(result) + + +def route_message_through_edge_groups( + edge_groups: list[EdgeGroup], + source_id: str, + message: Any, +) -> list[str]: + """Route a message through edge groups to find target executor IDs.""" + targets: list[str] = [] + + for group in edge_groups: + if source_id not in group.source_executor_ids: + continue + + if isinstance(group, (SwitchCaseEdgeGroup, FanOutEdgeGroup)): + if group.selection_func is not None: + selected = group.selection_func(message, group.target_executor_ids) + targets.extend(selected) + else: + targets.extend(group.target_executor_ids) + + elif isinstance(group, SingleEdgeGroup): + edge = group.edges[0] + if _evaluate_edge_condition_sync(edge, message): + targets.append(edge.target_id) + + elif isinstance(group, FanInEdgeGroup): + pass # Handled separately in the orchestrator loop + + else: + for edge in group.edges: + if edge.source_id == source_id and _evaluate_edge_condition_sync(edge, message): + targets.append(edge.target_id) + + return targets + + +def build_agent_executor_response( + executor_id: str, + response_text: str | None, + structured_response: dict[str, Any] | None, + previous_message: Any, +) -> AgentExecutorResponse: + """Build an AgentExecutorResponse from entity response data.""" + final_text: str = response_text or "" + if structured_response: + final_text = json.dumps(structured_response) + + assistant_message = Message(role="assistant", contents=[final_text]) + agent_response = AgentResponse(messages=[assistant_message]) + + full_conversation: list[Message] = [] + if isinstance(previous_message, AgentExecutorResponse) and previous_message.full_conversation: + full_conversation.extend(previous_message.full_conversation) + elif isinstance(previous_message, str): + full_conversation.append(Message(role="user", contents=[previous_message])) + full_conversation.append(assistant_message) + + return AgentExecutorResponse( + executor_id=executor_id, + agent_response=agent_response, + full_conversation=full_conversation, + ) + + +# ============================================================================ +# Task Preparation Helpers +# ============================================================================ + + +def _prepare_agent_task( + ctx: WorkflowOrchestrationContext, + executor_id: str, + message: Any, +) -> Any: + """Prepare an agent task for execution via the context adapter.""" + message_content = _extract_message_content(message) + return ctx.prepare_agent_task(executor_id, message_content, ctx.instance_id) + + +def _prepare_activity_task( + ctx: WorkflowOrchestrationContext, + executor_id: str, + message: Any, + source_executor_id: str, + shared_state_snapshot: dict[str, Any] | None, +) -> Any: + """Prepare an activity task for execution via the context adapter.""" + activity_input = { + "executor_id": executor_id, + "message": serialize_value(message), + "shared_state_snapshot": shared_state_snapshot, + "source_executor_ids": [source_executor_id], + } + activity_input_json = json.dumps(activity_input) + activity_name = f"dafx-{executor_id}" + return ctx.prepare_activity_task(activity_name, activity_input_json) + + +# ============================================================================ +# Result Processing Helpers +# ============================================================================ + + +def _process_agent_response( + agent_response: AgentResponse, + executor_id: str, + message: Any, +) -> ExecutorResult: + """Process an agent response into an ExecutorResult.""" + response_text = agent_response.text if agent_response else None + structured_response: dict[str, Any] | None = None + + if agent_response and agent_response.value is not None: + model_dump = getattr(agent_response.value, "model_dump", None) + if callable(model_dump): + dumped = model_dump() + if isinstance(dumped, dict): + structured_response = dumped # type: ignore[assignment] + elif isinstance(agent_response.value, dict): + structured_response = agent_response.value # type: ignore[assignment] + + output_message = build_agent_executor_response( + executor_id=executor_id, + response_text=response_text, + structured_response=structured_response, + previous_message=message, + ) + + return ExecutorResult( + executor_id=executor_id, + output_message=output_message, + activity_result=None, + task_type=TaskType.AGENT, + ) + + +def _process_activity_result( + result_json: str | None, + executor_id: str, + shared_state: dict[str, Any] | None, + workflow_outputs: list[Any], +) -> ExecutorResult: + """Process an activity result and apply shared state updates.""" + result = json.loads(result_json) if result_json else None + + if shared_state is not None and result: + if result.get("shared_state_updates"): + updates = result["shared_state_updates"] + logger.debug("[workflow] Applying SharedState updates from %s: %s", executor_id, updates) + shared_state.update(updates) + if result.get("shared_state_deletes"): + deletes = result["shared_state_deletes"] + logger.debug("[workflow] Applying SharedState deletes from %s: %s", executor_id, deletes) + for key in deletes: + shared_state.pop(key, None) + + if result and result.get("outputs"): + workflow_outputs.extend(result["outputs"]) + + return ExecutorResult( + executor_id=executor_id, + output_message=None, + activity_result=result, + task_type=TaskType.ACTIVITY, + ) + + +# ============================================================================ +# Routing Helpers +# ============================================================================ + + +def _route_result_messages( + result: ExecutorResult, + workflow: Workflow, + next_pending_messages: dict[str, list[tuple[Any, str]]], + fan_in_pending: dict[str, dict[str, list[tuple[Any, str]]]], +) -> None: + """Route messages from an executor result to their targets.""" + executor_id = result.executor_id + messages_to_route: list[tuple[Any, str | None]] = [] + + if result.output_message: + messages_to_route.append((result.output_message, None)) + + if result.activity_result and result.activity_result.get("sent_messages"): + for msg_data in result.activity_result["sent_messages"]: + sent_msg = msg_data.get("message") + target_id = msg_data.get("target_id") + # Use an explicit None check so legitimately falsy payloads + # (empty string, 0, False) are still routed. + if sent_msg is not None: + sent_msg = deserialize_value(sent_msg) + messages_to_route.append((sent_msg, target_id)) + + for msg_to_route, explicit_target in messages_to_route: + logger.debug("Routing output from %s", executor_id) + + if explicit_target: + if explicit_target not in next_pending_messages: + next_pending_messages[explicit_target] = [] + next_pending_messages[explicit_target].append((msg_to_route, executor_id)) + logger.debug("Routed message from %s to explicit target %s", executor_id, explicit_target) + continue + + for group in workflow.edge_groups: + if isinstance(group, FanInEdgeGroup) and executor_id in group.source_executor_ids: + fan_in_pending[group.id][executor_id].append((msg_to_route, executor_id)) + logger.debug("Accumulated message for FanIn group %s from %s", group.id, executor_id) + + targets = route_message_through_edge_groups(workflow.edge_groups, executor_id, msg_to_route) + + for target_id in targets: + logger.debug("Routing to %s", target_id) + if target_id not in next_pending_messages: + next_pending_messages[target_id] = [] + next_pending_messages[target_id].append((msg_to_route, executor_id)) + + +def _check_fan_in_ready( + workflow: Workflow, + fan_in_pending: dict[str, dict[str, list[tuple[Any, str]]]], + next_pending_messages: dict[str, list[tuple[Any, str]]], +) -> None: + """Check if any FanInEdgeGroups are ready and deliver their messages.""" + for group in workflow.edge_groups: + if not isinstance(group, FanInEdgeGroup): + continue + + pending_sources = fan_in_pending.get(group.id, {}) + + if not all(src in pending_sources and pending_sources[src] for src in group.source_executor_ids): + continue + + aggregated: list[Any] = [] + aggregated_sources: list[str] = [] + for src in group.source_executor_ids: + for msg, msg_source in pending_sources[src]: + aggregated.append(msg) + aggregated_sources.append(msg_source) + + target_id = group.target_executor_ids[0] + logger.debug("FanIn group %s ready, delivering %d messages to %s", group.id, len(aggregated), target_id) + + if target_id not in next_pending_messages: + next_pending_messages[target_id] = [] + + first_source = aggregated_sources[0] if aggregated_sources else "__fan_in__" + next_pending_messages[target_id].append((aggregated, first_source)) + + fan_in_pending[group.id] = defaultdict(list) + + +# ============================================================================ +# HITL Helpers +# ============================================================================ + + +def _collect_hitl_requests( + result: ExecutorResult, + pending_hitl_requests: dict[str, PendingHITLRequest], +) -> None: + """Collect pending HITL requests from an activity result.""" + if result.activity_result and result.activity_result.get("pending_request_info_events"): + for req_data in result.activity_result["pending_request_info_events"]: + request_id = req_data.get("request_id") + if request_id: + pending_hitl_requests[request_id] = PendingHITLRequest( + request_id=request_id, + source_executor_id=req_data.get("source_executor_id", result.executor_id), + request_data=req_data.get("data"), + request_type=req_data.get("request_type"), + response_type=req_data.get("response_type"), + ) + logger.debug( + "Collected HITL request %s from executor %s", + request_id, + result.executor_id, + ) + + +def _route_hitl_response( + hitl_request: PendingHITLRequest, + raw_response: Any, + pending_messages: dict[str, list[tuple[Any, str]]], +) -> None: + """Route a HITL response back to the source executor's @response_handler.""" + response_message = { + "request_id": hitl_request.request_id, + "original_request": hitl_request.request_data, + "response": raw_response, + "response_type": hitl_request.response_type, + } + + target_id = hitl_request.source_executor_id + if target_id not in pending_messages: + pending_messages[target_id] = [] + + source_id = f"{SOURCE_HITL_RESPONSE}_{hitl_request.request_id}" + pending_messages[target_id].append((response_message, source_id)) + + logger.debug( + "Routed HITL response for request %s to executor %s", + hitl_request.request_id, + target_id, + ) + + +# ============================================================================ +# Message Content Extraction +# ============================================================================ + + +def _extract_message_content(message: Any) -> str: + """Extract text content from various message types.""" + message_content = "" + if isinstance(message, AgentExecutorResponse) and message.agent_response: + if message.agent_response.text: + message_content = message.agent_response.text + elif message.agent_response.messages: + message_content = message.agent_response.messages[-1].text or "" + elif isinstance(message, AgentExecutorRequest) and message.messages: + message_content = message.messages[-1].text or "" + elif isinstance(message, dict): + key_names = list(message.keys()) # type: ignore[union-attr] + logger.warning("Unexpected dict message in _extract_message_content. Keys: %s", key_names) # type: ignore + elif isinstance(message, str): + message_content = message + return message_content + + +def _select_primary_input_type(executor: Executor) -> type | None: + """Return the executor's primary concrete declared input type, if any. + + The first declared input type that is a concrete class is used; union or + unannotated types yield ``None`` (the caller then passes the value through + unchanged). + """ + for input_type in executor.input_types: + if isinstance(input_type, type): + return input_type + return None + + +def _coerce_initial_input(workflow: Workflow, raw_value: Any) -> Any: + """Coerce the client's initial workflow input to the start executor's type. + + A durable workflow runs as a durable orchestration, so its initial payload + arrives as plain JSON via ``context.get_input()`` -- without the type markers + that inter-executor messages carry (those are reconstructed by + :func:`deserialize_value`). This single entry hop therefore needs explicit + reconstruction to mirror in-process delivery, where the start executor + receives its declared type: + + * Agent start executors only consume text, so non-text input is stringified. + * Other executors get their primary declared input type reconstructed + (``dict`` -> Pydantic/dataclass, ``str`` -> ``str``, ...) via + :func:`reconstruct_to_type`; union/unannotated types pass through unchanged. + """ + start_executor = workflow.executors.get(workflow.start_executor_id) + if start_executor is None: + return raw_value + + if isinstance(start_executor, AgentExecutor): + if isinstance(raw_value, str): + return raw_value + if isinstance(raw_value, (dict, list)): + return json.dumps(raw_value) + return str(raw_value) + + input_type = _select_primary_input_type(start_executor) + if input_type is None: + return raw_value + # The initial payload is untrusted external input (HTTP body / client input) with no + # legitimate checkpoint type markers, so neutralize any pickle-marker injection before + # it can reach deserialize_value() inside reconstruct_to_type() (avoids pickle RCE). + return reconstruct_to_type(strip_pickle_markers(raw_value), input_type) + + +# ============================================================================ +# HITL Response Handler Execution +# ============================================================================ + + +async def execute_hitl_response_handler( + executor: Any, + hitl_message: dict[str, Any], + shared_state: State, + runner_context: Any, +) -> None: + """Execute a HITL response handler on an executor. + + Args: + executor: The executor instance that has a @response_handler. + hitl_message: The HITL response message dict. + shared_state: The shared state for the workflow context. + runner_context: The runner context for capturing outputs. + """ + from agent_framework._workflows._workflow_context import WorkflowContext + + original_request_data = hitl_message.get("original_request") + response_data = hitl_message.get("response") + response_type_str = hitl_message.get("response_type") + + original_request = deserialize_value(original_request_data) + response = _deserialize_hitl_response(response_data, response_type_str) + + handler = executor._find_response_handler(original_request, response) # pyright: ignore[reportPrivateUsage] + + if handler is None: + logger.warning( + "No response handler found for HITL response in executor %s. Request type: %s, Response type: %s", + executor.id, + type(original_request).__name__, + type(response).__name__, + ) + return + + ctx = WorkflowContext( + executor=executor, + source_executor_ids=[SOURCE_HITL_RESPONSE], + runner_context=runner_context, + state=shared_state, + ) + + logger.debug( + "Invoking response handler for HITL request in executor %s", + executor.id, + ) + await handler(response, ctx) + + +def _deserialize_hitl_response(response_data: Any, response_type_str: str | None) -> Any: + """Deserialize a HITL response to its expected type.""" + logger.debug( + "Deserializing HITL response. response_type_str=%s, response_data type=%s", + response_type_str, + type(response_data).__name__, + ) + + if response_data is None: + return None + + response_data = strip_pickle_markers(response_data) + if response_data is None: + return None + + if not isinstance(response_data, dict): + logger.debug("Response data is not a dict, returning as-is: %s", type(response_data).__name__) + return response_data + + if response_type_str: + response_type = resolve_type(response_type_str) + if response_type: + logger.debug("Found response type %s, attempting reconstruction", response_type) + result = reconstruct_to_type(response_data, response_type) + logger.debug("Reconstructed response type: %s", type(result).__name__) + return result + logger.warning("Could not resolve response type: %s", response_type_str) + + logger.debug("No type hint; returning sanitized data as-is") + return response_data # type: ignore[reportUnknownVariableType] + + +# ============================================================================ +# Task Preparation (All Tasks) +# ============================================================================ + + +def _prepare_all_tasks( + ctx: WorkflowOrchestrationContext, + workflow: Workflow, + pending_messages: dict[str, list[tuple[Any, str]]], + shared_state: dict[str, Any] | None, +) -> tuple[list[Any], list[TaskMetadata], list[tuple[str, Any, str]]]: + """Prepare all pending tasks for parallel execution. + + Groups agent messages by executor ID so that only the first message per agent + runs in the parallel batch. Additional messages to the same agent are returned + for sequential processing. + """ + all_tasks: list[Any] = [] + task_metadata_list: list[TaskMetadata] = [] + remaining_agent_messages: list[tuple[str, Any, str]] = [] + + agent_messages_by_executor: dict[str, list[tuple[str, Any, str]]] = defaultdict(list) + + for executor_id, messages_with_sources in pending_messages.items(): + executor = workflow.executors[executor_id] + is_agent = isinstance(executor, AgentExecutor) + + for message, source_executor_id in messages_with_sources: + if is_agent: + agent_messages_by_executor[executor_id].append((executor_id, message, source_executor_id)) + else: + logger.debug("Preparing activity task: %s", executor_id) + task = _prepare_activity_task(ctx, executor_id, message, source_executor_id, shared_state) + all_tasks.append(task) + task_metadata_list.append( + TaskMetadata( + executor_id=executor_id, + message=message, + source_executor_id=source_executor_id, + task_type=TaskType.ACTIVITY, + ) + ) + + for executor_id, messages_list in agent_messages_by_executor.items(): + first_msg = messages_list[0] + remaining = messages_list[1:] + + logger.debug("Preparing agent task: %s", executor_id) + task = _prepare_agent_task(ctx, first_msg[0], first_msg[1]) + all_tasks.append(task) + task_metadata_list.append( + TaskMetadata( + executor_id=first_msg[0], + message=first_msg[1], + source_executor_id=first_msg[2], + task_type=TaskType.AGENT, + ) + ) + + remaining_agent_messages.extend(remaining) + + return all_tasks, task_metadata_list, remaining_agent_messages + + +# ============================================================================ +# Main Orchestrator +# ============================================================================ + + +def run_workflow_orchestrator( + ctx: WorkflowOrchestrationContext, + workflow: Workflow, + initial_message: Any, + shared_state: dict[str, Any] | None = None, + hitl_timeout_hours: float = DEFAULT_HITL_TIMEOUT_HOURS, +) -> Generator[Any, Any, list[Any]]: + """Traverse and execute the workflow graph as a durable orchestration. + + This is a generator-based orchestrator that works with any host by + programming against the :class:`WorkflowOrchestrationContext` protocol. + + Supports: + - SingleEdgeGroup: Direct 1:1 routing with optional condition + - SwitchCaseEdgeGroup: First matching condition wins + - FanOutEdgeGroup: Broadcast to multiple targets (parallel execution) + - FanInEdgeGroup: Aggregates messages from multiple sources + - SharedState: Cross-executor state sharing (local to orchestration) + - HITL: Human-in-the-loop via request_info / @response_handler + + Args: + ctx: Host-specific orchestration context adapter. + workflow: The MAF Workflow instance to execute. + initial_message: Initial message to send to the start executor. + shared_state: Optional dict for cross-executor state sharing. + hitl_timeout_hours: Timeout in hours for HITL requests. + + Returns: + List of workflow outputs collected from executor activities. + """ + pending_messages: dict[str, list[tuple[Any, str]]] = { + workflow.start_executor_id: [(_coerce_initial_input(workflow, initial_message), SOURCE_WORKFLOW_START)] + } + workflow_outputs: list[Any] = [] + iteration = 0 + + fan_in_pending: dict[str, dict[str, list[tuple[Any, str]]]] = { + group.id: defaultdict(list) for group in workflow.edge_groups if isinstance(group, FanInEdgeGroup) + } + + pending_hitl_requests: dict[str, PendingHITLRequest] = {} + + while pending_messages and iteration < workflow.max_iterations: + logger.debug("Orchestrator iteration %d", iteration) + next_pending_messages: dict[str, list[tuple[Any, str]]] = {} + + # Phase 1: Prepare all tasks + all_tasks, task_metadata_list, remaining_agent_messages = _prepare_all_tasks( + ctx, workflow, pending_messages, shared_state + ) + + # Phase 2: Execute all tasks in parallel + all_results: list[ExecutorResult] = [] + if all_tasks: + logger.debug("Executing %d tasks in parallel (agents + activities)", len(all_tasks)) + raw_results = yield ctx.task_all(all_tasks) + logger.debug("All %d tasks completed", len(all_tasks)) + + for idx, raw_result in enumerate(raw_results): + metadata = task_metadata_list[idx] + if metadata.task_type == TaskType.AGENT: + result = _process_agent_response(raw_result, metadata.executor_id, metadata.message) + else: + result = _process_activity_result(raw_result, metadata.executor_id, shared_state, workflow_outputs) + all_results.append(result) + + # Phase 3: Process sequential agent messages + for executor_id, message, _source_executor_id in remaining_agent_messages: + logger.debug("Processing sequential message for agent: %s", executor_id) + task = _prepare_agent_task(ctx, executor_id, message) + agent_response: AgentResponse = yield task + logger.debug("Agent %s sequential response completed", executor_id) + + result = _process_agent_response(agent_response, executor_id, message) + all_results.append(result) + + # Phase 4: Collect HITL requests + for result in all_results: + _collect_hitl_requests(result, pending_hitl_requests) + + # Phase 5: Route results + for result in all_results: + _route_result_messages(result, workflow, next_pending_messages, fan_in_pending) + + # Phase 6: Check fan-in readiness + _check_fan_in_ready(workflow, fan_in_pending, next_pending_messages) + + pending_messages = next_pending_messages + + # Phase 7: HITL wait + if not pending_messages and pending_hitl_requests: + logger.debug("Workflow paused for HITL - %d pending requests", len(pending_hitl_requests)) + + ctx.set_custom_status({ + "state": "waiting_for_human_input", + "pending_requests": { + req_id: { + "request_id": req.request_id, + "source_executor_id": req.source_executor_id, + "data": req.request_data, + "request_type": req.request_type, + "response_type": req.response_type, + } + for req_id, req in pending_hitl_requests.items() + }, + }) + + for request_id, hitl_request in list(pending_hitl_requests.items()): + # Re-wait until a valid response arrives (or the request times out). A + # payload rejected by sanitization (pickle/type markers) does not consume + # the request, so the caller can resubmit a corrected response instead of + # losing the entire workflow run. + while True: + logger.debug("Waiting for HITL response for request: %s", request_id) + + approval_task = ctx.wait_for_external_event(request_id) + timeout_task = ctx.create_timer(ctx.current_utc_datetime + timedelta(hours=hitl_timeout_hours)) + + winner = yield ctx.task_any([approval_task, timeout_task]) + + if winner != approval_task: + ctx.cancel_task(approval_task) + logger.warning("HITL request %s timed out after %s hours", request_id, hitl_timeout_hours) + raise TimeoutError( + f"Human-in-the-loop request '{request_id}' timed out after {hitl_timeout_hours} hours." + ) + + ctx.cancel_task(timeout_task) + + raw_response = ctx.get_task_result(approval_task) + logger.debug( + "Received HITL response for request %s. Type: %s, Value: %s", + request_id, + type(raw_response).__name__, + raw_response, + ) + + if isinstance(raw_response, str): + try: + raw_response = json.loads(raw_response) + logger.debug("Parsed JSON string response to: %s", type(raw_response).__name__) + except (json.JSONDecodeError, TypeError): + logger.debug("Response is not JSON, keeping as string") + + # Sanitize against pickle-marker injection in case a caller bypassed + # DurableWorkflowClient.send_hitl_response and raised the external + # event directly (e.g. via the raw DTS client). Sanitize *before* + # consuming the request so a rejected payload can be resubmitted. + sanitized_response = strip_pickle_markers(raw_response) + if sanitized_response is None and raw_response is not None: + logger.warning( + "Rejected HITL response for request %s: payload contained " + "disallowed pickle/type markers. Awaiting a new response.", + request_id, + ) + continue + + del pending_hitl_requests[request_id] + _route_hitl_response( + hitl_request, + sanitized_response, + pending_messages, + ) + break + + ctx.set_custom_status({"state": "running"}) + + iteration += 1 + + # Match the core WorkflowRunner: if the loop stopped because max_iterations + # was reached while messages are still pending, the workflow did not converge. + if pending_messages: + raise WorkflowConvergenceException( + f"Workflow did not converge after {workflow.max_iterations} iterations." + ) + + return workflow_outputs # noqa: B901 diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflows/registration.py b/python/packages/durabletask/agent_framework_durabletask/_workflows/registration.py new file mode 100644 index 00000000000..684f25323a7 --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_workflows/registration.py @@ -0,0 +1,74 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Host-agnostic plan for registering a MAF Workflow as a durable orchestration. + +A MAF :class:`Workflow` is hosted by turning each graph node into a durable +primitive: + +- each :class:`AgentExecutor` becomes a durable **entity**, and +- each other :class:`Executor` becomes a durable **activity**, + +driven by a single workflow **orchestrator**. + +The *decision* of which executor maps to which primitive is identical on every +host (Azure Functions or a standalone durabletask worker); only the *mechanism* +for registering them differs (Functions trigger decorators vs. +``worker.add_*``). :func:`plan_workflow_registration` captures the shared +decision so each host applies one consistent plan with its own registration +mechanism — analogous to .NET's shared ``DurableWorkflowOptions`` feeding +host-specific trigger generation. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +from agent_framework import AgentExecutor, Executor, Workflow + +from .orchestrator import WORKFLOW_ORCHESTRATOR_NAME + + +@dataclass +class WorkflowRegistrationPlan: + """The durable primitives a workflow registers, independent of host. + + Attributes: + agent_executors: Agent executors to register as durable entities. The + full :class:`AgentExecutor` is carried (not just its agent) so each + host can register the entity under the executor's ``id`` — the same + identity the orchestrator dispatches to — which keeps + ``AgentExecutor(agent, id=...)`` working when the id differs from + ``agent.name``. + activity_executors: Non-agent executors to register as durable activities. + orchestrator_name: The orchestrator name to register and to start runs with. + """ + + agent_executors: list[AgentExecutor] + activity_executors: list[Executor] + orchestrator_name: str + + +def plan_workflow_registration(workflow: Workflow) -> WorkflowRegistrationPlan: + """Classify a workflow's executors into the durable primitives to register. + + Args: + workflow: The MAF :class:`Workflow` to host. + + Returns: + A :class:`WorkflowRegistrationPlan` describing the agent executors + (entities), non-agent executors (activities), and the orchestrator name. + """ + agent_executors: list[AgentExecutor] = [] + activity_executors: list[Executor] = [] + + for executor in workflow.executors.values(): + if isinstance(executor, AgentExecutor): + agent_executors.append(executor) + else: + activity_executors.append(executor) + + return WorkflowRegistrationPlan( + agent_executors=agent_executors, + activity_executors=activity_executors, + orchestrator_name=WORKFLOW_ORCHESTRATOR_NAME, + ) diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflows/runner_context.py b/python/packages/durabletask/agent_framework_durabletask/_workflows/runner_context.py new file mode 100644 index 00000000000..1851339bb37 --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_workflows/runner_context.py @@ -0,0 +1,147 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Runner context for activity execution within durable orchestrations. + +This module provides the :class:`CapturingRunnerContext` class that captures +messages and events produced during executor execution within activities. +It is host-agnostic and works on any durable task host. +""" + +from __future__ import annotations + +import asyncio +from copy import copy +from typing import Any + +from agent_framework import ( + CheckpointStorage, + RunnerContext, + WorkflowCheckpoint, + WorkflowEvent, + WorkflowMessage, +) +from agent_framework._workflows._runner_context import YieldOutputClassifier, YieldOutputEventType +from agent_framework._workflows._state import State + + +class CapturingRunnerContext(RunnerContext): + """A RunnerContext that captures messages and events for durable activities. + + This context captures all messages and events produced during execution + without requiring durable entity storage, allowing the results to be + returned to the orchestrator. + + Checkpointing is not supported — the orchestrator manages state. + """ + + def __init__(self) -> None: + self._messages: dict[str, list[WorkflowMessage]] = {} + self._event_queue: asyncio.Queue[WorkflowEvent] = asyncio.Queue() + self._pending_request_info_events: dict[str, WorkflowEvent[Any]] = {} + self._workflow_id: str | None = None + self._streaming: bool = False + self._yield_output_classifier: YieldOutputClassifier = lambda _executor_id: "output" + + # -- Messaging ------------------------------------------------------------ + + async def send_message(self, message: WorkflowMessage) -> None: + self._messages.setdefault(message.source_id, []) + self._messages[message.source_id].append(message) + + async def drain_messages(self) -> dict[str, list[WorkflowMessage]]: + messages = copy(self._messages) + self._messages.clear() + return messages + + async def has_messages(self) -> bool: + return bool(self._messages) + + # -- Events --------------------------------------------------------------- + + async def add_event(self, event: WorkflowEvent) -> None: + await self._event_queue.put(event) + + async def drain_events(self) -> list[WorkflowEvent]: + events: list[WorkflowEvent] = [] + while True: + try: + events.append(self._event_queue.get_nowait()) + except asyncio.QueueEmpty: + break + return events + + async def has_events(self) -> bool: + return not self._event_queue.empty() + + async def next_event(self) -> WorkflowEvent: + return await self._event_queue.get() + + # -- Checkpointing (not supported) ---------------------------------------- + + def has_checkpointing(self) -> bool: + return False + + def set_runtime_checkpoint_storage(self, storage: CheckpointStorage) -> None: + pass + + def clear_runtime_checkpoint_storage(self) -> None: + pass + + async def create_checkpoint( + self, + workflow_name: str, + graph_signature_hash: str, + state: State, + previous_checkpoint_id: str | None, + iteration_count: int, + metadata: dict[str, Any] | None = None, + ) -> str: + raise NotImplementedError("Checkpointing is not supported in activity context") + + async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None: + raise NotImplementedError("Checkpointing is not supported in activity context") + + async def apply_checkpoint(self, checkpoint: WorkflowCheckpoint) -> None: + raise NotImplementedError("Checkpointing is not supported in activity context") + + # -- Workflow configuration ----------------------------------------------- + + def set_workflow_id(self, workflow_id: str) -> None: + self._workflow_id = workflow_id + + def reset_for_new_run(self) -> None: + self._messages.clear() + self._event_queue = asyncio.Queue() + self._pending_request_info_events.clear() + self._streaming = False + + def set_streaming(self, streaming: bool) -> None: + self._streaming = streaming + + def is_streaming(self) -> bool: + return self._streaming + + # -- Yield-output classification ------------------------------------------- + + def set_yield_output_classifier(self, classifier: YieldOutputClassifier) -> None: + """Set the classifier used by ``WorkflowContext.yield_output()``.""" + self._yield_output_classifier = classifier + + def classify_yielded_output(self, executor_id: str) -> YieldOutputEventType | None: + """Classify an executor's yield_output payload as output, intermediate, or hidden.""" + return self._yield_output_classifier(executor_id) + + # -- Request Info Events -------------------------------------------------- + + async def add_request_info_event(self, event: WorkflowEvent[Any]) -> None: + self._pending_request_info_events[event.request_id] = event + await self.add_event(event) + + async def send_request_info_response(self, request_id: str, response: Any) -> None: + raise NotImplementedError( + "send_request_info_response is not supported in activity context. " + "Human-in-the-loop scenarios should be handled at the orchestrator level." + ) + + async def get_pending_request_info_events(self) -> dict[str, WorkflowEvent[Any]]: + return dict(self._pending_request_info_events) diff --git a/python/packages/durabletask/agent_framework_durabletask/_workflows/serialization.py b/python/packages/durabletask/agent_framework_durabletask/_workflows/serialization.py new file mode 100644 index 00000000000..0cd616b2cef --- /dev/null +++ b/python/packages/durabletask/agent_framework_durabletask/_workflows/serialization.py @@ -0,0 +1,213 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Serialization utilities for workflow execution. + +This module provides thin wrappers around the core checkpoint encoding system +(encode_checkpoint_value / decode_checkpoint_value) from agent_framework._workflows. + +The core checkpoint encoding uses pickle + base64 for type-safe roundtripping of +arbitrary Python objects (dataclasses, Pydantic models, Message, etc.) while +keeping JSON-native types (str, int, float, bool, None) as-is. + +This module adds: +- serialize_value / deserialize_value: convenience aliases for encode/decode +- reconstruct_to_type: for HITL responses where external data (without type markers) + needs to be reconstructed to a known type +- resolve_type: resolves 'module:class' type keys to Python types +""" + +from __future__ import annotations + +import importlib +import logging +from contextlib import suppress +from dataclasses import is_dataclass +from typing import Any, cast + +from agent_framework._workflows._checkpoint_encoding import ( + _PICKLE_MARKER, # pyright: ignore[reportPrivateUsage] + _TYPE_MARKER, # pyright: ignore[reportPrivateUsage] + decode_checkpoint_value, + encode_checkpoint_value, +) +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + + +def resolve_type(type_key: str) -> type | None: + """Resolve a 'module:class' type key to its Python type. + + Args: + type_key: Fully qualified type reference in 'module_name:class_name' format. + + Returns: + The resolved type, or None if resolution fails. + """ + try: + module_name, class_name = type_key.split(":", 1) + module = importlib.import_module(module_name) + resolved = getattr(module, class_name, None) + # Only return actual classes. A non-type attribute (function, module member, + # etc.) would raise TypeError in issubclass() inside reconstruct_to_type(). + return resolved if isinstance(resolved, type) else None + except Exception: + logger.debug("Could not resolve type %s", type_key) + return None + + +# ============================================================================ +# Pickle marker sanitization (security) +# ============================================================================ + + +def strip_pickle_markers(data: Any) -> Any: + """Recursively strip pickle/type markers from untrusted data. + + The core checkpoint encoding uses ``__pickled__`` and ``__type__`` markers to + roundtrip arbitrary Python objects via *pickle*. If an attacker crafts an + HTTP payload that contains these markers, the data would flow into + ``pickle.loads()`` and enable **arbitrary code execution**. + + This function walks the incoming data structure and replaces any ``dict`` + that contains either marker key with ``None``, neutralizing the attack + vector while leaving all other data untouched. + + It **must** be called on every value that originates from an untrusted + source (e.g. ``req.get_json()``) *before* the value is passed to + ``deserialize_value`` / ``decode_checkpoint_value``. + """ + if isinstance(data, dict): + if _PICKLE_MARKER in data or _TYPE_MARKER in data: + logger.debug("Stripped pickle/type markers from untrusted input.") + return None + typed_dict = cast(dict[str, Any], data) + return {k: strip_pickle_markers(v) for k, v in typed_dict.items()} + + if isinstance(data, list): + typed_list = cast(list[Any], data) # type: ignore[redundant-cast] + return [strip_pickle_markers(item) for item in typed_list] + + return data + + +# ============================================================================ +# Serialize / Deserialize +# ============================================================================ + + +def serialize_value(value: Any) -> Any: + """Serialize a value for JSON-compatible cross-activity communication. + + Delegates to core checkpoint encoding which uses pickle + base64 for + non-JSON-native types (dataclasses, Pydantic models, Message, etc.). + + Args: + value: Any Python value (primitive, dataclass, Pydantic model, Message, etc.) + + Returns: + A JSON-serializable representation with embedded type metadata for reconstruction. + """ + return encode_checkpoint_value(value) + + +def deserialize_value(value: Any) -> Any: + """Deserialize a value previously serialized with serialize_value(). + + Delegates to core checkpoint decoding which unpickles base64-encoded values + and verifies type integrity. + + Args: + value: The serialized data (dict with pickle markers, list, or primitive) + + Returns: + Reconstructed typed object if type metadata found, otherwise original value. + """ + return decode_checkpoint_value(value) + + +def deserialize_workflow_output(output: Any) -> Any: + """Reconstruct the workflow outputs produced by the shared activity. + + Each value an executor yields is encoded with :func:`serialize_value` before + it reaches the orchestrator, so typed objects (dataclasses, Pydantic models, + ``AgentResponse``, ...) are stored as checkpoint-marker dicts. This reverses + that encoding so callers receive the original objects. + + This is the single decode path shared by every host (the in-process + :class:`DurableWorkflowClient` and the Azure Functions status endpoint) so + they never diverge in how a completed workflow's output is reconstructed. + + ``output`` must originate from the workflow's own orchestration result + (trusted durable storage), never from untrusted external input. Markers in + untrusted input must be neutralized with :func:`strip_pickle_markers` first. + + Args: + output: The workflow's orchestration result, already JSON-decoded (a list + of yielded outputs or a single value). + + Returns: + The output with every checkpoint-encoded value reconstructed; primitives + and plain JSON structures pass through unchanged. + """ + return deserialize_value(output) + + +# ============================================================================ +# HITL Type Reconstruction +# ============================================================================ + + +def reconstruct_to_type(value: Any, target_type: type) -> Any: + """Reconstruct a value to a known target type. + + Used for HITL responses where external data (without checkpoint type markers) + needs to be reconstructed to a specific type determined by the response_type hint. + + Tries strategies in order: + 1. Return as-is if already the correct type + 2. deserialize_value (for data with any type markers) + 3. Pydantic model_validate (for Pydantic models) + 4. Dataclass constructor (for dataclasses) + + Args: + value: The value to reconstruct (typically a dict from JSON) + target_type: The expected type to reconstruct to + + Returns: + Reconstructed value if possible, otherwise the original value + """ + if value is None: + return None + + with suppress(TypeError): + if isinstance(value, target_type): + return value + + if not isinstance(value, dict): + return value + + # Try decoding if data has pickle markers (from checkpoint encoding). + # NOTE: This function is general-purpose. Callers that handle untrusted + # data (e.g. HITL responses) MUST call strip_pickle_markers() before + # passing data here. See _deserialize_hitl_response in orchestrator.py. + decoded = deserialize_value(value) + if not isinstance(decoded, dict): + return decoded + + # Try Pydantic model validation (for unmarked dicts, e.g., external HITL data) + if issubclass(target_type, BaseModel): + try: + return target_type.model_validate(value) + except Exception: + logger.debug("Could not validate Pydantic model %s", target_type) + return value # type: ignore[return-value] + + # Try dataclass construction (for unmarked dicts, e.g., external HITL data) + if is_dataclass(target_type) and isinstance(target_type, type): # type: ignore + try: + return target_type(**value) + except Exception: + logger.debug("Could not construct dataclass %s", target_type) + + return value # type: ignore[return-value] diff --git a/python/packages/durabletask/tests/integration_tests/conftest.py b/python/packages/durabletask/tests/integration_tests/conftest.py index 65202e29079..1a1cd142aca 100644 --- a/python/packages/durabletask/tests/integration_tests/conftest.py +++ b/python/packages/durabletask/tests/integration_tests/conftest.py @@ -21,7 +21,7 @@ from durabletask.azuremanaged.client import DurableTaskSchedulerClient from durabletask.client import OrchestrationStatus -from agent_framework_durabletask import DurableAIAgentClient +from agent_framework_durabletask import DurableAIAgentClient, DurableWorkflowClient # Load environment variables from .env file load_dotenv(Path(__file__).parent / ".env") @@ -492,3 +492,10 @@ def create(cls, max_poll_retries: int = 90) -> tuple[DurableTaskSchedulerClient, return create_agent_client(cls.endpoint, cls.taskhub, max_poll_retries) return AgentClientFactory + + +@pytest.fixture(scope="module") +def workflow_client(worker_process: dict[str, Any]) -> DurableWorkflowClient: + """Create a DurableWorkflowClient bound to the current sample worker's task hub.""" + dts_client = create_dts_client(worker_process["endpoint"], worker_process["taskhub"]) + return DurableWorkflowClient(dts_client) diff --git a/python/packages/durabletask/tests/integration_tests/test_08_dt_workflow.py b/python/packages/durabletask/tests/integration_tests/test_08_dt_workflow.py new file mode 100644 index 00000000000..edf271757eb --- /dev/null +++ b/python/packages/durabletask/tests/integration_tests/test_08_dt_workflow.py @@ -0,0 +1,78 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Integration tests for the standalone durabletask workflow sample (08_workflow). + +Exercises the standalone (non-Azure-Functions) workflow path: +- ``DurableAIAgentWorker.configure_workflow`` auto-registers the agent entities, + non-agent executor activities, and the workflow orchestrator. +- A client starts the workflow by scheduling ``WORKFLOW_ORCHESTRATOR_NAME``. +- Conditional routing sends spam to a non-agent handler and legitimate email + through a second agent and a sender executor. +""" + +import logging + +import pytest +from durabletask.client import OrchestrationStatus + +from agent_framework_durabletask import WORKFLOW_ORCHESTRATOR_NAME + +logging.basicConfig(level=logging.WARNING) + +# Module-level markers +pytestmark = [ + pytest.mark.flaky, + pytest.mark.integration, + pytest.mark.sample("08_workflow"), + pytest.mark.integration_test, + pytest.mark.requires_dts, +] + + +class TestStandaloneWorkflow: + """Standalone (non-Azure-Functions) workflow execution on a durabletask worker.""" + + @pytest.fixture(autouse=True) + def setup(self, agent_client_factory: type, orchestration_helper) -> None: + """Provide a DTS client and orchestration helper for each test.""" + self.dts_client, self.agent_client = agent_client_factory.create() + self.orch_helper = orchestration_helper + + def test_legitimate_email_drafts_response(self) -> None: + """A legitimate email routes through the email agent and is 'sent'.""" + instance_id = self.dts_client.schedule_new_orchestration( + orchestrator=WORKFLOW_ORCHESTRATOR_NAME, + input=( + "Hi team, just a reminder about our sprint planning meeting tomorrow at 10 AM. " + "Please review the agenda in Jira." + ), + ) + + metadata, output = self.orch_helper.wait_for_orchestration_with_output( + instance_id=instance_id, + timeout=180.0, + ) + + assert metadata.runtime_status == OrchestrationStatus.COMPLETED + assert output is not None + assert "Email sent" in str(output) + + def test_spam_email_handled(self) -> None: + """A spam email routes to the non-agent spam handler.""" + instance_id = self.dts_client.schedule_new_orchestration( + orchestrator=WORKFLOW_ORCHESTRATOR_NAME, + input="URGENT! You've won $1,000,000! Click here now to claim your prize! Limited time offer!", + ) + + metadata, output = self.orch_helper.wait_for_orchestration_with_output( + instance_id=instance_id, + timeout=180.0, + ) + + assert metadata.runtime_status == OrchestrationStatus.COMPLETED + assert output is not None + assert "spam" in str(output).lower() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/python/packages/durabletask/tests/integration_tests/test_09_dt_workflow_hitl.py b/python/packages/durabletask/tests/integration_tests/test_09_dt_workflow_hitl.py new file mode 100644 index 00000000000..6ff54222e79 --- /dev/null +++ b/python/packages/durabletask/tests/integration_tests/test_09_dt_workflow_hitl.py @@ -0,0 +1,108 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Integration tests for the standalone durabletask HITL workflow sample (09_workflow_hitl). + +Exercises the human-in-the-loop workflow path on a standalone durabletask worker: +- The ``InputRouter`` start executor receives a typed ``ContentSubmission`` that the + shared engine reconstructs from the client's JSON payload (no manual parsing). +- An analysis agent produces a recommendation, then the workflow pauses for human + approval via ``request_info``. +- The client retrieves the pending request, replies with ``send_hitl_response``, and + the workflow resumes to an approved/rejected outcome read via ``await_workflow_output``. +""" + +import logging +import time +from typing import Any + +import pytest + +from agent_framework_durabletask import DurableWorkflowClient + +logging.basicConfig(level=logging.WARNING) + +# Module-level markers +pytestmark = [ + pytest.mark.flaky, + pytest.mark.integration, + pytest.mark.sample("09_workflow_hitl"), + pytest.mark.integration_test, + pytest.mark.requires_dts, + pytest.mark.requires_azure_openai, +] + + +def _wait_for_hitl_request( + client: DurableWorkflowClient, instance_id: str, timeout_seconds: int = 90 +) -> list[dict[str, Any]]: + """Poll until the workflow records at least one pending HITL request.""" + deadline = time.time() + timeout_seconds + while time.time() < deadline: + pending = client.get_pending_hitl_requests(instance_id) + if pending: + return pending + time.sleep(2) + raise AssertionError(f"Timed out waiting for a HITL request on instance {instance_id}") + + +class TestStandaloneWorkflowHITL: + """Human-in-the-loop workflow execution on a standalone durabletask worker.""" + + @pytest.fixture(autouse=True) + def setup(self, workflow_client: DurableWorkflowClient) -> None: + """Bind the DurableWorkflowClient for the current sample worker.""" + self.client = workflow_client + + def _run_case(self, submission: dict[str, Any], *, approve: bool) -> Any: + """Start a moderation case, answer the HITL pause, and return the final output.""" + instance_id = self.client.start_workflow(input=submission) + + pending = _wait_for_hitl_request(self.client, instance_id) + request = pending[0] + assert request["request_id"] + assert request["source_executor_id"] + + self.client.send_hitl_response( + instance_id, + request["request_id"], + {"approved": approve, "reviewer_notes": "Looks good." if approve else "Violates content policy."}, + ) + + return self.client.await_workflow_output(instance_id, timeout_seconds=180) + + def test_hitl_workflow_approval(self) -> None: + """Appropriate content is approved after the reviewer says yes.""" + output = self._run_case( + { + "content_id": "article-001", + "title": "Introduction to AI in Healthcare", + "body": ( + "Artificial intelligence is improving healthcare by enabling faster diagnosis, " + "personalized treatment plans, and better patient outcomes." + ), + "author": "Dr. Jane Smith", + }, + approve=True, + ) + + assert output is not None + assert "APPROVED" in str(output).upper() + + def test_hitl_workflow_rejection(self) -> None: + """Spammy content is rejected after the reviewer says no.""" + output = self._run_case( + { + "content_id": "article-002", + "title": "Get Rich Quick", + "body": "Click here NOW to make $10,000 overnight! GUARANTEED! Limited time offer!", + "author": "Definitely Not Spam", + }, + approve=False, + ) + + assert output is not None + assert "REJECTED" in str(output).upper() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/python/packages/durabletask/tests/test_worker.py b/python/packages/durabletask/tests/test_worker.py index e6dabcdfdf0..315690a821e 100644 --- a/python/packages/durabletask/tests/test_worker.py +++ b/python/packages/durabletask/tests/test_worker.py @@ -164,5 +164,62 @@ def test_start_works_with_multiple_agents(self, agent_worker: DurableAIAgentWork assert len(agent_worker.registered_agent_names) == 2 +class TestDurableAIAgentWorkerWorkflow: + """Test workflow registration, including the agent-executor identity fix.""" + + def test_add_agent_with_entity_id_registers_under_override( + self, agent_worker: DurableAIAgentWorker, mock_agent: Mock + ) -> None: + """An explicit entity_id overrides the agent name as the entity identity.""" + agent_worker.add_agent(mock_agent, entity_id="node-7") + + assert "node-7" in agent_worker.registered_agent_names + assert "test_agent" not in agent_worker.registered_agent_names + + def test_configure_workflow_registers_agent_entity_by_executor_id( + self, agent_worker: DurableAIAgentWorker, mock_grpc_worker: Mock + ) -> None: + """Workflow agent executors register entities keyed by executor id. + + The orchestrator dispatches by executor id, so an + ``AgentExecutor(agent, id=...)`` whose id differs from the agent name must + still be reachable. + """ + from agent_framework import AgentExecutor + + agent = Mock() + agent.name = "Reviewer" + agent_executor = Mock(spec=AgentExecutor) + agent_executor.id = "custom-executor-id" + agent_executor.agent = agent + + workflow = Mock() + workflow.executors = {"custom-executor-id": agent_executor} + + agent_worker.configure_workflow(workflow) + + assert "custom-executor-id" in agent_worker.registered_agent_names + assert "Reviewer" not in agent_worker.registered_agent_names + mock_grpc_worker.add_orchestrator.assert_called_once() + + def test_configure_workflow_registers_non_agent_executor_as_activity( + self, agent_worker: DurableAIAgentWorker, mock_grpc_worker: Mock + ) -> None: + """Non-agent executors are registered as activities, not entities.""" + from agent_framework import Executor + + activity_executor = Mock(spec=Executor) + activity_executor.id = "router-node" + + workflow = Mock() + workflow.executors = {"router-node": activity_executor} + + agent_worker.configure_workflow(workflow) + + assert agent_worker.registered_agent_names == [] + mock_grpc_worker.add_activity.assert_called_once() + mock_grpc_worker.add_orchestrator.assert_called_once() + + if __name__ == "__main__": pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/durabletask/tests/test_workflow_activity.py b/python/packages/durabletask/tests/test_workflow_activity.py new file mode 100644 index 00000000000..b2ff9c159bc --- /dev/null +++ b/python/packages/durabletask/tests/test_workflow_activity.py @@ -0,0 +1,123 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for execute_workflow_activity (shared non-agent executor activity body). + +These tests exercise the host-agnostic activity execution shared by the Azure +Functions and standalone durabletask workflow hosts. In particular they protect +the state snapshot/diff semantics: the snapshot must be a *deep* copy so that +in-place mutations to nested objects (dicts, lists) are correctly detected as +updates (regression guard for the shallow-copy bug, #4500). +""" + +import json +from typing import Any +from unittest.mock import AsyncMock, Mock + +from agent_framework_durabletask import execute_workflow_activity +from agent_framework_durabletask._workflows.orchestrator import SOURCE_ORCHESTRATOR + + +def _make_executor(executor_id: str, mutate: Any) -> Mock: + """Build a mock non-agent executor whose execute() mutates shared state.""" + executor = Mock() + executor.id = executor_id + executor.execute = AsyncMock(side_effect=mutate) + return executor + + +def _run(executor: Mock, snapshot: dict[str, Any]) -> dict[str, Any]: + """Invoke execute_workflow_activity and return the parsed result dict.""" + input_data = json.dumps({ + "message": "test", + "shared_state_snapshot": snapshot, + "source_executor_ids": [SOURCE_ORCHESTRATOR], + }) + return json.loads(execute_workflow_activity(executor, input_data)) + + +class TestExecuteWorkflowActivityStateDiff: + """State snapshot/diff behavior of the shared workflow activity body.""" + + def test_nested_dict_mutation_detected(self) -> None: + """In-place mutation of a nested dict is reported as an update.""" + + async def mutate(message: Any, source_executor_ids: Any, state: Any, runner_context: Any) -> None: + config = state.get("Local.config") + config["code"] = "SOMECODEXXX" + config["enabled"] = True + state.commit() + + executor = _make_executor("test-exec", mutate) + result = _run(executor, {"Local.config": {"code": "", "enabled": False}, "simple_key": "simple_value"}) + + updates = result["shared_state_updates"] + assert "Local.config" in updates, "nested mutation not detected — snapshot may be a shallow copy" + assert updates["Local.config"]["code"] == "SOMECODEXXX" + assert updates["Local.config"]["enabled"] is True + + def test_new_key_in_nested_dict_detected(self) -> None: + """Adding a key to a nested dict is reported as an update.""" + + async def mutate(message: Any, source_executor_ids: Any, state: Any, runner_context: Any) -> None: + state.get("Local.data")["code"] = "NEW_CODE" + state.commit() + + executor = _make_executor("test-exec", mutate) + result = _run(executor, {"Local.data": {"existing": "value"}}) + + assert result["shared_state_updates"]["Local.data"]["code"] == "NEW_CODE" + + def test_nested_list_mutation_detected(self) -> None: + """Appending to a nested list is reported as an update.""" + + async def mutate(message: Any, source_executor_ids: Any, state: Any, runner_context: Any) -> None: + state.get("Local.items").append(4) + state.commit() + + executor = _make_executor("test-exec", mutate) + result = _run(executor, {"Local.items": [1, 2, 3]}) + + assert result["shared_state_updates"]["Local.items"] == [1, 2, 3, 4] + + def test_new_top_level_key_detected(self) -> None: + """Setting a new top-level key is reported as an update.""" + + async def mutate(message: Any, source_executor_ids: Any, state: Any, runner_context: Any) -> None: + state.set("Local.code", "SOMECODEXXX") + state.commit() + + executor = _make_executor("test-exec", mutate) + result = _run(executor, {"existing": "value"}) + + assert result["shared_state_updates"]["Local.code"] == "SOMECODEXXX" + + def test_unchanged_state_produces_empty_diff(self) -> None: + """Unmodified state produces no updates.""" + + async def mutate(message: Any, source_executor_ids: Any, state: Any, runner_context: Any) -> None: + # No mutations performed. + state.commit() + + executor = _make_executor("test-exec", mutate) + result = _run(executor, {"Local.config": {"code": "existing", "enabled": True}, "simple_key": "v"}) + + assert result["shared_state_updates"] == {} + + def test_deleted_key_reported(self) -> None: + """A key removed during execution is reported as a delete.""" + + async def mutate(message: Any, source_executor_ids: Any, state: Any, runner_context: Any) -> None: + state.delete("to_remove") + state.commit() + + executor = _make_executor("test-exec", mutate) + result = _run(executor, {"to_remove": "value", "keep": "value"}) + + assert "to_remove" in result["shared_state_deletes"] + assert "keep" not in result["shared_state_deletes"] + + +if __name__ == "__main__": + import pytest + + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/durabletask/tests/test_workflow_client.py b/python/packages/durabletask/tests/test_workflow_client.py new file mode 100644 index 00000000000..434a25a7074 --- /dev/null +++ b/python/packages/durabletask/tests/test_workflow_client.py @@ -0,0 +1,274 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for DurableWorkflowClient. + +Covers starting workflows, awaiting output (including error/timeout paths), +parsing pending human-in-the-loop (HITL) requests from custom status, and +sanitizing HITL responses before delivery. +""" + +import json +from dataclasses import dataclass +from unittest.mock import Mock + +import pytest + +from agent_framework_durabletask import DurableWorkflowClient +from agent_framework_durabletask._workflows.orchestrator import WORKFLOW_ORCHESTRATOR_NAME +from agent_framework_durabletask._workflows.serialization import serialize_value + + +@dataclass +class _Receipt: + """Module-level dataclass so it is picklable by serialize_value.""" + + order_id: int + total: float + + +@pytest.fixture +def mock_client() -> Mock: + """Create a mock TaskHubGrpcClient.""" + return Mock() + + +@pytest.fixture +def workflow_client(mock_client: Mock) -> DurableWorkflowClient: + """Create a DurableWorkflowClient wrapping the mock client.""" + return DurableWorkflowClient(mock_client) + + +class TestStartWorkflow: + """Test starting workflow orchestrations.""" + + def test_start_workflow_schedules_orchestrator( + self, workflow_client: DurableWorkflowClient, mock_client: Mock + ) -> None: + """start_workflow schedules the auto-registered orchestrator by name.""" + mock_client.schedule_new_orchestration.return_value = "instance-1" + + result = workflow_client.start_workflow(input="hello") + + assert result == "instance-1" + mock_client.schedule_new_orchestration.assert_called_once_with( + WORKFLOW_ORCHESTRATOR_NAME, input="hello", instance_id=None + ) + + def test_start_workflow_passes_non_string_input_unchanged( + self, workflow_client: DurableWorkflowClient, mock_client: Mock + ) -> None: + """Non-string payloads are forwarded as-is (no string coercion).""" + mock_client.schedule_new_orchestration.return_value = "instance-2" + payload = {"order_id": 42, "items": ["a", "b"]} + + workflow_client.start_workflow(input=payload) + + _, kwargs = mock_client.schedule_new_orchestration.call_args + assert kwargs["input"] == payload + + def test_start_workflow_forwards_instance_id( + self, workflow_client: DurableWorkflowClient, mock_client: Mock + ) -> None: + """An explicit instance id is forwarded to the underlying client.""" + mock_client.schedule_new_orchestration.return_value = "explicit-id" + + workflow_client.start_workflow(input="x", instance_id="explicit-id") + + _, kwargs = mock_client.schedule_new_orchestration.call_args + assert kwargs["instance_id"] == "explicit-id" + + +class TestAwaitWorkflowOutput: + """Test awaiting workflow completion and output.""" + + def test_returns_deserialized_output_on_completion( + self, workflow_client: DurableWorkflowClient, mock_client: Mock + ) -> None: + """A COMPLETED workflow returns its deserialized output.""" + metadata = Mock() + metadata.runtime_status.name = "COMPLETED" + metadata.serialized_output = json.dumps(["result"]) + mock_client.wait_for_orchestration_completion.return_value = metadata + + output = workflow_client.await_workflow_output("instance-1") + + assert output == ["result"] + + def test_returns_none_when_no_output(self, workflow_client: DurableWorkflowClient, mock_client: Mock) -> None: + """A COMPLETED workflow with no output returns None.""" + metadata = Mock() + metadata.runtime_status.name = "COMPLETED" + metadata.serialized_output = None + mock_client.wait_for_orchestration_completion.return_value = metadata + + assert workflow_client.await_workflow_output("instance-1") is None + + def test_reconstructs_typed_outputs(self, workflow_client: DurableWorkflowClient, mock_client: Mock) -> None: + """Typed outputs encoded by the activity come back as objects, not marker dicts.""" + receipt = _Receipt(order_id=7, total=19.99) + # The shared activity stores each yielded output via serialize_value(), so a + # typed object is persisted as a checkpoint-marker dict. + metadata = Mock() + metadata.runtime_status.name = "COMPLETED" + metadata.serialized_output = json.dumps([serialize_value(receipt)]) + mock_client.wait_for_orchestration_completion.return_value = metadata + + output = workflow_client.await_workflow_output("instance-1") + + assert output == [receipt] + assert isinstance(output[0], _Receipt) + + def test_raises_timeout_when_not_completed(self, workflow_client: DurableWorkflowClient, mock_client: Mock) -> None: + """A None metadata (no completion) raises TimeoutError.""" + mock_client.wait_for_orchestration_completion.return_value = None + + with pytest.raises(TimeoutError, match="did not complete"): + workflow_client.await_workflow_output("instance-1", timeout_seconds=5) + + def test_raises_runtime_error_on_failed_status( + self, workflow_client: DurableWorkflowClient, mock_client: Mock + ) -> None: + """A non-COMPLETED status raises RuntimeError.""" + metadata = Mock() + metadata.runtime_status.name = "FAILED" + metadata.serialized_output = "boom" + mock_client.wait_for_orchestration_completion.return_value = metadata + + with pytest.raises(RuntimeError, match="status FAILED"): + workflow_client.await_workflow_output("instance-1") + + +class TestGetRuntimeStatus: + """Test reading the workflow's runtime status.""" + + def test_returns_status_name(self, workflow_client: DurableWorkflowClient, mock_client: Mock) -> None: + """The runtime status name is returned when state is available.""" + state = Mock() + state.runtime_status.name = "RUNNING" + mock_client.get_orchestration_state.return_value = state + + assert workflow_client.get_runtime_status("instance-1") == "RUNNING" + + def test_returns_none_when_no_state(self, workflow_client: DurableWorkflowClient, mock_client: Mock) -> None: + """No orchestration state yields None (status unknown).""" + mock_client.get_orchestration_state.return_value = None + + assert workflow_client.get_runtime_status("instance-1") is None + + +class TestGetPendingHitlRequests: + """Test parsing pending HITL requests from custom status.""" + + def _state_with_status(self, status: object) -> Mock: + state = Mock() + state.serialized_custom_status = json.dumps(status) if status is not None else None + return state + + def test_returns_empty_when_no_state(self, workflow_client: DurableWorkflowClient, mock_client: Mock) -> None: + """No orchestration state yields an empty list.""" + mock_client.get_orchestration_state.return_value = None + + assert workflow_client.get_pending_hitl_requests("instance-1") == [] + + def test_returns_empty_when_status_blank(self, workflow_client: DurableWorkflowClient, mock_client: Mock) -> None: + """A blank custom status yields an empty list.""" + state = Mock() + state.serialized_custom_status = "" + mock_client.get_orchestration_state.return_value = state + + assert workflow_client.get_pending_hitl_requests("instance-1") == [] + + def test_returns_empty_on_invalid_json(self, workflow_client: DurableWorkflowClient, mock_client: Mock) -> None: + """Malformed custom status JSON yields an empty list.""" + state = Mock() + state.serialized_custom_status = "{not-json" + mock_client.get_orchestration_state.return_value = state + + assert workflow_client.get_pending_hitl_requests("instance-1") == [] + + def test_parses_pending_requests(self, workflow_client: DurableWorkflowClient, mock_client: Mock) -> None: + """Pending requests are normalized into the documented shape.""" + status = { + "pending_requests": { + "req-1": { + "request_id": "req-1", + "source_executor_id": "approver", + "data": {"prompt": "approve?"}, + "request_type": "ApprovalRequest", + "response_type": "ApprovalResponse", + } + } + } + mock_client.get_orchestration_state.return_value = self._state_with_status(status) + + requests = workflow_client.get_pending_hitl_requests("instance-1") + + assert requests == [ + { + "request_id": "req-1", + "source_executor_id": "approver", + "data": {"prompt": "approve?"}, + "request_type": "ApprovalRequest", + "response_type": "ApprovalResponse", + } + ] + + def test_falls_back_to_dict_key_for_request_id( + self, workflow_client: DurableWorkflowClient, mock_client: Mock + ) -> None: + """When a request omits request_id, the dict key is used.""" + status = {"pending_requests": {"req-key": {"source_executor_id": "x"}}} + mock_client.get_orchestration_state.return_value = self._state_with_status(status) + + requests = workflow_client.get_pending_hitl_requests("instance-1") + + assert requests[0]["request_id"] == "req-key" + + def test_ignores_non_dict_entries(self, workflow_client: DurableWorkflowClient, mock_client: Mock) -> None: + """Non-dict request entries are skipped.""" + status = {"pending_requests": {"req-1": "not-a-dict"}} + mock_client.get_orchestration_state.return_value = self._state_with_status(status) + + assert workflow_client.get_pending_hitl_requests("instance-1") == [] + + def test_returns_empty_when_pending_not_dict( + self, workflow_client: DurableWorkflowClient, mock_client: Mock + ) -> None: + """A non-dict pending_requests field yields an empty list.""" + status = {"pending_requests": ["unexpected"]} + mock_client.get_orchestration_state.return_value = self._state_with_status(status) + + assert workflow_client.get_pending_hitl_requests("instance-1") == [] + + +class TestSendHitlResponse: + """Test delivering HITL responses.""" + + def test_raises_orchestration_event_with_request_id( + self, workflow_client: DurableWorkflowClient, mock_client: Mock + ) -> None: + """The response is delivered as an external event named by request id.""" + workflow_client.send_hitl_response("instance-1", "req-1", {"approved": True}) + + mock_client.raise_orchestration_event.assert_called_once() + _, kwargs = mock_client.raise_orchestration_event.call_args + assert kwargs["event_name"] == "req-1" + assert kwargs["data"] == {"approved": True} + + def test_strips_pickle_markers_before_delivery( + self, workflow_client: DurableWorkflowClient, mock_client: Mock + ) -> None: + """A crafted pickle-marker payload is neutralized before reaching the worker. + + The HITL response is sent to the worker which deserializes it, so a payload + carrying the checkpoint ``__pickled__`` marker must be stripped client-side + (regression guard for the strip_pickle_markers call in send_hitl_response). + """ + malicious = {"__pickled__": "", "approved": True} + + workflow_client.send_hitl_response("instance-1", "req-1", malicious) + + _, kwargs = mock_client.raise_orchestration_event.call_args + # The whole marker-bearing dict is neutralized (replaced with None) rather + # than forwarded, so it can never reach pickle.loads on the worker. + assert kwargs["data"] is None diff --git a/python/packages/durabletask/tests/test_workflow_input_coercion.py b/python/packages/durabletask/tests/test_workflow_input_coercion.py new file mode 100644 index 00000000000..2b785a145d3 --- /dev/null +++ b/python/packages/durabletask/tests/test_workflow_input_coercion.py @@ -0,0 +1,135 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for workflow initial-input coercion (`_coerce_initial_input`). + +A durable workflow runs as a durable orchestration, so its initial payload +arrives as plain JSON (no type markers). The shared engine reconstructs the +start executor's declared input type from that JSON, mirroring in-process +delivery. These tests pin that behavior across the relevant start-executor +shapes. +""" + +import json +from dataclasses import dataclass +from unittest.mock import Mock + +from agent_framework import AgentExecutor, Executor, WorkflowContext, handler +from pydantic import BaseModel + +from agent_framework_durabletask._workflows.orchestrator import _coerce_initial_input + + +@dataclass +class _Submission: + content_id: str + title: str + + +class _SubmissionModel(BaseModel): + content_id: str + title: str + + +class _StrStart(Executor): + def __init__(self) -> None: + super().__init__(id="str_start") + + @handler + async def run(self, message: str, ctx: WorkflowContext) -> None: # pragma: no cover - never invoked + ... + + +class _DataclassStart(Executor): + def __init__(self) -> None: + super().__init__(id="dc_start") + + @handler + async def run(self, message: _Submission, ctx: WorkflowContext) -> None: # pragma: no cover - never invoked + ... + + +class _PydanticStart(Executor): + def __init__(self) -> None: + super().__init__(id="pyd_start") + + @handler + async def run(self, message: _SubmissionModel, ctx: WorkflowContext) -> None: # pragma: no cover - never invoked + ... + + +def _workflow_with(executor: Executor | Mock) -> Mock: + workflow = Mock() + workflow.executors = {executor.id: executor} + workflow.start_executor_id = executor.id + return workflow + + +class TestCoerceInitialInput: + """Test reconstruction of the initial workflow input by start-executor type.""" + + def test_str_start_passes_string_through(self) -> None: + workflow = _workflow_with(_StrStart()) + + assert _coerce_initial_input(workflow, "hello world") == "hello world" + + def test_dataclass_start_reconstructs_from_dict(self) -> None: + workflow = _workflow_with(_DataclassStart()) + + result = _coerce_initial_input(workflow, {"content_id": "x", "title": "T"}) + + assert isinstance(result, _Submission) + assert result.content_id == "x" + assert result.title == "T" + + def test_pydantic_start_reconstructs_from_dict(self) -> None: + workflow = _workflow_with(_PydanticStart()) + + result = _coerce_initial_input(workflow, {"content_id": "x", "title": "T"}) + + assert isinstance(result, _SubmissionModel) + assert result.content_id == "x" + + def test_str_start_leaves_dict_unchanged(self) -> None: + """A str-typed start executor declares text; a dict is not coerced to str.""" + workflow = _workflow_with(_StrStart()) + payload = {"content_id": "x"} + + assert _coerce_initial_input(workflow, payload) == payload + + def test_agent_start_passes_string_through(self) -> None: + agent_executor = Mock(spec=AgentExecutor) + agent_executor.id = "agent" + workflow = _workflow_with(agent_executor) + + assert _coerce_initial_input(workflow, "draft this email") == "draft this email" + + def test_agent_start_stringifies_dict(self) -> None: + """Agents only consume text, so a structured payload is serialized to text.""" + agent_executor = Mock(spec=AgentExecutor) + agent_executor.id = "agent" + workflow = _workflow_with(agent_executor) + + result = _coerce_initial_input(workflow, {"email": "hi"}) + + assert result == json.dumps({"email": "hi"}) + + def test_missing_start_executor_passes_through(self) -> None: + workflow = Mock() + workflow.executors = {} + workflow.start_executor_id = "missing" + payload = {"a": 1} + + assert _coerce_initial_input(workflow, payload) == payload + + def test_pickle_marker_injection_is_neutralized(self) -> None: + """A crafted pickle-marker payload is stripped before reconstruction (no pickle RCE). + + The initial workflow input is untrusted, so a dict carrying the checkpoint + ``__pickled__`` marker must be neutralized rather than flowing into + ``deserialize_value`` (which would ``pickle.loads`` it). + """ + workflow = _workflow_with(_DataclassStart()) + malicious = {"__pickled__": "", "content_id": "x", "title": "T"} + + # The marker-bearing dict is replaced with None, never unpickled or reconstructed. + assert _coerce_initial_input(workflow, malicious) is None diff --git a/python/packages/durabletask/tests/test_workflow_registration.py b/python/packages/durabletask/tests/test_workflow_registration.py new file mode 100644 index 00000000000..f9cb9e190ee --- /dev/null +++ b/python/packages/durabletask/tests/test_workflow_registration.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for plan_workflow_registration. + +Verifies the host-agnostic decision of which executors become durable entities +(agent executors) versus durable activities (everything else), and that agent +executors are carried whole so each host can register entities under the +executor id the orchestrator dispatches to. +""" + +from unittest.mock import Mock + +from agent_framework import AgentExecutor, Executor + +from agent_framework_durabletask import WorkflowRegistrationPlan, plan_workflow_registration +from agent_framework_durabletask._workflows.orchestrator import WORKFLOW_ORCHESTRATOR_NAME + + +def _agent_executor(executor_id: str, agent_name: str) -> Mock: + agent = Mock() + agent.name = agent_name + executor = Mock(spec=AgentExecutor) + executor.id = executor_id + executor.agent = agent + return executor + + +def _activity_executor(executor_id: str) -> Mock: + executor = Mock(spec=Executor) + executor.id = executor_id + return executor + + +class TestPlanWorkflowRegistration: + """Test classification of workflow executors into durable primitives.""" + + def test_agent_executor_classified_as_entity(self) -> None: + """An AgentExecutor is carried whole in agent_executors.""" + agent_exec = _agent_executor("reviewer-node", "Reviewer") + workflow = Mock() + workflow.executors = {"reviewer-node": agent_exec} + + plan = plan_workflow_registration(workflow) + + assert plan.agent_executors == [agent_exec] + assert plan.activity_executors == [] + assert plan.orchestrator_name == WORKFLOW_ORCHESTRATOR_NAME + + def test_non_agent_executor_classified_as_activity(self) -> None: + """A plain Executor is classified as an activity.""" + activity_exec = _activity_executor("router-node") + workflow = Mock() + workflow.executors = {"router-node": activity_exec} + + plan = plan_workflow_registration(workflow) + + assert plan.agent_executors == [] + assert plan.activity_executors == [activity_exec] + + def test_mixed_executors_are_partitioned(self) -> None: + """Agent and non-agent executors are split into the correct buckets.""" + agent_exec = _agent_executor("agent-node", "Agent") + activity_exec = _activity_executor("activity-node") + workflow = Mock() + workflow.executors = {"agent-node": agent_exec, "activity-node": activity_exec} + + plan = plan_workflow_registration(workflow) + + assert plan.agent_executors == [agent_exec] + assert plan.activity_executors == [activity_exec] + + def test_agent_executor_id_is_preserved_when_distinct_from_name(self) -> None: + """The plan keeps the executor (and its id), not just the bare agent. + + This is the core of the identity fix: dispatch targets the executor id, + so registration must be able to use the id even when it differs from + ``agent.name``. + """ + agent_exec = _agent_executor("custom-executor-id", "ReusedAgentName") + workflow = Mock() + workflow.executors = {"custom-executor-id": agent_exec} + + plan = plan_workflow_registration(workflow) + + assert plan.agent_executors[0].id == "custom-executor-id" + assert plan.agent_executors[0].agent.name == "ReusedAgentName" + + def test_returns_workflow_registration_plan(self) -> None: + """The return value is a WorkflowRegistrationPlan.""" + workflow = Mock() + workflow.executors = {} + + plan = plan_workflow_registration(workflow) + + assert isinstance(plan, WorkflowRegistrationPlan) + assert plan.agent_executors == [] + assert plan.activity_executors == [] diff --git a/python/packages/durabletask/tests/test_workflow_routing.py b/python/packages/durabletask/tests/test_workflow_routing.py new file mode 100644 index 00000000000..079e0b69909 --- /dev/null +++ b/python/packages/durabletask/tests/test_workflow_routing.py @@ -0,0 +1,38 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for synchronous edge-condition evaluation on the durabletask host. + +Durable orchestrators run as generators and evaluate edge conditions +synchronously. A condition that returns an awaitable cannot be evaluated in +that context, so the edge is treated as *not matched* (not traversed). +""" + +from agent_framework._workflows._edge import Edge # pyright: ignore[reportPrivateImportUsage] + +from agent_framework_durabletask._workflows.orchestrator import _evaluate_edge_condition_sync + + +class TestEvaluateEdgeConditionSync: + """Synchronous edge-condition evaluation semantics.""" + + def test_no_condition_traverses(self) -> None: + edge = Edge("a", "b") + assert _evaluate_edge_condition_sync(edge, {"x": 1}) is True + + def test_sync_true_traverses(self) -> None: + edge = Edge("a", "b", condition=lambda m: m["ok"]) + assert _evaluate_edge_condition_sync(edge, {"ok": True}) is True + + def test_sync_false_does_not_traverse(self) -> None: + edge = Edge("a", "b", condition=lambda m: m["ok"]) + assert _evaluate_edge_condition_sync(edge, {"ok": False}) is False + + def test_async_condition_is_not_traversed(self) -> None: + # The durabletask host evaluates conditions synchronously; an async + # condition cannot be evaluated, so the edge is treated as not matched + # even though it would resolve True when awaited. + async def gate(_message: object) -> bool: + return True + + edge = Edge("a", "b", condition=gate) + assert _evaluate_edge_condition_sync(edge, {"x": 1}) is False diff --git a/python/packages/durabletask/tests/test_workflow_serialization.py b/python/packages/durabletask/tests/test_workflow_serialization.py new file mode 100644 index 00000000000..246e2e25968 --- /dev/null +++ b/python/packages/durabletask/tests/test_workflow_serialization.py @@ -0,0 +1,73 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for workflow serialization helpers. + +``resolve_type`` is annotated ``type | None`` and its result flows into +``reconstruct_to_type``, which calls ``issubclass``. A non-class attribute +(function, module member, etc.) would raise ``TypeError`` there, so the +resolver must only ever return actual classes. + +``deserialize_workflow_output`` reverses the per-output ``serialize_value`` +encoding the shared activity applies, so typed outputs are returned as the +original objects rather than checkpoint-marker dicts. +""" + +import json +from collections import OrderedDict +from dataclasses import dataclass + +from agent_framework_durabletask._workflows.serialization import ( + deserialize_workflow_output, + resolve_type, + serialize_value, +) + + +@dataclass +class _Decision: + """Module-level dataclass so it is picklable by serialize_value.""" + + approved: bool + note: str + + +class TestResolveType: + """Test that resolve_type only returns real classes.""" + + def test_resolves_a_real_class(self) -> None: + assert resolve_type("collections:OrderedDict") is OrderedDict + + def test_returns_none_for_non_class_attribute(self) -> None: + # json.dumps is a function; if resolve_type returned it, issubclass() + # inside reconstruct_to_type() would raise TypeError at runtime. + assert resolve_type("json:dumps") is None + + def test_returns_none_for_unknown_attribute(self) -> None: + assert resolve_type("json:DoesNotExist") is None + + def test_returns_none_for_malformed_key(self) -> None: + assert resolve_type("not-a-valid-key") is None + + +class TestDeserializeWorkflowOutput: + """Reconstruction of stored workflow outputs.""" + + def test_primitives_pass_through(self) -> None: + # Mirror the stored shape: a list of yielded outputs, JSON round-tripped. + stored = json.loads(json.dumps([serialize_value("hello"), serialize_value(42)])) + + assert deserialize_workflow_output(stored) == ["hello", 42] + + def test_typed_outputs_are_reconstructed(self) -> None: + # A typed object is stored as a checkpoint-marker dict; it must come back + # as the original object, not the marker dict. + decision = _Decision(approved=True, note="ok") + stored = json.loads(json.dumps([serialize_value(decision)])) + + result = deserialize_workflow_output(stored) + + assert result == [decision] + assert isinstance(result[0], _Decision) + + def test_none_passes_through(self) -> None: + assert deserialize_workflow_output(None) is None diff --git a/python/samples/04-hosting/azure_functions/11_workflow_parallel/function_app.py b/python/samples/04-hosting/azure_functions/11_workflow_parallel/function_app.py index 0669d95e7b1..49ac41ffb19 100644 --- a/python/samples/04-hosting/azure_functions/11_workflow_parallel/function_app.py +++ b/python/samples/04-hosting/azure_functions/11_workflow_parallel/function_app.py @@ -26,7 +26,6 @@ - Ensure Azurite and the Durable Task Scheduler emulator are running """ -import json import logging import os from dataclasses import dataclass @@ -142,17 +141,12 @@ class FinalReport: @executor(id="input_router") -async def input_router(doc: str, ctx: WorkflowContext[DocumentInput]) -> None: - """Route input document to parallel processors. +async def input_router(document: DocumentInput, ctx: WorkflowContext[DocumentInput]) -> None: + """Route the input document to the parallel processors. - Accepts a JSON string from the HTTP request and converts to DocumentInput. + The durable engine reconstructs ``DocumentInput`` from the client's JSON + payload before delivery, mirroring in-process execution. """ - # Parse the JSON string input - data = json.loads(doc) if isinstance(doc, str) else doc - document = DocumentInput( - document_id=data.get("document_id", "unknown"), - content=data.get("content", ""), - ) logger.info("[input_router] Routing document: %s", document.document_id) await ctx.send_message(document) diff --git a/python/samples/04-hosting/azure_functions/12_workflow_hitl/function_app.py b/python/samples/04-hosting/azure_functions/12_workflow_hitl/function_app.py index e1f9389a6da..f8245d2381c 100644 --- a/python/samples/04-hosting/azure_functions/12_workflow_hitl/function_app.py +++ b/python/samples/04-hosting/azure_functions/12_workflow_hitl/function_app.py @@ -23,7 +23,6 @@ - Authentication via Azure CLI (az login) """ -import json import logging import os from dataclasses import dataclass @@ -332,19 +331,14 @@ def __init__(self): @handler async def route_input( self, - input_json: str, + submission: ContentSubmission, ctx: WorkflowContext[AgentExecutorRequest], ) -> None: - """Parse input and create agent request.""" - data = json.loads(input_json) if isinstance(input_json, str) else input_json - - submission = ContentSubmission( - content_id=data.get("content_id", "unknown"), - title=data.get("title", "Untitled"), - body=data.get("body", ""), - author=data.get("author", "Anonymous"), - ) + """Create the agent request from the submitted content. + The durable engine reconstructs this ``ContentSubmission`` from the + client's JSON payload before delivery, mirroring in-process execution. + """ # Store submission in shared state for later retrieval ctx.set_state("current_submission", submission) diff --git a/python/samples/04-hosting/durabletask/08_workflow/README.md b/python/samples/04-hosting/durabletask/08_workflow/README.md new file mode 100644 index 00000000000..0e032693776 --- /dev/null +++ b/python/samples/04-hosting/durabletask/08_workflow/README.md @@ -0,0 +1,57 @@ +# Workflow on a Standalone Durable Task Worker + +This sample demonstrates running an agent-framework `Workflow` as a durable +orchestration on a **standalone Durable Task worker** — no Azure Functions +required. It is the durabletask counterpart to the Azure Functions workflow +samples (`samples/04-hosting/azure_functions/10_workflow_no_shared_state`). + +## Key Concepts Demonstrated + +- Hosting a MAF `Workflow` outside Azure Functions via + `DurableAIAgentWorker.configure_workflow(workflow)`, which auto-registers: + - a durable **entity** for each agent executor, + - a durable **activity** for each non-agent executor, and + - the **workflow orchestrator** (registered as `WORKFLOW_ORCHESTRATOR_NAME`). +- Conditional routing with `add_switch_case_edge_group` (spam vs. legitimate email). +- Mixing AI agents with non-agent executors in one workflow graph. +- Starting the workflow from a client with + `DurableWorkflowClient.start_workflow(input=...)` and reading its result with + `await_workflow_output(instance_id)`. + +## Environment Setup + +See the [README.md](../README.md) in the parent directory for environment setup. + +This sample uses Azure AI Foundry credentials: + +- `FOUNDRY_PROJECT_ENDPOINT` +- `FOUNDRY_MODEL` + +It also needs a Durable Task Scheduler. For local development, start the +emulator (defaults to `http://localhost:8080`): + +```bash +docker run -d -p 8080:8080 -p 8082:8082 mcr.microsoft.com/dts/dts-emulator:latest +``` + +## Running the Sample + +Start the worker in one terminal: + +```bash +cd samples/04-hosting/durabletask/08_workflow +python worker.py +``` + +In a second terminal, run the client: + +```bash +python client.py +``` + +The client runs two cases: + +- **Legitimate email** → `SpamDetectionAgent` → `EmailAssistantAgent` → + `email_sender` → `"Email sent: ..."`. +- **Spam email** → `SpamDetectionAgent` → `spam_handler` → + `"Email marked as spam: ..."`. diff --git a/python/samples/04-hosting/durabletask/08_workflow/client.py b/python/samples/04-hosting/durabletask/08_workflow/client.py new file mode 100644 index 00000000000..7f8a40ff841 --- /dev/null +++ b/python/samples/04-hosting/durabletask/08_workflow/client.py @@ -0,0 +1,73 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Client that starts the standalone workflow orchestration and prints the result. + +The worker (``worker.py``) must be running first. The workflow is started via +``DurableWorkflowClient.start_workflow`` - which schedules the orchestrator that +``DurableAIAgentWorker.configure_workflow`` auto-registers, so the caller never +needs to know its internal name. + +Prerequisites: +- ``worker.py`` running and connected to the same Durable Task Scheduler. +- A Durable Task Scheduler reachable at ``ENDPOINT`` (default ``http://localhost:8080``). +""" + +import asyncio +import logging +import os + +from agent_framework.azure import DurableWorkflowClient +from azure.identity import AzureCliCredential +from dotenv import load_dotenv +from durabletask.azuremanaged.client import DurableTaskSchedulerClient + +load_dotenv() + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def get_client(taskhub: str | None = None, endpoint: str | None = None) -> DurableTaskSchedulerClient: + """Create a configured DurableTaskSchedulerClient.""" + taskhub_name = taskhub or os.getenv("TASKHUB", "default") + endpoint_url = endpoint or os.getenv("ENDPOINT", "http://localhost:8080") + + credential = None if endpoint_url == "http://localhost:8080" else AzureCliCredential() + + return DurableTaskSchedulerClient( + host_address=endpoint_url, + secure_channel=endpoint_url != "http://localhost:8080", + taskhub=taskhub_name, + token_credential=credential, + ) + + +def run_workflow(client: DurableWorkflowClient, email_content: str) -> None: + """Start the workflow with an email and wait for the result.""" + instance_id = client.start_workflow(input=email_content) + logger.info("Started workflow instance: %s", instance_id) + + output = client.await_workflow_output(instance_id) + logger.info("Workflow output: %s", output) + + +async def main() -> None: + """Run the workflow against a legitimate email and a spam email.""" + client = DurableWorkflowClient(get_client()) + + logger.info("TEST 1: Legitimate email") + run_workflow( + client, + "Hi team, just a reminder about our sprint planning meeting tomorrow at 10 AM. " + "Please review the agenda in Jira.", + ) + + logger.info("TEST 2: Spam email") + run_workflow( + client, + "URGENT! You've won $1,000,000! Click here now to claim your prize! Limited time offer!", + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/04-hosting/durabletask/08_workflow/worker.py b/python/samples/04-hosting/durabletask/08_workflow/worker.py new file mode 100644 index 00000000000..48cd2c58b2e --- /dev/null +++ b/python/samples/04-hosting/durabletask/08_workflow/worker.py @@ -0,0 +1,214 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Worker that hosts a MAF Workflow as a durable orchestration (no Azure Functions). + +This sample shows how to run an agent-framework ``Workflow`` on a standalone +Durable Task worker using ``DurableAIAgentWorker.configure_workflow``. The worker +auto-registers: + +- a durable entity for each agent executor, +- a durable activity for each non-agent executor, and +- the workflow orchestrator (named ``WORKFLOW_ORCHESTRATOR_NAME``). + +The workflow classifies an email and conditionally routes it: spam is handled by +a non-agent executor, while legitimate email is drafted by a second agent and +"sent" by another non-agent executor. + +Prerequisites: +- Set ``FOUNDRY_PROJECT_ENDPOINT`` and ``FOUNDRY_MODEL``. +- Sign in with Azure CLI (``az login``) for ``AzureCliCredential``. +- Start a Durable Task Scheduler (e.g. the DTS emulator on ``localhost:8080``). + +Run the worker (this process), then run ``client.py`` in another process. +""" + +import asyncio +import logging +import os +from typing import Any + +from agent_framework import ( + Agent, + AgentExecutorResponse, + Case, + Default, + Executor, + Workflow, + WorkflowBuilder, + WorkflowContext, + handler, +) +from agent_framework.azure import DurableAIAgentWorker +from agent_framework.foundry import FoundryChatClient +from azure.identity import AzureCliCredential +from azure.identity.aio import AzureCliCredential as AsyncAzureCliCredential +from dotenv import load_dotenv +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker +from pydantic import BaseModel, ValidationError +from typing_extensions import Never + +load_dotenv() + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +SPAM_AGENT_NAME = "SpamDetectionAgent" +EMAIL_AGENT_NAME = "EmailAssistantAgent" + +SPAM_DETECTION_INSTRUCTIONS = ( + "You are a spam detection assistant that identifies spam emails. " + "Return JSON with fields is_spam (bool) and reason (string)." +) +EMAIL_ASSISTANT_INSTRUCTIONS = ( + "You are an email assistant that drafts professional replies to legitimate emails. " + "Return JSON with a single field 'response' containing the drafted reply." +) + + +class SpamDetectionResult(BaseModel): + """Structured output from the spam detection agent.""" + + is_spam: bool + reason: str + + +class EmailResponse(BaseModel): + """Structured output from the email assistant agent.""" + + response: str + + +class SpamHandlerExecutor(Executor): + """Non-agent executor that finalizes spam emails.""" + + @handler + async def handle_spam_result( + self, agent_response: AgentExecutorResponse, ctx: WorkflowContext[Never, str] + ) -> None: + text = agent_response.agent_response.text + try: + result = SpamDetectionResult.model_validate_json(text) + reason = result.reason + except ValidationError: + reason = "Invalid JSON from agent" + await ctx.yield_output(f"Email marked as spam: {reason}") + + +class EmailSenderExecutor(Executor): + """Non-agent executor that 'sends' the drafted reply.""" + + @handler + async def handle_email_response( + self, agent_response: AgentExecutorResponse, ctx: WorkflowContext[Never, str] + ) -> None: + text = agent_response.agent_response.text + try: + email = EmailResponse.model_validate_json(text) + reply = email.response + except ValidationError: + reply = "Error generating response." + await ctx.yield_output(f"Email sent: {reply}") + + +def is_spam_detected(message: Any) -> bool: + """Routing condition: True when the spam agent flagged the email as spam.""" + if not isinstance(message, AgentExecutorResponse): + return False + try: + return SpamDetectionResult.model_validate_json(message.agent_response.text).is_spam + except Exception: + return False + + +def _create_chat_client() -> FoundryChatClient: + """Create an Azure AI Foundry chat client using AzureCliCredential.""" + return FoundryChatClient( + project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"], + model=os.environ["FOUNDRY_MODEL"], + credential=AsyncAzureCliCredential(), + ) + + +def create_workflow() -> Workflow: + """Build the conditional spam-detection workflow.""" + chat_client = _create_chat_client() + + spam_agent = Agent( + client=chat_client, + name=SPAM_AGENT_NAME, + instructions=SPAM_DETECTION_INSTRUCTIONS, + default_options={"response_format": SpamDetectionResult}, + ) + email_agent = Agent( + client=chat_client, + name=EMAIL_AGENT_NAME, + instructions=EMAIL_ASSISTANT_INSTRUCTIONS, + default_options={"response_format": EmailResponse}, + ) + + spam_handler = SpamHandlerExecutor(id="spam_handler") + email_sender = EmailSenderExecutor(id="email_sender") + + return ( + WorkflowBuilder(start_executor=spam_agent) + .add_switch_case_edge_group( + spam_agent, + [ + Case(condition=is_spam_detected, target=spam_handler), + Default(target=email_agent), + ], + ) + .add_edge(email_agent, email_sender) + .build() + ) + + +def get_worker( + taskhub: str | None = None, endpoint: str | None = None, log_handler: logging.Handler | None = None +) -> DurableTaskSchedulerWorker: + """Create a configured DurableTaskSchedulerWorker.""" + taskhub_name = taskhub or os.getenv("TASKHUB", "default") + endpoint_url = endpoint or os.getenv("ENDPOINT", "http://localhost:8080") + + credential = None if endpoint_url == "http://localhost:8080" else AzureCliCredential() + + return DurableTaskSchedulerWorker( + host_address=endpoint_url, + secure_channel=endpoint_url != "http://localhost:8080", + taskhub=taskhub_name, + token_credential=credential, + log_handler=log_handler, + ) + + +def setup_worker(worker: DurableTaskSchedulerWorker) -> DurableAIAgentWorker: + """Register the workflow (agents + activities + orchestrator) on the worker.""" + agent_worker = DurableAIAgentWorker(worker) + + workflow = create_workflow() + # One call wires up: agent entities, non-agent executor activities, and the + # workflow orchestrator (registered as WORKFLOW_ORCHESTRATOR_NAME). + agent_worker.configure_workflow(workflow) + logger.info("✓ Configured workflow with %d executors", len(workflow.executors)) + + return agent_worker + + +async def main() -> None: + """Start the worker and block until interrupted.""" + worker = get_worker() + setup_worker(worker) + + logger.info("Worker is ready and listening for work items. Press Ctrl+C to stop.") + try: + worker.start() + while True: + await asyncio.sleep(1) + except KeyboardInterrupt: + logger.info("Worker shutdown initiated") + + logger.info("Worker stopped") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/04-hosting/durabletask/09_workflow_hitl/README.md b/python/samples/04-hosting/durabletask/09_workflow_hitl/README.md new file mode 100644 index 00000000000..653535bc8cd --- /dev/null +++ b/python/samples/04-hosting/durabletask/09_workflow_hitl/README.md @@ -0,0 +1,78 @@ +# Human-in-the-Loop Workflow on a Standalone Durable Task Worker + +This sample demonstrates a Human-in-the-Loop (HITL) agent-framework `Workflow` +running as a durable orchestration on a **standalone Durable Task worker** — no +Azure Functions required. It is the durabletask counterpart to the Azure +Functions sample `samples/04-hosting/azure_functions/12_workflow_hitl`. + +## Key Concepts Demonstrated + +- Pausing a workflow for human input with MAF's `ctx.request_info()` / + `@response_handler` pattern, hosted on a standalone worker via + `DurableAIAgentWorker.configure_workflow(workflow)`. +- Discovering pending HITL requests from a client with + `DurableWorkflowClient.get_pending_hitl_requests(instance_id)`. +- Resuming the workflow by sending a decision with + `DurableWorkflowClient.send_hitl_response(instance_id, request_id, response)`. +- Reading the final result with `DurableWorkflowClient.await_workflow_output(instance_id)`. + +The workflow is a content-moderation pipeline: + +``` +input_router -> ContentAnalyzerAgent -> content_analyzer_executor + -> human_review_executor (HITL pause) -> publish_executor +``` + +## How HITL Works Here + +The HITL mechanism is host-agnostic — the same shared workflow orchestrator +drives it on both Azure Functions and a standalone worker: + +1. `human_review_executor` calls `ctx.request_info(...)`, which pauses the + workflow. The orchestrator records the open request in its **custom status** + and waits for an external event named by the request's `request_id`. +2. The client reads the custom status via `get_pending_hitl_requests` and sends + a response via `send_hitl_response`, which raises that external event. +3. The orchestrator routes the response back to the executor's + `@response_handler`, and the workflow resumes. + +`send_hitl_response` sanitizes the payload (neutralizing pickle-marker +injection) before delivery, since the worker deserializes it. + +## Environment Setup + +See the [README.md](../README.md) in the parent directory for environment setup. + +This sample uses Azure AI Foundry credentials: + +- `FOUNDRY_PROJECT_ENDPOINT` +- `FOUNDRY_MODEL` + +It also needs a Durable Task Scheduler. For local development, start the +emulator (defaults to `http://localhost:8080`): + +```bash +docker run -d -p 8080:8080 -p 8082:8082 mcr.microsoft.com/dts/dts-emulator:latest +``` + +## Running the Sample + +Start the worker in one terminal: + +```bash +cd samples/04-hosting/durabletask/09_workflow_hitl +python worker.py +``` + +In a second terminal, run the client: + +```bash +python client.py +``` + +The client runs two cases: + +- **Appropriate content** → analyzed → HITL pause → client **approves** → + `"Content '...' has been APPROVED and published."` +- **Spammy content** → analyzed → HITL pause → client **rejects** → + `"Content '...' has been REJECTED."` diff --git a/python/samples/04-hosting/durabletask/09_workflow_hitl/client.py b/python/samples/04-hosting/durabletask/09_workflow_hitl/client.py new file mode 100644 index 00000000000..73d06dde3a8 --- /dev/null +++ b/python/samples/04-hosting/durabletask/09_workflow_hitl/client.py @@ -0,0 +1,130 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Client that drives the standalone HITL workflow to completion. + +The worker (``worker.py``) must be running first. This client: + +1. Starts the workflow with ``DurableWorkflowClient.start_workflow``. +2. Polls ``get_pending_hitl_requests`` until the workflow pauses for human input. +3. Sends a decision with ``send_hitl_response`` (the request_id correlates the + response back to the paused executor). +4. Reads the final output with ``await_workflow_output``. + +It runs two cases: appropriate content (approved) and spammy content (rejected). + +Prerequisites: +- ``worker.py`` running and connected to the same Durable Task Scheduler. +- A Durable Task Scheduler reachable at ``ENDPOINT`` (default ``http://localhost:8080``). +""" + +import asyncio +import logging +import os +import time +from typing import Any + +from agent_framework.azure import DurableWorkflowClient +from azure.identity import AzureCliCredential +from dotenv import load_dotenv +from durabletask.azuremanaged.client import DurableTaskSchedulerClient + +load_dotenv() + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def get_client(taskhub: str | None = None, endpoint: str | None = None) -> DurableTaskSchedulerClient: + """Create a configured DurableTaskSchedulerClient.""" + taskhub_name = taskhub or os.getenv("TASKHUB", "default") + endpoint_url = endpoint or os.getenv("ENDPOINT", "http://localhost:8080") + + credential = None if endpoint_url == "http://localhost:8080" else AzureCliCredential() + + return DurableTaskSchedulerClient( + host_address=endpoint_url, + secure_channel=endpoint_url != "http://localhost:8080", + taskhub=taskhub_name, + token_credential=credential, + ) + + +def _wait_for_hitl_request( + client: DurableWorkflowClient, instance_id: str, timeout_seconds: int = 60 +) -> list[dict[str, Any]]: + """Poll until the workflow has at least one pending HITL request. + + Stops early if the workflow reaches a terminal state (e.g. completed or failed) + without pausing, so a misconfiguration or early failure surfaces the real + status instead of a misleading timeout. + """ + terminal_statuses = {"COMPLETED", "FAILED", "TERMINATED"} + deadline = time.time() + timeout_seconds + while time.time() < deadline: + pending = client.get_pending_hitl_requests(instance_id) + if pending: + return pending + status = client.get_runtime_status(instance_id) + if status in terminal_statuses: + raise RuntimeError( + f"Workflow instance {instance_id} reached terminal state '{status}' " + "before pausing for human input." + ) + time.sleep(2) + raise TimeoutError(f"Timed out waiting for a HITL request on instance {instance_id}") + + +def run_case(client: DurableWorkflowClient, submission: dict[str, Any], *, approve: bool) -> None: + """Run one moderation case: start, respond to the HITL pause, print the result.""" + instance_id = client.start_workflow(input=submission) + logger.info("Started workflow instance: %s", instance_id) + + pending = _wait_for_hitl_request(client, instance_id) + request = pending[0] + logger.info("Pending HITL request %s from %s", request["request_id"], request["source_executor_id"]) + + decision = { + "approved": approve, + "reviewer_notes": "Looks good." if approve else "Violates content policy.", + } + client.send_hitl_response(instance_id, request["request_id"], decision) + logger.info("Sent decision: approved=%s", approve) + + output = client.await_workflow_output(instance_id) + logger.info("Workflow output: %s", output) + + +async def main() -> None: + """Run an approved case and a rejected case.""" + client = DurableWorkflowClient(get_client()) + + logger.info("CASE 1: Appropriate content (will approve)") + run_case( + client, + { + "content_id": "article-001", + "title": "Introduction to AI in Healthcare", + "body": ( + "Artificial intelligence is improving healthcare by enabling faster diagnosis, " + "personalized treatment plans, and better patient outcomes." + ), + "author": "Dr. Jane Smith", + }, + approve=True, + ) + + logger.info("CASE 2: Spammy content (will reject)") + run_case( + client, + { + "content_id": "article-002", + "title": "Get Rich Quick", + "body": "Click here NOW to make $10,000 overnight! GUARANTEED! Limited time offer!", + "author": "Definitely Not Spam", + }, + approve=False, + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/04-hosting/durabletask/09_workflow_hitl/worker.py b/python/samples/04-hosting/durabletask/09_workflow_hitl/worker.py new file mode 100644 index 00000000000..33e8ef54087 --- /dev/null +++ b/python/samples/04-hosting/durabletask/09_workflow_hitl/worker.py @@ -0,0 +1,343 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Worker that hosts a Human-in-the-Loop (HITL) MAF Workflow on a standalone worker. + +This sample is the durabletask counterpart to the Azure Functions +``12_workflow_hitl`` sample. It runs an agent-framework ``Workflow`` that pauses +for human approval using MAF's ``ctx.request_info`` / ``@response_handler`` +pattern, hosted on a standalone Durable Task worker (no Azure Functions). + +``DurableAIAgentWorker.configure_workflow`` auto-registers: + +- a durable entity for each agent executor, +- a durable activity for each non-agent executor, and +- the workflow orchestrator (named ``WORKFLOW_ORCHESTRATOR_NAME``). + +When the workflow calls ``ctx.request_info``, the orchestrator pauses and records +the open request in its custom status. An external client discovers the request +(``DurableWorkflowClient.get_pending_hitl_requests``) and resumes the workflow by +sending a response (``DurableWorkflowClient.send_hitl_response``). + +The workflow is a content-moderation pipeline: +``input_router`` -> ``ContentAnalyzerAgent`` -> ``content_analyzer_executor`` +-> ``human_review_executor`` (HITL pause) -> ``publish_executor``. + +Prerequisites: +- Set ``FOUNDRY_PROJECT_ENDPOINT`` and ``FOUNDRY_MODEL``. +- Sign in with Azure CLI (``az login``) for ``AzureCliCredential``. +- Start a Durable Task Scheduler (e.g. the DTS emulator on ``localhost:8080``). + +Run the worker (this process), then run ``client.py`` in another process. +""" + +import asyncio +import logging +import os +from dataclasses import dataclass + +from agent_framework import ( + Agent, + AgentExecutorRequest, + AgentExecutorResponse, + Executor, + Message, + Workflow, + WorkflowBuilder, + WorkflowContext, + handler, + response_handler, +) +from agent_framework.azure import DurableAIAgentWorker +from agent_framework.foundry import FoundryChatClient +from azure.identity import AzureCliCredential +from azure.identity.aio import AzureCliCredential as AsyncAzureCliCredential +from dotenv import load_dotenv +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker +from pydantic import BaseModel, ValidationError +from typing_extensions import Never + +load_dotenv() + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +CONTENT_ANALYZER_AGENT_NAME = "ContentAnalyzerAgent" + +CONTENT_ANALYZER_INSTRUCTIONS = ( + "You are a content moderation assistant that analyzes user-submitted content for policy compliance. " + "Evaluate appropriateness, assign a risk level ('low', 'medium', 'high'), list any concerns, and give a " + "brief recommendation for human reviewers. Return JSON with fields is_appropriate (bool), risk_level (str), " + "concerns (list of str), and recommendation (str)." +) + + +# ============================================================================ +# Data Models +# ============================================================================ + + +class ContentAnalysisResult(BaseModel): + """Structured output from the content analysis agent.""" + + is_appropriate: bool + risk_level: str + concerns: list[str] + recommendation: str + + +@dataclass +class ContentSubmission: + """Content submitted for moderation.""" + + content_id: str + title: str + body: str + author: str + + +@dataclass +class AnalysisWithSubmission: + """Combines the AI analysis with the original submission for downstream processing.""" + + submission: ContentSubmission + analysis: ContentAnalysisResult + + +@dataclass +class HumanApprovalRequest: + """Request sent to a human reviewer. Surfaced to clients via the orchestration status.""" + + content_id: str + title: str + body: str + author: str + ai_analysis: ContentAnalysisResult + prompt: str + + +class HumanApprovalResponse(BaseModel): + """Response the external client sends back via the HITL response endpoint/method.""" + + approved: bool + reviewer_notes: str = "" + + +@dataclass +class ModerationResult: + """Final result of the moderation workflow.""" + + content_id: str + status: str + reviewer_notes: str + + +# ============================================================================ +# Executors +# ============================================================================ + + +class InputRouterExecutor(Executor): + """Parses the incoming submission and routes it to the analysis agent.""" + + def __init__(self) -> None: + super().__init__(id="input_router") + + @handler + async def route_input(self, submission: ContentSubmission, ctx: WorkflowContext[AgentExecutorRequest]) -> None: + ctx.set_state("current_submission", submission) + + message = ( + f"Please analyze the following content for policy compliance:\n\n" + f"Title: {submission.title}\n" + f"Author: {submission.author}\n" + f"Content:\n{submission.body}" + ) + await ctx.send_message( + AgentExecutorRequest(messages=[Message(role="user", contents=[message])], should_respond=True) + ) + + +class ContentAnalyzerExecutor(Executor): + """Parses the AI agent's response and forwards it with the original submission.""" + + def __init__(self) -> None: + super().__init__(id="content_analyzer_executor") + + @handler + async def handle_analysis( + self, response: AgentExecutorResponse, ctx: WorkflowContext[AnalysisWithSubmission] + ) -> None: + try: + analysis = ContentAnalysisResult.model_validate_json(response.agent_response.text) + except ValidationError: + analysis = ContentAnalysisResult( + is_appropriate=False, + risk_level="high", + concerns=["Agent execution failed or yielded invalid JSON."], + recommendation="Manual review required", + ) + + submission: ContentSubmission = ctx.get_state("current_submission") + await ctx.send_message(AnalysisWithSubmission(submission=submission, analysis=analysis)) + + +class HumanReviewExecutor(Executor): + """Requests human approval using MAF's request_info / response_handler pattern.""" + + def __init__(self) -> None: + super().__init__(id="human_review_executor") + + @handler + async def request_review(self, data: AnalysisWithSubmission, ctx: WorkflowContext) -> None: + submission = data.submission + analysis = data.analysis + + prompt = ( + f"Please review the following content for publication:\n\n" + f"Title: {submission.title}\n" + f"Author: {submission.author}\n" + f"Content: {submission.body}\n\n" + f"AI Analysis:\n" + f"- Appropriate: {analysis.is_appropriate}\n" + f"- Risk Level: {analysis.risk_level}\n" + f"- Concerns: {', '.join(analysis.concerns) if analysis.concerns else 'None'}\n" + f"- Recommendation: {analysis.recommendation}\n\n" + f"Please approve or reject this content." + ) + approval_request = HumanApprovalRequest( + content_id=submission.content_id, + title=submission.title, + body=submission.body, + author=submission.author, + ai_analysis=analysis, + prompt=prompt, + ) + + # Pause the workflow and wait for a human response. + await ctx.request_info(request_data=approval_request, response_type=HumanApprovalResponse) + + @response_handler + async def handle_approval_response( + self, + original_request: HumanApprovalRequest, + response: HumanApprovalResponse, + ctx: WorkflowContext[ModerationResult], + ) -> None: + logger.info( + "Human review received for content %s: approved=%s", + original_request.content_id, + response.approved, + ) + await ctx.send_message( + ModerationResult( + content_id=original_request.content_id, + status="approved" if response.approved else "rejected", + reviewer_notes=response.reviewer_notes, + ) + ) + + +class PublishExecutor(Executor): + """Finalizes publication or rejection of the content.""" + + def __init__(self) -> None: + super().__init__(id="publish_executor") + + @handler + async def handle_result(self, result: ModerationResult, ctx: WorkflowContext[Never, str]) -> None: + if result.status == "approved": + message = ( + f"Content '{result.content_id}' has been APPROVED and published. " + f"Reviewer notes: {result.reviewer_notes or 'None'}" + ) + else: + message = ( + f"Content '{result.content_id}' has been REJECTED. " + f"Reviewer notes: {result.reviewer_notes or 'None'}" + ) + logger.info(message) + await ctx.yield_output(message) + + +def _create_chat_client() -> FoundryChatClient: + """Create an Azure AI Foundry chat client using AzureCliCredential.""" + return FoundryChatClient( + project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"], + model=os.environ["FOUNDRY_MODEL"], + credential=AsyncAzureCliCredential(), + ) + + +def create_workflow() -> Workflow: + """Build the content-moderation workflow with a human-in-the-loop pause.""" + chat_client = _create_chat_client() + + content_analyzer_agent = Agent( + client=chat_client, + name=CONTENT_ANALYZER_AGENT_NAME, + instructions=CONTENT_ANALYZER_INSTRUCTIONS, + default_options={"response_format": ContentAnalysisResult}, + ) + + input_router = InputRouterExecutor() + content_analyzer_executor = ContentAnalyzerExecutor() + human_review_executor = HumanReviewExecutor() + publish_executor = PublishExecutor() + + return ( + WorkflowBuilder(start_executor=input_router) + .add_edge(input_router, content_analyzer_agent) + .add_edge(content_analyzer_agent, content_analyzer_executor) + .add_edge(content_analyzer_executor, human_review_executor) + .add_edge(human_review_executor, publish_executor) + .build() + ) + + +def get_worker( + taskhub: str | None = None, endpoint: str | None = None, log_handler: logging.Handler | None = None +) -> DurableTaskSchedulerWorker: + """Create a configured DurableTaskSchedulerWorker.""" + taskhub_name = taskhub or os.getenv("TASKHUB", "default") + endpoint_url = endpoint or os.getenv("ENDPOINT", "http://localhost:8080") + + credential = None if endpoint_url == "http://localhost:8080" else AzureCliCredential() + + return DurableTaskSchedulerWorker( + host_address=endpoint_url, + secure_channel=endpoint_url != "http://localhost:8080", + taskhub=taskhub_name, + token_credential=credential, + log_handler=log_handler, + ) + + +def setup_worker(worker: DurableTaskSchedulerWorker) -> DurableAIAgentWorker: + """Register the workflow (agents + activities + orchestrator) on the worker.""" + agent_worker = DurableAIAgentWorker(worker) + + workflow = create_workflow() + agent_worker.configure_workflow(workflow) + logger.info("✓ Configured HITL workflow with %d executors", len(workflow.executors)) + + return agent_worker + + +async def main() -> None: + """Start the worker and block until interrupted.""" + worker = get_worker() + setup_worker(worker) + + logger.info("Worker is ready and listening for work items. Press Ctrl+C to stop.") + try: + worker.start() + while True: + await asyncio.sleep(1) + except KeyboardInterrupt: + logger.info("Worker shutdown initiated") + + logger.info("Worker stopped") + + +if __name__ == "__main__": + asyncio.run(main())