diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/__init__.py b/python/packages/declarative/agent_framework_declarative/_workflows/__init__.py index 579cabfafaf..31ee990cb69 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/__init__.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/__init__.py @@ -76,12 +76,10 @@ from ._executors_tools import ( FUNCTION_TOOL_REGISTRY_KEY, TOOL_ACTION_EXECUTORS, - TOOL_APPROVAL_STATE_KEY, BaseToolExecutor, InvokeFunctionToolExecutor, ToolApprovalRequest, ToolApprovalResponse, - ToolApprovalState, ToolInvocationResult, ) from ._factory import WorkflowFactory @@ -111,7 +109,6 @@ "HTTP_ACTION_EXECUTORS", "MCP_ACTION_EXECUTORS", "TOOL_ACTION_EXECUTORS", - "TOOL_APPROVAL_STATE_KEY", "TOOL_REGISTRY_KEY", "ActionComplete", "ActionTrigger", @@ -164,7 +161,6 @@ "SetVariableExecutor", "ToolApprovalRequest", "ToolApprovalResponse", - "ToolApprovalState", "ToolInvocationResult", "WorkflowFactory", "WorkflowState", diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py index 73b66341ea3..1b16a87277c 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py @@ -10,17 +10,11 @@ Security notes: -- The executor never echoes header VALUES (auth tokens, API keys) into the - approval request — only header NAMES are surfaced to the caller. This - matches the security posture of :mod:`._executors_http` (which never logs - request headers either) and prevents secrets from leaking through workflow - events that are typically observable to operators / UIs. -- ``_MCPToolApprovalState`` snapshots the EVALUATED values for non-secret - fields (server URL, tool name, arguments) at approval-request time so that - subsequent state mutations cannot make the executor "approve X then call - Y". Headers are stored as the raw expression strings (not evaluated values) - so secrets are not persisted in the workflow's checkpoint state. They are - re-evaluated on resume. +- Approval requests surface header NAMES only; header values are not echoed, + matching the posture of :mod:`._executors_http`. +- :class:`MCPToolApprovalRequest` carries the values the resume handler will + use; header values are re-evaluated on resume to keep secrets out of + checkpoint state. - Tool outputs flow back into agent conversations through ``conversationId`` and through Tool-role messages emitted to ``output.messages``. They share the same prompt-injection risk surface as ``HttpRequestAction``: workflow @@ -60,8 +54,6 @@ logger = logging.getLogger(__name__) -_MCP_APPROVAL_STATE_KEY = "_mcp_tool_approval_state" - # --------------------------------------------------------------------------- # Request / state types @@ -72,20 +64,16 @@ class MCPToolApprovalRequest: """Approval request emitted before invoking an MCP tool. - Mirrors :class:`agent_framework_declarative.ToolApprovalRequest` but for - MCP-style invocations. Only header NAMES are surfaced — header values are - intentionally omitted because they typically carry authentication - secrets. - Attributes: - request_id: Unique identifier for this approval request. Matches the - id workflow event-emitters use. - tool_name: Evaluated name of the tool to be invoked. + request_id: Identifier matching the framework's pending-request key. + tool_name: Evaluated tool name. server_url: Evaluated MCP server URL. - server_label: Optional human-readable label for diagnostics. - arguments: Evaluated arguments to be forwarded to the tool. - header_names: Sorted list of outbound header names (no values). Empty - when no headers are configured. + server_label: Optional human-readable label. + arguments: Evaluated tool arguments. + header_names: Outbound header names (values withheld). + connection_name: Connection identifier the invocation will use. + metadata: Internal routing data pinned at approval-request time + (e.g. ``conversation_id``) for use by the resume handler. """ request_id: str @@ -94,28 +82,8 @@ class MCPToolApprovalRequest: server_label: str | None arguments: dict[str, Any] header_names: list[str] = field(default_factory=lambda: []) - - -@dataclass -class _MCPToolApprovalState: - """Internal state saved during the approval yield for resumption. - - Stores **evaluated** values for non-secret fields to prevent - "approve X / execute Y" attacks. Stores the raw expression string for - ``headers`` so that secret values are NOT persisted in checkpoint state; - the expressions are re-evaluated against current state on resume. - """ - - server_url: str - tool_name: str - server_label: str | None - arguments: dict[str, Any] - connection_name: str | None - headers_def: Any - auto_send: bool - conversation_id_expr: str | None - output_messages_path: str | None - output_result_path: str | None + connection_name: str | None = None + metadata: dict[str, Any] = field(default_factory=lambda: {}) # --------------------------------------------------------------------------- @@ -123,21 +91,15 @@ class _MCPToolApprovalState: # --------------------------------------------------------------------------- -def _get_messages_path(state: DeclarativeWorkflowState, conversation_id_expr: str | None) -> str | None: - """Return the configured conversation messages path, if any. - - Returns ``System.conversations.{evaluated_id}.messages`` when a - ``conversation_id_expr`` is configured and evaluates to a non-empty value. - Returns ``None`` when no conversation id expression is configured or when - the expression evaluates to ``None`` or an empty string (mirrors .NET - ``GetConversationId`` behaviour). - """ - if not conversation_id_expr: +def _evaluate_conversation_id(state: DeclarativeWorkflowState, conversation_id_expr: Any) -> str | None: + """Return the evaluated ``conversationId`` string, or None when empty/unset.""" + if not isinstance(conversation_id_expr, str) or not conversation_id_expr: return None evaluated = state.eval_if_expression(conversation_id_expr) - if evaluated is None or (isinstance(evaluated, str) and not evaluated): + if evaluated is None: return None - return f"System.conversations.{evaluated}.messages" + text = str(evaluated) + return text or None def _get_output_path(action_def: Mapping[str, Any], key: str) -> str | None: @@ -260,20 +222,7 @@ async def handle_action( if require_approval: request_id = str(uuid.uuid4()) - approval_state = _MCPToolApprovalState( - server_url=server_url, - tool_name=tool_name, - server_label=server_label, - arguments=arguments, - connection_name=connection_name, - headers_def=self._action_def.get("headers"), - auto_send=auto_send, - conversation_id_expr=conversation_id_expr if isinstance(conversation_id_expr, str) else None, - output_messages_path=output_messages_path, - output_result_path=output_result_path, - ) - ctx.state.set(self._approval_key(), approval_state) - + conversation_id = _evaluate_conversation_id(state, conversation_id_expr) request = MCPToolApprovalRequest( request_id=request_id, tool_name=tool_name, @@ -281,6 +230,8 @@ async def handle_action( server_label=server_label, arguments=arguments, header_names=sorted(headers.keys()), + connection_name=connection_name, + metadata={"conversation_id": conversation_id}, ) logger.info( "%s: requesting approval for MCP tool '%s' on '%s'", @@ -289,7 +240,6 @@ async def handle_action( server_url, ) await ctx.request_info(request, ToolApprovalResponse, request_id=request_id) - # Workflow yields here — resume in handle_approval_response. return # No approval required - invoke directly. @@ -307,7 +257,7 @@ async def handle_action( state=state, result=result, auto_send=auto_send, - conversation_id_expr=conversation_id_expr if isinstance(conversation_id_expr, str) else None, + conversation_id=_evaluate_conversation_id(state, conversation_id_expr), output_messages_path=output_messages_path, output_result_path=output_result_path, ) @@ -322,54 +272,46 @@ async def handle_approval_response( response: ToolApprovalResponse, ctx: WorkflowContext[ActionComplete, str], ) -> None: - """Resume after the workflow yielded for an approval request.""" + """Resume the invocation using the values pinned on ``original_request``.""" state = self._get_state(ctx.state) - approval_key = self._approval_key() - try: - approval_state: _MCPToolApprovalState = ctx.state.get(approval_key) - except KeyError: - logger.error("%s: approval state missing for executor '%s'", self.__class__.__name__, self.id) - await ctx.send_message(ActionComplete()) - return - try: - ctx.state.delete(approval_key) - except KeyError: - logger.warning("%s: approval state already deleted for '%s'", self.__class__.__name__, self.id) + tool_name = original_request.tool_name + metadata: dict[str, Any] = getattr(original_request, "metadata", None) or {} + raw_conversation_id = metadata.get("conversation_id") + conversation_id = raw_conversation_id if isinstance(raw_conversation_id, str) and raw_conversation_id else None + + auto_send = self._get_auto_send(state) + output_messages_path = _get_output_path(self._action_def, "messages") + output_result_path = _get_output_path(self._action_def, "result") if not response.approved: logger.info( "%s: MCP tool '%s' rejected: %s", self.__class__.__name__, - approval_state.tool_name, + tool_name, response.reason, ) - self._assign_error( - state, approval_state.output_result_path, "MCP tool invocation was not approved by user." - ) + self._assign_error(state, output_result_path, "MCP tool invocation was not approved by user.") await ctx.send_message(ActionComplete()) return - # Approved — re-evaluate headers (not stored at approval time for security). - headers = self._evaluate_headers(state, approval_state.headers_def) - invocation = MCPToolInvocation( - server_url=approval_state.server_url, - tool_name=approval_state.tool_name, - server_label=approval_state.server_label, - arguments=approval_state.arguments, - headers=headers, - connection_name=approval_state.connection_name, + server_url=original_request.server_url, + tool_name=tool_name, + server_label=original_request.server_label, + arguments=original_request.arguments, + headers=self._evaluate_headers(state, self._action_def.get("headers")), + connection_name=getattr(original_request, "connection_name", None), ) result = await self._invoke_with_narrow_catch(invocation) await self._process_result( ctx=ctx, state=state, result=result, - auto_send=approval_state.auto_send, - conversation_id_expr=approval_state.conversation_id_expr, - output_messages_path=approval_state.output_messages_path, - output_result_path=approval_state.output_result_path, + auto_send=auto_send, + conversation_id=conversation_id, + output_messages_path=output_messages_path, + output_result_path=output_result_path, ) await ctx.send_message(ActionComplete()) @@ -528,7 +470,7 @@ async def _process_result( state: DeclarativeWorkflowState, result: MCPToolResult, auto_send: bool, - conversation_id_expr: str | None, + conversation_id: str | None, output_messages_path: str | None, output_result_path: str | None, ) -> None: @@ -557,14 +499,10 @@ async def _process_result( if auto_send and parsed_results: await ctx.yield_output(_format_outputs_for_send(parsed_results)) - if conversation_id_expr: - messages_path = _get_messages_path(state, conversation_id_expr) - if messages_path is not None: - # Mirrors .NET: conversation gets ASSISTANT-role message with - # the same outputs (so chat history reads it as the agent's - # contribution). - assistant_message = Message(role="assistant", contents=list(result.outputs)) - state.append(messages_path, assistant_message) + if conversation_id: + messages_path = f"System.conversations.{conversation_id}.messages" + assistant_message = Message(role="assistant", contents=list(result.outputs)) + state.append(messages_path, assistant_message) @staticmethod def _assign_error( @@ -577,9 +515,6 @@ def _assign_error( return state.set(output_result_path, f"Error: {error_message}") - def _approval_key(self) -> str: - return f"{_MCP_APPROVAL_STATE_KEY}_{self.id}" - def _parse_outputs(outputs: list[Content]) -> list[Any]: """Parse :class:`Content` outputs into Python values for ``output.result``. diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_tools.py b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_tools.py index b2c046a69bb..d522cf56643 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_tools.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_tools.py @@ -41,10 +41,6 @@ # at runtime are discoverable by both agent-based and function-based tool executors. FUNCTION_TOOL_REGISTRY_KEY = TOOL_REGISTRY_KEY -# State key prefix for storing approval state during yield/resume. -# The executor's ID is appended to create a per-executor key. -TOOL_APPROVAL_STATE_KEY = "_tool_approval_state" - # ============================================================================ # Request/Response Types for Approval Flow @@ -87,26 +83,6 @@ class ToolApprovalResponse: reason: str | None = None -# ============================================================================ -# State Types for Approval Flow -# ============================================================================ - - -@dataclass -class ToolApprovalState: - """State saved during approval yield for resumption. - - Stored in State under a per-executor key when requireApproval=true. - Retrieved by handle_approval_response() to continue execution. - """ - - function_name: str - arguments: dict[str, Any] - output_messages_var: str | None - output_result_var: str | None - auto_send: bool - - # ============================================================================ # Result Types # ============================================================================ @@ -501,25 +477,16 @@ async def handle_action( require_approval = self._action_def.get("requireApproval", False) if require_approval: - # Save state for resumption (keyed by executor ID to avoid collisions) - approval_state = ToolApprovalState( - function_name=function_name, - arguments=arguments, - output_messages_var=messages_var, - output_result_var=result_var, - auto_send=auto_send, - ) - approval_key = f"{TOOL_APPROVAL_STATE_KEY}_{self.id}" - ctx.state.set(approval_key, approval_state) - - # Emit approval request - workflow yields here + # Emit approval request - the request payload is the source of + # truth for resumed invocation; no side-channel state is written. + request_id = str(uuid.uuid4()) request = ToolApprovalRequest( - request_id=str(uuid.uuid4()), + request_id=request_id, function_name=function_name, arguments=arguments, ) logger.info(f"{self.__class__.__name__}: requesting approval for '{function_name}'") - await ctx.request_info(request, ToolApprovalResponse) + await ctx.request_info(request, ToolApprovalResponse, request_id=request_id) # Workflow yields - will resume in handle_approval_response return @@ -545,36 +512,16 @@ async def handle_approval_response( ) -> None: """Handle response to a ToolApprovalRequest. - Called when the workflow resumes after yielding for approval. - Either executes the tool (if approved) or stores rejection status. + Resumes after the workflow yielded for approval. The invocation + ``function_name`` and ``arguments`` are sourced from + ``original_request`` (the payload the reviewer approved); output + configuration is re-derived from the executor's action definition. """ state = self._get_state(ctx.state) - approval_key = f"{TOOL_APPROVAL_STATE_KEY}_{self.id}" - - # Retrieve saved invocation state - try: - approval_state: ToolApprovalState = ctx.state.get(approval_key) - except KeyError: - error_msg = "Approval state not found, cannot resume tool invocation" - logger.error(f"{self.__class__.__name__}: {error_msg}") - # Try to store error - get output config from action def as fallback - _, result_var, _ = self._get_output_config() - if result_var and state: - state.set(_normalize_variable_path(result_var), {"error": error_msg}) - await ctx.send_message(ActionComplete()) - return - # Clean up approval state - try: - ctx.state.delete(approval_key) - except KeyError: - logger.warning(f"{self.__class__.__name__}: approval state already deleted") - - function_name = approval_state.function_name - arguments = approval_state.arguments - messages_var = approval_state.output_messages_var - result_var = approval_state.output_result_var - auto_send = approval_state.auto_send + function_name = original_request.function_name + arguments = original_request.arguments + messages_var, result_var, auto_send = self._get_output_config() # Check if approved if not response.approved: diff --git a/python/packages/declarative/tests/test_declarative_approval_binding.py b/python/packages/declarative/tests/test_declarative_approval_binding.py new file mode 100644 index 00000000000..ba0d4108f12 --- /dev/null +++ b/python/packages/declarative/tests/test_declarative_approval_binding.py @@ -0,0 +1,528 @@ +# Copyright (c) Microsoft. All rights reserved. +# pyright: reportUnknownParameterType=false, reportUnknownArgumentType=false +# pyright: reportMissingParameterType=false, reportUnknownMemberType=false +# pyright: reportPrivateUsage=false, reportUnknownVariableType=false +# pyright: reportGeneralTypeIssues=false + +"""Regression tests pinning the approval-flow binding contract. + +The resumed invocation MUST come from the framework-delivered +``original_request`` payload (the data the reviewer approved) for both +``InvokeFunctionTool`` and ``InvokeMcpTool``. These tests verify that: + +* Invocation parameters come from ``original_request``, not from any prior + side-channel state. +* Concurrent pending approvals on the same executor do not swap. +* Pre-existing state at old approval keys is ignored entirely. +* Resume works on a freshly constructed executor (checkpoint-restore + simulation), without any prior ``ctx.state`` write. +* For MCP, ``connection_name`` is sourced from the approval payload and + ``headers`` are re-evaluated from the action definition on resume. +""" + +import sys +from dataclasses import dataclass +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +try: + import powerfx # noqa: F401 + + _powerfx_available = True +except (ImportError, RuntimeError): + _powerfx_available = False + +pytestmark = pytest.mark.skipif( + not _powerfx_available or sys.version_info >= (3, 14), + reason="PowerFx engine not available (requires dotnet runtime)", +) + +from agent_framework import Content # noqa: E402 + +from agent_framework_declarative._workflows import ( # noqa: E402 + DECLARATIVE_STATE_KEY, + ActionComplete, + InvokeFunctionToolExecutor, + MCPToolApprovalRequest, + MCPToolHandler, + MCPToolInvocation, + MCPToolResult, + ToolApprovalRequest, + ToolApprovalResponse, +) +from agent_framework_declarative._workflows._declarative_base import DeclarativeWorkflowState # noqa: E402 +from agent_framework_declarative._workflows._executors_mcp import ( # noqa: E402 + InvokeMcpToolActionExecutor, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mock_state() -> MagicMock: + """In-memory mock of the underlying State.""" + state = MagicMock() + state._data = {} + + def _get(key: str, default: Any = None) -> Any: + return state._data.get(key, default) + + def _set(key: str, value: Any) -> None: + state._data[key] = value + + def _has(key: str) -> bool: + return key in state._data + + def _delete(key: str) -> None: + state._data.pop(key, None) + + state.get = MagicMock(side_effect=_get) + state.set = MagicMock(side_effect=_set) + state.has = MagicMock(side_effect=_has) + state.delete = MagicMock(side_effect=_delete) + return state + + +@pytest.fixture +def mock_context(mock_state: MagicMock) -> MagicMock: + ctx = MagicMock() + ctx.state = mock_state + ctx.send_message = AsyncMock() + ctx.yield_output = AsyncMock() + ctx.request_info = AsyncMock() + return ctx + + +def _seed_state(mock_state: MagicMock) -> None: + mock_state._data[DECLARATIVE_STATE_KEY] = { + "Inputs": {}, + "Outputs": {}, + "Local": {}, + "Custom": {}, + "System": { + "ConversationId": "00000000-0000-0000-0000-000000000000", + "LastMessage": {"Text": "", "Id": ""}, + "LastMessageText": "", + "LastMessageId": "", + }, + "Agent": {}, + "Conversation": {"messages": [], "history": []}, + } + + +class _RecordingMcpHandler(MCPToolHandler): + def __init__(self, result: MCPToolResult | None = None) -> None: + self.result = result or MCPToolResult(outputs=[Content.from_text("ok")]) + self.invocations: list[MCPToolInvocation] = [] + + @property + def call_count(self) -> int: + return len(self.invocations) + + @property + def last(self) -> MCPToolInvocation | None: + return self.invocations[-1] if self.invocations else None + + async def invoke_tool(self, invocation: MCPToolInvocation) -> MCPToolResult: + self.invocations.append(invocation) + return self.result + + +# --------------------------------------------------------------------------- +# InvokeFunctionTool: approval-binding regression +# --------------------------------------------------------------------------- + + +class TestFunctionToolApprovalBinding: + def _action(self, *, fn_name: str = "my_tool") -> dict[str, Any]: + return { + "kind": "InvokeFunctionTool", + "id": "fn_action", + "functionName": fn_name, + "requireApproval": True, + "output": {"result": "Local.result"}, + } + + @pytest.mark.asyncio + async def test_request_id_matches_framework_pending_key(self, mock_state, mock_context) -> None: + """The id on the emitted ToolApprovalRequest must match the framework's pending-request key.""" + from agent_framework_declarative._workflows._declarative_base import ActionTrigger + + _seed_state(mock_state) + + def my_tool(x: int) -> int: + return x + + executor = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool}) + await executor.handle_action(ActionTrigger(), mock_context) + + mock_context.request_info.assert_called_once() + emitted_request = mock_context.request_info.call_args[0][0] + framework_request_id = mock_context.request_info.call_args.kwargs["request_id"] + assert isinstance(emitted_request, ToolApprovalRequest) + assert emitted_request.request_id == framework_request_id + + @pytest.mark.asyncio + async def test_resume_uses_request_payload_arguments(self, mock_state, mock_context) -> None: + _seed_state(mock_state) + call_log: list[int] = [] + + def my_tool(x: int) -> int: + call_log.append(x) + return x + + executor = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool}) + + request = ToolApprovalRequest(request_id="r-1", function_name="my_tool", arguments={"x": 1}) + await executor.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context) + + assert call_log == [1] + + @pytest.mark.asyncio + async def test_concurrent_pending_approvals_do_not_swap(self, mock_state, mock_context) -> None: + """Two pending approvals, responses delivered out of order — each invocation uses its own payload.""" + _seed_state(mock_state) + call_log: list[int] = [] + + def my_tool(x: int) -> int: + call_log.append(x) + return x + + executor = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool}) + + request_a = ToolApprovalRequest(request_id="r-A", function_name="my_tool", arguments={"x": 1}) + request_b = ToolApprovalRequest(request_id="r-B", function_name="my_tool", arguments={"x": 999}) + + # Deliver response for B first, then for A. Each invocation must use its own payload. + await executor.handle_approval_response(request_b, ToolApprovalResponse(approved=True), mock_context) + await executor.handle_approval_response(request_a, ToolApprovalResponse(approved=True), mock_context) + + assert call_log == [999, 1] + + @pytest.mark.asyncio + async def test_resume_ignores_stale_state_at_old_approval_key(self, mock_state, mock_context) -> None: + """Pre-existing state at the OLD approval key is ignored — payload wins.""" + _seed_state(mock_state) + call_log: list[int] = [] + + def my_tool(x: int) -> int: + call_log.append(x) + return x + + executor = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool}) + + # Poison the old key shape (no longer read by the executor). + mock_state._data["_tool_approval_state_fn_action"] = {"function_name": "other", "arguments": {"x": 999}} + + request = ToolApprovalRequest(request_id="r-3", function_name="my_tool", arguments={"x": 7}) + await executor.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context) + + assert call_log == [7] + # The poison was never read or deleted by the executor. + assert "_tool_approval_state_fn_action" in mock_state._data + + @pytest.mark.asyncio + async def test_fresh_executor_resume_works(self, mock_state, mock_context) -> None: + """Simulates checkpoint restore: a brand-new executor instance handles the approval response.""" + _seed_state(mock_state) + call_log: list[int] = [] + + def my_tool(x: int) -> int: + call_log.append(x) + return x + + # Pretend the executor that emitted the request is gone; a fresh one handles the response. + fresh = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool}) + + request = ToolApprovalRequest(request_id="r-4", function_name="my_tool", arguments={"x": 42}) + await fresh.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context) + + assert call_log == [42] + mock_context.send_message.assert_called_once() + sent = mock_context.send_message.call_args[0][0] + assert isinstance(sent, ActionComplete) + + @pytest.mark.asyncio + async def test_rejection_uses_request_payload_function_name(self, mock_state, mock_context) -> None: + _seed_state(mock_state) + + def my_tool(x: int) -> int: + raise AssertionError("should not be called when rejected") + + executor = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool}) + + request = ToolApprovalRequest(request_id="r-5", function_name="my_tool", arguments={"x": 3}) + await executor.handle_approval_response( + request, ToolApprovalResponse(approved=False, reason="not authorized"), mock_context + ) + + # The rejection message references the function name from the request payload. + local = mock_state._data[DECLARATIVE_STATE_KEY]["Local"] + assert local["result"]["rejected"] is True + assert local["result"]["reason"] == "not authorized" + + +# --------------------------------------------------------------------------- +# InvokeMcpTool: approval-binding regression +# --------------------------------------------------------------------------- + + +class TestMcpToolApprovalBinding: + def _action(self, *, headers: dict[str, Any] | None = None) -> dict[str, Any]: + action: dict[str, Any] = { + "kind": "InvokeMcpTool", + "id": "mcp_action", + "serverUrl": "https://mcp.example/api", + "toolName": "search", + "requireApproval": True, + "output": {"result": "Local.Result"}, + } + if headers is not None: + action["headers"] = headers + return action + + @pytest.mark.asyncio + async def test_request_id_matches_framework_pending_key(self, mock_state, mock_context) -> None: + """The id on the emitted MCPToolApprovalRequest must match the framework's pending-request key.""" + from agent_framework_declarative._workflows._declarative_base import ActionTrigger + + _seed_state(mock_state) + executor = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=_RecordingMcpHandler()) + await executor.handle_action(ActionTrigger(), mock_context) + + mock_context.request_info.assert_called_once() + emitted_request = mock_context.request_info.call_args[0][0] + framework_request_id = mock_context.request_info.call_args.kwargs["request_id"] + assert isinstance(emitted_request, MCPToolApprovalRequest) + assert emitted_request.request_id == framework_request_id + + @pytest.mark.asyncio + async def test_resume_uses_request_payload_fields(self, mock_state, mock_context) -> None: + _seed_state(mock_state) + handler = _RecordingMcpHandler() + executor = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=handler) + + request = MCPToolApprovalRequest( + request_id="r-1", + tool_name="search", + server_url="https://mcp.example/api", + server_label="prod", + arguments={"q": "x"}, + connection_name="conn-A", + ) + await executor.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context) + + assert handler.call_count == 1 + inv = handler.last + assert inv is not None + assert inv.tool_name == "search" + assert inv.server_url == "https://mcp.example/api" + assert inv.server_label == "prod" + assert inv.arguments == {"q": "x"} + assert inv.connection_name == "conn-A" + + @pytest.mark.asyncio + async def test_concurrent_pending_mcp_approvals_do_not_swap(self, mock_state, mock_context) -> None: + _seed_state(mock_state) + handler = _RecordingMcpHandler() + executor = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=handler) + + request_a = MCPToolApprovalRequest( + request_id="r-A", + tool_name="search", + server_url="https://mcp.example/api", + server_label=None, + arguments={"q": "alpha"}, + connection_name="conn-A", + ) + request_b = MCPToolApprovalRequest( + request_id="r-B", + tool_name="search", + server_url="https://mcp.example/api", + server_label=None, + arguments={"q": "beta"}, + connection_name="conn-B", + ) + + await executor.handle_approval_response(request_b, ToolApprovalResponse(approved=True), mock_context) + await executor.handle_approval_response(request_a, ToolApprovalResponse(approved=True), mock_context) + + assert handler.call_count == 2 + assert handler.invocations[0].arguments == {"q": "beta"} + assert handler.invocations[0].connection_name == "conn-B" + assert handler.invocations[1].arguments == {"q": "alpha"} + assert handler.invocations[1].connection_name == "conn-A" + + @pytest.mark.asyncio + async def test_headers_reevaluated_from_action_def_on_resume(self, mock_state, mock_context) -> None: + """Headers come from the action definition (re-evaluated) so secrets are not in the payload.""" + _seed_state(mock_state) + handler = _RecordingMcpHandler() + executor = InvokeMcpToolActionExecutor( + self._action(headers={"Authorization": "Bearer tk"}), + mcp_tool_handler=handler, + ) + + request = MCPToolApprovalRequest( + request_id="r-1", + tool_name="search", + server_url="https://mcp.example/api", + server_label=None, + arguments={"q": "x"}, + connection_name=None, + ) + await executor.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context) + + assert handler.last is not None + assert handler.last.headers == {"Authorization": "Bearer tk"} + + @pytest.mark.asyncio + async def test_mcp_resume_ignores_stale_state_at_old_approval_key(self, mock_state, mock_context) -> None: + _seed_state(mock_state) + handler = _RecordingMcpHandler() + executor = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=handler) + + mock_state._data["_mcp_tool_approval_state_mcp_action"] = {"poison": True} + + request = MCPToolApprovalRequest( + request_id="r-1", + tool_name="search", + server_url="https://mcp.example/api", + server_label=None, + arguments={"q": "real"}, + connection_name=None, + ) + await executor.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context) + + assert handler.call_count == 1 + assert handler.last is not None + assert handler.last.arguments == {"q": "real"} + # The poison was never read or deleted by the executor. + assert "_mcp_tool_approval_state_mcp_action" in mock_state._data + + @pytest.mark.asyncio + async def test_fresh_mcp_executor_resume_works(self, mock_state, mock_context) -> None: + """Checkpoint-restore simulation: fresh executor handles the response.""" + _seed_state(mock_state) + handler = _RecordingMcpHandler() + fresh = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=handler) + + request = MCPToolApprovalRequest( + request_id="r-1", + tool_name="search", + server_url="https://mcp.example/api", + server_label=None, + arguments={"q": "fresh"}, + connection_name=None, + ) + await fresh.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context) + + assert handler.call_count == 1 + assert handler.last is not None + assert handler.last.arguments == {"q": "fresh"} + + @pytest.mark.asyncio + async def test_request_payload_carries_connection_name(self, mock_state, mock_context) -> None: + """When emitting the approval request, connection_name flows into MCPToolApprovalRequest.""" + from agent_framework_declarative._workflows._declarative_base import ActionTrigger + + _seed_state(mock_state) + action = self._action() + action["connection"] = {"name": "conn-from-action"} + executor = InvokeMcpToolActionExecutor(action, mcp_tool_handler=_RecordingMcpHandler()) + + await executor.handle_action(ActionTrigger(), mock_context) + + mock_context.request_info.assert_called_once() + request = mock_context.request_info.call_args[0][0] + assert isinstance(request, MCPToolApprovalRequest) + assert request.connection_name == "conn-from-action" + + @pytest.mark.asyncio + async def test_request_payload_pins_conversation_id(self, mock_state, mock_context) -> None: + """Evaluated ``conversationId`` is pinned in ``metadata`` at request-emit time.""" + from agent_framework_declarative._workflows._declarative_base import ActionTrigger + + _seed_state(mock_state) + state = DeclarativeWorkflowState(mock_state) + state.set("Local.targetConversation", "conv-original") + action = self._action() + action["conversationId"] = "=Local.targetConversation" + executor = InvokeMcpToolActionExecutor(action, mcp_tool_handler=_RecordingMcpHandler()) + + await executor.handle_action(ActionTrigger(), mock_context) + + mock_context.request_info.assert_called_once() + request = mock_context.request_info.call_args[0][0] + assert isinstance(request, MCPToolApprovalRequest) + assert request.metadata.get("conversation_id") == "conv-original" + + @pytest.mark.asyncio + async def test_resume_routes_output_to_pinned_conversation_not_mutated_state( + self, mock_state, mock_context + ) -> None: + """Output appends to the conversation pinned on ``original_request``, not the + current state evaluation.""" + _seed_state(mock_state) + state = DeclarativeWorkflowState(mock_state) + state.set("System.conversations.conv-original.messages", []) + state.set("System.conversations.conv-mutated.messages", []) + state.set("Local.targetConversation", "conv-mutated") + + handler = _RecordingMcpHandler(MCPToolResult(outputs=[Content.from_text("approved-output")])) + action = self._action() + action["conversationId"] = "=Local.targetConversation" + executor = InvokeMcpToolActionExecutor(action, mcp_tool_handler=handler) + + original_request = MCPToolApprovalRequest( + request_id="r-1", + tool_name="search", + server_url="https://mcp.example/api", + server_label=None, + arguments={"q": "x"}, + connection_name=None, + metadata={"conversation_id": "conv-original"}, + ) + await executor.handle_approval_response(original_request, ToolApprovalResponse(approved=True), mock_context) + + assert len(state.get("System.conversations.conv-original.messages") or []) == 1 + assert state.get("System.conversations.conv-mutated.messages") == [] + + @pytest.mark.asyncio + async def test_resume_handles_legacy_request_without_new_fields(self, mock_state, mock_context) -> None: + """Resume tolerates payloads lacking ``connection_name`` / ``metadata`` (legacy pickle shape).""" + + @dataclass + class _LegacyMCPApprovalRequest: + request_id: str + tool_name: str + server_url: str + server_label: str | None + arguments: dict[str, Any] + header_names: list[str] + + _seed_state(mock_state) + handler = _RecordingMcpHandler() + executor = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=handler) + + legacy_request = _LegacyMCPApprovalRequest( + request_id="r-1", + tool_name="search", + server_url="https://mcp.example/api", + server_label=None, + arguments={"q": "x"}, + header_names=[], + ) + await executor.handle_approval_response( + legacy_request, # type: ignore[arg-type] + ToolApprovalResponse(approved=True), + mock_context, + ) + + assert handler.call_count == 1 + assert handler.last is not None + assert handler.last.connection_name is None diff --git a/python/packages/declarative/tests/test_function_tool_executor.py b/python/packages/declarative/tests/test_function_tool_executor.py index f11b3568658..bcf04bd21d2 100644 --- a/python/packages/declarative/tests/test_function_tool_executor.py +++ b/python/packages/declarative/tests/test_function_tool_executor.py @@ -35,14 +35,12 @@ from agent_framework_declarative._workflows import ( # noqa: E402 DECLARATIVE_STATE_KEY, FUNCTION_TOOL_REGISTRY_KEY, - TOOL_APPROVAL_STATE_KEY, ActionComplete, ActionTrigger, DeclarativeWorkflowBuilder, InvokeFunctionToolExecutor, ToolApprovalRequest, ToolApprovalResponse, - ToolApprovalState, ToolInvocationResult, WorkflowFactory, ) @@ -393,21 +391,6 @@ def test_approval_response_rejected(self): assert response.approved is False assert response.reason == "Not authorized" - def test_approval_state(self): - """Test creating approval state for yield/resume.""" - state = ToolApprovalState( - function_name="delete_user", - arguments={"user_id": "123"}, - output_messages_var="Local.messages", - output_result_var="Local.result", - auto_send=True, - ) - assert state.function_name == "delete_user" - assert state.arguments == {"user_id": "123"} - assert state.output_messages_var == "Local.messages" - assert state.output_result_var == "Local.result" - assert state.auto_send is True - class TestInvokeFunctionToolEdgeCases: """Tests for edge cases and error handling.""" @@ -1075,13 +1058,6 @@ def my_tool(x: int) -> int: # Should NOT have sent ActionComplete (workflow yields) mock_context.send_message.assert_not_called() - # Approval state should be saved in state - approval_key = f"{TOOL_APPROVAL_STATE_KEY}_approval_test" - saved_state = mock_state._data[approval_key] - assert isinstance(saved_state, ToolApprovalState) - assert saved_state.function_name == "my_tool" - assert saved_state.arguments == {"x": 5} - @pytest.mark.asyncio async def test_approval_response_approved(self, mock_state, mock_context): """When approval response is approved, the tool should be invoked.""" @@ -1104,17 +1080,7 @@ def my_tool(x: int) -> int: executor = InvokeFunctionToolExecutor(action_def, tools={"my_tool": my_tool}) - # Pre-populate approval state (simulating what handle_action stores) - approval_key = f"{TOOL_APPROVAL_STATE_KEY}_approval_approved" - mock_state._data[approval_key] = ToolApprovalState( - function_name="my_tool", - arguments={"x": 7}, - output_messages_var=None, - output_result_var="Local.result", - auto_send=True, - ) - - # Simulate the response + # Simulate the response — invocation params come from original_request original_request = ToolApprovalRequest( request_id="req-123", function_name="my_tool", @@ -1124,7 +1090,7 @@ def my_tool(x: int) -> int: await executor.handle_approval_response(original_request, response, mock_context) - # Tool should have been called + # Tool should have been called with the approved arguments assert call_log == [7] # ActionComplete should have been sent @@ -1132,9 +1098,6 @@ def my_tool(x: int) -> int: sent = mock_context.send_message.call_args[0][0] assert isinstance(sent, ActionComplete) - # Approval state should be cleaned up - assert approval_key not in mock_state._data - @pytest.mark.asyncio async def test_approval_response_rejected(self, mock_state, mock_context): """When approval response is rejected, rejection status should be stored.""" @@ -1154,16 +1117,6 @@ def my_tool(x: int) -> int: executor = InvokeFunctionToolExecutor(action_def, tools={"my_tool": my_tool}) - # Pre-populate approval state - approval_key = f"{TOOL_APPROVAL_STATE_KEY}_approval_rejected" - mock_state._data[approval_key] = ToolApprovalState( - function_name="my_tool", - arguments={"x": 5}, - output_messages_var=None, - output_result_var="Local.result", - auto_send=True, - ) - original_request = ToolApprovalRequest( request_id="req-456", function_name="my_tool", @@ -1185,36 +1138,6 @@ def my_tool(x: int) -> int: assert result["reason"] == "Not authorized" assert result["approved"] is False - @pytest.mark.asyncio - async def test_approval_response_missing_state(self, mock_state, mock_context): - """When approval state is missing on resume, should log error and complete.""" - self._init_state(mock_state) - - action_def = { - "kind": "InvokeFunctionTool", - "id": "missing_state_test", - "functionName": "my_tool", - "requireApproval": True, - "output": {"result": "Local.result"}, - } - - executor = InvokeFunctionToolExecutor(action_def, tools={}) - - # Don't populate approval state - simulate missing state - original_request = ToolApprovalRequest( - request_id="req-789", - function_name="my_tool", - arguments={}, - ) - response = ToolApprovalResponse(approved=True) - - await executor.handle_approval_response(original_request, response, mock_context) - - # Should still send ActionComplete - mock_context.send_message.assert_called_once() - sent = mock_context.send_message.call_args[0][0] - assert isinstance(sent, ActionComplete) - # ============================================================================ # State registry tool lookup (lines 255-257) diff --git a/python/packages/declarative/tests/test_invoke_mcp_tool_executor.py b/python/packages/declarative/tests/test_invoke_mcp_tool_executor.py index fdee1f7df1d..549cdd30a70 100644 --- a/python/packages/declarative/tests/test_invoke_mcp_tool_executor.py +++ b/python/packages/declarative/tests/test_invoke_mcp_tool_executor.py @@ -403,7 +403,6 @@ class TestApprovalFlow: async def test_approval_required_emits_request_and_yields(self, mock_state, mock_context) -> None: # type: ignore[no-untyped-def] from agent_framework_declarative._workflows._declarative_base import ActionTrigger from agent_framework_declarative._workflows._executors_mcp import ( - _MCP_APPROVAL_STATE_KEY, InvokeMcpToolActionExecutor, MCPToolApprovalRequest, ) @@ -439,18 +438,12 @@ async def test_approval_required_emits_request_and_yields(self, mock_state, mock # Handler not invoked yet. assert handler.call_count == 0 - # Approval state stored. - approval_key = f"{_MCP_APPROVAL_STATE_KEY}_mcp_action" - assert approval_key in mock_state._data - @pytest.mark.asyncio async def test_approval_response_approved_invokes_handler(self, mock_state, mock_context) -> None: # type: ignore[no-untyped-def] from agent_framework_declarative._workflows import ActionComplete, ToolApprovalResponse from agent_framework_declarative._workflows._executors_mcp import ( - _MCP_APPROVAL_STATE_KEY, InvokeMcpToolActionExecutor, MCPToolApprovalRequest, - _MCPToolApprovalState, ) _seed_state(mock_state) @@ -458,24 +451,11 @@ async def test_approval_response_approved_invokes_handler(self, mock_state, mock executor = InvokeMcpToolActionExecutor( _action( require_approval=True, + headers={"Authorization": "Bearer tk"}, output={"result": "Local.Result"}, ), mcp_tool_handler=handler, ) - # Pre-populate approval state. - approval_key = f"{_MCP_APPROVAL_STATE_KEY}_mcp_action" - mock_state._data[approval_key] = _MCPToolApprovalState( - server_url="https://mcp.example/api", - tool_name="search", - server_label=None, - arguments={"q": "x"}, - connection_name=None, - headers_def={"Authorization": "Bearer tk"}, - auto_send=False, - conversation_id_expr=None, - output_messages_path=None, - output_result_path="Local.Result", - ) await executor.handle_approval_response( MCPToolApprovalRequest( request_id="req-1", @@ -491,10 +471,12 @@ async def test_approval_response_approved_invokes_handler(self, mock_state, mock assert handler.call_count == 1 inv = handler.last_invocation assert inv is not None - # Headers are re-evaluated from headers_def. + # Invocation fields source from the approval request payload. + assert inv.tool_name == "search" + assert inv.server_url == "https://mcp.example/api" + assert inv.arguments == {"q": "x"} + # Headers are re-evaluated from the action definition on resume. assert inv.headers == {"Authorization": "Bearer tk"} - # Approval state was cleaned up. - assert approval_key not in mock_state._data # ActionComplete was sent. mock_context.send_message.assert_called_once() sent = mock_context.send_message.call_args[0][0] @@ -504,10 +486,8 @@ async def test_approval_response_approved_invokes_handler(self, mock_state, mock async def test_approval_response_rejected_assigns_error(self, mock_state, mock_context) -> None: # type: ignore[no-untyped-def] from agent_framework_declarative._workflows import ToolApprovalResponse from agent_framework_declarative._workflows._executors_mcp import ( - _MCP_APPROVAL_STATE_KEY, InvokeMcpToolActionExecutor, MCPToolApprovalRequest, - _MCPToolApprovalState, ) _seed_state(mock_state) @@ -519,19 +499,6 @@ async def test_approval_response_rejected_assigns_error(self, mock_state, mock_c ), mcp_tool_handler=handler, ) - approval_key = f"{_MCP_APPROVAL_STATE_KEY}_mcp_action" - mock_state._data[approval_key] = _MCPToolApprovalState( - server_url="https://mcp.example/api", - tool_name="search", - server_label=None, - arguments={}, - connection_name=None, - headers_def=None, - auto_send=True, - conversation_id_expr=None, - output_messages_path=None, - output_result_path="Local.Result", - ) await executor.handle_approval_response( MCPToolApprovalRequest( request_id="req-2", diff --git a/python/samples/03-workflows/declarative/invoke_mcp_tool/main.py b/python/samples/03-workflows/declarative/invoke_mcp_tool/main.py index 85b513b5620..358ee919047 100644 --- a/python/samples/03-workflows/declarative/invoke_mcp_tool/main.py +++ b/python/samples/03-workflows/declarative/invoke_mcp_tool/main.py @@ -87,6 +87,8 @@ def _prompt_for_approval(request: MCPToolApprovalRequest) -> ToolApprovalRespons print(f" outbound header names: {', '.join(request.header_names)}") else: print(" outbound header names: (none)") + if request.connection_name: + print(f" connection: {request.connection_name}") print("-" * 60) while True: