diff --git a/python/packages/ag-ui/agent_framework_ag_ui/__init__.py b/python/packages/ag-ui/agent_framework_ag_ui/__init__.py index 7d5bfc951b..c787de5167 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/__init__.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/__init__.py @@ -9,6 +9,7 @@ from ._endpoint import add_agent_framework_fastapi_endpoint from ._event_converters import AGUIEventConverter from ._http_service import AGUIHttpService +from ._state import state_update from ._types import AgentState, AGUIChatOptions, AGUIRequest, PredictStateConfig, RunMetadata from ._workflow import AgentFrameworkWorkflow, WorkflowFactory @@ -34,5 +35,6 @@ "PredictStateConfig", "RunMetadata", "DEFAULT_TAGS", + "state_update", "__version__", ] diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_run_common.py b/python/packages/ag-ui/agent_framework_ag_ui/_run_common.py index 81d5fadbbe..58236cdf0e 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_run_common.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_run_common.py @@ -6,6 +6,7 @@ import json import logging +from collections.abc import Mapping from dataclasses import dataclass, field from typing import Any, cast @@ -31,6 +32,7 @@ from agent_framework import Content from ._orchestration._predictive_state import PredictiveStateHandler +from ._state import TOOL_RESULT_STATE_KEY from ._utils import generate_event_id, make_json_safe logger = logging.getLogger(__name__) @@ -233,16 +235,66 @@ def _emit_tool_call( return events +def _extract_tool_result_state(content: Content) -> dict[str, Any] | None: + """Extract a deterministic AG-UI state update from a tool-result ``Content``. + + Tools using :func:`agent_framework_ag_ui.state_update` carry the state + payload in ``additional_properties[TOOL_RESULT_STATE_KEY]`` on the inner + text item produced by ``parse_result``. We also check the outer + function_result content's ``additional_properties`` for robustness. + + If multiple items carry state, they are merged in order so later items + override earlier ones (plain ``dict.update`` semantics). + + Returns: + The merged state dict to apply, or ``None`` if no state update is + present. + """ + merged: dict[str, Any] | None = None + + outer_ap = getattr(content, "additional_properties", None) or {} + outer_state = outer_ap.get(TOOL_RESULT_STATE_KEY) + if isinstance(outer_state, dict): + merged = dict(outer_state) + + for item in content.items or (): + item_ap = getattr(item, "additional_properties", None) or {} + item_state = item_ap.get(TOOL_RESULT_STATE_KEY) + if isinstance(item_state, dict): + if merged is None: + merged = dict(item_state) + else: + merged.update(item_state) + + return merged + + def _emit_tool_result_common( call_id: str, raw_result: Any, flow: FlowState, predictive_handler: PredictiveStateHandler | None = None, + *, + state_update: Mapping[str, Any] | None = None, ) -> list[BaseEvent]: """Shared helper for emitting ToolCallEnd + ToolCallResult events and performing FlowState cleanup. Both ``_emit_tool_result`` (standard function results) and ``_emit_mcp_tool_result`` (MCP server tool results) delegate to this function. + + Args: + call_id: Tool call identifier. + raw_result: The stringified tool result content sent back to the LLM. + flow: Current ``FlowState``. + predictive_handler: Optional predictive state handler driven by + ``predict_state_config``. + state_update: Optional deterministic state snapshot produced by a tool + returning :func:`agent_framework_ag_ui.state_update`. When present, + it is merged into ``flow.current_state`` and a ``StateSnapshotEvent`` + is emitted after the ``ToolCallResult`` event. When both + ``predictive_handler`` and ``state_update`` are active, predictive + updates are applied first, then the deterministic merge, and a + single coalesced ``StateSnapshotEvent`` is emitted. """ events: list[BaseEvent] = [] @@ -271,8 +323,18 @@ def _emit_tool_result_common( if predictive_handler: predictive_handler.apply_pending_updates() - if flow.current_state: - events.append(StateSnapshotEvent(snapshot=flow.current_state)) + + if state_update: + flow.current_state.update(state_update) + logger.debug( + "Emitted deterministic tool-result StateSnapshotEvent for call_id=%s (keys=%s)", + call_id, + list(state_update.keys()), + ) + + # Emit a single coalesced snapshot when either mechanism updated state. + if (predictive_handler or state_update) and flow.current_state: + events.append(StateSnapshotEvent(snapshot=flow.current_state)) flow.tool_call_id = None flow.tool_call_name = None @@ -295,7 +357,14 @@ def _emit_tool_result( if not content.call_id: return [] raw_result = content.result if content.result is not None else "" - return _emit_tool_result_common(content.call_id, raw_result, flow, predictive_handler) + state_update = _extract_tool_result_state(content) + return _emit_tool_result_common( + content.call_id, + raw_result, + flow, + predictive_handler, + state_update=state_update, + ) def _emit_approval_request( @@ -460,7 +529,14 @@ def _emit_mcp_tool_result( logger.warning("MCP tool result content missing call_id, skipping") return [] raw_output = content.output if content.output is not None else "" - return _emit_tool_result_common(content.call_id, raw_output, flow, predictive_handler) + state_update = _extract_tool_result_state(content) + return _emit_tool_result_common( + content.call_id, + raw_output, + flow, + predictive_handler, + state_update=state_update, + ) def _close_reasoning_block(flow: FlowState) -> list[BaseEvent]: diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_state.py b/python/packages/ag-ui/agent_framework_ag_ui/_state.py new file mode 100644 index 0000000000..efce2988fa --- /dev/null +++ b/python/packages/ag-ui/agent_framework_ag_ui/_state.py @@ -0,0 +1,84 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Deterministic tool-driven AG-UI state updates. + +Tools wired into the :mod:`agent_framework_ag_ui` endpoint can push a +deterministic state update by returning :func:`state_update`. Unlike +``predict_state_config`` — which emits ``StateDeltaEvent``s optimistically from +LLM-predicted tool call arguments — ``state_update`` runs *after* the tool +executes, so the AG-UI state always reflects the tool's actual return value. + +See issue https://github.com/microsoft/agent-framework/issues/3167 for the +motivating discussion. +""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from agent_framework import Content + +__all__ = ["TOOL_RESULT_STATE_KEY", "state_update"] + + +TOOL_RESULT_STATE_KEY = "__ag_ui_tool_result_state__" +"""Reserved ``Content.additional_properties`` key used to carry a tool-driven +state snapshot from a tool return value through to the AG-UI emitter.""" + + +def state_update( + text: str = "", + *, + state: Mapping[str, Any], +) -> Content: + """Build a tool return value that deterministically updates AG-UI shared state. + + Return the result of this helper from an agent tool to push a state update + to AG-UI clients using the actual tool output, rather than LLM-predicted + tool arguments. + + When the AG-UI endpoint emits the tool result, it will: + + * Forward ``text`` to the LLM as the normal ``function_result`` content. + * Merge ``state`` into ``FlowState.current_state``. + * Emit a deterministic ``StateSnapshotEvent`` after the ``ToolCallResult`` + event so frontends observe the updated state deterministically. If + predictive state is enabled, a predictive snapshot may be emitted first. + + Example: + .. code-block:: python + + from agent_framework import tool + from agent_framework_ag_ui import state_update + + + @tool + async def get_weather(city: str) -> Content: + data = await _fetch_weather(city) + return state_update( + text=f"Weather in {city}: {data['temp']}°C {data['conditions']}", + state={"weather": {"city": city, **data}}, + ) + + Args: + text: Text passed back to the LLM as the ``function_result`` content. + Defaults to an empty string for tools whose only output is a state + update. + state: A mapping merged into the AG-UI shared state via JSON-compatible + ``dict.update`` semantics. Nested dicts are replaced, not deep-merged. + + Returns: + A ``Content`` object with ``type="text"``. The state payload rides in + ``additional_properties`` under :data:`TOOL_RESULT_STATE_KEY` and is + extracted by the AG-UI emitter. + + Raises: + TypeError: If ``state`` is not a ``Mapping``. + """ + if not isinstance(state, Mapping): + raise TypeError(f"state_update() 'state' must be a Mapping, got {type(state).__name__}") + return Content.from_text( + text, + additional_properties={TOOL_RESULT_STATE_KEY: dict(state)}, + ) diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/weather_state_agent.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/weather_state_agent.py new file mode 100644 index 0000000000..f556af3458 --- /dev/null +++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/weather_state_agent.py @@ -0,0 +1,92 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Deterministic tool-driven AG-UI state example. + +This sample demonstrates how a tool can push a *deterministic* state update +to the AG-UI frontend based on its actual return value — in contrast to +``predict_state_config`` which fires optimistically from LLM-predicted tool +call arguments. See issue https://github.com/microsoft/agent-framework/issues/3167. + +The :func:`agent_framework_ag_ui.state_update` helper wraps a text result +together with a state snapshot. When a tool returns one of these, the AG-UI +endpoint merges the snapshot into the shared state and emits a +``StateSnapshotEvent`` after the tool result. +""" + +from __future__ import annotations + +from typing import Any + +from agent_framework import Agent, Content, SupportsChatGetResponse, tool +from agent_framework.ag_ui import AgentFrameworkAgent + +from agent_framework_ag_ui import state_update + +# Simulated weather database — in the issue's motivating example the tool +# would instead call a real weather API. +_WEATHER_DB: dict[str, dict[str, Any]] = { + "seattle": {"temperature": 11, "conditions": "rainy", "humidity": 75}, + "san francisco": {"temperature": 14, "conditions": "foggy", "humidity": 85}, + "new york city": {"temperature": 18, "conditions": "sunny", "humidity": 60}, + "miami": {"temperature": 29, "conditions": "hot and humid", "humidity": 90}, + "chicago": {"temperature": 9, "conditions": "windy", "humidity": 65}, +} + + +@tool +async def get_weather(location: str) -> Content: + """Fetch current weather for a location and push it into AG-UI shared state. + + Unlike ``predict_state_config`` — which derives state optimistically from + LLM-predicted tool call arguments — this tool uses ``state_update`` to + forward the *actual* fetched weather to the frontend. The ``text`` goes + back to the LLM as the normal tool result, and the ``state`` dict is merged + into the AG-UI shared state. + + Args: + location: City name to look up. + + Returns: + A :class:`Content` carrying both the LLM-visible text result and a + deterministic state snapshot. + """ + key = location.lower() + data = _WEATHER_DB.get( + key, + {"temperature": 21, "conditions": "partly cloudy", "humidity": 50}, + ) + weather_record = {"location": location, **data} + return state_update( + text=( + f"The weather in {location} is {data['conditions']} at " + f"{data['temperature']}°C with {data['humidity']}% humidity." + ), + state={"weather": weather_record}, + ) + + +def weather_state_agent(client: SupportsChatGetResponse[Any]) -> AgentFrameworkAgent: + """Create an AG-UI agent with a deterministic tool-driven state tool.""" + agent = Agent[Any]( + name="weather_state_agent", + instructions=( + "You are a weather assistant. When a user asks about the weather " + "in a city, call the get_weather tool and use its output to give a " + "friendly, concise reply. The tool also updates the shared UI state " + "so the frontend can render a weather card from the `weather` key." + ), + client=client, + tools=[get_weather], + ) + + return AgentFrameworkAgent( + agent=agent, + name="WeatherStateAgent", + description="Weather agent that deterministically updates shared state from tool results.", + state_schema={ + "weather": { + "type": "object", + "description": "Last fetched weather record", + }, + }, + ) diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/server/main.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/server/main.py index 31a7c47963..4b7d56fba5 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui_examples/server/main.py +++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/server/main.py @@ -24,6 +24,7 @@ from ..agents.task_steps_agent import task_steps_agent_wrapped from ..agents.ui_generator_agent import ui_generator_agent from ..agents.weather_agent import weather_agent +from ..agents.weather_state_agent import weather_state_agent AnthropicClient: type[Any] | None try: @@ -141,6 +142,14 @@ path="/subgraphs", ) +# Deterministic Tool-Driven State - tool returns state_update() to push snapshot +# from actual tool output (see issue #3167). +add_agent_framework_fastapi_endpoint( + app=app, + agent=weather_state_agent(client), + path="/deterministic_state", +) + def main(): """Run the server.""" diff --git a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_deterministic_state.py b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_deterministic_state.py new file mode 100644 index 0000000000..70bc5c129b --- /dev/null +++ b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_deterministic_state.py @@ -0,0 +1,267 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Golden event-stream tests for the deterministic tool-driven state scenario. + +Covers issue https://github.com/microsoft/agent-framework/issues/3167 — a tool +returning :func:`agent_framework_ag_ui.state_update` must push a deterministic +``StateSnapshotEvent`` derived from its actual return value, orthogonal to the +optimistic ``predict_state_config`` path. These golden tests pin the user-visible +event stream so additive changes cannot silently regress it. +""" + +from __future__ import annotations + +from typing import Any + +from agent_framework import AgentResponseUpdate, Content +from conftest import StubAgent +from event_stream import EventStream + +from agent_framework_ag_ui import AgentFrameworkAgent, state_update + +STATE_SCHEMA = { + "weather": {"type": "object", "description": "Last fetched weather"}, +} + + +def _build_agent(updates: list[AgentResponseUpdate], **kwargs: Any) -> AgentFrameworkAgent: + stub = StubAgent(updates=updates) + kwargs.setdefault("state_schema", STATE_SCHEMA) + return AgentFrameworkAgent(agent=stub, **kwargs) + + +async def _run(agent: AgentFrameworkAgent, payload: dict[str, Any]) -> EventStream: + return EventStream([event async for event in agent.run(payload)]) + + +PAYLOAD: dict[str, Any] = { + "thread_id": "thread-det-state", + "run_id": "run-det-state", + "messages": [{"role": "user", "content": "What's the weather in SF?"}], + "state": {"weather": {}}, +} + + +def _tool_call(call_id: str, name: str, arguments: str) -> AgentResponseUpdate: + return AgentResponseUpdate( + contents=[Content.from_function_call(name=name, call_id=call_id, arguments=arguments)], + role="assistant", + ) + + +def _tool_result_with_state(call_id: str, text: str, state: dict[str, Any]) -> AgentResponseUpdate: + """Build a function_result update whose inner item carries a state marker. + + This mirrors what the core framework produces when a real ``@tool`` returns + :func:`state_update`: ``parse_result`` keeps the ``Content`` as-is, and + ``Content.from_function_result`` preserves its ``additional_properties`` + inside ``items``. + """ + return AgentResponseUpdate( + contents=[ + Content.from_function_result( + call_id=call_id, + result=[state_update(text=text, state=state)], + ) + ], + role="assistant", + ) + + +# ── Golden stream tests ── + + +async def test_deterministic_state_emits_snapshot_after_tool_result() -> None: + """The happy path: STATE_SNAPSHOT follows TOOL_CALL_RESULT in order.""" + updates = [ + _tool_call("call-1", "get_weather", '{"city": "SF"}'), + _tool_result_with_state( + "call-1", + text="Weather in SF: 14°C foggy", + state={"weather": {"city": "SF", "temp": 14, "conditions": "foggy"}}, + ), + AgentResponseUpdate( + contents=[Content.from_text(text="It's 14°C and foggy in SF.")], + role="assistant", + ), + ] + agent = _build_agent(updates) + stream = await _run(agent, PAYLOAD) + + stream.assert_bookends() + stream.assert_no_run_error() + stream.assert_tool_calls_balanced() + stream.assert_text_messages_balanced() + + # Ordered subsequence: the deterministic STATE_SNAPSHOT must follow the + # TOOL_CALL_RESULT. This is the central contract for #3167. + stream.assert_ordered_types( + [ + "RUN_STARTED", + "TOOL_CALL_START", + "TOOL_CALL_ARGS", + "TOOL_CALL_END", + "TOOL_CALL_RESULT", + "STATE_SNAPSHOT", + "RUN_FINISHED", + ] + ) + + # The final STATE_SNAPSHOT must carry the tool-driven state. + snapshot = stream.snapshot() + assert snapshot["weather"] == {"city": "SF", "temp": 14, "conditions": "foggy"} + + +async def test_deterministic_state_does_not_fire_for_plain_tool_result() -> None: + """Regression guard: tools returning plain strings must NOT emit a new STATE_SNAPSHOT. + + The initial STATE_SNAPSHOT fires once from the schema + initial payload + state. A plain (non-state_update) tool result must not add another one. + """ + updates = [ + _tool_call("call-1", "get_weather", '{"city": "SF"}'), + AgentResponseUpdate( + contents=[Content.from_function_result(call_id="call-1", result="14°C foggy")], + role="assistant", + ), + AgentResponseUpdate( + contents=[Content.from_text(text="It's 14°C and foggy.")], + role="assistant", + ), + ] + agent = _build_agent(updates) + stream = await _run(agent, PAYLOAD) + + stream.assert_bookends() + stream.assert_no_run_error() + + snapshots = stream.get("STATE_SNAPSHOT") + # Only the initial snapshot (from state_schema + payload state) should exist. + # No deterministic snapshot should have been added by the plain tool result. + assert len(snapshots) == 1, ( + f"Expected exactly 1 STATE_SNAPSHOT (initial only) for plain tool result; " + f"got {len(snapshots)}. Snapshots: {[s.snapshot for s in snapshots]}" + ) + + +async def test_deterministic_state_merges_into_initial_state() -> None: + """The tool-driven snapshot must merge into, not replace, pre-existing state keys.""" + payload = dict(PAYLOAD) + payload["state"] = {"weather": {}, "user_preferences": {"unit": "C"}} + + updates = [ + _tool_call("call-1", "get_weather", '{"city": "SF"}'), + _tool_result_with_state( + "call-1", + text="Weather: 14°C", + state={"weather": {"city": "SF", "temp": 14}}, + ), + ] + agent = _build_agent(updates, state_schema={**STATE_SCHEMA, "user_preferences": {"type": "object"}}) + stream = await _run(agent, payload) + + stream.assert_bookends() + stream.assert_no_run_error() + + final_snapshot = stream.snapshot() + assert final_snapshot["weather"] == {"city": "SF", "temp": 14} + assert final_snapshot["user_preferences"] == {"unit": "C"}, ( + "Pre-existing state keys must survive the deterministic merge" + ) + + +async def test_deterministic_state_llm_visible_text_is_clean() -> None: + """The LLM-visible TOOL_CALL_RESULT content must not leak the state marker key.""" + updates = [ + _tool_call("call-1", "get_weather", '{"city": "SF"}'), + _tool_result_with_state( + "call-1", + text="Weather in SF: 14°C foggy", + state={"weather": {"city": "SF", "temp": 14}}, + ), + ] + agent = _build_agent(updates) + stream = await _run(agent, PAYLOAD) + + result = stream.first("TOOL_CALL_RESULT") + assert result.content == "Weather in SF: 14°C foggy" + # The marker key must never appear in the content sent back to the LLM. + assert "__ag_ui_tool_result_state__" not in result.content + assert "weather" not in result.content # not as a raw state dump + + +async def test_deterministic_state_multiple_tools_merge_in_order() -> None: + """Two state-updating tools in one run merge in order; later wins on key collisions.""" + updates = [ + _tool_call("call-a", "get_weather", '{"city": "SF"}'), + _tool_result_with_state( + "call-a", + text="First result", + state={"weather": {"city": "SF", "temp": 14}, "source": "primary"}, + ), + _tool_call("call-b", "get_weather_refined", '{"city": "SF"}'), + _tool_result_with_state( + "call-b", + text="Refined result", + state={"source": "refined"}, + ), + AgentResponseUpdate( + contents=[Content.from_text(text="Here you go.")], + role="assistant", + ), + ] + agent = _build_agent( + updates, + state_schema={**STATE_SCHEMA, "source": {"type": "string"}}, + ) + stream = await _run(agent, PAYLOAD) + + stream.assert_bookends() + stream.assert_tool_calls_balanced() + stream.assert_no_run_error() + + # Two tool-driven snapshots emitted (one per tool) plus the initial snapshot. + snapshots = stream.get("STATE_SNAPSHOT") + assert len(snapshots) >= 2, f"Expected at least 2 STATE_SNAPSHOTs; got {len(snapshots)}" + + final = stream.snapshot() + assert final["weather"] == {"city": "SF", "temp": 14} + # Later tool must override earlier tool on the shared key. + assert final["source"] == "refined" + + +async def test_deterministic_state_coexists_with_predict_state_config() -> None: + """Predictive state and deterministic state must coexist without clobbering each other.""" + predict_config = { + "draft": { + "tool": "write_draft", + "tool_argument": "body", + } + } + updates = [ + # Predictive tool: its argument "body" populates state.draft optimistically. + _tool_call("call-1", "write_draft", '{"body": "Hello world"}'), + # Then a deterministic tool result landing a different key. + _tool_result_with_state( + "call-1", + text="Draft saved", + state={"weather": {"city": "SF", "temp": 14}}, + ), + ] + agent = _build_agent( + updates, + state_schema={**STATE_SCHEMA, "draft": {"type": "string"}}, + predict_state_config=predict_config, + require_confirmation=False, + ) + payload = dict(PAYLOAD) + payload["state"] = {"weather": {}, "draft": ""} + stream = await _run(agent, payload) + + stream.assert_bookends() + stream.assert_no_run_error() + stream.assert_tool_calls_balanced() + + # The final observed state must contain both the deterministic and predictive contributions. + final = stream.snapshot() + assert final["weather"] == {"city": "SF", "temp": 14}, f"Deterministic state missing from final snapshot: {final}" diff --git a/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py b/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py index e6f58ef0fd..5ea284c68d 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py +++ b/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py @@ -1405,3 +1405,95 @@ async def stream_fn( for content in msg.contents: if content.type == "function_result" and content.call_id == "fake_reject_001": assert False, "Fabricated rejection response leaked as function_result into LLM messages" + + +async def test_state_update_end_to_end_via_real_tool_invocation(streaming_chat_client_stub): + """End-to-end coverage for issue #3167: a real ``@tool`` returning ``state_update`` must + emit a deterministic STATE_SNAPSHOT through the full pipeline. + + This test exercises the entire chain that a user would hit in production: + ``FunctionInvocationLayer`` executes the tool, ``FunctionTool.parse_result`` + preserves the returned ``Content`` with its ``additional_properties`` marker, + ``Content.from_function_result`` carries the marker through in ``items``, + and the AG-UI emitter extracts it via ``_extract_tool_result_state`` and + emits the snapshot. A regression anywhere in that chain will fail this test. + """ + from agent_framework import tool + from agent_framework.ag_ui import AgentFrameworkAgent + + from agent_framework_ag_ui import state_update + + @tool(name="get_weather", description="Get current weather for a city.") + async def get_weather(city: str) -> Content: + return state_update( + text=f"Weather in {city}: 14°C foggy", + state={"weather": {"city": city, "temperature": 14, "conditions": "foggy"}}, + ) + + call_count = {"n": 0} + + async def stream_fn( + messages: MutableSequence[Message], options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + """First turn proposes a tool call; second turn (after tool execution) returns text.""" + call_count["n"] += 1 + if call_count["n"] == 1: + yield ChatResponseUpdate( + contents=[ + Content.from_function_call( + name="get_weather", + call_id="call-weather-1", + arguments='{"city": "SF"}', + ) + ] + ) + else: + yield ChatResponseUpdate(contents=[Content.from_text(text="It's 14°C and foggy in SF.")]) + + agent = Agent( + client=streaming_chat_client_stub(stream_fn), + name="weather_agent", + instructions="Answer weather questions.", + tools=[get_weather], + ) + wrapper = AgentFrameworkAgent( + agent=agent, + state_schema={"weather": {"type": "object"}}, + ) + + events: list[Any] = [] + async for event in wrapper.run( + { + "thread_id": "thread-weather", + "run_id": "run-weather", + "messages": [{"role": "user", "content": "What's the weather in SF?"}], + "state": {"weather": {}}, + } + ): + events.append(event) + + types = [e.type for e in events] + + # The tool call must be visible in the stream. + assert "TOOL_CALL_START" in types, f"Missing TOOL_CALL_START in: {types}" + assert "TOOL_CALL_RESULT" in types, f"Missing TOOL_CALL_RESULT in: {types}" + + # A STATE_SNAPSHOT must be emitted after the tool result. + tool_result_idx = types.index("TOOL_CALL_RESULT") + snapshot_indices_after_result = [i for i, t in enumerate(types) if t == "STATE_SNAPSHOT" and i > tool_result_idx] + assert snapshot_indices_after_result, ( + f"Expected a STATE_SNAPSHOT after TOOL_CALL_RESULT (index {tool_result_idx}); got types: {types}" + ) + + # The tool's deterministic snapshot carries the actual fetched weather data. + final_snapshot = events[snapshot_indices_after_result[-1]].snapshot + assert final_snapshot["weather"] == { + "city": "SF", + "temperature": 14, + "conditions": "foggy", + } + + # The LLM-visible tool result must carry the plain text, not the marker key. + tool_result_event = next(e for e in events if e.type == "TOOL_CALL_RESULT") + assert tool_result_event.content == "Weather in SF: 14°C foggy" + assert "__ag_ui_tool_result_state__" not in tool_result_event.content diff --git a/python/packages/ag-ui/tests/ag_ui/test_public_exports.py b/python/packages/ag-ui/tests/ag_ui/test_public_exports.py index 433935fb24..ea570f50a6 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_public_exports.py +++ b/python/packages/ag-ui/tests/ag_ui/test_public_exports.py @@ -18,7 +18,24 @@ def test_core_ag_ui_lazy_exports_include_only_stable_api() -> None: assert hasattr(ag_ui, "AgentFrameworkAgent") assert hasattr(ag_ui, "AGUIChatClient") assert hasattr(ag_ui, "add_agent_framework_fastapi_endpoint") + assert hasattr(ag_ui, "state_update") assert not hasattr(ag_ui, "WorkflowFactory") assert not hasattr(ag_ui, "AGUIRequest") assert not hasattr(ag_ui, "RunMetadata") + + +def test_agent_framework_ag_ui_exports_state_update() -> None: + """Runtime package should export the ``state_update`` helper.""" + from agent_framework_ag_ui import state_update + + assert callable(state_update) + + +def test_core_ag_ui_lazy_exports_include_event_converter_and_http_service() -> None: + """Core facade must expose AGUIEventConverter, AGUIHttpService, and __version__.""" + from agent_framework import ag_ui + + assert hasattr(ag_ui, "AGUIEventConverter") + assert hasattr(ag_ui, "AGUIHttpService") + assert hasattr(ag_ui, "__version__") diff --git a/python/packages/ag-ui/tests/ag_ui/test_run_common.py b/python/packages/ag-ui/tests/ag_ui/test_run_common.py index 526a3c33c1..27294d9171 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_run_common.py +++ b/python/packages/ag-ui/tests/ag_ui/test_run_common.py @@ -2,14 +2,20 @@ """Tests for _run_common.py edge cases.""" +from ag_ui.core import EventType from agent_framework import Content +from agent_framework_ag_ui import state_update +from agent_framework_ag_ui._orchestration._predictive_state import PredictiveStateHandler from agent_framework_ag_ui._run_common import ( FlowState, + _emit_mcp_tool_result, _emit_tool_result, _extract_resume_payload, + _extract_tool_result_state, _normalize_resume_interrupts, ) +from agent_framework_ag_ui._state import TOOL_RESULT_STATE_KEY class TestNormalizeResumeInterrupts: @@ -120,3 +126,223 @@ def test_tool_result_closes_open_text_message(self): assert "TEXT_MESSAGE_END" in event_types assert flow.message_id is None assert flow.accumulated_text == "" + + +class TestStateUpdateHelper: + """Tests for the public ``state_update`` helper.""" + + def test_builds_text_content_with_state_marker(self): + """state_update returns a text Content carrying state in additional_properties.""" + c = state_update(text="done", state={"weather": {"temp": 14}}) + assert c.type == "text" + assert c.text == "done" + assert c.additional_properties == { + TOOL_RESULT_STATE_KEY: {"weather": {"temp": 14}}, + } + + def test_empty_text_is_allowed(self): + """State-only tools can omit the text argument.""" + c = state_update(state={"steps": ["a", "b"]}) + assert c.text == "" + assert c.additional_properties[TOOL_RESULT_STATE_KEY] == {"steps": ["a", "b"]} + + def test_non_mapping_state_raises(self): + """Passing a non-mapping value for state raises TypeError.""" + import pytest + + with pytest.raises(TypeError): + state_update(text="t", state=["not", "a", "mapping"]) # type: ignore[arg-type] + + def test_state_is_copied_defensively(self): + """Mutating the caller's dict after ``state_update`` must not mutate the content.""" + caller_state = {"weather": {"temp": 14}} + c = state_update(text="ok", state=caller_state) + caller_state["weather"]["temp"] = 99 + # The top-level dict was copied, so replacing the key in caller_state + # would not affect the Content, but nested dicts share references — document + # this by asserting only the top-level copy semantics. + assert TOOL_RESULT_STATE_KEY in c.additional_properties + inner = c.additional_properties[TOOL_RESULT_STATE_KEY] + assert inner is not caller_state + + +class TestExtractToolResultState: + """Tests for ``_extract_tool_result_state``.""" + + def test_returns_none_for_plain_string_result(self): + content = Content.from_function_result(call_id="c1", result="plain") + assert _extract_tool_result_state(content) is None + + def test_extracts_state_from_inner_item(self): + tool_return = state_update(text="hi", state={"k": 1}) + content = Content.from_function_result(call_id="c1", result=[tool_return]) + assert _extract_tool_result_state(content) == {"k": 1} + + def test_extracts_state_from_outer_additional_properties(self): + """Outer function_result content can also carry state (legacy/advanced use).""" + content = Content.from_function_result( + call_id="c1", + result="hi", + additional_properties={TOOL_RESULT_STATE_KEY: {"k": 1}}, + ) + assert _extract_tool_result_state(content) == {"k": 1} + + def test_merges_multiple_items(self): + a = state_update(text="a", state={"k": 1, "shared": "from_a"}) + b = state_update(text="b", state={"shared": "from_b", "extra": True}) + content = Content.from_function_result(call_id="c1", result=[a, b]) + merged = _extract_tool_result_state(content) + assert merged == {"k": 1, "shared": "from_b", "extra": True} + + def test_ignores_non_dict_marker_value(self): + """A garbled marker value must not break extraction (defensive guard).""" + bad = Content.from_text( + "hi", + additional_properties={TOOL_RESULT_STATE_KEY: "not-a-dict"}, + ) + content = Content.from_function_result(call_id="c1", result=[bad]) + assert _extract_tool_result_state(content) is None + + +class TestEmitToolResultWithState: + """Tests for the deterministic state emission in ``_emit_tool_result``.""" + + def test_emits_state_snapshot_after_tool_call_result(self): + """Tool returning state_update produces a StateSnapshotEvent right after the result.""" + tool_return = state_update( + text="Weather: 14°C", + state={"weather": {"temp": 14, "conditions": "foggy"}}, + ) + content = Content.from_function_result(call_id="call_1", result=[tool_return]) + flow = FlowState() + + events = _emit_tool_result(content, flow) + event_types = [e.type for e in events] + + # Expect TOOL_CALL_END, TOOL_CALL_RESULT, STATE_SNAPSHOT in that order. + assert event_types[0] == EventType.TOOL_CALL_END + assert event_types[1] == EventType.TOOL_CALL_RESULT + state_idx = event_types.index(EventType.STATE_SNAPSHOT) + assert state_idx == 2 + assert events[state_idx].snapshot == {"weather": {"temp": 14, "conditions": "foggy"}} + + def test_updates_flow_current_state(self): + tool_return = state_update(text="", state={"a": 1}) + content = Content.from_function_result(call_id="c1", result=[tool_return]) + flow = FlowState(current_state={"existing": "value"}) + + _emit_tool_result(content, flow) + + # Existing keys must survive (merge semantics), new keys must be added. + assert flow.current_state == {"existing": "value", "a": 1} + + def test_merge_overrides_existing_key(self): + tool_return = state_update(text="", state={"existing": "new"}) + content = Content.from_function_result(call_id="c1", result=[tool_return]) + flow = FlowState(current_state={"existing": "old", "other": 1}) + + _emit_tool_result(content, flow) + + assert flow.current_state == {"existing": "new", "other": 1} + + def test_no_state_snapshot_when_result_has_no_state(self): + """Plain tool results must not emit a StateSnapshotEvent.""" + content = Content.from_function_result(call_id="c1", result="plain") + flow = FlowState() + + events = _emit_tool_result(content, flow) + assert all(e.type != EventType.STATE_SNAPSHOT for e in events) + + def test_tool_result_content_text_unchanged(self): + """The text sent to the LLM must not leak the state marker.""" + tool_return = state_update(text="Weather: 14°C", state={"weather": {"temp": 14}}) + content = Content.from_function_result(call_id="c1", result=[tool_return]) + flow = FlowState() + + events = _emit_tool_result(content, flow) + result_events = [e for e in events if e.type == EventType.TOOL_CALL_RESULT] + assert len(result_events) == 1 + assert result_events[0].content == "Weather: 14°C" + assert TOOL_RESULT_STATE_KEY not in result_events[0].content + + def test_coexists_with_active_predictive_state_handler(self): + """Both predictive and deterministic state produce a single coalesced snapshot. + + Predictive state (``predict_state_config``) and deterministic state + (``state_update``) are two independent mechanisms. When both are active, + a single coalesced ``StateSnapshotEvent`` is emitted containing the + merged result of both contributions. + """ + flow = FlowState(current_state={"preexisting": "value"}) + handler = PredictiveStateHandler( + predict_state_config={"draft": {"tool": "write_draft", "tool_argument": "body"}}, + current_state=flow.current_state, + ) + + tool_return = state_update(text="Draft written", state={"draft_final": True}) + content = Content.from_function_result(call_id="c1", result=[tool_return]) + + events = _emit_tool_result(content, flow, predictive_handler=handler) + + # Exactly one coalesced snapshot must be emitted containing all merged keys. + snapshots = [e for e in events if e.type == EventType.STATE_SNAPSHOT] + assert len(snapshots) == 1 + assert snapshots[0].snapshot["draft_final"] is True + assert snapshots[0].snapshot["preexisting"] == "value" + assert flow.current_state["draft_final"] is True + assert flow.current_state["preexisting"] == "value" + + def test_predictive_and_deterministic_emit_single_snapshot(self): + """When both predictive_handler and state_update are active, only one snapshot is emitted.""" + flow = FlowState(current_state={"existing": "yes"}) + handler = PredictiveStateHandler( + predict_state_config={"draft": {"tool": "write_draft", "tool_argument": "body"}}, + current_state=flow.current_state, + ) + + tool_return = state_update(text="ok", state={"new_key": 42}) + content = Content.from_function_result(call_id="c1", result=[tool_return]) + + events = _emit_tool_result(content, flow, predictive_handler=handler) + + snapshots = [e for e in events if e.type == EventType.STATE_SNAPSHOT] + assert len(snapshots) == 1, f"Expected 1 coalesced snapshot, got {len(snapshots)}" + assert snapshots[0].snapshot == {"existing": "yes", "new_key": 42} + + +class TestEmitMcpToolResultWithState: + """MCP tool results should honour the same state_update marker. + + MCP results come from an external MCP server rather than a locally + executed ``@tool`` function, so they do not flow through ``parse_result`` + and ``content.items`` is typically empty. State is instead carried on the + outer content's ``additional_properties`` (e.g. by middleware that + inspects the MCP output and attaches a marker). ``_extract_tool_result_state`` + supports both locations so this path remains usable. + """ + + def test_mcp_tool_result_emits_state_snapshot_from_additional_properties(self): + content = Content.from_mcp_server_tool_result( + call_id="mcp_1", + output="server result", + additional_properties={TOOL_RESULT_STATE_KEY: {"mcp_ok": True}}, + ) + flow = FlowState() + + events = _emit_mcp_tool_result(content, flow) + event_types = [e.type for e in events] + + assert EventType.TOOL_CALL_END in event_types + assert EventType.TOOL_CALL_RESULT in event_types + assert EventType.STATE_SNAPSHOT in event_types + assert flow.current_state == {"mcp_ok": True} + + def test_mcp_tool_result_without_state_emits_no_snapshot(self): + content = Content.from_mcp_server_tool_result( + call_id="mcp_1", + output="server result", + ) + flow = FlowState() + + events = _emit_mcp_tool_result(content, flow) + assert all(e.type != EventType.STATE_SNAPSHOT for e in events) diff --git a/python/packages/core/agent_framework/ag_ui/__init__.py b/python/packages/core/agent_framework/ag_ui/__init__.py index 8e1385a26c..91754e01b4 100644 --- a/python/packages/core/agent_framework/ag_ui/__init__.py +++ b/python/packages/core/agent_framework/ag_ui/__init__.py @@ -7,10 +7,13 @@ Supported classes and functions: - AgentFrameworkAgent +- AgentFrameworkWorkflow - AGUIChatClient - AGUIEventConverter - AGUIHttpService - add_agent_framework_fastapi_endpoint +- state_update +- __version__ """ import importlib @@ -23,6 +26,10 @@ "AgentFrameworkWorkflow", "add_agent_framework_fastapi_endpoint", "AGUIChatClient", + "AGUIEventConverter", + "AGUIHttpService", + "state_update", + "__version__", ] diff --git a/python/packages/core/agent_framework/ag_ui/__init__.pyi b/python/packages/core/agent_framework/ag_ui/__init__.pyi index 17a5b3a4db..1f6636ae81 100644 --- a/python/packages/core/agent_framework/ag_ui/__init__.pyi +++ b/python/packages/core/agent_framework/ag_ui/__init__.pyi @@ -8,6 +8,7 @@ from agent_framework_ag_ui import ( AGUIHttpService, __version__, add_agent_framework_fastapi_endpoint, + state_update, ) __all__ = [ @@ -18,4 +19,5 @@ __all__ = [ "AgentFrameworkWorkflow", "__version__", "add_agent_framework_fastapi_endpoint", + "state_update", ]