From 564259a4aae76ca836fb4a2835e94ea23c0d322b Mon Sep 17 00:00:00 2001 From: Peter Ibekwe Date: Wed, 10 Jun 2026 14:13:52 -0700 Subject: [PATCH 1/8] Fix declarative object parsing bug --- .../_workflows/_declarative_base.py | 38 ++- .../test_declarative_state_path_safety.py | 251 ++++++++++++++++++ .../declarative/tests/test_graph_coverage.py | 4 +- 3 files changed, 282 insertions(+), 11 deletions(-) create mode 100644 python/packages/declarative/tests/test_declarative_state_path_safety.py diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py index e6fc0a820d5..02771426102 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py @@ -63,6 +63,9 @@ _ENV_REFERENCE_RE = re.compile(r"\bEnv\.([A-Za-z_][A-Za-z0-9_]*)") +# Allowed identifier shape for object-attribute steps in declarative state paths +# (matches PowerFx / Copilot Studio identifier rules). +_SAFE_PATH_SEGMENT_RE = re.compile(r"^[A-Za-z][A-Za-z0-9_]*$") @dataclass(frozen=True) class DeclarativeEnvConfig: @@ -331,16 +334,21 @@ def set_state_data(self, data: DeclarativeStateData) -> None: def get(self, path: str, default: Any = None) -> Any: """Get a value from the state using a dot-notated path. + Dict-keyed segments may use arbitrary string keys (e.g. UUIDs in + ``System.conversations..messages``). Segments that would resolve + via object-attribute access must be valid declarative identifiers + (``[A-Za-z][A-Za-z0-9_]*``); other shapes return ``default``. + Args: path: Dot-notated path like 'Local.results' or 'Workflow.Inputs.query' default: Default value if path doesn't exist Returns: - The value at the path, or default if not found + The value at the path, or default if not found or unreachable. """ state_data = self.get_state_data() parts = path.split(".") - if not parts: + if not parts or any(not p for p in parts): return default namespace = parts[0] @@ -377,10 +385,19 @@ def get(self, path: str, default: Any = None) -> Any: obj = obj.get(part, default) # type: ignore[union-attr] if obj is default: return default - elif hasattr(obj, part): # type: ignore[arg-type] - obj = getattr(obj, part) # type: ignore[arg-type] else: - return default + # Attribute access is only allowed for safe declarative identifiers. + if not _SAFE_PATH_SEGMENT_RE.match(part): + logger.warning( + "DeclarativeWorkflowState.get: rejecting attribute segment %r in path %r", + part, + path, + ) + return default + if hasattr(obj, part): # type: ignore[arg-type] + obj = getattr(obj, part) # type: ignore[arg-type] + else: + return default return obj # type: ignore[return-value] @@ -392,7 +409,7 @@ def set(self, path: str, value: Any) -> None: value: The value to set Raises: - ValueError: If attempting to set Workflow.Inputs (which is read-only) + ValueError: If attempting to set Workflow.Inputs (which is read-only). """ state_data = self.get_state_data() parts = path.split(".") @@ -692,7 +709,7 @@ def _preprocess_custom_functions(self, formula: str) -> str: if isinstance(replacement, str): if len(replacement) > MAX_INLINE_LENGTH: # Store long strings in a temp variable to avoid PowerFx expression limit - temp_var_name = f"_TempMessageText{temp_var_counter}" + temp_var_name = f"TempMessageText{temp_var_counter}" temp_var_counter += 1 self.set(f"Local.{temp_var_name}", replacement) replacement_str = f"Local.{temp_var_name}" @@ -849,6 +866,9 @@ def eval_if_expression(self, value: Any) -> Any: def interpolate_string(self, text: str) -> str: """Interpolate {Variable.Path} references in a string. + Matched path segments must be valid declarative identifiers + (``[A-Za-z][A-Za-z0-9_]*``); other braced tokens are left as-is. + This handles template-style variable substitution like: - "Created ticket #{Local.TicketParameters.TicketId}" - "Routing to {Local.RoutingParameters.TeamName}" @@ -866,8 +886,8 @@ def replace_var(match: re.Match[str]) -> str: value = self.get(var_path) return str(value) if value is not None else "" - # Match {Variable.Path} patterns - pattern = r"\{([A-Za-z][A-Za-z0-9_.]*)\}" + # Match {Variable.Path} patterns where each segment is a declarative identifier. + pattern = r"\{([A-Za-z][A-Za-z0-9_]*(?:\.[A-Za-z][A-Za-z0-9_]*)*)\}" # Replace all matches result = text diff --git a/python/packages/declarative/tests/test_declarative_state_path_safety.py b/python/packages/declarative/tests/test_declarative_state_path_safety.py new file mode 100644 index 00000000000..6760ab47f1b --- /dev/null +++ b/python/packages/declarative/tests/test_declarative_state_path_safety.py @@ -0,0 +1,251 @@ +# Copyright (c) Microsoft. All rights reserved. +# pyright: reportUnknownParameterType=false, reportUnknownArgumentType=false +# pyright: reportMissingParameterType=false, reportUnknownMemberType=false +# pyright: reportPrivateUsage=false, reportUnknownVariableType=false +# pyright: reportGeneralTypeIssues=false + +"""Path-segment validation tests for DeclarativeWorkflowState. + +Path segments handed to ``get``/``set``/``append`` and ``{Variable.Path}`` +placeholders in ``interpolate_string`` must be valid declarative identifiers +(``[A-Za-z][A-Za-z0-9_]*``). These tests pin that contract and the related +behavior across legitimate and invalid inputs. +""" + +import logging +from dataclasses import dataclass +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from agent_framework_declarative._workflows import DeclarativeWorkflowState + +try: + import powerfx # noqa: F401 + + _powerfx_available = True +except (ImportError, RuntimeError): + _powerfx_available = False + +_requires_powerfx = pytest.mark.skipif(not _powerfx_available, reason="PowerFx engine not available") + + +@pytest.fixture +def mock_state() -> MagicMock: + """In-memory mock for the underlying State.""" + ms = MagicMock() + ms._data = {} + + def get(key: str, default: Any = None) -> Any: + return ms._data.get(key, default) + + def set_(key: str, value: Any) -> None: + ms._data[key] = value + + def has(key: str) -> bool: + return key in ms._data + + def delete(key: str) -> None: + ms._data.pop(key, None) + + ms.get = MagicMock(side_effect=get) + ms.set = MagicMock(side_effect=set_) + ms.has = MagicMock(side_effect=has) + ms.delete = MagicMock(side_effect=delete) + return ms + + +@pytest.fixture +def state(mock_state: MagicMock) -> DeclarativeWorkflowState: + s = DeclarativeWorkflowState(mock_state) + s.initialize() + return s + + +@dataclass +class _PlainObj: + """Non-dict object so ``get`` falls through to attribute access.""" + + text: str = "hi" + + +# --------------------------------------------------------------------------- +# get(): invalid paths return default +# --------------------------------------------------------------------------- + + +class TestGetRejectsInvalidPaths: + def test_rejects_dunder_segment_via_attribute_access(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.obj", _PlainObj()) + assert state.get("Local.obj.__class__") is None + assert state.get("Local.obj.__class__", default="DEF") == "DEF" + + def test_rejects_full_env_exfil_chain(self, state: DeclarativeWorkflowState, monkeypatch) -> None: + sentinel = "agent-framework-path-safety-sentinel" + monkeypatch.setenv("AF_PATH_SAFETY_SENTINEL", sentinel) + state.set("Local.obj", _PlainObj()) + + result = state.get("Local.obj.__class__.__init__.__globals__.os.environ") + + assert result is None + assert sentinel not in str(result) + + def test_rejects_leading_underscore_via_attribute_access(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.obj", _PlainObj()) + assert state.get("Local.obj._private") is None + + def test_rejects_invalid_chars_via_attribute_access(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.obj", _PlainObj()) + assert state.get("Local.obj.text bar") is None + assert state.get("Local.obj.text-bar") is None + + def test_rejects_empty_path_and_empty_segments(self, state: DeclarativeWorkflowState) -> None: + assert state.get("") is None + assert state.get(".") is None + assert state.get("Local.") is None + assert state.get(".Local") is None + + def test_warning_logged_on_rejected_attribute_segment( + self, + state: DeclarativeWorkflowState, + caplog: pytest.LogCaptureFixture, + ) -> None: + state.set("Local.obj", _PlainObj()) + with caplog.at_level(logging.WARNING, logger="agent_framework_declarative._workflows._declarative_base"): + state.get("Local.obj.__class__") + assert any("rejecting attribute segment" in r.message for r in caplog.records) + + def test_dict_keyed_dunder_is_not_attribute_access(self, state: DeclarativeWorkflowState) -> None: + """A literal dunder dict key is harmless because dict lookup never reaches getattr.""" + state.set("Local.bag", {"__class__": "harmless-string"}) + assert state.get("Local.bag.__class__") == "harmless-string" + + +# --------------------------------------------------------------------------- +# get(): legitimate paths continue to work +# --------------------------------------------------------------------------- + + +class TestGetAllowsValidPaths: + def test_underscore_inside_identifier(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.user_input", "ok") + assert state.get("Local.user_input") == "ok" + + def test_mixed_case_identifiers(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.UserInput", "u1") + state.set("Local.userInput", "u2") + assert state.get("Local.UserInput") == "u1" + assert state.get("Local.userInput") == "u2" + + def test_object_attribute_traversal_still_works(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.msg", _PlainObj(text="hello")) + assert state.get("Local.msg.text") == "hello" + + def test_nested_dict_traversal_still_works(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.params", {"team": {"name": "alpha"}}) + assert state.get("Local.params.team.name") == "alpha" + + def test_uuid_and_hyphenated_dict_keys_are_allowed(self, state: DeclarativeWorkflowState) -> None: + """Conversation-id style paths use arbitrary dict keys (UUIDs / hyphens).""" + conv_id = "eb815014-06f1-4db6-b7c1-304ea135424f" + state.set(f"System.conversations.{conv_id}.messages", ["m1", "m2"]) + assert state.get(f"System.conversations.{conv_id}.messages") == ["m1", "m2"] + + +# --------------------------------------------------------------------------- +# set() / append(): dict-keyed operations accept arbitrary string keys +# --------------------------------------------------------------------------- + + +class TestSetAndAppend: + def test_set_allows_underscore_inside_identifier(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.user_input", "ok") + assert state.get("Local.user_input") == "ok" + + def test_set_allows_uuid_and_hyphenated_dict_keys(self, state: DeclarativeWorkflowState) -> None: + conv_id = "conv-test-1" + state.set(f"System.conversations.{conv_id}.messages", []) + assert state.get(f"System.conversations.{conv_id}.messages") == [] + + def test_append_allows_uuid_and_hyphenated_dict_keys(self, state: DeclarativeWorkflowState) -> None: + conv_id = "conv-42" + state.append(f"System.conversations.{conv_id}.messages", {"role": "user", "text": "hi"}) + msgs = state.get(f"System.conversations.{conv_id}.messages") + assert msgs == [{"role": "user", "text": "hi"}] + + def test_workflow_inputs_still_read_only(self, state: DeclarativeWorkflowState) -> None: + with pytest.raises(ValueError, match="read-only"): + state.set("Workflow.Inputs.x", 1) + + +# --------------------------------------------------------------------------- +# interpolate_string(): invalid placeholders left intact, valid ones resolved +# --------------------------------------------------------------------------- + + +class TestInterpolateString: + def test_ignores_dunder_payload(self, state: DeclarativeWorkflowState, monkeypatch) -> None: + sentinel = "agent-framework-interp-sentinel" + monkeypatch.setenv("AF_INTERP_SENTINEL", sentinel) + state.set("Local.obj", _PlainObj()) + + out = state.interpolate_string("X={Local.obj.__class__.__init__.__globals__.os.environ}") + + assert sentinel not in out + assert "{Local.obj.__class__" in out # placeholder left as literal text + + def test_ignores_leading_underscore_segment(self, state: DeclarativeWorkflowState) -> None: + out = state.interpolate_string("v={Local._private}") + assert out == "v={Local._private}" + + def test_allows_underscore_inside_identifier(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.user_input", "hello") + assert state.interpolate_string("v={Local.user_input}") == "v=hello" + + def test_resolves_nested_dict_path(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.params", {"team": "alpha"}) + assert state.interpolate_string("team={Local.params.team}") == "team=alpha" + + def test_end_to_end_send_activity_literal_placeholder( + self, + state: DeclarativeWorkflowState, + monkeypatch, + ) -> None: + """Mirror the SendActivity flow: eval_if_expression then interpolate_string.""" + sentinel = "agent-framework-e2e-sentinel" + monkeypatch.setenv("AF_E2E_SENTINEL", sentinel) + state.set("Local.toolResult", _PlainObj()) + + payload = "{Local.toolResult.__class__.__init__.__globals__.os.environ}" + evaluated = state.eval_if_expression(payload) + rendered = state.interpolate_string(evaluated) if isinstance(evaluated, str) else str(evaluated) + + assert rendered == payload + assert sentinel not in rendered + + +# --------------------------------------------------------------------------- +# Regressions: PowerFx and internal temp-variable handling still work +# --------------------------------------------------------------------------- + + +@_requires_powerfx +class TestPowerFxStillWorks: + def test_simple_powerfx_expression_evaluates(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.x", 6) + state.set("Local.y", 7) + assert state.eval("=Local.x * Local.y") == 42 + + def test_internal_temp_message_text_still_works(self, state: DeclarativeWorkflowState) -> None: + """Long MessageText() results stored in TempMessageText{n} still round-trip.""" + long_text = "A" * 600 + state.set( + "Local.Messages", + [{"text": long_text, "contents": [{"type": "text", "text": long_text}]}], + ) + + result = state.eval("=Upper(MessageText(Local.Messages))") + assert result == "A" * 600 + + assert state.get("Local.TempMessageText0") == long_text diff --git a/python/packages/declarative/tests/test_graph_coverage.py b/python/packages/declarative/tests/test_graph_coverage.py index f114c8f0ae9..47742f2f69f 100644 --- a/python/packages/declarative/tests/test_graph_coverage.py +++ b/python/packages/declarative/tests/test_graph_coverage.py @@ -2761,7 +2761,7 @@ async def test_short_message_text_embedded_inline(self, mock_state): assert result == "HELLO WORLD" # No temp variable should be created for short strings - temp_var = state.get("Local._TempMessageText0") + temp_var = state.get("Local.TempMessageText0") assert temp_var is None async def test_long_message_text_stored_in_temp_variable(self, mock_state): @@ -2778,7 +2778,7 @@ async def test_long_message_text_stored_in_temp_variable(self, mock_state): assert result == "A" * 600 # Upper on 'A' is still 'A' # A temp variable should have been created - temp_var = state.get("Local._TempMessageText0") + temp_var = state.get("Local.TempMessageText0") assert temp_var == long_text async def test_find_with_long_message_text(self, mock_state): From 3498f9dc66c74f386e0dca8376bc942ce9a575c5 Mon Sep 17 00:00:00 2001 From: Peter Ibekwe Date: Wed, 10 Jun 2026 15:43:36 -0700 Subject: [PATCH 2/8] Remove unnecessary comment --- .../agent_framework_declarative/_workflows/_declarative_base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py index 02771426102..eaa97ba0790 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py @@ -64,7 +64,6 @@ _ENV_REFERENCE_RE = re.compile(r"\bEnv\.([A-Za-z_][A-Za-z0-9_]*)") # Allowed identifier shape for object-attribute steps in declarative state paths -# (matches PowerFx / Copilot Studio identifier rules). _SAFE_PATH_SEGMENT_RE = re.compile(r"^[A-Za-z][A-Za-z0-9_]*$") @dataclass(frozen=True) From c0ea099bd8a816c22de514df90d961c9cc0354f0 Mon Sep 17 00:00:00 2001 From: Peter Ibekwe Date: Wed, 10 Jun 2026 17:43:42 -0700 Subject: [PATCH 3/8] Address PR comments --- .../_workflows/_declarative_base.py | 18 +++++-- .../test_declarative_state_path_safety.py | 54 +++++++++++++++++-- 2 files changed, 66 insertions(+), 6 deletions(-) diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py index eaa97ba0790..ae2ed94f442 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py @@ -66,6 +66,7 @@ # Allowed identifier shape for object-attribute steps in declarative state paths _SAFE_PATH_SEGMENT_RE = re.compile(r"^[A-Za-z][A-Za-z0-9_]*$") + @dataclass(frozen=True) class DeclarativeEnvConfig: """Configuration that populates the PowerFx ``Env`` symbol for a workflow. @@ -408,12 +409,14 @@ def set(self, path: str, value: Any) -> None: value: The value to set Raises: - ValueError: If attempting to set Workflow.Inputs (which is read-only). + ValueError: If ``path`` is empty or contains empty segments + (e.g. ``"Local."``, ``"Local..foo"``), or if attempting to set + ``Workflow.Inputs`` (which is read-only). """ state_data = self.get_state_data() parts = path.split(".") - if not parts: - return + if not parts or any(not p for p in parts): + raise ValueError(f"Invalid path {path!r}: empty segments are not allowed") namespace = parts[0] remaining = parts[1:] @@ -469,7 +472,16 @@ def append(self, path: str, value: Any) -> None: Args: path: Dot-notated path to a list value: The value to append + + Raises: + ValueError: If ``path`` is empty or contains empty segments + (e.g. ``"Local."``, ``"Local..foo"``), or if the existing + value at ``path`` is not a list. """ + parts = path.split(".") + if not parts or any(not p for p in parts): + raise ValueError(f"Invalid path {path!r}: empty segments are not allowed") + existing = self.get(path) if existing is None: self.set(path, [value]) diff --git a/python/packages/declarative/tests/test_declarative_state_path_safety.py b/python/packages/declarative/tests/test_declarative_state_path_safety.py index 6760ab47f1b..783607a87de 100644 --- a/python/packages/declarative/tests/test_declarative_state_path_safety.py +++ b/python/packages/declarative/tests/test_declarative_state_path_safety.py @@ -7,9 +7,20 @@ """Path-segment validation tests for DeclarativeWorkflowState. Path segments handed to ``get``/``set``/``append`` and ``{Variable.Path}`` -placeholders in ``interpolate_string`` must be valid declarative identifiers -(``[A-Za-z][A-Za-z0-9_]*``). These tests pin that contract and the related -behavior across legitimate and invalid inputs. +placeholders in ``interpolate_string`` are subject to three distinct rules +that this module pins: + +- **Empty segments** (e.g. ``""``, ``"Local."``, ``"Local..foo"``) are rejected + by all of ``get``/``set``/``append`` and ``interpolate_string``. ``get`` and + ``interpolate_string`` return their default / leave the placeholder literal; + ``set`` and ``append`` raise ``ValueError``. +- **Object-attribute segments** — segments that ``get`` would resolve via + ``getattr`` because the parent is a non-dict object — must match the safe + identifier shape ``[A-Za-z][A-Za-z0-9_]*``. Other shapes are rejected with a + warning log and the default is returned. +- **Dict-keyed segments** — segments that resolve via dict lookup because the + parent is a ``dict`` — may use arbitrary non-empty string keys (e.g. UUIDs + or hyphenated identifiers like ``System.conversations..messages``). """ import logging @@ -179,6 +190,43 @@ def test_workflow_inputs_still_read_only(self, state: DeclarativeWorkflowState) state.set("Workflow.Inputs.x", 1) +# --------------------------------------------------------------------------- +# set() / append(): malformed paths (empty segments) raise ValueError +# --------------------------------------------------------------------------- + + +class TestSetRejectsInvalidPaths: + @pytest.mark.parametrize("bad_path", ["", "Local.", "Local..foo", ".Local"]) + def test_set_rejects_empty_segment(self, state: DeclarativeWorkflowState, bad_path: str) -> None: + with pytest.raises(ValueError, match="empty segments are not allowed"): + state.set(bad_path, "x") + + @pytest.mark.parametrize("bad_path", ["", "Local.", "Local..foo", ".Local"]) + def test_append_rejects_empty_segment(self, state: DeclarativeWorkflowState, bad_path: str) -> None: + with pytest.raises(ValueError, match="empty segments are not allowed"): + state.append(bad_path, "x") + + def test_set_rejection_makes_no_partial_write(self, state: DeclarativeWorkflowState) -> None: + """Rejected set() must not create an unreachable entry in the state.""" + state.set("Local.user_input", "pre") + with pytest.raises(ValueError): + state.set("Local.", "leak") + local = state.get_state_data().get("Local", {}) + assert "" not in local + assert local == {"user_input": "pre"} + assert state.get("Local.") is None + assert state.get("Local.user_input") == "pre" + + def test_append_rejection_makes_no_partial_write(self, state: DeclarativeWorkflowState) -> None: + """Rejected append() must not create an unreachable entry in the state.""" + state.set("Local.items", ["a"]) + with pytest.raises(ValueError): + state.append("Local.", "leak") + local = state.get_state_data().get("Local", {}) + assert "" not in local + assert local == {"items": ["a"]} + + # --------------------------------------------------------------------------- # interpolate_string(): invalid placeholders left intact, valid ones resolved # --------------------------------------------------------------------------- From 4b0aeb76a5e4d613f340bf7006016c466eb20184 Mon Sep 17 00:00:00 2001 From: Peter Ibekwe Date: Thu, 11 Jun 2026 10:35:34 -0700 Subject: [PATCH 4/8] Address PR comments. --- .../_workflows/_declarative_base.py | 144 +++++++++++------- .../test_declarative_state_path_safety.py | 89 +++++++++-- .../declarative/tests/test_graph_coverage.py | 10 +- 3 files changed, 169 insertions(+), 74 deletions(-) diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py index ae2ed94f442..8451ab65fa8 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py @@ -269,6 +269,9 @@ class DeclarativeWorkflowState: - Conversation: Conversation history """ + # Sentinel marking "no prior value" for temporary-key bookkeeping. + _MISSING: Any = object() + def __init__(self, state: State, env_config: DeclarativeEnvConfig | None = None): """Initialize with a State instance. @@ -492,6 +495,15 @@ def append(self, path: str, value: Any) -> None: else: raise ValueError(f"Cannot append to non-list at path '{path}'") + def _clear_local_path(self, name: str) -> None: + """Remove ``name`` from the ``Local`` namespace, if present.""" + state_data = self.get_state_data() + local = cast(dict[str, Any], state_data.get("Local")) + if local is None or name not in local: + return + local.pop(name, None) + self.set_state_data(state_data) + def eval(self, expression: str) -> Any: """Evaluate a PowerFx expression with the current state. @@ -532,53 +544,64 @@ def eval(self, expression: str) -> Any: return result # Pre-process nested custom functions (e.g., Upper(MessageText(...))) - # Replace them with their evaluated results before sending to PowerFx - formula = self._preprocess_custom_functions(formula) - - if Engine is None: - raise RuntimeError( - f"PowerFx is not available (dotnet runtime not installed). " - f"Expression '={formula[:80]}' cannot be evaluated. " - f"Install dotnet and the powerfx package for full PowerFx support." - ) - - symbols = self._to_powerfx_symbols() - # Use setlocale(category) query form so we can restore the exact prior value. - # getlocale() returns a normalized tuple and is not always a lossless - # round-trip for setlocale across platforms/locales. - original_numeric_locale = locale.setlocale(locale.LC_NUMERIC) + # and run PowerFx. The finally below restores any temporary state + # written during preprocessing, regardless of where execution exits. + temp_writes: list[tuple[str, Any]] = [] + try: - for locale_candidate in _POWERFX_NUMERIC_LOCALE_CANDIDATES: - try: - locale.setlocale(locale.LC_NUMERIC, locale_candidate) - break - except locale.Error: - continue + formula = self._preprocess_custom_functions(formula, temp_writes) - engine = Engine() - try: - from System.Globalization import ( # pyright: ignore[reportMissingImports] - CultureInfo, # pyright: ignore[reportUnknownVariableType] + if Engine is None: + raise RuntimeError( + f"PowerFx is not available (dotnet runtime not installed). " + f"Expression '={formula[:80]}' cannot be evaluated. " + f"Install dotnet and the powerfx package for full PowerFx support." ) - except ImportError: - return engine.eval(formula, symbols=symbols, locale=_POWERFX_EVAL_LOCALE) - original_culture = cast(Any, CultureInfo.CurrentCulture) # pyright: ignore[reportUnknownMemberType] + symbols = self._to_powerfx_symbols() + # Use setlocale(category) query form so we can restore the exact prior value. + # getlocale() returns a normalized tuple and is not always a lossless + # round-trip for setlocale across platforms/locales. + original_numeric_locale = locale.setlocale(locale.LC_NUMERIC) try: - CultureInfo.CurrentCulture = CultureInfo(_POWERFX_EVAL_LOCALE) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] - return engine.eval(formula, symbols=symbols, locale=_POWERFX_EVAL_LOCALE) + for locale_candidate in _POWERFX_NUMERIC_LOCALE_CANDIDATES: + try: + locale.setlocale(locale.LC_NUMERIC, locale_candidate) + break + except locale.Error: + continue + + engine = Engine() + try: + from System.Globalization import ( # pyright: ignore[reportMissingImports] + CultureInfo, # pyright: ignore[reportUnknownVariableType] + ) + except ImportError: + return engine.eval(formula, symbols=symbols, locale=_POWERFX_EVAL_LOCALE) + + original_culture = cast(Any, CultureInfo.CurrentCulture) # pyright: ignore[reportUnknownMemberType] + try: + CultureInfo.CurrentCulture = CultureInfo(_POWERFX_EVAL_LOCALE) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + return engine.eval(formula, symbols=symbols, locale=_POWERFX_EVAL_LOCALE) + finally: + CultureInfo.CurrentCulture = original_culture # pyright: ignore[reportUnknownMemberType] + except ValueError as e: + error_msg = str(e) + # Handle undefined variable errors gracefully by returning None + # This matches the behavior of the legacy fallback parser + if "isn't recognized" in error_msg or "Name isn't valid" in error_msg: + logger.debug(f"PowerFx: undefined variable in expression '{formula}', returning None") + return None + raise finally: - CultureInfo.CurrentCulture = original_culture # pyright: ignore[reportUnknownMemberType] - except ValueError as e: - error_msg = str(e) - # Handle undefined variable errors gracefully by returning None - # This matches the behavior of the legacy fallback parser - if "isn't recognized" in error_msg or "Name isn't valid" in error_msg: - logger.debug(f"PowerFx: undefined variable in expression '{formula}', returning None") - return None - raise + locale.setlocale(locale.LC_NUMERIC, original_numeric_locale) finally: - locale.setlocale(locale.LC_NUMERIC, original_numeric_locale) + # Restore each temporary key to its prior value (or remove it). + for path, previous in reversed(temp_writes): + if previous is self._MISSING: + self._clear_local_path(path.removeprefix("Local.")) + else: + self.set(path, previous) def _eval_custom_function(self, formula: str) -> Any | None: """Handle custom functions not supported by the Python PowerFx library. @@ -637,7 +660,7 @@ def _eval_custom_function(self, formula: str) -> Any | None: return None - def _preprocess_custom_functions(self, formula: str) -> str: + def _preprocess_custom_functions(self, formula: str, temp_writes: list[tuple[str, Any]]) -> str: """Pre-process custom functions nested inside other PowerFx functions. Custom functions like MessageText() are not supported by the PowerFx engine. @@ -652,9 +675,14 @@ def _preprocess_custom_functions(self, formula: str) -> str: Args: formula: The PowerFx formula to pre-process + temp_writes: Caller-owned list. Each write to a temporary key + appends a ``(path, previous_value)`` entry where + ``previous_value`` is the value at ``path`` before the write + or :attr:`_MISSING` if none. The caller must restore every + entry, including when this method raises mid-write. Returns: - The formula with custom function calls replaced by their evaluated results + The rewritten formula. """ import re @@ -663,7 +691,6 @@ def _preprocess_custom_functions(self, formula: str) -> str: # We use 500 to leave room for the rest of the expression around the replaced value. MAX_INLINE_LENGTH = 500 - # Counter for generating unique temp variable names temp_var_counter = 0 # Custom functions that need pre-processing: (regex pattern, handler) @@ -719,11 +746,14 @@ def _preprocess_custom_functions(self, formula: str) -> str: # Replace in formula if isinstance(replacement, str): if len(replacement) > MAX_INLINE_LENGTH: - # Store long strings in a temp variable to avoid PowerFx expression limit - temp_var_name = f"TempMessageText{temp_var_counter}" + # Store long results in an underscore-prefixed temp key; + # record the prior value so eval() can restore it. + temp_var_name = f"_TempMessageText{temp_var_counter}" temp_var_counter += 1 - self.set(f"Local.{temp_var_name}", replacement) - replacement_str = f"Local.{temp_var_name}" + temp_var_path = f"Local.{temp_var_name}" + temp_writes.append((temp_var_path, self.get(temp_var_path, default=self._MISSING))) + self.set(temp_var_path, replacement) + replacement_str = temp_var_path logger.debug( f"Stored long MessageText result ({len(replacement)} chars) " f"in temp variable {temp_var_name}" @@ -875,14 +905,13 @@ def eval_if_expression(self, value: Any) -> Any: return value def interpolate_string(self, text: str) -> str: - """Interpolate {Variable.Path} references in a string. - - Matched path segments must be valid declarative identifiers - (``[A-Za-z][A-Za-z0-9_]*``); other braced tokens are left as-is. + """Interpolate ``{Variable.Path}`` references in a string. - This handles template-style variable substitution like: - - "Created ticket #{Local.TicketParameters.TicketId}" - - "Routing to {Local.RoutingParameters.TeamName}" + Captures brace-delimited tokens whose root segment is an identifier + (``[A-Za-z][A-Za-z0-9_]*``) followed by zero or more ``.`` separated + dict-key segments. Resolution is delegated to :meth:`get`; unresolved + tokens are replaced with the empty string. Tokens that do not look + like state paths (e.g. ``{foo-bar}``, ``{Ctrl+C}``) are left literal. Args: text: Text that may contain {Variable.Path} references @@ -897,10 +926,11 @@ def replace_var(match: re.Match[str]) -> str: value = self.get(var_path) return str(value) if value is not None else "" - # Match {Variable.Path} patterns where each segment is a declarative identifier. - pattern = r"\{([A-Za-z][A-Za-z0-9_]*(?:\.[A-Za-z][A-Za-z0-9_]*)*)\}" + # Root segment must be an identifier; follow-on segments accept any + # non-empty dict-key (e.g. ``_id``, ``1``, UUIDs). ``get()`` enforces + # per-segment safety on attribute traversal. + pattern = r"\{([A-Za-z][A-Za-z0-9_]*(?:\.[^{}\s.]+)*)\}" - # Replace all matches result = text for match in re.finditer(pattern, text): replacement = replace_var(match) diff --git a/python/packages/declarative/tests/test_declarative_state_path_safety.py b/python/packages/declarative/tests/test_declarative_state_path_safety.py index 783607a87de..2446fc3cf4d 100644 --- a/python/packages/declarative/tests/test_declarative_state_path_safety.py +++ b/python/packages/declarative/tests/test_declarative_state_path_safety.py @@ -210,7 +210,7 @@ def test_set_rejection_makes_no_partial_write(self, state: DeclarativeWorkflowSt """Rejected set() must not create an unreachable entry in the state.""" state.set("Local.user_input", "pre") with pytest.raises(ValueError): - state.set("Local.", "leak") + state.set("Local.", "value") local = state.get_state_data().get("Local", {}) assert "" not in local assert local == {"user_input": "pre"} @@ -221,14 +221,14 @@ def test_append_rejection_makes_no_partial_write(self, state: DeclarativeWorkflo """Rejected append() must not create an unreachable entry in the state.""" state.set("Local.items", ["a"]) with pytest.raises(ValueError): - state.append("Local.", "leak") + state.append("Local.", "value") local = state.get_state_data().get("Local", {}) assert "" not in local assert local == {"items": ["a"]} # --------------------------------------------------------------------------- -# interpolate_string(): invalid placeholders left intact, valid ones resolved +# interpolate_string(): permissive matcher; get() enforces safety # --------------------------------------------------------------------------- @@ -241,11 +241,17 @@ def test_ignores_dunder_payload(self, state: DeclarativeWorkflowState, monkeypat out = state.interpolate_string("X={Local.obj.__class__.__init__.__globals__.os.environ}") assert sentinel not in out - assert "{Local.obj.__class__" in out # placeholder left as literal text + assert out == "X=" - def test_ignores_leading_underscore_segment(self, state: DeclarativeWorkflowState) -> None: - out = state.interpolate_string("v={Local._private}") - assert out == "v={Local._private}" + def test_unknown_path_reduces_to_empty(self, state: DeclarativeWorkflowState) -> None: + assert state.interpolate_string("v={Local._private}") == "v=" + + @pytest.mark.parametrize( + "literal", + ["{foo-bar}", "{Ctrl+C}", "{not:a:path}", "{Local.}", "{}"], + ) + def test_non_state_braced_tokens_left_literal(self, state: DeclarativeWorkflowState, literal: str) -> None: + assert state.interpolate_string(f"v={literal}") == f"v={literal}" def test_allows_underscore_inside_identifier(self, state: DeclarativeWorkflowState) -> None: state.set("Local.user_input", "hello") @@ -255,12 +261,29 @@ def test_resolves_nested_dict_path(self, state: DeclarativeWorkflowState) -> Non state.set("Local.params", {"team": "alpha"}) assert state.interpolate_string("team={Local.params.team}") == "team=alpha" - def test_end_to_end_send_activity_literal_placeholder( + @pytest.mark.parametrize( + ("key", "value"), + [ + ("_id", "abc123"), + ("1", "one"), + ("2025", "year-bucket"), + ], + ) + def test_resolves_dict_keyed_segments(self, state: DeclarativeWorkflowState, key: str, value: str) -> None: + state.set("Local.bag", {key: value}) + assert state.interpolate_string(f"v={{Local.bag.{key}}}") == f"v={value}" + + def test_resolves_uuid_conversation_key(self, state: DeclarativeWorkflowState) -> None: + conv_id = "eb815014-06f1-4db6-b7c1-304ea135424f" + state.set(f"System.conversations.{conv_id}.messages", ["hello"]) + out = state.interpolate_string(f"m={{System.conversations.{conv_id}.messages}}") + assert out == "m=['hello']" + + def test_end_to_end_send_activity_payload_neutralized( self, state: DeclarativeWorkflowState, monkeypatch, ) -> None: - """Mirror the SendActivity flow: eval_if_expression then interpolate_string.""" sentinel = "agent-framework-e2e-sentinel" monkeypatch.setenv("AF_E2E_SENTINEL", sentinel) state.set("Local.toolResult", _PlainObj()) @@ -269,8 +292,8 @@ def test_end_to_end_send_activity_literal_placeholder( evaluated = state.eval_if_expression(payload) rendered = state.interpolate_string(evaluated) if isinstance(evaluated, str) else str(evaluated) - assert rendered == payload assert sentinel not in rendered + assert rendered == "" # --------------------------------------------------------------------------- @@ -286,7 +309,7 @@ def test_simple_powerfx_expression_evaluates(self, state: DeclarativeWorkflowSta assert state.eval("=Local.x * Local.y") == 42 def test_internal_temp_message_text_still_works(self, state: DeclarativeWorkflowState) -> None: - """Long MessageText() results stored in TempMessageText{n} still round-trip.""" + """Long MessageText() results round-trip and the temp key is removed after eval.""" long_text = "A" * 600 state.set( "Local.Messages", @@ -296,4 +319,46 @@ def test_internal_temp_message_text_still_works(self, state: DeclarativeWorkflow result = state.eval("=Upper(MessageText(Local.Messages))") assert result == "A" * 600 - assert state.get("Local.TempMessageText0") == long_text + local = state.get_state_data().get("Local", {}) + remaining = sorted(k for k in local if k.startswith("_TempMessageText")) + assert not remaining, f"Temporary keys remain in Local: {remaining}" + + def test_message_text_eval_preserves_user_temp_value(self, state: DeclarativeWorkflowState) -> None: + """User state at the temp key path survives a long MessageText eval.""" + long_text = "A" * 600 + state.set("Local._TempMessageText0", "user-important-value") + state.set( + "Local.Messages", + [{"text": long_text, "contents": [{"type": "text", "text": long_text}]}], + ) + + result = state.eval("=Upper(MessageText(Local.Messages))") + assert result == "A" * 600 + assert state.get("Local._TempMessageText0") == "user-important-value" + + def test_message_text_eval_cleans_up_on_powerfx_failure( + self, + state: DeclarativeWorkflowState, + monkeypatch, + ) -> None: + """Temp key is removed even when PowerFx evaluation raises.""" + from agent_framework_declarative._workflows import _declarative_base as base + + class _FailingEngine: + def eval(self, *args: Any, **kwargs: Any) -> Any: + raise RuntimeError("boom") + + monkeypatch.setattr(base, "Engine", _FailingEngine) + + long_text = "A" * 600 + state.set( + "Local.Messages", + [{"text": long_text, "contents": [{"type": "text", "text": long_text}]}], + ) + + with pytest.raises(RuntimeError, match="boom"): + state.eval("=Upper(MessageText(Local.Messages))") + + local = state.get_state_data().get("Local", {}) + remaining = sorted(k for k in local if k.startswith("_TempMessageText")) + assert not remaining, f"Temporary keys remain in Local after PowerFx failure: {remaining}" diff --git a/python/packages/declarative/tests/test_graph_coverage.py b/python/packages/declarative/tests/test_graph_coverage.py index 47742f2f69f..bc20a27f09b 100644 --- a/python/packages/declarative/tests/test_graph_coverage.py +++ b/python/packages/declarative/tests/test_graph_coverage.py @@ -2761,11 +2761,11 @@ async def test_short_message_text_embedded_inline(self, mock_state): assert result == "HELLO WORLD" # No temp variable should be created for short strings - temp_var = state.get("Local.TempMessageText0") + temp_var = state.get("Local._TempMessageText0") assert temp_var is None async def test_long_message_text_stored_in_temp_variable(self, mock_state): - """Test that long MessageText results are stored in temp variables.""" + """Long MessageText results round-trip and the temp key is removed after eval.""" state = DeclarativeWorkflowState(mock_state) state.initialize() @@ -2777,9 +2777,9 @@ async def test_long_message_text_stored_in_temp_variable(self, mock_state): result = state.eval("=Upper(MessageText(Local.Messages))") assert result == "A" * 600 # Upper on 'A' is still 'A' - # A temp variable should have been created - temp_var = state.get("Local.TempMessageText0") - assert temp_var == long_text + local = state.get_state_data().get("Local", {}) + remaining = sorted(k for k in local if k.startswith("_TempMessageText")) + assert not remaining, f"Temporary keys remain in Local: {remaining}" async def test_find_with_long_message_text(self, mock_state): """Test Find function works with long MessageText stored in temp variable.""" From d83bc07582caf8977f13fce823bc1bc166b2b1fa Mon Sep 17 00:00:00 2001 From: Peter Ibekwe Date: Thu, 11 Jun 2026 10:51:06 -0700 Subject: [PATCH 5/8] Fix CI failures. --- .../agent_framework_declarative/_workflows/_declarative_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py index 8451ab65fa8..6a035a448a2 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py @@ -498,7 +498,7 @@ def append(self, path: str, value: Any) -> None: def _clear_local_path(self, name: str) -> None: """Remove ``name`` from the ``Local`` namespace, if present.""" state_data = self.get_state_data() - local = cast(dict[str, Any], state_data.get("Local")) + local = state_data.get("Local") if local is None or name not in local: return local.pop(name, None) From 0ade298cc4fb3931e5a00e2417349c14710e2db9 Mon Sep 17 00:00:00 2001 From: Peter Ibekwe Date: Thu, 11 Jun 2026 13:16:36 -0700 Subject: [PATCH 6/8] declarative action approval bugfix --- .../_workflows/__init__.py | 4 - .../_workflows/_executors_mcp.py | 116 ++--- .../_workflows/_executors_tools.py | 77 +-- .../test_declarative_approval_binding.py | 441 ++++++++++++++++++ .../tests/test_function_tool_executor.py | 81 +--- .../tests/test_invoke_mcp_tool_executor.py | 45 +- .../declarative/invoke_mcp_tool/main.py | 2 + 7 files changed, 505 insertions(+), 261 deletions(-) create mode 100644 python/packages/declarative/tests/test_declarative_approval_binding.py diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/__init__.py b/python/packages/declarative/agent_framework_declarative/_workflows/__init__.py index 579cabfafaf..31ee990cb69 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/__init__.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/__init__.py @@ -76,12 +76,10 @@ from ._executors_tools import ( FUNCTION_TOOL_REGISTRY_KEY, TOOL_ACTION_EXECUTORS, - TOOL_APPROVAL_STATE_KEY, BaseToolExecutor, InvokeFunctionToolExecutor, ToolApprovalRequest, ToolApprovalResponse, - ToolApprovalState, ToolInvocationResult, ) from ._factory import WorkflowFactory @@ -111,7 +109,6 @@ "HTTP_ACTION_EXECUTORS", "MCP_ACTION_EXECUTORS", "TOOL_ACTION_EXECUTORS", - "TOOL_APPROVAL_STATE_KEY", "TOOL_REGISTRY_KEY", "ActionComplete", "ActionTrigger", @@ -164,7 +161,6 @@ "SetVariableExecutor", "ToolApprovalRequest", "ToolApprovalResponse", - "ToolApprovalState", "ToolInvocationResult", "WorkflowFactory", "WorkflowState", diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py index 73b66341ea3..9a7d9187046 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py @@ -15,12 +15,11 @@ matches the security posture of :mod:`._executors_http` (which never logs request headers either) and prevents secrets from leaking through workflow events that are typically observable to operators / UIs. -- ``_MCPToolApprovalState`` snapshots the EVALUATED values for non-secret - fields (server URL, tool name, arguments) at approval-request time so that - subsequent state mutations cannot make the executor "approve X then call - Y". Headers are stored as the raw expression strings (not evaluated values) - so secrets are not persisted in the workflow's checkpoint state. They are - re-evaluated on resume. +- The :class:`MCPToolApprovalRequest` payload is the source of truth for the + resumed invocation: ``tool_name``, ``server_url``, ``server_label``, + ``arguments``, and ``connection_name`` come from the request the reviewer + approved. Headers are re-evaluated from the action definition on resume so + that secret values are not persisted in the workflow's checkpoint state. - Tool outputs flow back into agent conversations through ``conversationId`` and through Tool-role messages emitted to ``output.messages``. They share the same prompt-injection risk surface as ``HttpRequestAction``: workflow @@ -60,8 +59,6 @@ logger = logging.getLogger(__name__) -_MCP_APPROVAL_STATE_KEY = "_mcp_tool_approval_state" - # --------------------------------------------------------------------------- # Request / state types @@ -86,6 +83,9 @@ class MCPToolApprovalRequest: arguments: Evaluated arguments to be forwarded to the tool. header_names: Sorted list of outbound header names (no values). Empty when no headers are configured. + connection_name: Optional connection identifier the invocation will + use. Surfaced so the reviewer can see which connection is bound + to the approved call. """ request_id: str @@ -94,28 +94,7 @@ class MCPToolApprovalRequest: server_label: str | None arguments: dict[str, Any] header_names: list[str] = field(default_factory=lambda: []) - - -@dataclass -class _MCPToolApprovalState: - """Internal state saved during the approval yield for resumption. - - Stores **evaluated** values for non-secret fields to prevent - "approve X / execute Y" attacks. Stores the raw expression string for - ``headers`` so that secret values are NOT persisted in checkpoint state; - the expressions are re-evaluated against current state on resume. - """ - - server_url: str - tool_name: str - server_label: str | None - arguments: dict[str, Any] - connection_name: str | None - headers_def: Any - auto_send: bool - conversation_id_expr: str | None - output_messages_path: str | None - output_result_path: str | None + connection_name: str | None = None # --------------------------------------------------------------------------- @@ -260,20 +239,6 @@ async def handle_action( if require_approval: request_id = str(uuid.uuid4()) - approval_state = _MCPToolApprovalState( - server_url=server_url, - tool_name=tool_name, - server_label=server_label, - arguments=arguments, - connection_name=connection_name, - headers_def=self._action_def.get("headers"), - auto_send=auto_send, - conversation_id_expr=conversation_id_expr if isinstance(conversation_id_expr, str) else None, - output_messages_path=output_messages_path, - output_result_path=output_result_path, - ) - ctx.state.set(self._approval_key(), approval_state) - request = MCPToolApprovalRequest( request_id=request_id, tool_name=tool_name, @@ -281,6 +246,7 @@ async def handle_action( server_label=server_label, arguments=arguments, header_names=sorted(headers.keys()), + connection_name=connection_name, ) logger.info( "%s: requesting approval for MCP tool '%s' on '%s'", @@ -322,54 +288,59 @@ async def handle_approval_response( response: ToolApprovalResponse, ctx: WorkflowContext[ActionComplete, str], ) -> None: - """Resume after the workflow yielded for an approval request.""" + """Resume after the workflow yielded for an approval request. + + Invocation fields (``tool_name``, ``server_url``, ``server_label``, + ``arguments``, ``connection_name``) are sourced from + ``original_request``. Output configuration is re-derived from the + action definition; header values are re-evaluated from the action + definition so secrets remain out of checkpoint state. + """ state = self._get_state(ctx.state) - approval_key = self._approval_key() - try: - approval_state: _MCPToolApprovalState = ctx.state.get(approval_key) - except KeyError: - logger.error("%s: approval state missing for executor '%s'", self.__class__.__name__, self.id) - await ctx.send_message(ActionComplete()) - return - try: - ctx.state.delete(approval_key) - except KeyError: - logger.warning("%s: approval state already deleted for '%s'", self.__class__.__name__, self.id) + tool_name = original_request.tool_name + server_url = original_request.server_url + server_label = original_request.server_label + arguments = original_request.arguments + connection_name = original_request.connection_name + + auto_send = self._get_auto_send(state) + conversation_id_value = self._action_def.get("conversationId") + conversation_id_expr = conversation_id_value if isinstance(conversation_id_value, str) else None + output_messages_path = _get_output_path(self._action_def, "messages") + output_result_path = _get_output_path(self._action_def, "result") if not response.approved: logger.info( "%s: MCP tool '%s' rejected: %s", self.__class__.__name__, - approval_state.tool_name, + tool_name, response.reason, ) - self._assign_error( - state, approval_state.output_result_path, "MCP tool invocation was not approved by user." - ) + self._assign_error(state, output_result_path, "MCP tool invocation was not approved by user.") await ctx.send_message(ActionComplete()) return - # Approved — re-evaluate headers (not stored at approval time for security). - headers = self._evaluate_headers(state, approval_state.headers_def) + # Approved — re-evaluate headers (not surfaced at approval time for security). + headers = self._evaluate_headers(state, self._action_def.get("headers")) invocation = MCPToolInvocation( - server_url=approval_state.server_url, - tool_name=approval_state.tool_name, - server_label=approval_state.server_label, - arguments=approval_state.arguments, + server_url=server_url, + tool_name=tool_name, + server_label=server_label, + arguments=arguments, headers=headers, - connection_name=approval_state.connection_name, + connection_name=connection_name, ) result = await self._invoke_with_narrow_catch(invocation) await self._process_result( ctx=ctx, state=state, result=result, - auto_send=approval_state.auto_send, - conversation_id_expr=approval_state.conversation_id_expr, - output_messages_path=approval_state.output_messages_path, - output_result_path=approval_state.output_result_path, + auto_send=auto_send, + conversation_id_expr=conversation_id_expr, + output_messages_path=output_messages_path, + output_result_path=output_result_path, ) await ctx.send_message(ActionComplete()) @@ -577,9 +548,6 @@ def _assign_error( return state.set(output_result_path, f"Error: {error_message}") - def _approval_key(self) -> str: - return f"{_MCP_APPROVAL_STATE_KEY}_{self.id}" - def _parse_outputs(outputs: list[Content]) -> list[Any]: """Parse :class:`Content` outputs into Python values for ``output.result``. diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_tools.py b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_tools.py index b2c046a69bb..d522cf56643 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_tools.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_tools.py @@ -41,10 +41,6 @@ # at runtime are discoverable by both agent-based and function-based tool executors. FUNCTION_TOOL_REGISTRY_KEY = TOOL_REGISTRY_KEY -# State key prefix for storing approval state during yield/resume. -# The executor's ID is appended to create a per-executor key. -TOOL_APPROVAL_STATE_KEY = "_tool_approval_state" - # ============================================================================ # Request/Response Types for Approval Flow @@ -87,26 +83,6 @@ class ToolApprovalResponse: reason: str | None = None -# ============================================================================ -# State Types for Approval Flow -# ============================================================================ - - -@dataclass -class ToolApprovalState: - """State saved during approval yield for resumption. - - Stored in State under a per-executor key when requireApproval=true. - Retrieved by handle_approval_response() to continue execution. - """ - - function_name: str - arguments: dict[str, Any] - output_messages_var: str | None - output_result_var: str | None - auto_send: bool - - # ============================================================================ # Result Types # ============================================================================ @@ -501,25 +477,16 @@ async def handle_action( require_approval = self._action_def.get("requireApproval", False) if require_approval: - # Save state for resumption (keyed by executor ID to avoid collisions) - approval_state = ToolApprovalState( - function_name=function_name, - arguments=arguments, - output_messages_var=messages_var, - output_result_var=result_var, - auto_send=auto_send, - ) - approval_key = f"{TOOL_APPROVAL_STATE_KEY}_{self.id}" - ctx.state.set(approval_key, approval_state) - - # Emit approval request - workflow yields here + # Emit approval request - the request payload is the source of + # truth for resumed invocation; no side-channel state is written. + request_id = str(uuid.uuid4()) request = ToolApprovalRequest( - request_id=str(uuid.uuid4()), + request_id=request_id, function_name=function_name, arguments=arguments, ) logger.info(f"{self.__class__.__name__}: requesting approval for '{function_name}'") - await ctx.request_info(request, ToolApprovalResponse) + await ctx.request_info(request, ToolApprovalResponse, request_id=request_id) # Workflow yields - will resume in handle_approval_response return @@ -545,36 +512,16 @@ async def handle_approval_response( ) -> None: """Handle response to a ToolApprovalRequest. - Called when the workflow resumes after yielding for approval. - Either executes the tool (if approved) or stores rejection status. + Resumes after the workflow yielded for approval. The invocation + ``function_name`` and ``arguments`` are sourced from + ``original_request`` (the payload the reviewer approved); output + configuration is re-derived from the executor's action definition. """ state = self._get_state(ctx.state) - approval_key = f"{TOOL_APPROVAL_STATE_KEY}_{self.id}" - - # Retrieve saved invocation state - try: - approval_state: ToolApprovalState = ctx.state.get(approval_key) - except KeyError: - error_msg = "Approval state not found, cannot resume tool invocation" - logger.error(f"{self.__class__.__name__}: {error_msg}") - # Try to store error - get output config from action def as fallback - _, result_var, _ = self._get_output_config() - if result_var and state: - state.set(_normalize_variable_path(result_var), {"error": error_msg}) - await ctx.send_message(ActionComplete()) - return - # Clean up approval state - try: - ctx.state.delete(approval_key) - except KeyError: - logger.warning(f"{self.__class__.__name__}: approval state already deleted") - - function_name = approval_state.function_name - arguments = approval_state.arguments - messages_var = approval_state.output_messages_var - result_var = approval_state.output_result_var - auto_send = approval_state.auto_send + function_name = original_request.function_name + arguments = original_request.arguments + messages_var, result_var, auto_send = self._get_output_config() # Check if approved if not response.approved: diff --git a/python/packages/declarative/tests/test_declarative_approval_binding.py b/python/packages/declarative/tests/test_declarative_approval_binding.py new file mode 100644 index 00000000000..f96b16ee8a2 --- /dev/null +++ b/python/packages/declarative/tests/test_declarative_approval_binding.py @@ -0,0 +1,441 @@ +# Copyright (c) Microsoft. All rights reserved. +# pyright: reportUnknownParameterType=false, reportUnknownArgumentType=false +# pyright: reportMissingParameterType=false, reportUnknownMemberType=false +# pyright: reportPrivateUsage=false, reportUnknownVariableType=false +# pyright: reportGeneralTypeIssues=false + +"""Regression tests pinning the approval-flow binding contract. + +The resumed invocation MUST come from the framework-delivered +``original_request`` payload (the data the reviewer approved) for both +``InvokeFunctionTool`` and ``InvokeMcpTool``. These tests verify that: + +* Invocation parameters come from ``original_request``, not from any prior + side-channel state. +* Concurrent pending approvals on the same executor do not swap. +* Pre-existing state at old approval keys is ignored entirely. +* Resume works on a freshly constructed executor (checkpoint-restore + simulation), without any prior ``ctx.state`` write. +* For MCP, ``connection_name`` is sourced from the approval payload and + ``headers`` are re-evaluated from the action definition on resume. +""" + +import sys +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +try: + import powerfx # noqa: F401 + + _powerfx_available = True +except (ImportError, RuntimeError): + _powerfx_available = False + +pytestmark = pytest.mark.skipif( + not _powerfx_available or sys.version_info >= (3, 14), + reason="PowerFx engine not available (requires dotnet runtime)", +) + +from agent_framework import Content # noqa: E402 + +from agent_framework_declarative._workflows import ( # noqa: E402 + DECLARATIVE_STATE_KEY, + ActionComplete, + InvokeFunctionToolExecutor, + MCPToolApprovalRequest, + MCPToolHandler, + MCPToolInvocation, + MCPToolResult, + ToolApprovalRequest, + ToolApprovalResponse, +) +from agent_framework_declarative._workflows._executors_mcp import ( # noqa: E402 + InvokeMcpToolActionExecutor, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mock_state() -> MagicMock: + """In-memory mock of the underlying State.""" + state = MagicMock() + state._data = {} + + def _get(key: str, default: Any = None) -> Any: + return state._data.get(key, default) + + def _set(key: str, value: Any) -> None: + state._data[key] = value + + def _has(key: str) -> bool: + return key in state._data + + def _delete(key: str) -> None: + state._data.pop(key, None) + + state.get = MagicMock(side_effect=_get) + state.set = MagicMock(side_effect=_set) + state.has = MagicMock(side_effect=_has) + state.delete = MagicMock(side_effect=_delete) + return state + + +@pytest.fixture +def mock_context(mock_state: MagicMock) -> MagicMock: + ctx = MagicMock() + ctx.state = mock_state + ctx.send_message = AsyncMock() + ctx.yield_output = AsyncMock() + ctx.request_info = AsyncMock() + return ctx + + +def _seed_state(mock_state: MagicMock) -> None: + mock_state._data[DECLARATIVE_STATE_KEY] = { + "Inputs": {}, + "Outputs": {}, + "Local": {}, + "Custom": {}, + "System": { + "ConversationId": "00000000-0000-0000-0000-000000000000", + "LastMessage": {"Text": "", "Id": ""}, + "LastMessageText": "", + "LastMessageId": "", + }, + "Agent": {}, + "Conversation": {"messages": [], "history": []}, + } + + +class _RecordingMcpHandler(MCPToolHandler): + def __init__(self, result: MCPToolResult | None = None) -> None: + self.result = result or MCPToolResult(outputs=[Content.from_text("ok")]) + self.invocations: list[MCPToolInvocation] = [] + + @property + def call_count(self) -> int: + return len(self.invocations) + + @property + def last(self) -> MCPToolInvocation | None: + return self.invocations[-1] if self.invocations else None + + async def invoke_tool(self, invocation: MCPToolInvocation) -> MCPToolResult: + self.invocations.append(invocation) + return self.result + + +# --------------------------------------------------------------------------- +# InvokeFunctionTool: approval-binding regression +# --------------------------------------------------------------------------- + + +class TestFunctionToolApprovalBinding: + def _action(self, *, fn_name: str = "my_tool") -> dict[str, Any]: + return { + "kind": "InvokeFunctionTool", + "id": "fn_action", + "functionName": fn_name, + "requireApproval": True, + "output": {"result": "Local.result"}, + } + + @pytest.mark.asyncio + async def test_request_id_matches_framework_pending_key(self, mock_state, mock_context) -> None: + """The id on the emitted ToolApprovalRequest must match the framework's pending-request key.""" + from agent_framework_declarative._workflows._declarative_base import ActionTrigger + + _seed_state(mock_state) + + def my_tool(x: int) -> int: + return x + + executor = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool}) + await executor.handle_action(ActionTrigger(), mock_context) + + mock_context.request_info.assert_called_once() + emitted_request = mock_context.request_info.call_args[0][0] + framework_request_id = mock_context.request_info.call_args.kwargs["request_id"] + assert isinstance(emitted_request, ToolApprovalRequest) + assert emitted_request.request_id == framework_request_id + + @pytest.mark.asyncio + async def test_resume_uses_request_payload_arguments(self, mock_state, mock_context) -> None: + _seed_state(mock_state) + call_log: list[int] = [] + + def my_tool(x: int) -> int: + call_log.append(x) + return x + + executor = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool}) + + request = ToolApprovalRequest(request_id="r-1", function_name="my_tool", arguments={"x": 1}) + await executor.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context) + + assert call_log == [1] + + @pytest.mark.asyncio + async def test_concurrent_pending_approvals_do_not_swap(self, mock_state, mock_context) -> None: + """Two pending approvals, responses delivered out of order — each invocation uses its own payload.""" + _seed_state(mock_state) + call_log: list[int] = [] + + def my_tool(x: int) -> int: + call_log.append(x) + return x + + executor = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool}) + + request_a = ToolApprovalRequest(request_id="r-A", function_name="my_tool", arguments={"x": 1}) + request_b = ToolApprovalRequest(request_id="r-B", function_name="my_tool", arguments={"x": 999}) + + # Deliver response for B first, then for A. Each invocation must use its own payload. + await executor.handle_approval_response(request_b, ToolApprovalResponse(approved=True), mock_context) + await executor.handle_approval_response(request_a, ToolApprovalResponse(approved=True), mock_context) + + assert call_log == [999, 1] + + @pytest.mark.asyncio + async def test_resume_ignores_stale_state_at_old_approval_key(self, mock_state, mock_context) -> None: + """Pre-existing state at the OLD approval key is ignored — payload wins.""" + _seed_state(mock_state) + call_log: list[int] = [] + + def my_tool(x: int) -> int: + call_log.append(x) + return x + + executor = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool}) + + # Poison the old key shape (no longer read by the executor). + mock_state._data["_tool_approval_state_fn_action"] = {"function_name": "other", "arguments": {"x": 999}} + + request = ToolApprovalRequest(request_id="r-3", function_name="my_tool", arguments={"x": 7}) + await executor.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context) + + assert call_log == [7] + # The poison was never read or deleted by the executor. + assert "_tool_approval_state_fn_action" in mock_state._data + + @pytest.mark.asyncio + async def test_fresh_executor_resume_works(self, mock_state, mock_context) -> None: + """Simulates checkpoint restore: a brand-new executor instance handles the approval response.""" + _seed_state(mock_state) + call_log: list[int] = [] + + def my_tool(x: int) -> int: + call_log.append(x) + return x + + # Pretend the executor that emitted the request is gone; a fresh one handles the response. + fresh = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool}) + + request = ToolApprovalRequest(request_id="r-4", function_name="my_tool", arguments={"x": 42}) + await fresh.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context) + + assert call_log == [42] + mock_context.send_message.assert_called_once() + sent = mock_context.send_message.call_args[0][0] + assert isinstance(sent, ActionComplete) + + @pytest.mark.asyncio + async def test_rejection_uses_request_payload_function_name(self, mock_state, mock_context) -> None: + _seed_state(mock_state) + + def my_tool(x: int) -> int: + raise AssertionError("should not be called when rejected") + + executor = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool}) + + request = ToolApprovalRequest(request_id="r-5", function_name="my_tool", arguments={"x": 3}) + await executor.handle_approval_response( + request, ToolApprovalResponse(approved=False, reason="not authorized"), mock_context + ) + + # The rejection message references the function name from the request payload. + local = mock_state._data[DECLARATIVE_STATE_KEY]["Local"] + assert local["result"]["rejected"] is True + assert local["result"]["reason"] == "not authorized" + + +# --------------------------------------------------------------------------- +# InvokeMcpTool: approval-binding regression +# --------------------------------------------------------------------------- + + +class TestMcpToolApprovalBinding: + def _action(self, *, headers: dict[str, Any] | None = None) -> dict[str, Any]: + action: dict[str, Any] = { + "kind": "InvokeMcpTool", + "id": "mcp_action", + "serverUrl": "https://mcp.example/api", + "toolName": "search", + "requireApproval": True, + "output": {"result": "Local.Result"}, + } + if headers is not None: + action["headers"] = headers + return action + + @pytest.mark.asyncio + async def test_request_id_matches_framework_pending_key(self, mock_state, mock_context) -> None: + """The id on the emitted MCPToolApprovalRequest must match the framework's pending-request key.""" + from agent_framework_declarative._workflows._declarative_base import ActionTrigger + + _seed_state(mock_state) + executor = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=_RecordingMcpHandler()) + await executor.handle_action(ActionTrigger(), mock_context) + + mock_context.request_info.assert_called_once() + emitted_request = mock_context.request_info.call_args[0][0] + framework_request_id = mock_context.request_info.call_args.kwargs["request_id"] + assert isinstance(emitted_request, MCPToolApprovalRequest) + assert emitted_request.request_id == framework_request_id + + @pytest.mark.asyncio + async def test_resume_uses_request_payload_fields(self, mock_state, mock_context) -> None: + _seed_state(mock_state) + handler = _RecordingMcpHandler() + executor = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=handler) + + request = MCPToolApprovalRequest( + request_id="r-1", + tool_name="search", + server_url="https://mcp.example/api", + server_label="prod", + arguments={"q": "x"}, + connection_name="conn-A", + ) + await executor.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context) + + assert handler.call_count == 1 + inv = handler.last + assert inv is not None + assert inv.tool_name == "search" + assert inv.server_url == "https://mcp.example/api" + assert inv.server_label == "prod" + assert inv.arguments == {"q": "x"} + assert inv.connection_name == "conn-A" + + @pytest.mark.asyncio + async def test_concurrent_pending_mcp_approvals_do_not_swap(self, mock_state, mock_context) -> None: + _seed_state(mock_state) + handler = _RecordingMcpHandler() + executor = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=handler) + + request_a = MCPToolApprovalRequest( + request_id="r-A", + tool_name="search", + server_url="https://mcp.example/api", + server_label=None, + arguments={"q": "alpha"}, + connection_name="conn-A", + ) + request_b = MCPToolApprovalRequest( + request_id="r-B", + tool_name="search", + server_url="https://mcp.example/api", + server_label=None, + arguments={"q": "beta"}, + connection_name="conn-B", + ) + + await executor.handle_approval_response(request_b, ToolApprovalResponse(approved=True), mock_context) + await executor.handle_approval_response(request_a, ToolApprovalResponse(approved=True), mock_context) + + assert handler.call_count == 2 + assert handler.invocations[0].arguments == {"q": "beta"} + assert handler.invocations[0].connection_name == "conn-B" + assert handler.invocations[1].arguments == {"q": "alpha"} + assert handler.invocations[1].connection_name == "conn-A" + + @pytest.mark.asyncio + async def test_headers_reevaluated_from_action_def_on_resume(self, mock_state, mock_context) -> None: + """Headers come from the action definition (re-evaluated) so secrets are not in the payload.""" + _seed_state(mock_state) + handler = _RecordingMcpHandler() + executor = InvokeMcpToolActionExecutor( + self._action(headers={"Authorization": "Bearer tk"}), + mcp_tool_handler=handler, + ) + + request = MCPToolApprovalRequest( + request_id="r-1", + tool_name="search", + server_url="https://mcp.example/api", + server_label=None, + arguments={"q": "x"}, + connection_name=None, + ) + await executor.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context) + + assert handler.last is not None + assert handler.last.headers == {"Authorization": "Bearer tk"} + + @pytest.mark.asyncio + async def test_mcp_resume_ignores_stale_state_at_old_approval_key(self, mock_state, mock_context) -> None: + _seed_state(mock_state) + handler = _RecordingMcpHandler() + executor = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=handler) + + mock_state._data["_mcp_tool_approval_state_mcp_action"] = {"poison": True} + + request = MCPToolApprovalRequest( + request_id="r-1", + tool_name="search", + server_url="https://mcp.example/api", + server_label=None, + arguments={"q": "real"}, + connection_name=None, + ) + await executor.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context) + + assert handler.call_count == 1 + assert handler.last is not None + assert handler.last.arguments == {"q": "real"} + # The poison was never read or deleted by the executor. + assert "_mcp_tool_approval_state_mcp_action" in mock_state._data + + @pytest.mark.asyncio + async def test_fresh_mcp_executor_resume_works(self, mock_state, mock_context) -> None: + """Checkpoint-restore simulation: fresh executor handles the response.""" + _seed_state(mock_state) + handler = _RecordingMcpHandler() + fresh = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=handler) + + request = MCPToolApprovalRequest( + request_id="r-1", + tool_name="search", + server_url="https://mcp.example/api", + server_label=None, + arguments={"q": "fresh"}, + connection_name=None, + ) + await fresh.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context) + + assert handler.call_count == 1 + assert handler.last is not None + assert handler.last.arguments == {"q": "fresh"} + + @pytest.mark.asyncio + async def test_request_payload_carries_connection_name(self, mock_state, mock_context) -> None: + """When emitting the approval request, connection_name flows into MCPToolApprovalRequest.""" + from agent_framework_declarative._workflows._declarative_base import ActionTrigger + + _seed_state(mock_state) + action = self._action() + action["connection"] = {"name": "conn-from-action"} + executor = InvokeMcpToolActionExecutor(action, mcp_tool_handler=_RecordingMcpHandler()) + + await executor.handle_action(ActionTrigger(), mock_context) + + mock_context.request_info.assert_called_once() + request = mock_context.request_info.call_args[0][0] + assert isinstance(request, MCPToolApprovalRequest) + assert request.connection_name == "conn-from-action" diff --git a/python/packages/declarative/tests/test_function_tool_executor.py b/python/packages/declarative/tests/test_function_tool_executor.py index f11b3568658..bcf04bd21d2 100644 --- a/python/packages/declarative/tests/test_function_tool_executor.py +++ b/python/packages/declarative/tests/test_function_tool_executor.py @@ -35,14 +35,12 @@ from agent_framework_declarative._workflows import ( # noqa: E402 DECLARATIVE_STATE_KEY, FUNCTION_TOOL_REGISTRY_KEY, - TOOL_APPROVAL_STATE_KEY, ActionComplete, ActionTrigger, DeclarativeWorkflowBuilder, InvokeFunctionToolExecutor, ToolApprovalRequest, ToolApprovalResponse, - ToolApprovalState, ToolInvocationResult, WorkflowFactory, ) @@ -393,21 +391,6 @@ def test_approval_response_rejected(self): assert response.approved is False assert response.reason == "Not authorized" - def test_approval_state(self): - """Test creating approval state for yield/resume.""" - state = ToolApprovalState( - function_name="delete_user", - arguments={"user_id": "123"}, - output_messages_var="Local.messages", - output_result_var="Local.result", - auto_send=True, - ) - assert state.function_name == "delete_user" - assert state.arguments == {"user_id": "123"} - assert state.output_messages_var == "Local.messages" - assert state.output_result_var == "Local.result" - assert state.auto_send is True - class TestInvokeFunctionToolEdgeCases: """Tests for edge cases and error handling.""" @@ -1075,13 +1058,6 @@ def my_tool(x: int) -> int: # Should NOT have sent ActionComplete (workflow yields) mock_context.send_message.assert_not_called() - # Approval state should be saved in state - approval_key = f"{TOOL_APPROVAL_STATE_KEY}_approval_test" - saved_state = mock_state._data[approval_key] - assert isinstance(saved_state, ToolApprovalState) - assert saved_state.function_name == "my_tool" - assert saved_state.arguments == {"x": 5} - @pytest.mark.asyncio async def test_approval_response_approved(self, mock_state, mock_context): """When approval response is approved, the tool should be invoked.""" @@ -1104,17 +1080,7 @@ def my_tool(x: int) -> int: executor = InvokeFunctionToolExecutor(action_def, tools={"my_tool": my_tool}) - # Pre-populate approval state (simulating what handle_action stores) - approval_key = f"{TOOL_APPROVAL_STATE_KEY}_approval_approved" - mock_state._data[approval_key] = ToolApprovalState( - function_name="my_tool", - arguments={"x": 7}, - output_messages_var=None, - output_result_var="Local.result", - auto_send=True, - ) - - # Simulate the response + # Simulate the response — invocation params come from original_request original_request = ToolApprovalRequest( request_id="req-123", function_name="my_tool", @@ -1124,7 +1090,7 @@ def my_tool(x: int) -> int: await executor.handle_approval_response(original_request, response, mock_context) - # Tool should have been called + # Tool should have been called with the approved arguments assert call_log == [7] # ActionComplete should have been sent @@ -1132,9 +1098,6 @@ def my_tool(x: int) -> int: sent = mock_context.send_message.call_args[0][0] assert isinstance(sent, ActionComplete) - # Approval state should be cleaned up - assert approval_key not in mock_state._data - @pytest.mark.asyncio async def test_approval_response_rejected(self, mock_state, mock_context): """When approval response is rejected, rejection status should be stored.""" @@ -1154,16 +1117,6 @@ def my_tool(x: int) -> int: executor = InvokeFunctionToolExecutor(action_def, tools={"my_tool": my_tool}) - # Pre-populate approval state - approval_key = f"{TOOL_APPROVAL_STATE_KEY}_approval_rejected" - mock_state._data[approval_key] = ToolApprovalState( - function_name="my_tool", - arguments={"x": 5}, - output_messages_var=None, - output_result_var="Local.result", - auto_send=True, - ) - original_request = ToolApprovalRequest( request_id="req-456", function_name="my_tool", @@ -1185,36 +1138,6 @@ def my_tool(x: int) -> int: assert result["reason"] == "Not authorized" assert result["approved"] is False - @pytest.mark.asyncio - async def test_approval_response_missing_state(self, mock_state, mock_context): - """When approval state is missing on resume, should log error and complete.""" - self._init_state(mock_state) - - action_def = { - "kind": "InvokeFunctionTool", - "id": "missing_state_test", - "functionName": "my_tool", - "requireApproval": True, - "output": {"result": "Local.result"}, - } - - executor = InvokeFunctionToolExecutor(action_def, tools={}) - - # Don't populate approval state - simulate missing state - original_request = ToolApprovalRequest( - request_id="req-789", - function_name="my_tool", - arguments={}, - ) - response = ToolApprovalResponse(approved=True) - - await executor.handle_approval_response(original_request, response, mock_context) - - # Should still send ActionComplete - mock_context.send_message.assert_called_once() - sent = mock_context.send_message.call_args[0][0] - assert isinstance(sent, ActionComplete) - # ============================================================================ # State registry tool lookup (lines 255-257) diff --git a/python/packages/declarative/tests/test_invoke_mcp_tool_executor.py b/python/packages/declarative/tests/test_invoke_mcp_tool_executor.py index fdee1f7df1d..549cdd30a70 100644 --- a/python/packages/declarative/tests/test_invoke_mcp_tool_executor.py +++ b/python/packages/declarative/tests/test_invoke_mcp_tool_executor.py @@ -403,7 +403,6 @@ class TestApprovalFlow: async def test_approval_required_emits_request_and_yields(self, mock_state, mock_context) -> None: # type: ignore[no-untyped-def] from agent_framework_declarative._workflows._declarative_base import ActionTrigger from agent_framework_declarative._workflows._executors_mcp import ( - _MCP_APPROVAL_STATE_KEY, InvokeMcpToolActionExecutor, MCPToolApprovalRequest, ) @@ -439,18 +438,12 @@ async def test_approval_required_emits_request_and_yields(self, mock_state, mock # Handler not invoked yet. assert handler.call_count == 0 - # Approval state stored. - approval_key = f"{_MCP_APPROVAL_STATE_KEY}_mcp_action" - assert approval_key in mock_state._data - @pytest.mark.asyncio async def test_approval_response_approved_invokes_handler(self, mock_state, mock_context) -> None: # type: ignore[no-untyped-def] from agent_framework_declarative._workflows import ActionComplete, ToolApprovalResponse from agent_framework_declarative._workflows._executors_mcp import ( - _MCP_APPROVAL_STATE_KEY, InvokeMcpToolActionExecutor, MCPToolApprovalRequest, - _MCPToolApprovalState, ) _seed_state(mock_state) @@ -458,24 +451,11 @@ async def test_approval_response_approved_invokes_handler(self, mock_state, mock executor = InvokeMcpToolActionExecutor( _action( require_approval=True, + headers={"Authorization": "Bearer tk"}, output={"result": "Local.Result"}, ), mcp_tool_handler=handler, ) - # Pre-populate approval state. - approval_key = f"{_MCP_APPROVAL_STATE_KEY}_mcp_action" - mock_state._data[approval_key] = _MCPToolApprovalState( - server_url="https://mcp.example/api", - tool_name="search", - server_label=None, - arguments={"q": "x"}, - connection_name=None, - headers_def={"Authorization": "Bearer tk"}, - auto_send=False, - conversation_id_expr=None, - output_messages_path=None, - output_result_path="Local.Result", - ) await executor.handle_approval_response( MCPToolApprovalRequest( request_id="req-1", @@ -491,10 +471,12 @@ async def test_approval_response_approved_invokes_handler(self, mock_state, mock assert handler.call_count == 1 inv = handler.last_invocation assert inv is not None - # Headers are re-evaluated from headers_def. + # Invocation fields source from the approval request payload. + assert inv.tool_name == "search" + assert inv.server_url == "https://mcp.example/api" + assert inv.arguments == {"q": "x"} + # Headers are re-evaluated from the action definition on resume. assert inv.headers == {"Authorization": "Bearer tk"} - # Approval state was cleaned up. - assert approval_key not in mock_state._data # ActionComplete was sent. mock_context.send_message.assert_called_once() sent = mock_context.send_message.call_args[0][0] @@ -504,10 +486,8 @@ async def test_approval_response_approved_invokes_handler(self, mock_state, mock async def test_approval_response_rejected_assigns_error(self, mock_state, mock_context) -> None: # type: ignore[no-untyped-def] from agent_framework_declarative._workflows import ToolApprovalResponse from agent_framework_declarative._workflows._executors_mcp import ( - _MCP_APPROVAL_STATE_KEY, InvokeMcpToolActionExecutor, MCPToolApprovalRequest, - _MCPToolApprovalState, ) _seed_state(mock_state) @@ -519,19 +499,6 @@ async def test_approval_response_rejected_assigns_error(self, mock_state, mock_c ), mcp_tool_handler=handler, ) - approval_key = f"{_MCP_APPROVAL_STATE_KEY}_mcp_action" - mock_state._data[approval_key] = _MCPToolApprovalState( - server_url="https://mcp.example/api", - tool_name="search", - server_label=None, - arguments={}, - connection_name=None, - headers_def=None, - auto_send=True, - conversation_id_expr=None, - output_messages_path=None, - output_result_path="Local.Result", - ) await executor.handle_approval_response( MCPToolApprovalRequest( request_id="req-2", diff --git a/python/samples/03-workflows/declarative/invoke_mcp_tool/main.py b/python/samples/03-workflows/declarative/invoke_mcp_tool/main.py index 85b513b5620..358ee919047 100644 --- a/python/samples/03-workflows/declarative/invoke_mcp_tool/main.py +++ b/python/samples/03-workflows/declarative/invoke_mcp_tool/main.py @@ -87,6 +87,8 @@ def _prompt_for_approval(request: MCPToolApprovalRequest) -> ToolApprovalRespons print(f" outbound header names: {', '.join(request.header_names)}") else: print(" outbound header names: (none)") + if request.connection_name: + print(f" connection: {request.connection_name}") print("-" * 60) while True: From 56ccbc86a853adf35c0ed78d92210ce79e0a41cd Mon Sep 17 00:00:00 2001 From: Peter Ibekwe Date: Thu, 11 Jun 2026 22:02:58 -0700 Subject: [PATCH 7/8] Address PR comments --- .../_workflows/_executors_mcp.py | 95 +++++++------------ .../test_declarative_approval_binding.py | 87 +++++++++++++++++ 2 files changed, 121 insertions(+), 61 deletions(-) diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py index 9a7d9187046..af43518165a 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py @@ -10,16 +10,11 @@ Security notes: -- The executor never echoes header VALUES (auth tokens, API keys) into the - approval request — only header NAMES are surfaced to the caller. This - matches the security posture of :mod:`._executors_http` (which never logs - request headers either) and prevents secrets from leaking through workflow - events that are typically observable to operators / UIs. -- The :class:`MCPToolApprovalRequest` payload is the source of truth for the - resumed invocation: ``tool_name``, ``server_url``, ``server_label``, - ``arguments``, and ``connection_name`` come from the request the reviewer - approved. Headers are re-evaluated from the action definition on resume so - that secret values are not persisted in the workflow's checkpoint state. +- Approval requests surface header NAMES only; header values are not echoed, + matching the posture of :mod:`._executors_http`. +- :class:`MCPToolApprovalRequest` carries the values the resume handler will + use; header values are re-evaluated on resume to keep secrets out of + checkpoint state. - Tool outputs flow back into agent conversations through ``conversationId`` and through Tool-role messages emitted to ``output.messages``. They share the same prompt-injection risk surface as ``HttpRequestAction``: workflow @@ -69,23 +64,16 @@ class MCPToolApprovalRequest: """Approval request emitted before invoking an MCP tool. - Mirrors :class:`agent_framework_declarative.ToolApprovalRequest` but for - MCP-style invocations. Only header NAMES are surfaced — header values are - intentionally omitted because they typically carry authentication - secrets. - Attributes: - request_id: Unique identifier for this approval request. Matches the - id workflow event-emitters use. - tool_name: Evaluated name of the tool to be invoked. + request_id: Identifier matching the framework's pending-request key. + tool_name: Evaluated tool name. server_url: Evaluated MCP server URL. - server_label: Optional human-readable label for diagnostics. - arguments: Evaluated arguments to be forwarded to the tool. - header_names: Sorted list of outbound header names (no values). Empty - when no headers are configured. - connection_name: Optional connection identifier the invocation will - use. Surfaced so the reviewer can see which connection is bound - to the approved call. + server_label: Optional human-readable label. + arguments: Evaluated tool arguments. + header_names: Outbound header names (values withheld). + connection_name: Connection identifier the invocation will use. + metadata: Internal routing data pinned at approval-request time + (e.g. ``conversation_id``) for use by the resume handler. """ request_id: str @@ -95,6 +83,7 @@ class MCPToolApprovalRequest: arguments: dict[str, Any] header_names: list[str] = field(default_factory=lambda: []) connection_name: str | None = None + metadata: dict[str, Any] = field(default_factory=lambda: {}) # --------------------------------------------------------------------------- @@ -102,21 +91,15 @@ class MCPToolApprovalRequest: # --------------------------------------------------------------------------- -def _get_messages_path(state: DeclarativeWorkflowState, conversation_id_expr: str | None) -> str | None: - """Return the configured conversation messages path, if any. - - Returns ``System.conversations.{evaluated_id}.messages`` when a - ``conversation_id_expr`` is configured and evaluates to a non-empty value. - Returns ``None`` when no conversation id expression is configured or when - the expression evaluates to ``None`` or an empty string (mirrors .NET - ``GetConversationId`` behaviour). - """ - if not conversation_id_expr: +def _evaluate_conversation_id(state: DeclarativeWorkflowState, conversation_id_expr: Any) -> str | None: + """Return the evaluated ``conversationId`` string, or None when empty/unset.""" + if not isinstance(conversation_id_expr, str) or not conversation_id_expr: return None evaluated = state.eval_if_expression(conversation_id_expr) - if evaluated is None or (isinstance(evaluated, str) and not evaluated): + if evaluated is None: return None - return f"System.conversations.{evaluated}.messages" + text = str(evaluated) + return text or None def _get_output_path(action_def: Mapping[str, Any], key: str) -> str | None: @@ -239,6 +222,7 @@ async def handle_action( if require_approval: request_id = str(uuid.uuid4()) + conversation_id = _evaluate_conversation_id(state, conversation_id_expr) request = MCPToolApprovalRequest( request_id=request_id, tool_name=tool_name, @@ -247,6 +231,7 @@ async def handle_action( arguments=arguments, header_names=sorted(headers.keys()), connection_name=connection_name, + metadata={"conversation_id": conversation_id}, ) logger.info( "%s: requesting approval for MCP tool '%s' on '%s'", @@ -255,7 +240,6 @@ async def handle_action( server_url, ) await ctx.request_info(request, ToolApprovalResponse, request_id=request_id) - # Workflow yields here — resume in handle_approval_response. return # No approval required - invoke directly. @@ -273,7 +257,7 @@ async def handle_action( state=state, result=result, auto_send=auto_send, - conversation_id_expr=conversation_id_expr if isinstance(conversation_id_expr, str) else None, + conversation_id=_evaluate_conversation_id(state, conversation_id_expr), output_messages_path=output_messages_path, output_result_path=output_result_path, ) @@ -288,25 +272,19 @@ async def handle_approval_response( response: ToolApprovalResponse, ctx: WorkflowContext[ActionComplete, str], ) -> None: - """Resume after the workflow yielded for an approval request. - - Invocation fields (``tool_name``, ``server_url``, ``server_label``, - ``arguments``, ``connection_name``) are sourced from - ``original_request``. Output configuration is re-derived from the - action definition; header values are re-evaluated from the action - definition so secrets remain out of checkpoint state. - """ + """Resume the invocation using the values pinned on ``original_request``.""" state = self._get_state(ctx.state) tool_name = original_request.tool_name server_url = original_request.server_url server_label = original_request.server_label arguments = original_request.arguments - connection_name = original_request.connection_name + connection_name = getattr(original_request, "connection_name", None) + metadata: dict[str, Any] = getattr(original_request, "metadata", None) or {} + raw_conversation_id = metadata.get("conversation_id") + conversation_id = str(raw_conversation_id) if isinstance(raw_conversation_id, str) and raw_conversation_id else None auto_send = self._get_auto_send(state) - conversation_id_value = self._action_def.get("conversationId") - conversation_id_expr = conversation_id_value if isinstance(conversation_id_value, str) else None output_messages_path = _get_output_path(self._action_def, "messages") output_result_path = _get_output_path(self._action_def, "result") @@ -321,7 +299,6 @@ async def handle_approval_response( await ctx.send_message(ActionComplete()) return - # Approved — re-evaluate headers (not surfaced at approval time for security). headers = self._evaluate_headers(state, self._action_def.get("headers")) invocation = MCPToolInvocation( @@ -338,7 +315,7 @@ async def handle_approval_response( state=state, result=result, auto_send=auto_send, - conversation_id_expr=conversation_id_expr, + conversation_id=conversation_id, output_messages_path=output_messages_path, output_result_path=output_result_path, ) @@ -499,7 +476,7 @@ async def _process_result( state: DeclarativeWorkflowState, result: MCPToolResult, auto_send: bool, - conversation_id_expr: str | None, + conversation_id: str | None, output_messages_path: str | None, output_result_path: str | None, ) -> None: @@ -528,14 +505,10 @@ async def _process_result( if auto_send and parsed_results: await ctx.yield_output(_format_outputs_for_send(parsed_results)) - if conversation_id_expr: - messages_path = _get_messages_path(state, conversation_id_expr) - if messages_path is not None: - # Mirrors .NET: conversation gets ASSISTANT-role message with - # the same outputs (so chat history reads it as the agent's - # contribution). - assistant_message = Message(role="assistant", contents=list(result.outputs)) - state.append(messages_path, assistant_message) + if conversation_id: + messages_path = f"System.conversations.{conversation_id}.messages" + assistant_message = Message(role="assistant", contents=list(result.outputs)) + state.append(messages_path, assistant_message) @staticmethod def _assign_error( diff --git a/python/packages/declarative/tests/test_declarative_approval_binding.py b/python/packages/declarative/tests/test_declarative_approval_binding.py index f96b16ee8a2..ba0d4108f12 100644 --- a/python/packages/declarative/tests/test_declarative_approval_binding.py +++ b/python/packages/declarative/tests/test_declarative_approval_binding.py @@ -21,6 +21,7 @@ """ import sys +from dataclasses import dataclass from typing import Any from unittest.mock import AsyncMock, MagicMock @@ -51,6 +52,7 @@ ToolApprovalRequest, ToolApprovalResponse, ) +from agent_framework_declarative._workflows._declarative_base import DeclarativeWorkflowState # noqa: E402 from agent_framework_declarative._workflows._executors_mcp import ( # noqa: E402 InvokeMcpToolActionExecutor, ) @@ -439,3 +441,88 @@ async def test_request_payload_carries_connection_name(self, mock_state, mock_co request = mock_context.request_info.call_args[0][0] assert isinstance(request, MCPToolApprovalRequest) assert request.connection_name == "conn-from-action" + + @pytest.mark.asyncio + async def test_request_payload_pins_conversation_id(self, mock_state, mock_context) -> None: + """Evaluated ``conversationId`` is pinned in ``metadata`` at request-emit time.""" + from agent_framework_declarative._workflows._declarative_base import ActionTrigger + + _seed_state(mock_state) + state = DeclarativeWorkflowState(mock_state) + state.set("Local.targetConversation", "conv-original") + action = self._action() + action["conversationId"] = "=Local.targetConversation" + executor = InvokeMcpToolActionExecutor(action, mcp_tool_handler=_RecordingMcpHandler()) + + await executor.handle_action(ActionTrigger(), mock_context) + + mock_context.request_info.assert_called_once() + request = mock_context.request_info.call_args[0][0] + assert isinstance(request, MCPToolApprovalRequest) + assert request.metadata.get("conversation_id") == "conv-original" + + @pytest.mark.asyncio + async def test_resume_routes_output_to_pinned_conversation_not_mutated_state( + self, mock_state, mock_context + ) -> None: + """Output appends to the conversation pinned on ``original_request``, not the + current state evaluation.""" + _seed_state(mock_state) + state = DeclarativeWorkflowState(mock_state) + state.set("System.conversations.conv-original.messages", []) + state.set("System.conversations.conv-mutated.messages", []) + state.set("Local.targetConversation", "conv-mutated") + + handler = _RecordingMcpHandler(MCPToolResult(outputs=[Content.from_text("approved-output")])) + action = self._action() + action["conversationId"] = "=Local.targetConversation" + executor = InvokeMcpToolActionExecutor(action, mcp_tool_handler=handler) + + original_request = MCPToolApprovalRequest( + request_id="r-1", + tool_name="search", + server_url="https://mcp.example/api", + server_label=None, + arguments={"q": "x"}, + connection_name=None, + metadata={"conversation_id": "conv-original"}, + ) + await executor.handle_approval_response(original_request, ToolApprovalResponse(approved=True), mock_context) + + assert len(state.get("System.conversations.conv-original.messages") or []) == 1 + assert state.get("System.conversations.conv-mutated.messages") == [] + + @pytest.mark.asyncio + async def test_resume_handles_legacy_request_without_new_fields(self, mock_state, mock_context) -> None: + """Resume tolerates payloads lacking ``connection_name`` / ``metadata`` (legacy pickle shape).""" + + @dataclass + class _LegacyMCPApprovalRequest: + request_id: str + tool_name: str + server_url: str + server_label: str | None + arguments: dict[str, Any] + header_names: list[str] + + _seed_state(mock_state) + handler = _RecordingMcpHandler() + executor = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=handler) + + legacy_request = _LegacyMCPApprovalRequest( + request_id="r-1", + tool_name="search", + server_url="https://mcp.example/api", + server_label=None, + arguments={"q": "x"}, + header_names=[], + ) + await executor.handle_approval_response( + legacy_request, # type: ignore[arg-type] + ToolApprovalResponse(approved=True), + mock_context, + ) + + assert handler.call_count == 1 + assert handler.last is not None + assert handler.last.connection_name is None From 81ea2f67035b17291ad90c91885bc5ca94c12653 Mon Sep 17 00:00:00 2001 From: Peter Ibekwe Date: Fri, 12 Jun 2026 08:55:36 -0700 Subject: [PATCH 8/8] Inlined single use variables. --- .../_workflows/_executors_mcp.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py index af43518165a..1b16a87277c 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py @@ -276,13 +276,9 @@ async def handle_approval_response( state = self._get_state(ctx.state) tool_name = original_request.tool_name - server_url = original_request.server_url - server_label = original_request.server_label - arguments = original_request.arguments - connection_name = getattr(original_request, "connection_name", None) metadata: dict[str, Any] = getattr(original_request, "metadata", None) or {} raw_conversation_id = metadata.get("conversation_id") - conversation_id = str(raw_conversation_id) if isinstance(raw_conversation_id, str) and raw_conversation_id else None + conversation_id = raw_conversation_id if isinstance(raw_conversation_id, str) and raw_conversation_id else None auto_send = self._get_auto_send(state) output_messages_path = _get_output_path(self._action_def, "messages") @@ -299,15 +295,13 @@ async def handle_approval_response( await ctx.send_message(ActionComplete()) return - headers = self._evaluate_headers(state, self._action_def.get("headers")) - invocation = MCPToolInvocation( - server_url=server_url, + server_url=original_request.server_url, tool_name=tool_name, - server_label=server_label, - arguments=arguments, - headers=headers, - connection_name=connection_name, + server_label=original_request.server_label, + arguments=original_request.arguments, + headers=self._evaluate_headers(state, self._action_def.get("headers")), + connection_name=getattr(original_request, "connection_name", None), ) result = await self._invoke_with_narrow_catch(invocation) await self._process_result(