diff --git a/.github/.linkspector.yml b/.github/.linkspector.yml index 22fe804a09f..15ca806f41a 100644 --- a/.github/.linkspector.yml +++ b/.github/.linkspector.yml @@ -20,7 +20,10 @@ ignorePatterns: - pattern: "https://your-resource.openai.azure.com/" - pattern: "http://host.docker.internal" - pattern: "https://openai.github.io/openai-agents-js/openai/agents/classes/" - - pattern: "https:\/\/dotnet.microsoft.com\/download" + # dotnet.microsoft.com bot-blocks CI link checkers with intermittent 403s on any + # path (including localized variants like /en-us/download/...), so ignore the + # whole domain rather than just /download. + - pattern: "https:\/\/dotnet.microsoft.com" - pattern: "https://github.com/Rel1cx/eslint-react" # excludedDirs: # Folders which include links to localhost, since it's not ignored with regular expressions diff --git a/python/packages/ag-ui/AGENTS.md b/python/packages/ag-ui/AGENTS.md index 9139c9bbd57..656a3fa77f1 100644 --- a/python/packages/ag-ui/AGENTS.md +++ b/python/packages/ag-ui/AGENTS.md @@ -10,10 +10,12 @@ AG-UI protocol integration for building agent UIs with the AG-UI standard. - **`AGUIHttpService`** - HTTP service for AG-UI endpoints - **`AGUIEventConverter`** - Converts between Agent Framework and AG-UI events - **`add_agent_framework_fastapi_endpoint()`** - Add AG-UI endpoint to FastAPI app (`SupportsAgentRun` or `Workflow`) +- **`InMemoryAGUIThreadSnapshotStore`** - Memory-only latest AG-UI Thread Snapshot store for local development, demos, and tests ## Types - **`AGUIRequest`** / **`AGUIChatOptions`** - Request types +- **`AGUIThreadSnapshot`** / **`AGUIThreadSnapshotStore`** - Replayable thread snapshot model and scoped async store protocol - **`availableInterrupts` / `resume`** - Optional interrupt configuration and continuation payloads - **`AgentState`** / **`RunMetadata`** - State management types - **`PredictStateConfig`** - Configuration for state prediction diff --git a/python/packages/ag-ui/README.md b/python/packages/ag-ui/README.md index 6874c4d31ec..0119aa81889 100644 --- a/python/packages/ag-ui/README.md +++ b/python/packages/ag-ui/README.md @@ -198,6 +198,71 @@ The `dependencies` parameter accepts any FastAPI dependency, enabling integratio For a complete authentication example, see [getting_started/server.py](getting_started/server.py). +## AG-UI Thread Snapshots + +AG-UI Thread Snapshot persistence is opt-in and disabled by default. Existing endpoints keep their current behavior +unless you provide a `snapshot_store`. + +Thread snapshots let an AG-UI frontend recover replayable UI state after a refresh. When snapshot persistence is +enabled, the endpoint stores the latest replayable snapshot for an AG-UI Thread within an application-defined +Snapshot Scope. A Hydrate Request is an AG-UI request with a known `threadId`, `messages: []`, and no `resume` +payload. Hydration replays the stored Shared State, message snapshot, and interruption metadata when available, +then finishes without invoking the wrapped agent or workflow. + +Use the built-in in-memory store for local development, demos, and tests: + +```python +from fastapi import FastAPI + +from agent_framework.ag_ui import InMemoryAGUIThreadSnapshotStore, add_agent_framework_fastapi_endpoint + +app = FastAPI() +agent = ... +snapshot_store = InMemoryAGUIThreadSnapshotStore(max_snapshots=500) + + +def resolve_snapshot_scope(request): + # Local demo scope. Production apps should derive the scope from authenticated user or tenant context. + del request + return "local-demo" + + +add_agent_framework_fastapi_endpoint( + app, + agent, + "/", + snapshot_store=snapshot_store, + snapshot_scope_resolver=resolve_snapshot_scope, +) +``` + +A frontend can then hydrate the latest stored snapshot for the scoped thread: + +```json +{ + "threadId": "thread-1", + "messages": [] +} +``` + +Endpoint configuration requires `snapshot_scope_resolver` whenever a snapshot store is configured, including when +the store is already set on a pre-wrapped `AgentFrameworkAgent` or `AgentFrameworkWorkflow`. The resolver returns +the application-defined Snapshot Scope used with the AG-UI Thread id as the storage key. + +AG-UI Thread ids identify AG-UI Threads; they do not authorize snapshot access. Do not treat a thread id as a bearer +credential or tenant boundary. Production applications must authenticate and authorize every AG-UI endpoint request +and choose a Snapshot Scope that represents the app's real access boundary, such as an authenticated user, tenant, +or workspace. Do not rely on untrusted client-provided fields by themselves to choose that boundary. + +Stored snapshots are untrusted application data with confidentiality impact. They may contain sensitive user text, +model output, tool results, function arguments, UI payloads, Shared State, and interruption data. The built-in +`InMemoryAGUIThreadSnapshotStore` is in-memory only, process-local, bounded, latest-only, and not durable production +storage. It is cleared on process restart and is not shared across workers. + +No file-backed AG-UI snapshot store is provided by the package. Applications that need durable persistence should +provide an app-owned implementation of the `AGUIThreadSnapshotStore` protocol and own storage hardening, including +encryption, access control, retention, audit, data residency, and deletion behavior. + ## Architecture The package uses a clean, orchestrator-based architecture: diff --git a/python/packages/ag-ui/agent_framework_ag_ui/__init__.py b/python/packages/ag-ui/agent_framework_ag_ui/__init__.py index c787de5167c..9be38154a33 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/__init__.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/__init__.py @@ -9,6 +9,15 @@ from ._endpoint import add_agent_framework_fastapi_endpoint from ._event_converters import AGUIEventConverter from ._http_service import AGUIHttpService +from ._snapshots import ( + DEFAULT_MAX_THREAD_SNAPSHOTS, + AGUIThreadID, + AGUIThreadSnapshot, + AGUIThreadSnapshotStore, + InMemoryAGUIThreadSnapshotStore, + SnapshotScope, + SnapshotScopeResolver, +) from ._state import state_update from ._types import AgentState, AGUIChatOptions, AGUIRequest, PredictStateConfig, RunMetadata from ._workflow import AgentFrameworkWorkflow, WorkflowFactory @@ -31,9 +40,16 @@ "AGUIEventConverter", "AGUIHttpService", "AGUIRequest", + "AGUIThreadID", + "AGUIThreadSnapshot", + "AGUIThreadSnapshotStore", "AgentState", + "InMemoryAGUIThreadSnapshotStore", "PredictStateConfig", "RunMetadata", + "SnapshotScope", + "SnapshotScopeResolver", + "DEFAULT_MAX_THREAD_SNAPSHOTS", "DEFAULT_TAGS", "state_update", "__version__", diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_agent.py b/python/packages/ag-ui/agent_framework_ag_ui/_agent.py index ecde5a67e17..17050f78b73 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_agent.py @@ -10,6 +10,7 @@ from agent_framework import SupportsAgentRun from ._agent_run import PendingApprovalEntry, run_agent_stream +from ._snapshots import AGUIThreadSnapshotStore class AgentConfig: @@ -21,6 +22,7 @@ def __init__( predict_state_config: dict[str, dict[str, str]] | None = None, use_service_session: bool = False, require_confirmation: bool = True, + snapshot_store: AGUIThreadSnapshotStore | None = None, ): """Initialize agent configuration. @@ -29,11 +31,14 @@ def __init__( predict_state_config: Configuration for predictive state updates use_service_session: Whether the agent session is service-managed require_confirmation: Whether predictive updates require user confirmation before applying + snapshot_store: Optional AG-UI Thread Snapshot store. Snapshot persistence remains inactive unless + endpoint setup also provides an explicit Snapshot Scope resolver. """ self.state_schema = self._normalize_state_schema(state_schema) self.predict_state_config = predict_state_config or {} self.use_service_session = use_service_session self.require_confirmation = require_confirmation + self.snapshot_store = snapshot_store @staticmethod def _normalize_state_schema(state_schema: Any | None) -> dict[str, Any]: @@ -79,6 +84,7 @@ def __init__( predict_state_config: dict[str, dict[str, str]] | None = None, require_confirmation: bool = True, use_service_session: bool = False, + snapshot_store: AGUIThreadSnapshotStore | None = None, ): """Initialize the AG-UI compatible agent wrapper. @@ -90,6 +96,8 @@ def __init__( predict_state_config: Configuration for predictive state updates require_confirmation: Whether predictive updates require user confirmation before applying use_service_session: Whether the agent session is service-managed + snapshot_store: Optional AG-UI Thread Snapshot store. Snapshot persistence remains inactive unless + endpoint setup also provides an explicit Snapshot Scope resolver. """ self.agent = agent self.name = name or getattr(agent, "name", "agent") @@ -100,6 +108,7 @@ def __init__( predict_state_config=predict_state_config, use_service_session=use_service_session, require_confirmation=require_confirmation, + snapshot_store=snapshot_store, ) # Server-side registry of pending approval requests. @@ -110,6 +119,11 @@ def __init__( self._pending_approvals: OrderedDict[str, PendingApprovalEntry] = OrderedDict() self._pending_approvals_max_size: int = 10_000 + @property + def snapshot_store(self) -> AGUIThreadSnapshotStore | None: + """Configured AG-UI Thread Snapshot store, if any.""" + return self.config.snapshot_store + async def run( self, input_data: dict[str, Any], diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py b/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py index 38578f1bf2b..30596ed408a 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py @@ -4,6 +4,7 @@ from __future__ import annotations # noqa: I001 +import copy import json import logging import uuid @@ -52,9 +53,11 @@ _extract_tool_result_display, # type: ignore _has_only_tool_calls, # type: ignore _normalize_resume_interrupts, # type: ignore + _reconstruct_messages_from_thread_snapshot, # type: ignore _resolve_ui_payload, # type: ignore _stringify_tool_result, # type: ignore ) +from ._snapshots import AGUIThreadSnapshot, _DEFAULT_STATE_INPUT_KEY, _SNAPSHOT_SCOPE_INPUT_KEY from ._utils import ( canonical_function_arguments, convert_agui_tools_to_agent_framework, @@ -748,6 +751,85 @@ def _build_messages_snapshot( return MessagesSnapshotEvent(messages=all_messages) # type: ignore[arg-type] +def _event_messages_to_snapshot_dicts(messages: list[Any]) -> list[dict[str, Any]]: + """Convert AG-UI message event models back to plain snapshot dictionaries.""" + safe_messages = make_json_safe(messages) + if not isinstance(safe_messages, list): + return [] + return [cast(dict[str, Any], message) for message in safe_messages if isinstance(message, dict)] + + +def _text_events_to_snapshot_messages(events: list[BaseEvent]) -> list[dict[str, Any]]: + """Convert streamed text-message events into snapshot message dictionaries.""" + messages: list[dict[str, Any]] = [] + messages_by_id: dict[str, dict[str, Any]] = {} + for event in events: + if isinstance(event, TextMessageStartEvent): + message: dict[str, Any] = {"id": event.message_id, "role": event.role, "content": ""} + messages.append(message) + messages_by_id[event.message_id] = message + elif isinstance(event, TextMessageContentEvent): + open_message = messages_by_id.get(event.message_id) + if open_message is not None: + open_message["content"] = f"{open_message['content']}{event.delta}" + return [message for message in messages if message.get("content")] + + +async def _hydrate_thread_snapshot( + *, + config: AgentConfig, + scope: str, + thread_id: str, + run_id: str, +) -> AsyncGenerator[BaseEvent]: + """Replay the latest stored AG-UI Thread Snapshot without invoking the agent.""" + yield RunStartedEvent(run_id=run_id, thread_id=thread_id) + if config.snapshot_store is None: + yield _build_run_finished_event(run_id=run_id, thread_id=thread_id) + return + + snapshot = await config.snapshot_store.get(scope=scope, thread_id=thread_id) + if snapshot is None: + yield _build_run_finished_event(run_id=run_id, thread_id=thread_id) + return + + if snapshot.state is not None: + yield StateSnapshotEvent(snapshot=snapshot.state) + if snapshot.messages: + yield MessagesSnapshotEvent(messages=snapshot.messages) # type: ignore[arg-type] + yield _build_run_finished_event(run_id=run_id, thread_id=thread_id, interrupts=snapshot.interrupt) + + +async def _save_thread_snapshot( + *, + config: AgentConfig, + scope: str | None, + thread_id: str, + messages: list[dict[str, Any]], + state: dict[str, Any] | None, + interrupt: list[dict[str, Any]] | None, +) -> None: + """Save the latest replayable AG-UI Thread Snapshot when persistence is configured.""" + if config.snapshot_store is None or scope is None: + return + + try: + await config.snapshot_store.save( + scope=scope, + thread_id=thread_id, + snapshot=AGUIThreadSnapshot(messages=messages, state=state, interrupt=interrupt), + ) + except Exception: + # The run itself already streamed successfully; a transient store failure + # must not surface as RUN_ERROR for a completed run. The previous snapshot + # stays available for hydration. + logger.exception( + "Failed to save AG-UI Thread Snapshot for scope=%s thread_id=%s; keeping previous snapshot.", + scope, + thread_id, + ) + + async def run_agent_stream( input_data: dict[str, Any], agent: SupportsAgentRun, @@ -774,15 +856,53 @@ async def run_agent_stream( # Parse IDs thread_id = input_data.get("thread_id") or input_data.get("threadId") or str(uuid.uuid4()) run_id = input_data.get("run_id") or input_data.get("runId") or str(uuid.uuid4()) - - # Initialize flow state with schema defaults - flow = FlowState() - if input_data.get("state"): - flow.current_state = dict(input_data["state"]) + snapshot_scope = cast(str | None, input_data.get(_SNAPSHOT_SCOPE_INPUT_KEY)) state_schema = cast(dict[str, Any], getattr(config, "state_schema", {}) or {}) predict_state_config = cast(dict[str, dict[str, str]], getattr(config, "predict_state_config", {}) or {}) + # Normalize messages + available_interrupts = input_data.get("available_interrupts") or input_data.get("availableInterrupts") + raw_messages: list[dict[str, Any]] = input_data.get("messages", []) or [] + resume_payload = _extract_resume_payload(input_data) + if config.snapshot_store is not None and snapshot_scope is not None and not raw_messages and resume_payload is None: + async for event in _hydrate_thread_snapshot( + config=config, + scope=snapshot_scope, + thread_id=thread_id, + run_id=run_id, + ): + yield event + return + + stored_snapshot: AGUIThreadSnapshot | None = None + if config.snapshot_store is not None and snapshot_scope is not None: + stored_snapshot = await config.snapshot_store.get(scope=snapshot_scope, thread_id=thread_id) + if stored_snapshot is not None and resume_payload is None: + raw_messages = _reconstruct_messages_from_thread_snapshot( + stored_messages=stored_snapshot.messages, + incoming_messages=raw_messages, + stored_interrupt=stored_snapshot.interrupt, + ) + + # Initialize flow state with stored state plus request-provided overrides. + flow = FlowState() + request_state = input_data.get("state") + if stored_snapshot is not None and stored_snapshot.state is not None: + flow.current_state = dict(stored_snapshot.state) + if isinstance(request_state, dict): + flow.current_state.update(request_state) + elif isinstance(request_state, dict): + flow.current_state = dict(request_state) + + # Apply endpoint-deferred defaults only for keys missing from both the stored + # snapshot state and the request state, so defaults never reset persisted state. + deferred_default_state = cast(dict[str, Any] | None, input_data.get(_DEFAULT_STATE_INPUT_KEY)) + if deferred_default_state: + for key, value in deferred_default_state.items(): + if key not in flow.current_state: + flow.current_state[key] = copy.deepcopy(value) + # Apply schema defaults for missing state keys if state_schema: for key, schema in state_schema.items(): @@ -801,10 +921,7 @@ async def run_agent_stream( current_state=flow.current_state, ) - # Normalize messages - available_interrupts = input_data.get("available_interrupts") or input_data.get("availableInterrupts") - raw_messages = list(cast(list[dict[str, Any]], input_data.get("messages", []) or [])) - resume_messages = _resume_to_tool_messages(_extract_resume_payload(input_data)) + resume_messages = _resume_to_tool_messages(resume_payload) if available_interrupts: logger.debug("Received available interrupts metadata: %s", available_interrupts) if resume_messages: @@ -892,8 +1009,24 @@ async def run_agent_stream( # Emit approved state snapshot before confirmation message if approved_state_snapshot_emitted: yield StateSnapshotEvent(snapshot=flow.current_state) - for event in _handle_step_based_approval(messages): + confirmation_events = _handle_step_based_approval(messages) + for event in confirmation_events: yield event + # Persist the completed confirmation turn with interrupt=None so hydration + # does not replay the stale pending interrupt after the user responded. + persisted_messages = snapshot_messages + _text_events_to_snapshot_messages(confirmation_events) + if resume_payload is not None and stored_snapshot is not None: + # Resume requests carry only the synthesized interrupt response, so prepend + # the stored thread history to avoid persisting a truncated thread. + persisted_messages = [copy.deepcopy(message) for message in stored_snapshot.messages] + persisted_messages + await _save_thread_snapshot( + config=config, + scope=snapshot_scope, + thread_id=thread_id, + messages=persisted_messages, + state=cast(dict[str, Any], make_json_safe(flow.current_state)) if flow.current_state else None, + interrupt=None, + ) yield _build_run_finished_event(run_id=run_id, thread_id=thread_id) return @@ -905,6 +1038,9 @@ async def run_agent_stream( # Stream from agent - emit RunStarted after first update to get service IDs run_started_emitted = False all_updates: list[Any] = [] # Collect for structured output processing + latest_state_snapshot: dict[str, Any] | None = ( + cast(dict[str, Any], make_json_safe(flow.current_state)) if flow.current_state else None + ) response_stream = agent.run(messages, stream=True, **run_kwargs) stream = await _normalize_response_stream(response_stream) async for update in stream: @@ -934,6 +1070,7 @@ async def run_agent_stream( yield CustomEvent(name="PredictState", value=predict_state_value) # Emit initial state snapshot only if we have both state_schema and state if state_schema and flow.current_state: + latest_state_snapshot = cast(dict[str, Any], make_json_safe(flow.current_state)) yield StateSnapshotEvent(snapshot=flow.current_state) run_started_emitted = True @@ -975,6 +1112,8 @@ async def run_agent_stream( skip_text, config.require_confirmation, ): + if isinstance(event, StateSnapshotEvent): + latest_state_snapshot = cast(dict[str, Any], make_json_safe(event.snapshot)) yield event # Stop if waiting for approval @@ -1019,6 +1158,7 @@ async def run_agent_stream( if state_updates: flow.current_state.update(state_updates) + latest_state_snapshot = cast(dict[str, Any], make_json_safe(flow.current_state)) yield StateSnapshotEvent(snapshot=flow.current_state) logger.info(f"Emitted StateSnapshotEvent with updates: {list(state_updates.keys())}") @@ -1056,6 +1196,7 @@ async def run_agent_stream( if result: state_key, state_value = result flow.current_state[state_key] = state_value + latest_state_snapshot = cast(dict[str, Any], make_json_safe(flow.current_state)) yield StateSnapshotEvent(snapshot=flow.current_state) except json.JSONDecodeError: # Ignore malformed JSON in tool arguments for predictive state; @@ -1136,7 +1277,12 @@ async def run_agent_stream( should_emit_snapshot = ( flow.pending_tool_calls or flow.tool_results or flow.accumulated_text or flow.reasoning_messages ) + latest_messages_snapshot = snapshot_messages if should_emit_snapshot: + # Always fold this turn's output into the persisted snapshot, even when the + # outbound MESSAGES_SNAPSHOT event is suppressed for predictive tools. + snapshot_event = _build_messages_snapshot(flow, snapshot_messages) + latest_messages_snapshot = _event_messages_to_snapshot_dicts(list(snapshot_event.messages)) # Check if we should suppress for predictive tool last_tool_name = None if flow.tool_results: @@ -1146,8 +1292,21 @@ async def run_agent_stream( if not _should_suppress_intermediate_snapshot( last_tool_name, predict_state_config, config.require_confirmation ): - yield _build_messages_snapshot(flow, snapshot_messages) + yield snapshot_event # Always emit RunFinished - confirm_changes tool call is complete (Start -> Args -> End) # The UI will show confirmation dialog and send a new request when user responds + persisted_messages = latest_messages_snapshot + if resume_payload is not None and stored_snapshot is not None: + # Resume requests carry only the synthesized interrupt response, so prepend + # the stored thread history to avoid persisting a truncated thread. + persisted_messages = [copy.deepcopy(message) for message in stored_snapshot.messages] + persisted_messages + await _save_thread_snapshot( + config=config, + scope=snapshot_scope, + thread_id=thread_id, + messages=persisted_messages, + state=latest_state_snapshot, + interrupt=flow.interrupts or None, + ) yield _build_run_finished_event(run_id=run_id, thread_id=thread_id, interrupts=flow.interrupts) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py b/python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py index d80ecea7a1e..1d04964ce67 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py @@ -7,6 +7,7 @@ import copy import logging from collections.abc import AsyncGenerator, Sequence +from inspect import isawaitable from typing import Any from ag_ui.core import RunErrorEvent @@ -17,12 +18,58 @@ from fastapi.responses import StreamingResponse from ._agent import AgentFrameworkAgent +from ._snapshots import ( + _DEFAULT_STATE_INPUT_KEY, + _SNAPSHOT_SCOPE_INPUT_KEY, + AGUIThreadSnapshotStore, + SnapshotScopeResolver, +) from ._types import AGUIRequest from ._workflow import AgentFrameworkWorkflow logger = logging.getLogger(__name__) +def _get_snapshot_store( + protocol_runner: AgentFrameworkAgent | AgentFrameworkWorkflow, +) -> AGUIThreadSnapshotStore | None: + if isinstance(protocol_runner, AgentFrameworkAgent): + return protocol_runner.config.snapshot_store + return protocol_runner.snapshot_store + + +def _set_snapshot_store( + protocol_runner: AgentFrameworkAgent | AgentFrameworkWorkflow, + snapshot_store: AGUIThreadSnapshotStore, +) -> None: + if isinstance(protocol_runner, AgentFrameworkAgent): + protocol_runner.config.snapshot_store = snapshot_store + return + protocol_runner.snapshot_store = snapshot_store + + +def _configure_snapshot_persistence( + protocol_runner: AgentFrameworkAgent | AgentFrameworkWorkflow, + *, + snapshot_store: AGUIThreadSnapshotStore | None, + snapshot_scope_resolver: SnapshotScopeResolver | None, +) -> None: + existing_snapshot_store = _get_snapshot_store(protocol_runner) + if snapshot_store is not None: + if existing_snapshot_store is not None and existing_snapshot_store is not snapshot_store: + raise ValueError("snapshot_store is already configured on the AG-UI runner.") + if existing_snapshot_store is None: + _set_snapshot_store(protocol_runner, snapshot_store) + existing_snapshot_store = snapshot_store + + if existing_snapshot_store is not None and snapshot_scope_resolver is None: + raise ValueError( + "snapshot_scope_resolver is required when snapshot_store is configured. " + "AG-UI Thread ids identify threads but do not authorize snapshot access; " + "provide a resolver that returns an explicit Snapshot Scope." + ) + + def add_agent_framework_fastapi_endpoint( app: FastAPI, agent: SupportsAgentRun | AgentFrameworkAgent | Workflow | AgentFrameworkWorkflow, @@ -33,6 +80,8 @@ def add_agent_framework_fastapi_endpoint( default_state: dict[str, Any] | None = None, tags: list[str] | None = None, dependencies: Sequence[Depends] | None = None, + snapshot_store: AGUIThreadSnapshotStore | None = None, + snapshot_scope_resolver: SnapshotScopeResolver | None = None, ) -> None: """Add an AG-UI endpoint to a FastAPI app. @@ -50,6 +99,10 @@ def add_agent_framework_fastapi_endpoint( These dependencies run before the endpoint handler. Use this to add authentication checks, rate limiting, or other middleware-like behavior. Example: `dependencies=[Depends(verify_api_key)]` + snapshot_store: Optional AG-UI Thread Snapshot store. Snapshot persistence is opt-in and requires an + explicit Snapshot Scope resolver. + snapshot_scope_resolver: Optional resolver for the application-defined Snapshot Scope. Required whenever + a snapshot store is configured because an AG-UI Thread id is not an authorization boundary. """ protocol_runner: AgentFrameworkAgent | AgentFrameworkWorkflow if isinstance(agent, AgentFrameworkWorkflow): @@ -63,10 +116,17 @@ def add_agent_framework_fastapi_endpoint( agent=agent, state_schema=state_schema, predict_state_config=predict_state_config, + snapshot_store=snapshot_store, ) else: raise TypeError("agent must be SupportsAgentRun, Workflow, AgentFrameworkAgent, or AgentFrameworkWorkflow.") + _configure_snapshot_persistence( + protocol_runner, + snapshot_store=snapshot_store, + snapshot_scope_resolver=snapshot_scope_resolver, + ) + @app.post(path, tags=tags or ["AG-UI"], dependencies=dependencies, response_model=None) # type: ignore[arg-type] async def agent_endpoint(request_body: AGUIRequest) -> StreamingResponse: """Handle AG-UI agent requests. @@ -76,11 +136,23 @@ async def agent_endpoint(request_body: AGUIRequest) -> StreamingResponse: """ try: input_data = request_body.model_dump(exclude_none=True) + snapshot_persistence_active = False + if snapshot_scope_resolver is not None and _get_snapshot_store(protocol_runner) is not None: + snapshot_scope = snapshot_scope_resolver(request_body) + if isawaitable(snapshot_scope): + snapshot_scope = await snapshot_scope + input_data[_SNAPSHOT_SCOPE_INPUT_KEY] = snapshot_scope + snapshot_persistence_active = True if default_state: - state = input_data.setdefault("state", {}) - for key, value in default_state.items(): - if key not in state: - state[key] = copy.deepcopy(value) + if snapshot_persistence_active: + # Defer default application to the runner so defaults only fill keys + # missing from both the stored snapshot state and the request state. + input_data[_DEFAULT_STATE_INPUT_KEY] = copy.deepcopy(default_state) + else: + state = input_data.setdefault("state", {}) + for key, value in default_state.items(): + if key not in state: + state[key] = copy.deepcopy(value) logger.debug( f"[{path}] Received request - Run ID: {input_data.get('run_id', 'no-run-id')}, " f"Thread ID: {input_data.get('thread_id', 'no-thread-id')}, " diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_run_common.py b/python/packages/ag-ui/agent_framework_ag_ui/_run_common.py index fe51e426183..a679c069cd7 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_run_common.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_run_common.py @@ -4,6 +4,7 @@ from __future__ import annotations +import copy import json import logging from collections.abc import Mapping @@ -33,7 +34,7 @@ from ._orchestration._predictive_state import PredictiveStateHandler from ._state import TOOL_RESULT_DISPLAY_KEY, TOOL_RESULT_STATE_KEY -from ._utils import generate_event_id, make_json_safe +from ._utils import generate_event_id, make_json_safe, normalize_agui_role logger = logging.getLogger(__name__) @@ -733,3 +734,117 @@ def _emit_content( return _emit_text_reasoning(content, flow) logger.debug("Skipping unsupported content type in AG-UI emitter: %s", content_type) return events + + +def _canonical_snapshot_message(message: dict[str, Any]) -> dict[str, Any]: + """Normalize an AG-UI message for identity comparison without generated ids.""" + from ._message_adapters import agui_messages_to_snapshot_format + + normalized_message = agui_messages_to_snapshot_format([copy.deepcopy(message)])[0] + normalized_message.pop("id", None) + return cast(dict[str, Any], make_json_safe(normalized_message)) + + +def _snapshot_messages_match(stored_message: dict[str, Any], incoming_message: dict[str, Any]) -> bool: + """Return whether an incoming message already represents the stored snapshot message.""" + stored_id = stored_message.get("id") + incoming_id = incoming_message.get("id") + if stored_id and incoming_id: + return str(stored_id) == str(incoming_id) + return _canonical_snapshot_message(stored_message) == _canonical_snapshot_message(incoming_message) + + +def _latest_user_message_index(messages: list[dict[str, Any]]) -> int | None: + """Find the newest incoming user message index.""" + for index in range(len(messages) - 1, -1, -1): + if normalize_agui_role(messages[index].get("role", "user")) == "user": + return index + return None + + +def _known_tool_call_ids( + stored_messages: list[dict[str, Any]], + stored_interrupt: list[dict[str, Any]] | None, +) -> set[str]: + """Collect tool call ids the backend previously issued for this thread.""" + known_ids: set[str] = set() + for message in stored_messages: + tool_calls = message.get("tool_calls") or message.get("toolCalls") or [] + if not isinstance(tool_calls, list): + continue + for tool_call in cast(list[Any], tool_calls): + if isinstance(tool_call, dict): + tool_call_id = cast(dict[str, Any], tool_call).get("id") + if tool_call_id: + known_ids.add(str(tool_call_id)) + for interrupt in stored_interrupt or []: + interrupt_id = interrupt.get("id") + if interrupt_id: + known_ids.add(str(interrupt_id)) + return known_ids + + +def _filter_untrusted_suffix( + incoming_suffix: list[dict[str, Any]], + *, + stored_messages: list[dict[str, Any]], + stored_interrupt: list[dict[str, Any]] | None, +) -> list[dict[str, Any]]: + """Drop client-forged non-user messages before promoting them to stored history. + + Only the user's own turns and tool results answering backend-issued tool calls + (including pending interrupts) may extend the authoritative thread history. + """ + known_ids: set[str] | None = None + filtered: list[dict[str, Any]] = [] + for message in incoming_suffix: + raw_role = str(message.get("role", "")).lower() + if raw_role == "user": + filtered.append(message) + continue + if raw_role == "tool": + tool_call_id = message.get("toolCallId") or message.get("tool_call_id") or message.get("actionExecutionId") + if known_ids is None: + known_ids = _known_tool_call_ids(stored_messages, stored_interrupt) + if tool_call_id and str(tool_call_id) in known_ids: + filtered.append(message) + continue + logger.warning( + "Dropping client-supplied %r message from the incoming thread suffix; " + "only user turns and tool results for backend-issued tool calls extend stored history.", + raw_role or "unknown", + ) + return filtered + + +def _reconstruct_messages_from_thread_snapshot( + *, + stored_messages: list[dict[str, Any]], + incoming_messages: list[dict[str, Any]], + stored_interrupt: list[dict[str, Any]] | None = None, +) -> list[dict[str, Any]]: + """Combine backend-owned prior history with the request-owned new user turn.""" + if not stored_messages or not incoming_messages: + return incoming_messages + + incoming_suffix: list[dict[str, Any]] + if len(incoming_messages) >= len(stored_messages) and all( + _snapshot_messages_match(stored_message, incoming_message) + for stored_message, incoming_message in zip(stored_messages, incoming_messages) + ): + incoming_suffix = incoming_messages[len(stored_messages) :] + else: + latest_user_index = _latest_user_message_index(incoming_messages) + if latest_user_index is None: + return incoming_messages + incoming_suffix = incoming_messages[latest_user_index:] + + incoming_suffix = _filter_untrusted_suffix( + incoming_suffix, + stored_messages=stored_messages, + stored_interrupt=stored_interrupt, + ) + + return [copy.deepcopy(message) for message in stored_messages] + [ + copy.deepcopy(message) for message in incoming_suffix + ] diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_snapshots.py b/python/packages/ag-ui/agent_framework_ag_ui/_snapshots.py new file mode 100644 index 00000000000..b619f99c810 --- /dev/null +++ b/python/packages/ag-ui/agent_framework_ag_ui/_snapshots.py @@ -0,0 +1,202 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""AG-UI Thread Snapshot storage primitives.""" + +from __future__ import annotations + +import copy +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, runtime_checkable + +if TYPE_CHECKING: + from ._types import AGUIRequest + +SnapshotScope: TypeAlias = str +"""Application-defined scope for authorizing access to AG-UI Thread Snapshots.""" + +AGUIThreadID: TypeAlias = str +"""AG-UI Thread identifier within a Snapshot Scope.""" + +SnapshotScopeResolver: TypeAlias = Callable[["AGUIRequest"], str | Awaitable[str]] +"""Callable that resolves the Snapshot Scope for an AG-UI endpoint request.""" + +_SnapshotKey: TypeAlias = tuple[SnapshotScope, AGUIThreadID] + +DEFAULT_MAX_THREAD_SNAPSHOTS = 1_000 +_SNAPSHOT_SCOPE_INPUT_KEY = "__ag_ui_snapshot_scope" +_DEFAULT_STATE_INPUT_KEY = "__ag_ui_default_state" + + +@dataclass(slots=True) +class AGUIThreadSnapshot: + """Replayable AG-UI Thread state. + + AG-UI Thread Snapshots intentionally contain only data that can be replayed + to a UI: message snapshots, optional Shared State, and optional interruption + state. They do not include raw events, request metadata, auth claims, + diagnostics, traces, or provider responses. + + Attributes: + messages: Replayable AG-UI message snapshots. + state: Optional AG-UI Shared State snapshot. + interrupt: Optional interruption state from ``RUN_FINISHED.interrupt``. + """ + + messages: list[dict[str, Any]] = field(default_factory=list) + state: dict[str, Any] | None = None + interrupt: list[dict[str, Any]] | None = None + + +@runtime_checkable +class AGUIThreadSnapshotStore(Protocol): + """Async store for latest AG-UI Thread Snapshots keyed by scope and thread id.""" + + async def save( + self, + *, + scope: SnapshotScope, + thread_id: AGUIThreadID, + snapshot: AGUIThreadSnapshot, + ) -> None: + """Save the latest snapshot for an AG-UI Thread within a Snapshot Scope. + + Args: + scope: Application-defined Snapshot Scope. This is part of the + storage key and must represent the app's authorization boundary. + thread_id: AG-UI Thread id within the scope. + snapshot: Snapshot to save. + """ + ... + + async def get( + self, + *, + scope: SnapshotScope, + thread_id: AGUIThreadID, + ) -> AGUIThreadSnapshot | None: + """Get the latest snapshot for an AG-UI Thread within a Snapshot Scope. + + Args: + scope: Application-defined Snapshot Scope. + thread_id: AG-UI Thread id within the scope. + + Returns: + The latest snapshot, or ``None`` when no snapshot exists for the key. + """ + ... + + async def delete( + self, + *, + scope: SnapshotScope, + thread_id: AGUIThreadID, + ) -> bool: + """Delete the latest snapshot for an AG-UI Thread within a Snapshot Scope. + + Args: + scope: Application-defined Snapshot Scope. + thread_id: AG-UI Thread id within the scope. + + Returns: + ``True`` when a snapshot was deleted, otherwise ``False``. + """ + ... + + async def clear(self, *, scope: SnapshotScope | None = None) -> None: + """Clear saved snapshots. + + Args: + scope: Optional Snapshot Scope to clear. When omitted, all in-memory + snapshots are cleared. + """ + ... + + +class InMemoryAGUIThreadSnapshotStore: + """Bounded memory-only latest snapshot store for local development, demos, and tests. + + This store keeps at most one snapshot per ``(scope, thread_id)`` key. It is + process-local and not durable production storage. + """ + + def __init__(self, *, max_snapshots: int = DEFAULT_MAX_THREAD_SNAPSHOTS) -> None: + """Initialize the in-memory snapshot store. + + Keyword Args: + max_snapshots: Maximum number of scoped thread snapshots to retain. + + Raises: + ValueError: If ``max_snapshots`` is less than 1. + """ + if max_snapshots < 1: + raise ValueError("max_snapshots must be greater than 0.") + self._max_snapshots = max_snapshots + self._snapshots: dict[_SnapshotKey, AGUIThreadSnapshot] = {} + + async def save( + self, + *, + scope: SnapshotScope, + thread_id: AGUIThreadID, + snapshot: AGUIThreadSnapshot, + ) -> None: + """Save the latest snapshot for an AG-UI Thread within a Snapshot Scope.""" + key = self._key(scope=scope, thread_id=thread_id) + if key in self._snapshots: + del self._snapshots[key] + self._snapshots[key] = copy.deepcopy(snapshot) + self._evict_oldest() + + async def get( + self, + *, + scope: SnapshotScope, + thread_id: AGUIThreadID, + ) -> AGUIThreadSnapshot | None: + """Get the latest snapshot for an AG-UI Thread within a Snapshot Scope.""" + snapshot = self._snapshots.get(self._key(scope=scope, thread_id=thread_id)) + return copy.deepcopy(snapshot) if snapshot is not None else None + + async def delete( + self, + *, + scope: SnapshotScope, + thread_id: AGUIThreadID, + ) -> bool: + """Delete the latest snapshot for an AG-UI Thread within a Snapshot Scope.""" + key = self._key(scope=scope, thread_id=thread_id) + if key not in self._snapshots: + return False + del self._snapshots[key] + return True + + async def clear(self, *, scope: SnapshotScope | None = None) -> None: + """Clear saved snapshots, optionally limited to one Snapshot Scope.""" + if scope is None: + self._snapshots.clear() + return + + normalized_scope = self._normalize_key_part(scope, "scope") + for key in list(self._snapshots): + if key[0] == normalized_scope: + del self._snapshots[key] + + @classmethod + def _key(cls, *, scope: SnapshotScope, thread_id: AGUIThreadID) -> _SnapshotKey: + return ( + cls._normalize_key_part(scope, "scope"), + cls._normalize_key_part(thread_id, "thread_id"), + ) + + @staticmethod + def _normalize_key_part(value: str, name: str) -> str: + if not isinstance(value, str): + raise TypeError(f"{name} must be a string.") + if not value: + raise ValueError(f"{name} must be a non-empty string.") + return value + + def _evict_oldest(self) -> None: + while len(self._snapshots) > self._max_snapshots: + del self._snapshots[next(iter(self._snapshots))] diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_workflow.py b/python/packages/ag-ui/agent_framework_ag_ui/_workflow.py index 10b1a6b21fc..aa583856a65 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_workflow.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_workflow.py @@ -4,18 +4,203 @@ from __future__ import annotations +import copy +import logging import uuid from collections.abc import AsyncGenerator, Callable -from typing import Any +from typing import Any, cast -from ag_ui.core import BaseEvent +from ag_ui.core import ( + BaseEvent, + MessagesSnapshotEvent, + RunErrorEvent, + RunFinishedEvent, + RunStartedEvent, + StateSnapshotEvent, + TextMessageContentEvent, + TextMessageEndEvent, + TextMessageStartEvent, + ToolCallArgsEvent, + ToolCallResultEvent, + ToolCallStartEvent, +) from agent_framework import Workflow +from ._message_adapters import agui_messages_to_snapshot_format +from ._run_common import ( + _build_run_finished_event, + _extract_resume_payload, + _reconstruct_messages_from_thread_snapshot, +) +from ._snapshots import ( + _DEFAULT_STATE_INPUT_KEY, + _SNAPSHOT_SCOPE_INPUT_KEY, + AGUIThreadSnapshot, + AGUIThreadSnapshotStore, +) +from ._utils import generate_event_id, make_json_safe from ._workflow_run import run_workflow_stream +logger = logging.getLogger(__name__) + WorkflowFactory = Callable[[str], Workflow] +def _event_messages_to_snapshot_dicts(messages: list[Any]) -> list[dict[str, Any]]: + """Convert AG-UI message event models to plain snapshot dictionaries.""" + safe_messages = make_json_safe(messages) + if not isinstance(safe_messages, list): + return [] + return [cast(dict[str, Any], message) for message in safe_messages if isinstance(message, dict)] + + +class _WorkflowSnapshotBuilder: + """Capture replayable workflow protocol output without retaining raw events.""" + + def __init__(self, raw_messages: list[dict[str, Any]]) -> None: + self._synthesized_messages = agui_messages_to_snapshot_format(raw_messages) + self._emitted_messages: list[dict[str, Any]] | None = None + self._open_text_message: dict[str, Any] | None = None + self._tool_call_message: dict[str, Any] | None = None + self._tool_calls_by_id: dict[str, dict[str, Any]] = {} + self.state: dict[str, Any] | None = None + self.interrupt: list[dict[str, Any]] | None = None + + def observe(self, event: BaseEvent) -> None: + """Fold one replayable AG-UI event into the latest snapshot state.""" + if isinstance(event, StateSnapshotEvent): + state = make_json_safe(event.snapshot) + if isinstance(state, dict): + self.state = cast(dict[str, Any], state) + return + + if isinstance(event, MessagesSnapshotEvent): + self._emitted_messages = _event_messages_to_snapshot_dicts(list(event.messages)) + return + + if isinstance(event, RunFinishedEvent): + interrupt = make_json_safe(getattr(event, "interrupt", None)) + if isinstance(interrupt, list): + self.interrupt = [cast(dict[str, Any], item) for item in interrupt if isinstance(item, dict)] + return + + if self._emitted_messages is not None: + return + + if isinstance(event, TextMessageStartEvent): + self._observe_text_start(event) + elif isinstance(event, TextMessageContentEvent): + self._observe_text_content(event) + elif isinstance(event, TextMessageEndEvent): + self._observe_text_end(event) + elif isinstance(event, ToolCallStartEvent): + self._observe_tool_call_start(event) + elif isinstance(event, ToolCallArgsEvent): + self._observe_tool_call_args(event) + elif isinstance(event, ToolCallResultEvent): + self._observe_tool_call_result(event) + + def build(self) -> AGUIThreadSnapshot: + """Return the replayable thread snapshot.""" + self._flush_open_text_message() + messages = self._emitted_messages if self._emitted_messages is not None else self._synthesized_messages + return AGUIThreadSnapshot(messages=messages, state=self.state, interrupt=self.interrupt) + + def _observe_text_start(self, event: TextMessageStartEvent) -> None: + if self._open_text_message is not None and self._open_text_message.get("id") != event.message_id: + self._flush_open_text_message() + self._open_text_message = {"id": event.message_id, "role": event.role, "content": ""} + + def _observe_text_content(self, event: TextMessageContentEvent) -> None: + if self._open_text_message is None or self._open_text_message.get("id") != event.message_id: + self._open_text_message = {"id": event.message_id, "role": "assistant", "content": ""} + self._open_text_message["content"] = f"{self._open_text_message.get('content', '')}{event.delta}" + + def _observe_text_end(self, event: TextMessageEndEvent) -> None: + if self._open_text_message is None or self._open_text_message.get("id") != event.message_id: + return + self._flush_open_text_message() + + def _observe_tool_call_start(self, event: ToolCallStartEvent) -> None: + parent_message_id = event.parent_message_id + if ( + self._open_text_message is not None + and parent_message_id is not None + and self._open_text_message.get("id") == parent_message_id + and self._open_text_message.get("content") + ): + self._open_text_message["id"] = generate_event_id() + self._flush_open_text_message() + if self._tool_call_message is None or ( + parent_message_id is not None and self._tool_call_message.get("id") != parent_message_id + ): + self._tool_call_message = { + "id": parent_message_id or generate_event_id(), + "role": "assistant", + "tool_calls": [], + } + self._synthesized_messages.append(self._tool_call_message) + + tool_call = { + "id": event.tool_call_id, + "type": "function", + "function": {"name": event.tool_call_name, "arguments": ""}, + } + cast(list[dict[str, Any]], self._tool_call_message["tool_calls"]).append(tool_call) + self._tool_calls_by_id[event.tool_call_id] = tool_call + + def _observe_tool_call_args(self, event: ToolCallArgsEvent) -> None: + tool_call = self._tool_calls_by_id.get(event.tool_call_id) + if tool_call is None: + return + function_payload = cast(dict[str, Any], tool_call["function"]) + function_payload["arguments"] = f"{function_payload.get('arguments', '')}{event.delta}" + + def _observe_tool_call_result(self, event: ToolCallResultEvent) -> None: + self._synthesized_messages.append( + { + "id": event.message_id, + "role": "tool", + "toolCallId": event.tool_call_id, + "content": event.content, + } + ) + # A result closes the current tool-call group; later tool calls start a new + # assistant message so replayed transcripts keep results adjacent to their + # tool_calls message, which provider APIs require. + self._tool_call_message = None + + def _flush_open_text_message(self) -> None: + if self._open_text_message is None: + return + if self._open_text_message.get("content"): + self._synthesized_messages.append(self._open_text_message) + # Text between tool calls closes the current tool-call group as well. + self._tool_call_message = None + self._open_text_message = None + + +async def _hydrate_workflow_thread_snapshot( + *, + snapshot_store: AGUIThreadSnapshotStore, + scope: str, + thread_id: str, + run_id: str, +) -> AsyncGenerator[BaseEvent]: + """Replay the latest stored workflow AG-UI Thread Snapshot without invoking the workflow.""" + yield RunStartedEvent(run_id=run_id, thread_id=thread_id) + snapshot = await snapshot_store.get(scope=scope, thread_id=thread_id) + if snapshot is None: + yield _build_run_finished_event(run_id=run_id, thread_id=thread_id) + return + + if snapshot.state is not None: + yield StateSnapshotEvent(snapshot=snapshot.state) + if snapshot.messages: + yield MessagesSnapshotEvent(messages=snapshot.messages) # type: ignore[arg-type] + yield _build_run_finished_event(run_id=run_id, thread_id=thread_id, interrupts=snapshot.interrupt) + + class AgentFrameworkWorkflow: """Base AG-UI workflow wrapper. @@ -29,15 +214,30 @@ def __init__( workflow_factory: WorkflowFactory | None = None, name: str | None = None, description: str | None = None, + snapshot_store: AGUIThreadSnapshotStore | None = None, ) -> None: + """Initialize the AG-UI workflow wrapper. + + Args: + workflow: Optional workflow instance to expose. + workflow_factory: Optional factory for thread-scoped workflow instances. + name: Optional workflow name. + description: Optional workflow description. + snapshot_store: Optional AG-UI Thread Snapshot store. Snapshot persistence remains inactive unless + endpoint setup also provides an explicit Snapshot Scope resolver. + """ if workflow is not None and workflow_factory is not None: raise ValueError("Pass either workflow= or workflow_factory=, not both.") self.workflow = workflow self._workflow_factory = workflow_factory - self._workflow_by_thread: dict[str, Workflow] = {} + # Cache keyed by (snapshot_scope, thread_id): the Snapshot Scope is the + # authorization boundary, so the same thread id under different scopes + # must never share an in-memory workflow instance. + self._workflow_by_thread: dict[tuple[str | None, str], Workflow] = {} self.name = name if name is not None else getattr(workflow, "name", "workflow") self.description = description if description is not None else getattr(workflow, "description", "") + self.snapshot_store = snapshot_store @staticmethod def _thread_id_from_input(input_data: dict[str, Any]) -> str: @@ -47,7 +247,7 @@ def _thread_id_from_input(input_data: dict[str, Any]) -> str: return str(thread_id) return str(uuid.uuid4()) - def _resolve_workflow(self, thread_id: str) -> Workflow: + def _resolve_workflow(self, thread_id: str, snapshot_scope: str | None = None) -> Workflow: """Get the workflow instance for the current run.""" if self.workflow is not None: return self.workflow @@ -55,17 +255,22 @@ def _resolve_workflow(self, thread_id: str) -> Workflow: if self._workflow_factory is None: raise NotImplementedError("No workflow is attached. Override run or pass workflow=/workflow_factory=.") - workflow = self._workflow_by_thread.get(thread_id) + cache_key = (snapshot_scope, thread_id) + workflow = self._workflow_by_thread.get(cache_key) if workflow is None: workflow = self._workflow_factory(thread_id) if not isinstance(workflow, Workflow): raise TypeError("workflow_factory must return a Workflow instance.") - self._workflow_by_thread[thread_id] = workflow + self._workflow_by_thread[cache_key] = workflow return workflow - def clear_thread_workflow(self, thread_id: str) -> None: - """Drop a single cached thread workflow instance.""" - self._workflow_by_thread.pop(thread_id, None) + def clear_thread_workflow(self, thread_id: str, snapshot_scope: str | None = None) -> None: + """Drop cached workflow instances for a thread, optionally limited to one Snapshot Scope.""" + if snapshot_scope is not None: + self._workflow_by_thread.pop((snapshot_scope, thread_id), None) + return + for key in [key for key in self._workflow_by_thread if key[1] == thread_id]: + del self._workflow_by_thread[key] def clear_workflow_cache(self) -> None: """Drop all cached thread workflow instances.""" @@ -77,6 +282,96 @@ async def run(self, input_data: dict[str, Any]) -> AsyncGenerator[BaseEvent]: Subclasses may override this to provide custom AG-UI streams. """ thread_id = self._thread_id_from_input(input_data) - workflow = self._resolve_workflow(thread_id) + run_id = str(input_data.get("run_id") or input_data.get("runId") or uuid.uuid4()) + snapshot_scope = cast(str | None, input_data.get(_SNAPSHOT_SCOPE_INPUT_KEY)) + raw_messages = list(cast(list[dict[str, Any]], input_data.get("messages", []) or [])) + resume_payload = _extract_resume_payload(input_data) + snapshot_store = self.snapshot_store + + if snapshot_store is not None and snapshot_scope is not None and not raw_messages and resume_payload is None: + async for event in _hydrate_workflow_thread_snapshot( + snapshot_store=snapshot_store, + scope=snapshot_scope, + thread_id=thread_id, + run_id=run_id, + ): + yield event + return + + # Load the stored snapshot for follow-up turns so the workflow runs with the + # full persisted thread history instead of just the latest request messages. + stored_snapshot: AGUIThreadSnapshot | None = None + if snapshot_store is not None and snapshot_scope is not None: + stored_snapshot = await snapshot_store.get(scope=snapshot_scope, thread_id=thread_id) + if stored_snapshot is not None and resume_payload is None: + raw_messages = _reconstruct_messages_from_thread_snapshot( + stored_messages=stored_snapshot.messages, + incoming_messages=raw_messages, + stored_interrupt=stored_snapshot.interrupt, + ) + input_data["messages"] = raw_messages + + # Merge stored state with request overrides, then fill endpoint-deferred + # defaults only for keys missing from both. + request_state = input_data.get("state") + deferred_default_state = cast(dict[str, Any] | None, input_data.get(_DEFAULT_STATE_INPUT_KEY)) + effective_state: dict[str, Any] = {} + if stored_snapshot is not None and stored_snapshot.state is not None: + effective_state.update(stored_snapshot.state) + if isinstance(request_state, dict): + effective_state.update(cast(dict[str, Any], request_state)) + if deferred_default_state: + for key, value in deferred_default_state.items(): + if key not in effective_state: + effective_state[key] = copy.deepcopy(value) + if effective_state: + input_data["state"] = effective_state + + workflow = self._resolve_workflow(thread_id, snapshot_scope) + builder_seed_messages = raw_messages + if resume_payload is not None and stored_snapshot is not None: + # Resume requests carry only the synthesized interrupt response, so seed + # the builder with stored history to avoid persisting a truncated thread. + builder_seed_messages = [ + copy.deepcopy(message) for message in stored_snapshot.messages + ] + builder_seed_messages + snapshot_builder = ( + _WorkflowSnapshotBuilder(builder_seed_messages) + if snapshot_store is not None and snapshot_scope is not None + else None + ) + if snapshot_builder is not None and effective_state: + # Seed builder state so a run that emits no StateSnapshotEvent still + # persists the latest known Shared State instead of dropping it. + state_snapshot = make_json_safe(effective_state) + if isinstance(state_snapshot, dict): + snapshot_builder.state = cast(dict[str, Any], state_snapshot) + run_error_emitted = False async for event in run_workflow_stream(input_data, workflow): + if snapshot_builder is not None: + snapshot_builder.observe(event) + if isinstance(event, RunErrorEvent): + run_error_emitted = True yield event + + if ( + snapshot_builder is not None + and not run_error_emitted + and snapshot_store is not None + and snapshot_scope is not None + ): + try: + await snapshot_store.save( + scope=snapshot_scope, + thread_id=thread_id, + snapshot=snapshot_builder.build(), + ) + except Exception: + # RUN_FINISHED has already been yielded; a store failure must not + # surface as a second terminal RUN_ERROR event. The previous + # snapshot stays available for hydration. + logger.exception( + "Failed to save AG-UI Thread Snapshot for scope=%s thread_id=%s; keeping previous snapshot.", + snapshot_scope, + thread_id, + ) diff --git a/python/packages/ag-ui/tests/ag_ui/test_endpoint.py b/python/packages/ag-ui/tests/ag_ui/test_endpoint.py index 51ab468b84c..20a72cd4381 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_endpoint.py +++ b/python/packages/ag-ui/tests/ag_ui/test_endpoint.py @@ -6,7 +6,7 @@ from typing import Any import pytest -from ag_ui.core import RunStartedEvent +from ag_ui.core import MessagesSnapshotEvent, RunStartedEvent, StateSnapshotEvent from agent_framework import ( Agent, ChatResponseUpdate, @@ -20,11 +20,24 @@ from fastapi.params import Depends from fastapi.testclient import TestClient -from agent_framework_ag_ui import add_agent_framework_fastapi_endpoint +from agent_framework_ag_ui import InMemoryAGUIThreadSnapshotStore, add_agent_framework_fastapi_endpoint from agent_framework_ag_ui._agent import AgentFrameworkAgent from agent_framework_ag_ui._workflow import AgentFrameworkWorkflow +def _decode_sse_events(response: Any) -> list[dict[str, Any]]: + content = response.content.decode("utf-8") + return [json.loads(line[6:]) for line in content.splitlines() if line.startswith("data: ")] + + +def _latest_messages_snapshot(response: Any) -> list[dict[str, Any]]: + snapshots = [ + event["messages"] for event in _decode_sse_events(response) if event.get("type") == "MESSAGES_SNAPSHOT" + ] + assert snapshots + return snapshots[-1] + + @pytest.fixture def build_chat_client(streaming_chat_client_stub, stream_from_updates_fixture): """Create a typed chat client stub for endpoint tests.""" @@ -287,10 +300,18 @@ async def test_endpoint_response_headers(build_chat_client): assert response.headers["cache-control"] == "no-cache" -async def test_endpoint_empty_messages(build_chat_client): - """Test endpoint with empty messages list.""" +async def test_endpoint_empty_messages(streaming_chat_client_stub): + """Empty messages keep the existing no-op run behavior when snapshot persistence is not configured.""" app = FastAPI() - agent = Agent(name="test", instructions="Test agent", client=build_chat_client()) + call_count = 0 + + async def stream_fn(messages: Any, options: Any, **kwargs: Any): + nonlocal call_count + del messages, options, kwargs + call_count += 1 + yield ChatResponseUpdate(contents=[Content.from_text(text="Should not run")]) + + agent = Agent(name="test", instructions="Test agent", client=streaming_chat_client_stub(stream_fn)) add_agent_framework_fastapi_endpoint(app, agent, path="/empty") @@ -298,6 +319,8 @@ async def test_endpoint_empty_messages(build_chat_client): response = client.post("/empty", json={"messages": []}) assert response.status_code == 200 + assert call_count == 0 + assert [event.get("type") for event in _decode_sse_events(response)] == ["RUN_STARTED", "RUN_FINISHED"] async def test_endpoint_complex_input(build_chat_client): @@ -560,6 +583,636 @@ async def test_endpoint_invalid_agent_type_raises_typeerror(): add_agent_framework_fastapi_endpoint(app, agent="not_an_agent") # type: ignore[arg-type] +async def test_endpoint_requires_snapshot_scope_resolver_when_store_configured(build_chat_client): + """Snapshot persistence setup must require an explicit Snapshot Scope resolver.""" + app = FastAPI() + agent = Agent(name="test", instructions="Test agent", client=build_chat_client()) + store = InMemoryAGUIThreadSnapshotStore() + + with pytest.raises(ValueError, match="snapshot_scope_resolver is required"): + add_agent_framework_fastapi_endpoint(app, agent, path="/snapshots", snapshot_store=store) + + +async def test_endpoint_requires_snapshot_scope_resolver_when_wrapped_runner_has_store(build_chat_client): + """Pre-wrapped runners with snapshot stores must also provide a Snapshot Scope resolver.""" + app = FastAPI() + agent = Agent(name="test", instructions="Test agent", client=build_chat_client()) + wrapped_agent = AgentFrameworkAgent(agent=agent, snapshot_store=InMemoryAGUIThreadSnapshotStore()) + + with pytest.raises(ValueError, match="snapshot_scope_resolver is required"): + add_agent_framework_fastapi_endpoint(app, wrapped_agent, path="/snapshots") + + +async def test_endpoint_accepts_snapshot_store_with_scope_resolver(build_chat_client): + """Endpoint behavior remains the normal event stream when snapshot persistence is explicitly configured.""" + app = FastAPI() + agent = Agent(name="test", instructions="Test agent", client=build_chat_client()) + store = InMemoryAGUIThreadSnapshotStore() + + add_agent_framework_fastapi_endpoint( + app, + agent, + path="/snapshots", + snapshot_store=store, + snapshot_scope_resolver=lambda _request: "tenant-a", + ) + + client = TestClient(app) + response = client.post( + "/snapshots", + json={"messages": [{"role": "user", "content": "Hello"}], "thread_id": "thread-1"}, + ) + + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + + +async def test_agent_endpoint_hydrates_stored_thread_snapshot_without_invoking_agent(streaming_chat_client_stub): + """A Hydrate Request replays stored agent messages and state without invoking the wrapped agent.""" + app = FastAPI() + call_count = 0 + + async def stream_fn(messages: Any, options: Any, **kwargs: Any): + nonlocal call_count + del messages, options, kwargs + call_count += 1 + yield ChatResponseUpdate(contents=[Content.from_text(text="Stored reply")]) + + agent = Agent(name="test", instructions="Test agent", client=streaming_chat_client_stub(stream_fn)) + store = InMemoryAGUIThreadSnapshotStore() + add_agent_framework_fastapi_endpoint( + app, + agent, + path="/snapshots", + state_schema={"recipe": {"type": "string"}}, + snapshot_store=store, + snapshot_scope_resolver=lambda _request: "tenant-a", + ) + client = TestClient(app) + + first_response = client.post( + "/snapshots", + json={ + "thread_id": "thread-1", + "messages": [{"role": "user", "content": "Hello"}], + "state": {"recipe": "pasta"}, + }, + ) + assert first_response.status_code == 200 + assert call_count == 1 + + hydrate_response = client.post("/snapshots", json={"thread_id": "thread-1", "messages": []}) + + assert hydrate_response.status_code == 200 + assert call_count == 1 + events = _decode_sse_events(hydrate_response) + event_types = [event.get("type") for event in events] + assert event_types == ["RUN_STARTED", "STATE_SNAPSHOT", "MESSAGES_SNAPSHOT", "RUN_FINISHED"] + assert events[1]["snapshot"] == {"recipe": "pasta"} + assert any(message.get("role") == "user" and message.get("content") == "Hello" for message in events[2]["messages"]) + assert any( + message.get("role") == "assistant" and message.get("content") == "Stored reply" + for message in events[2]["messages"] + ) + + +async def test_agent_endpoint_hydrates_snapshots_by_scope_and_thread(streaming_chat_client_stub): + """Hydration uses Snapshot Scope and AG-UI Thread id together when reading stored snapshots.""" + app = FastAPI() + call_count = 0 + + async def stream_fn(messages: Any, options: Any, **kwargs: Any): + nonlocal call_count + del messages, options, kwargs + call_count += 1 + yield ChatResponseUpdate(contents=[Content.from_text(text="Tenant A reply")]) + + agent = Agent(name="test", instructions="Test agent", client=streaming_chat_client_stub(stream_fn)) + store = InMemoryAGUIThreadSnapshotStore() + add_agent_framework_fastapi_endpoint( + app, + agent, + path="/snapshots", + state_schema={"tenant": {"type": "string"}}, + snapshot_store=store, + snapshot_scope_resolver=lambda request: request.forwarded_props["tenant"], + ) + client = TestClient(app) + + first_response = client.post( + "/snapshots", + json={ + "thread_id": "thread-1", + "messages": [{"role": "user", "content": "Hello tenant A"}], + "state": {"tenant": "tenant-a"}, + "forwardedProps": {"tenant": "tenant-a"}, + }, + ) + assert first_response.status_code == 200 + assert call_count == 1 + + tenant_b_response = client.post( + "/snapshots", + json={"thread_id": "thread-1", "messages": [], "forwardedProps": {"tenant": "tenant-b"}}, + ) + assert tenant_b_response.status_code == 200 + assert call_count == 1 + assert [event.get("type") for event in _decode_sse_events(tenant_b_response)] == [ + "RUN_STARTED", + "RUN_FINISHED", + ] + + tenant_a_response = client.post( + "/snapshots", + json={"thread_id": "thread-1", "messages": [], "forwardedProps": {"tenant": "tenant-a"}}, + ) + assert tenant_a_response.status_code == 200 + assert call_count == 1 + tenant_a_events = _decode_sse_events(tenant_a_response) + assert [event.get("type") for event in tenant_a_events] == [ + "RUN_STARTED", + "STATE_SNAPSHOT", + "MESSAGES_SNAPSHOT", + "RUN_FINISHED", + ] + assert tenant_a_events[1]["snapshot"] == {"tenant": "tenant-a"} + assert any(message.get("content") == "Tenant A reply" for message in tenant_a_events[2]["messages"]) + + +async def test_agent_endpoint_prepends_stored_snapshot_for_new_user_turn(streaming_chat_client_stub): + """A normal agent turn with a known thread id prepends stored history and keeps the new user input.""" + app = FastAPI() + captured_messages: list[list[tuple[str, str]]] = [] + + async def stream_fn(messages: Any, options: Any, **kwargs: Any): + del options, kwargs + captured_messages.append([(message.role, message.text) for message in messages]) + yield ChatResponseUpdate(contents=[Content.from_text(text=f"Reply {len(captured_messages)}")]) + + agent = Agent(name="test", instructions="Test agent", client=streaming_chat_client_stub(stream_fn)) + store = InMemoryAGUIThreadSnapshotStore() + add_agent_framework_fastapi_endpoint( + app, + agent, + path="/snapshots", + state_schema={"recipe": {"type": "string"}}, + snapshot_store=store, + snapshot_scope_resolver=lambda _request: "tenant-a", + ) + client = TestClient(app) + + first_response = client.post( + "/snapshots", + json={ + "thread_id": "thread-1", + "messages": [{"id": "user-1", "role": "user", "content": "Plan dinner"}], + "state": {"recipe": "pasta"}, + }, + ) + assert first_response.status_code == 200 + + second_response = client.post( + "/snapshots", + json={ + "thread_id": "thread-1", + "messages": [{"id": "user-2", "role": "user", "content": "Add dessert"}], + }, + ) + + assert second_response.status_code == 200 + assert len(captured_messages) == 2 + assert captured_messages[1] == [ + ("user", "Plan dinner"), + ("assistant", "Reply 1"), + ( + "system", + ( + "Current state of the application:\n" + '{\n "recipe": "pasta"\n}\n\n' + "When modifying state, you MUST include ALL existing data plus your changes.\n" + "For example, if adding one new item to a list, include ALL existing items PLUS the new item.\n" + "Never replace existing data - always preserve and append or merge." + ), + ), + ("user", "Add dessert"), + ] + events = _decode_sse_events(second_response) + state_snapshots = [event for event in events if event.get("type") == "STATE_SNAPSHOT"] + assert state_snapshots[0]["snapshot"] == {"recipe": "pasta"} + + +async def test_agent_endpoint_deduplicates_full_history_and_merges_fresh_state(streaming_chat_client_stub): + """Stored prior history is authoritative while incoming full history and fresh state remain supported.""" + app = FastAPI() + captured_messages: list[list[tuple[str, str]]] = [] + + async def stream_fn(messages: Any, options: Any, **kwargs: Any): + del options, kwargs + captured_messages.append([(message.role, message.text) for message in messages]) + yield ChatResponseUpdate(contents=[Content.from_text(text=f"Reply {len(captured_messages)}")]) + + agent = Agent(name="test", instructions="Test agent", client=streaming_chat_client_stub(stream_fn)) + store = InMemoryAGUIThreadSnapshotStore() + add_agent_framework_fastapi_endpoint( + app, + agent, + path="/snapshots", + state_schema={"recipe": {"type": "string"}, "theme": {"type": "string"}}, + snapshot_store=store, + snapshot_scope_resolver=lambda _request: "tenant-a", + ) + client = TestClient(app) + + first_response = client.post( + "/snapshots", + json={ + "thread_id": "thread-1", + "messages": [{"id": "user-1", "role": "user", "content": "Plan dinner"}], + "state": {"recipe": "pasta", "theme": "dark"}, + }, + ) + assert first_response.status_code == 200 + first_snapshot = _latest_messages_snapshot(first_response) + + second_response = client.post( + "/snapshots", + json={ + "thread_id": "thread-1", + "messages": [*first_snapshot, {"id": "user-2", "role": "user", "content": "Add dessert"}], + "state": {"recipe": "salad"}, + }, + ) + assert second_response.status_code == 200 + + second_non_system_messages = [message for message in captured_messages[1] if message[0] != "system"] + assert second_non_system_messages == [ + ("user", "Plan dinner"), + ("assistant", "Reply 1"), + ("user", "Add dessert"), + ] + second_events = _decode_sse_events(second_response) + second_state_snapshots = [event for event in second_events if event.get("type") == "STATE_SNAPSHOT"] + assert second_state_snapshots[0]["snapshot"] == {"recipe": "salad", "theme": "dark"} + + second_snapshot = _latest_messages_snapshot(second_response) + conflicting_history = [message.copy() for message in second_snapshot] + conflicting_history[0]["content"] = "Tampered dinner plan" + conflicting_history[1]["content"] = "Tampered reply" + third_response = client.post( + "/snapshots", + json={ + "thread_id": "thread-1", + "messages": [*conflicting_history, {"id": "user-3", "role": "user", "content": "Pick wine"}], + }, + ) + assert third_response.status_code == 200 + + third_texts = [text for role, text in captured_messages[2] if role != "system"] + assert third_texts == ["Plan dinner", "Reply 1", "Add dessert", "Reply 2", "Pick wine"] + assert "Tampered dinner plan" not in third_texts + assert "Tampered reply" not in third_texts + third_state_snapshots = [ + event for event in _decode_sse_events(third_response) if event.get("type") == "STATE_SNAPSHOT" + ] + assert third_state_snapshots[0]["snapshot"] == {"recipe": "salad", "theme": "dark"} + + +async def test_agent_endpoint_hydrates_interrupted_thread_without_invoking_agent(streaming_chat_client_stub): + """Hydrating an interrupted agent replays state, messages, and interrupt metadata without resuming it.""" + app = FastAPI() + call_count = 0 + + async def stream_fn(messages: Any, options: Any, **kwargs: Any): + nonlocal call_count + del messages, options, kwargs + call_count += 1 + yield ChatResponseUpdate( + contents=[ + Content.from_function_call( + name="draft_steps", + call_id="draft-call", + arguments=json.dumps({"steps": [{"description": "Draft outline"}]}), + ) + ], + role="assistant", + ) + + agent = Agent(name="test", instructions="Test agent", client=streaming_chat_client_stub(stream_fn)) + store = InMemoryAGUIThreadSnapshotStore() + add_agent_framework_fastapi_endpoint( + app, + agent, + path="/snapshots", + state_schema={"steps": {"type": "array", "items": {"type": "object"}}}, + predict_state_config={"steps": {"tool": "draft_steps", "tool_argument": "steps"}}, + snapshot_store=store, + snapshot_scope_resolver=lambda _request: "tenant-a", + ) + client = TestClient(app) + + first_response = client.post( + "/snapshots", + json={ + "thread_id": "agent-thread", + "messages": [{"role": "user", "content": "Draft the plan"}], + "state": {"steps": []}, + }, + ) + assert first_response.status_code == 200 + assert call_count == 1 + first_events = _decode_sse_events(first_response) + first_finished = [event for event in first_events if event.get("type") == "RUN_FINISHED"] + assert first_finished[-1]["interrupt"][0]["value"]["function_call"]["call_id"] == "draft-call" + + hydrate_response = client.post("/snapshots", json={"thread_id": "agent-thread", "messages": []}) + + assert hydrate_response.status_code == 200 + assert call_count == 1 + events = _decode_sse_events(hydrate_response) + assert [event.get("type") for event in events] == [ + "RUN_STARTED", + "STATE_SNAPSHOT", + "MESSAGES_SNAPSHOT", + "RUN_FINISHED", + ] + assert events[1]["snapshot"] == {"steps": [{"description": "Draft outline"}]} + assert events[-1]["interrupt"][0]["value"]["function_call"]["name"] == "draft_steps" + + +async def test_agent_endpoint_run_error_does_not_overwrite_previous_snapshot(streaming_chat_client_stub): + """A failing agent turn leaves the last good AG-UI Thread Snapshot available for hydration.""" + app = FastAPI() + call_count = 0 + + async def stream_fn(messages: Any, options: Any, **kwargs: Any): + nonlocal call_count + del messages, options, kwargs + call_count += 1 + if call_count == 1: + yield ChatResponseUpdate(contents=[Content.from_text(text="Stable reply")]) + return + raise RuntimeError("agent exploded") + + agent = Agent(name="test", instructions="Test agent", client=streaming_chat_client_stub(stream_fn)) + store = InMemoryAGUIThreadSnapshotStore() + add_agent_framework_fastapi_endpoint( + app, + agent, + path="/snapshots", + snapshot_store=store, + snapshot_scope_resolver=lambda _request: "tenant-a", + ) + client = TestClient(app) + + first_response = client.post( + "/snapshots", + json={"thread_id": "agent-thread", "messages": [{"role": "user", "content": "Start"}]}, + ) + assert first_response.status_code == 200 + assert call_count == 1 + + error_response = client.post( + "/snapshots", + json={"thread_id": "agent-thread", "messages": [{"role": "user", "content": "Break the run"}]}, + ) + assert error_response.status_code == 200 + assert call_count == 2 + assert "RUN_ERROR" in [event.get("type") for event in _decode_sse_events(error_response)] + + hydrate_response = client.post("/snapshots", json={"thread_id": "agent-thread", "messages": []}) + + assert hydrate_response.status_code == 200 + assert call_count == 2 + messages = _latest_messages_snapshot(hydrate_response) + assert any(message.get("role") == "assistant" and message.get("content") == "Stable reply" for message in messages) + assert not any(message.get("content") == "Break the run" for message in messages) + + +async def test_workflow_endpoint_hydrates_emitted_snapshots_without_invoking_workflow(): + """A workflow Hydrate Request replays emitted snapshots without invoking the wrapped workflow.""" + app = FastAPI() + call_count = 0 + + @executor(id="snapshotter") + async def snapshotter(message: Any, ctx: WorkflowContext) -> None: + nonlocal call_count + del message + call_count += 1 + await ctx.yield_output(StateSnapshotEvent(snapshot={"active_agent": "flights"})) + await ctx.yield_output( + MessagesSnapshotEvent( + messages=[{"id": "assistant-snapshot", "role": "assistant", "content": "Stored workflow reply"}] + ) + ) + + workflow = WorkflowBuilder(start_executor=snapshotter).build() + store = InMemoryAGUIThreadSnapshotStore() + add_agent_framework_fastapi_endpoint( + app, + workflow, + path="/workflow-snapshots", + snapshot_store=store, + snapshot_scope_resolver=lambda _request: "tenant-a", + ) + client = TestClient(app) + + first_response = client.post( + "/workflow-snapshots", + json={"thread_id": "workflow-thread", "messages": [{"role": "user", "content": "Start workflow"}]}, + ) + assert first_response.status_code == 200 + assert call_count == 1 + + hydrate_response = client.post("/workflow-snapshots", json={"thread_id": "workflow-thread", "messages": []}) + + assert hydrate_response.status_code == 200 + assert call_count == 1 + events = _decode_sse_events(hydrate_response) + assert [event.get("type") for event in events] == [ + "RUN_STARTED", + "STATE_SNAPSHOT", + "MESSAGES_SNAPSHOT", + "RUN_FINISHED", + ] + assert events[1]["snapshot"] == {"active_agent": "flights"} + assert events[2]["messages"] == [ + {"id": "assistant-snapshot", "role": "assistant", "content": "Stored workflow reply"} + ] + + +async def test_workflow_endpoint_hydrates_synthesized_text_and_tool_snapshot(): + """Workflow text and tool output are synthesized into replayable snapshot messages.""" + app = FastAPI() + call_count = 0 + + @executor(id="responder") + async def responder(message: Any, ctx: WorkflowContext) -> None: + nonlocal call_count + del message + call_count += 1 + await ctx.yield_output("Workflow answer") + await ctx.yield_output( + [ + Content.from_function_call( + name="lookup_weather", + call_id="call-1", + arguments='{"city":"SF"}', + ), + Content.from_function_result(call_id="call-1", result="72F"), + ] + ) + await ctx.yield_output({"diagnostic": "not persisted"}) + + workflow = WorkflowBuilder(start_executor=responder).build() + store = InMemoryAGUIThreadSnapshotStore() + add_agent_framework_fastapi_endpoint( + app, + workflow, + path="/workflow-snapshots", + snapshot_store=store, + snapshot_scope_resolver=lambda _request: "tenant-a", + ) + client = TestClient(app) + + first_response = client.post( + "/workflow-snapshots", + json={ + "thread_id": "workflow-thread", + "messages": [{"id": "user-1", "role": "user", "content": "Start workflow"}], + }, + ) + assert first_response.status_code == 200 + assert call_count == 1 + + hydrate_response = client.post("/workflow-snapshots", json={"thread_id": "workflow-thread", "messages": []}) + + assert hydrate_response.status_code == 200 + assert call_count == 1 + events = _decode_sse_events(hydrate_response) + assert [event.get("type") for event in events] == ["RUN_STARTED", "MESSAGES_SNAPSHOT", "RUN_FINISHED"] + messages = events[1]["messages"] + assert any(message.get("role") == "user" and message.get("content") == "Start workflow" for message in messages) + assert any( + message.get("role") == "assistant" and message.get("content") == "Workflow answer" for message in messages + ) + tool_call_messages = [ + message for message in messages if message.get("role") == "assistant" and message.get("toolCalls") + ] + assert len(tool_call_messages) == 1 + tool_call = tool_call_messages[0]["toolCalls"][0] + assert tool_call["id"] == "call-1" + assert tool_call["function"] == {"name": "lookup_weather", "arguments": '{"city":"SF"}'} + assert any( + message.get("role") == "tool" and message.get("toolCallId") == "call-1" and message.get("content") == "72F" + for message in messages + ) + + +async def test_workflow_endpoint_hydrates_interrupted_thread_without_invoking_workflow(): + """Hydrating an interrupted workflow replays state, messages, and interrupt metadata without resuming it.""" + app = FastAPI() + call_count = 0 + + @executor(id="requester") + async def requester(message: Any, ctx: WorkflowContext) -> None: + nonlocal call_count + del message + call_count += 1 + await ctx.yield_output(StateSnapshotEvent(snapshot={"step": "approval"})) + await ctx.request_info( + {"message": "Approve workflow step", "options": ["Approve", "Reject"]}, + dict, + request_id="workflow-approval", + ) + + workflow = WorkflowBuilder(start_executor=requester).build() + store = InMemoryAGUIThreadSnapshotStore() + add_agent_framework_fastapi_endpoint( + app, + workflow, + path="/workflow-snapshots", + snapshot_store=store, + snapshot_scope_resolver=lambda _request: "tenant-a", + ) + client = TestClient(app) + + first_response = client.post( + "/workflow-snapshots", + json={"thread_id": "workflow-thread", "messages": [{"role": "user", "content": "Start workflow"}]}, + ) + assert first_response.status_code == 200 + assert call_count == 1 + first_finished = [event for event in _decode_sse_events(first_response) if event.get("type") == "RUN_FINISHED"] + assert first_finished[-1]["interrupt"][0]["id"] == "workflow-approval" + + hydrate_response = client.post("/workflow-snapshots", json={"thread_id": "workflow-thread", "messages": []}) + + assert hydrate_response.status_code == 200 + assert call_count == 1 + events = _decode_sse_events(hydrate_response) + assert [event.get("type") for event in events] == [ + "RUN_STARTED", + "STATE_SNAPSHOT", + "MESSAGES_SNAPSHOT", + "RUN_FINISHED", + ] + assert events[1]["snapshot"] == {"step": "approval"} + assert events[-1]["interrupt"][0]["id"] == "workflow-approval" + assert events[-1]["interrupt"][0]["value"]["message"] == "Approve workflow step" + + +async def test_workflow_endpoint_run_error_does_not_overwrite_previous_snapshot(): + """A failing workflow turn leaves the last good AG-UI Thread Snapshot available for hydration.""" + app = FastAPI() + call_count = 0 + + @executor(id="responder") + async def responder(message: Any, ctx: WorkflowContext) -> None: + nonlocal call_count + del message + call_count += 1 + if call_count == 1: + await ctx.yield_output("Stable workflow reply") + return + raise RuntimeError("workflow exploded") + + workflow = WorkflowBuilder(start_executor=responder).build() + store = InMemoryAGUIThreadSnapshotStore() + add_agent_framework_fastapi_endpoint( + app, + workflow, + path="/workflow-snapshots", + snapshot_store=store, + snapshot_scope_resolver=lambda _request: "tenant-a", + ) + client = TestClient(app) + + first_response = client.post( + "/workflow-snapshots", + json={"thread_id": "workflow-thread", "messages": [{"role": "user", "content": "Start workflow"}]}, + ) + assert first_response.status_code == 200 + assert call_count == 1 + + error_response = client.post( + "/workflow-snapshots", + json={"thread_id": "workflow-thread", "messages": [{"role": "user", "content": "Break workflow"}]}, + ) + assert error_response.status_code == 200 + assert call_count == 2 + assert "RUN_ERROR" in [event.get("type") for event in _decode_sse_events(error_response)] + + hydrate_response = client.post("/workflow-snapshots", json={"thread_id": "workflow-thread", "messages": []}) + + assert hydrate_response.status_code == 200 + assert call_count == 2 + messages = _latest_messages_snapshot(hydrate_response) + assert any( + message.get("role") == "assistant" and message.get("content") == "Stable workflow reply" for message in messages + ) + assert not any(message.get("content") == "Break workflow" for message in messages) + + async def test_endpoint_encoding_failure_emits_run_error(): """Event encoding failure emits RUN_ERROR event in the SSE stream.""" from unittest.mock import patch @@ -603,3 +1256,589 @@ async def run(self, input_data: dict[str, Any]): # Should still get 200 (SSE stream), just with no events assert response.status_code == 200 + + +async def test_agent_endpoint_confirm_changes_clears_persisted_interrupt(streaming_chat_client_stub): + """A confirm_changes response persists the completed turn and clears the stored interrupt.""" + app = FastAPI() + call_count = 0 + + async def stream_fn(messages: Any, options: Any, **kwargs: Any): + nonlocal call_count + del messages, options, kwargs + call_count += 1 + yield ChatResponseUpdate( + contents=[ + Content.from_function_call( + name="draft_steps", + call_id="draft-call", + arguments=json.dumps({"steps": [{"description": "Draft outline"}]}), + ) + ], + role="assistant", + ) + + agent = Agent(name="test", instructions="Test agent", client=streaming_chat_client_stub(stream_fn)) + store = InMemoryAGUIThreadSnapshotStore() + add_agent_framework_fastapi_endpoint( + app, + agent, + path="/snapshots", + state_schema={"steps": {"type": "array", "items": {"type": "object"}}}, + predict_state_config={"steps": {"tool": "draft_steps", "tool_argument": "steps"}}, + snapshot_store=store, + snapshot_scope_resolver=lambda _request: "tenant-a", + ) + client = TestClient(app) + + first_response = client.post( + "/snapshots", + json={ + "thread_id": "agent-thread", + "messages": [{"id": "user-1", "role": "user", "content": "Draft the plan"}], + "state": {"steps": []}, + }, + ) + assert first_response.status_code == 200 + assert call_count == 1 + first_events = _decode_sse_events(first_response) + first_finished = [event for event in first_events if event.get("type") == "RUN_FINISHED"] + assert first_finished[-1]["interrupt"] + confirm_call_id = first_finished[-1]["interrupt"][0]["id"] + + confirm_response = client.post( + "/snapshots", + json={ + "thread_id": "agent-thread", + "messages": [], + "resume": {"interrupts": [{"id": confirm_call_id, "value": json.dumps({"accepted": True, "steps": []})}]}, + }, + ) + assert confirm_response.status_code == 200 + assert call_count == 1 + confirm_event_types = [event.get("type") for event in _decode_sse_events(confirm_response)] + assert "TEXT_MESSAGE_CONTENT" in confirm_event_types + + hydrate_response = client.post("/snapshots", json={"thread_id": "agent-thread", "messages": []}) + + assert hydrate_response.status_code == 200 + assert call_count == 1 + events = _decode_sse_events(hydrate_response) + assert not events[-1].get("interrupt") + messages = _latest_messages_snapshot(hydrate_response) + assert any( + message.get("role") == "assistant" and message.get("content") == "Changes confirmed and applied successfully!" + for message in messages + ) + assert any(message.get("role") == "user" and message.get("content") == "Draft the plan" for message in messages) + + +async def test_agent_endpoint_default_state_does_not_reset_persisted_state(streaming_chat_client_stub): + """Endpoint defaults fill missing keys but never override persisted Shared State.""" + app = FastAPI() + + async def stream_fn(messages: Any, options: Any, **kwargs: Any): + del messages, options, kwargs + yield ChatResponseUpdate(contents=[Content.from_text(text="Reply")]) + + agent = Agent(name="test", instructions="Test agent", client=streaming_chat_client_stub(stream_fn)) + store = InMemoryAGUIThreadSnapshotStore() + add_agent_framework_fastapi_endpoint( + app, + agent, + path="/snapshots", + state_schema={"recipe": {"type": "string"}}, + default_state={"recipe": ""}, + snapshot_store=store, + snapshot_scope_resolver=lambda _request: "tenant-a", + ) + client = TestClient(app) + + fresh_response = client.post( + "/snapshots", + json={"thread_id": "thread-fresh", "messages": [{"id": "user-0", "role": "user", "content": "Hi"}]}, + ) + assert fresh_response.status_code == 200 + fresh_state_snapshots = [ + event for event in _decode_sse_events(fresh_response) if event.get("type") == "STATE_SNAPSHOT" + ] + assert fresh_state_snapshots[0]["snapshot"] == {"recipe": ""} + + first_response = client.post( + "/snapshots", + json={ + "thread_id": "thread-1", + "messages": [{"id": "user-1", "role": "user", "content": "Plan dinner"}], + "state": {"recipe": "pasta"}, + }, + ) + assert first_response.status_code == 200 + + second_response = client.post( + "/snapshots", + json={ + "thread_id": "thread-1", + "messages": [{"id": "user-2", "role": "user", "content": "Add dessert"}], + }, + ) + assert second_response.status_code == 200 + second_state_snapshots = [ + event for event in _decode_sse_events(second_response) if event.get("type") == "STATE_SNAPSHOT" + ] + assert second_state_snapshots[0]["snapshot"] == {"recipe": "pasta"} + + hydrate_response = client.post("/snapshots", json={"thread_id": "thread-1", "messages": []}) + assert hydrate_response.status_code == 200 + hydrate_events = _decode_sse_events(hydrate_response) + hydrate_state_snapshots = [event for event in hydrate_events if event.get("type") == "STATE_SNAPSHOT"] + assert hydrate_state_snapshots[0]["snapshot"] == {"recipe": "pasta"} + + +async def test_agent_endpoint_persists_turn_output_when_intermediate_snapshot_suppressed(streaming_chat_client_stub): + """A no-confirmation predictive turn persists tool output even when the outbound snapshot is suppressed.""" + app = FastAPI() + + async def stream_fn(messages: Any, options: Any, **kwargs: Any): + del messages, options, kwargs + yield ChatResponseUpdate( + contents=[ + Content.from_function_call( + name="write_doc", + call_id="doc-call", + arguments=json.dumps({"document": "Draft text"}), + ) + ], + role="assistant", + ) + yield ChatResponseUpdate( + contents=[Content.from_function_result(call_id="doc-call", result="ok")], + role="tool", + ) + yield ChatResponseUpdate(contents=[Content.from_text(text="Done writing")], role="assistant") + + agent = Agent(name="test", instructions="Test agent", client=streaming_chat_client_stub(stream_fn)) + wrapped = AgentFrameworkAgent( + agent=agent, + state_schema={"document": {"type": "string"}}, + predict_state_config={"document": {"tool": "write_doc", "tool_argument": "document"}}, + require_confirmation=False, + ) + store = InMemoryAGUIThreadSnapshotStore() + add_agent_framework_fastapi_endpoint( + app, + wrapped, + path="/snapshots", + snapshot_store=store, + snapshot_scope_resolver=lambda _request: "tenant-a", + ) + client = TestClient(app) + + first_response = client.post( + "/snapshots", + json={ + "thread_id": "doc-thread", + "messages": [{"id": "user-1", "role": "user", "content": "Write the doc"}], + }, + ) + assert first_response.status_code == 200 + first_event_types = [event.get("type") for event in _decode_sse_events(first_response)] + assert "MESSAGES_SNAPSHOT" not in first_event_types + + hydrate_response = client.post("/snapshots", json={"thread_id": "doc-thread", "messages": []}) + + assert hydrate_response.status_code == 200 + messages = _latest_messages_snapshot(hydrate_response) + assert any(message.get("role") == "assistant" and message.get("content") == "Done writing" for message in messages) + assert any(message.get("role") == "tool" and message.get("toolCallId") == "doc-call" for message in messages) + + +async def test_workflow_preserves_history_across_turns(): + """Workflow follow-up turns merge stored history so persisted snapshots keep earlier turns. + + Uses async runner.run() directly instead of HTTP TestClient because the sync + TestClient runs each request in a different event loop, which conflicts with + the workflow's asyncio Queue across turns. + """ + from agent_framework_ag_ui._snapshots import _SNAPSHOT_SCOPE_INPUT_KEY + + call_count = 0 + + @executor(id="responder") + async def responder(message: Any, ctx: WorkflowContext) -> None: + nonlocal call_count + del message + call_count += 1 + await ctx.yield_output(f"Workflow reply {call_count}") + + workflow = WorkflowBuilder(start_executor=responder).build() + store = InMemoryAGUIThreadSnapshotStore() + runner = AgentFrameworkWorkflow(workflow=workflow, snapshot_store=store) + + first_events = [ + event + async for event in runner.run( + { + "thread_id": "workflow-thread", + "run_id": "run-1", + "messages": [{"id": "user-1", "role": "user", "content": "First question"}], + _SNAPSHOT_SCOPE_INPUT_KEY: "tenant-a", + } + ) + ] + assert first_events + assert call_count == 1 + + second_events = [ + event + async for event in runner.run( + { + "thread_id": "workflow-thread", + "run_id": "run-2", + "messages": [{"id": "user-2", "role": "user", "content": "Second question"}], + _SNAPSHOT_SCOPE_INPUT_KEY: "tenant-a", + } + ) + ] + assert second_events + assert call_count == 2 + + snapshot = await store.get(scope="tenant-a", thread_id="workflow-thread") + assert snapshot is not None + contents = [message.get("content") for message in snapshot.messages] + assert "First question" in contents + assert "Workflow reply 1" in contents + assert "Second question" in contents + assert "Workflow reply 2" in contents + + hydrate_events = [ + event + async for event in runner.run( + { + "thread_id": "workflow-thread", + "run_id": "run-3", + "messages": [], + _SNAPSHOT_SCOPE_INPUT_KEY: "tenant-a", + } + ) + ] + assert call_count == 2 + hydrated_snapshots = [event for event in hydrate_events if isinstance(event, MessagesSnapshotEvent)] + assert hydrated_snapshots + + +async def test_agent_endpoint_resume_preserves_persisted_history(streaming_chat_client_stub): + """A generic interrupt resume keeps stored history in the persisted snapshot.""" + app = FastAPI() + call_count = 0 + + async def stream_fn(messages: Any, options: Any, **kwargs: Any): + nonlocal call_count + del messages, options, kwargs + call_count += 1 + if call_count == 1: + yield ChatResponseUpdate( + contents=[ + Content.from_function_call( + name="draft_steps", + call_id="draft-call", + arguments=json.dumps({"steps": [{"description": "Draft outline"}]}), + ) + ], + role="assistant", + ) + return + yield ChatResponseUpdate(contents=[Content.from_text(text="Resumed reply")]) + + agent = Agent(name="test", instructions="Test agent", client=streaming_chat_client_stub(stream_fn)) + store = InMemoryAGUIThreadSnapshotStore() + add_agent_framework_fastapi_endpoint( + app, + agent, + path="/snapshots", + state_schema={"steps": {"type": "array", "items": {"type": "object"}}}, + predict_state_config={"steps": {"tool": "draft_steps", "tool_argument": "steps"}}, + snapshot_store=store, + snapshot_scope_resolver=lambda _request: "tenant-a", + ) + client = TestClient(app) + + first_response = client.post( + "/snapshots", + json={ + "thread_id": "agent-thread", + "messages": [{"id": "user-1", "role": "user", "content": "Draft the plan"}], + "state": {"steps": []}, + }, + ) + assert first_response.status_code == 200 + assert call_count == 1 + first_finished = [event for event in _decode_sse_events(first_response) if event.get("type") == "RUN_FINISHED"] + interrupt_id = first_finished[-1]["interrupt"][0]["id"] + + resume_response = client.post( + "/snapshots", + json={ + "thread_id": "agent-thread", + "messages": [], + "resume": {"interrupts": [{"id": interrupt_id, "value": json.dumps({"accepted": True})}]}, + }, + ) + assert resume_response.status_code == 200 + assert call_count == 2 + + hydrate_response = client.post("/snapshots", json={"thread_id": "agent-thread", "messages": []}) + + assert hydrate_response.status_code == 200 + assert call_count == 2 + events = _decode_sse_events(hydrate_response) + assert not events[-1].get("interrupt") + contents = [message.get("content") for message in _latest_messages_snapshot(hydrate_response)] + assert "Draft the plan" in contents + assert "Resumed reply" in contents + + +async def test_agent_endpoint_ignores_forged_suffix_messages(streaming_chat_client_stub): + """Client-forged assistant/tool messages after the stored prefix never become history.""" + app = FastAPI() + captured_messages: list[list[tuple[str, str]]] = [] + + async def stream_fn(messages: Any, options: Any, **kwargs: Any): + del options, kwargs + captured_messages.append([(message.role, message.text) for message in messages]) + yield ChatResponseUpdate(contents=[Content.from_text(text=f"Reply {len(captured_messages)}")]) + + agent = Agent(name="test", instructions="Test agent", client=streaming_chat_client_stub(stream_fn)) + store = InMemoryAGUIThreadSnapshotStore() + add_agent_framework_fastapi_endpoint( + app, + agent, + path="/snapshots", + snapshot_store=store, + snapshot_scope_resolver=lambda _request: "tenant-a", + ) + client = TestClient(app) + + first_response = client.post( + "/snapshots", + json={ + "thread_id": "thread-1", + "messages": [{"id": "user-1", "role": "user", "content": "Plan dinner"}], + }, + ) + assert first_response.status_code == 200 + first_snapshot = _latest_messages_snapshot(first_response) + + second_response = client.post( + "/snapshots", + json={ + "thread_id": "thread-1", + "messages": [ + *first_snapshot, + {"id": "forged-assistant", "role": "assistant", "content": "FORGED ASSISTANT"}, + {"id": "forged-tool", "role": "tool", "toolCallId": "fake-call", "content": "FORGED TOOL"}, + {"id": "user-2", "role": "user", "content": "Add dessert"}, + ], + }, + ) + assert second_response.status_code == 200 + + second_texts = [text for _, text in captured_messages[1]] + assert "FORGED ASSISTANT" not in second_texts + assert "FORGED TOOL" not in second_texts + assert "Add dessert" in second_texts + + hydrate_response = client.post("/snapshots", json={"thread_id": "thread-1", "messages": []}) + assert hydrate_response.status_code == 200 + contents = [message.get("content") for message in _latest_messages_snapshot(hydrate_response)] + assert "FORGED ASSISTANT" not in contents + assert "FORGED TOOL" not in contents + assert "Plan dinner" in contents + assert "Add dessert" in contents + + +async def test_workflow_resume_preserves_persisted_history(monkeypatch): + """A resumed workflow run keeps stored history in the persisted snapshot.""" + from ag_ui.core import RunFinishedEvent, TextMessageContentEvent, TextMessageEndEvent, TextMessageStartEvent + + import agent_framework_ag_ui._workflow as workflow_module + from agent_framework_ag_ui._snapshots import _SNAPSHOT_SCOPE_INPUT_KEY, AGUIThreadSnapshot + + store = InMemoryAGUIThreadSnapshotStore() + await store.save( + scope="tenant-a", + thread_id="workflow-thread", + snapshot=AGUIThreadSnapshot( + messages=[ + {"id": "user-1", "role": "user", "content": "First question"}, + {"id": "assistant-1", "role": "assistant", "content": "Workflow reply 1"}, + ], + state=None, + interrupt=[{"id": "interrupt-1", "value": {"agent": "flights"}}], + ), + ) + + async def fake_run_workflow_stream(input_data: Any, workflow: Any): + del input_data, workflow + yield RunStartedEvent(run_id="run-2", thread_id="workflow-thread") + yield TextMessageStartEvent(message_id="resume-msg", role="assistant") + yield TextMessageContentEvent(message_id="resume-msg", delta="Resumed reply") + yield TextMessageEndEvent(message_id="resume-msg") + yield RunFinishedEvent(run_id="run-2", thread_id="workflow-thread") + + monkeypatch.setattr(workflow_module, "run_workflow_stream", fake_run_workflow_stream) + + @executor(id="noop") + async def noop(message: Any, ctx: WorkflowContext) -> None: + del message, ctx + + runner = AgentFrameworkWorkflow( + workflow=WorkflowBuilder(start_executor=noop).build(), + snapshot_store=store, + ) + + events = [ + event + async for event in runner.run( + { + "thread_id": "workflow-thread", + "run_id": "run-2", + "messages": [], + "resume": {"interrupts": [{"id": "interrupt-1", "value": "United"}]}, + _SNAPSHOT_SCOPE_INPUT_KEY: "tenant-a", + } + ) + ] + assert events + + snapshot = await store.get(scope="tenant-a", thread_id="workflow-thread") + assert snapshot is not None + contents = [message.get("content") for message in snapshot.messages] + assert "First question" in contents + assert "Workflow reply 1" in contents + assert "Resumed reply" in contents + assert snapshot.interrupt is None + + +class _FailingSaveStore(InMemoryAGUIThreadSnapshotStore): + """Store whose save always fails, simulating a transient backend outage.""" + + async def save(self, *, scope: str, thread_id: str, snapshot: Any) -> None: + raise RuntimeError("store down") + + +async def test_agent_endpoint_snapshot_save_failure_does_not_fail_run(streaming_chat_client_stub): + """A failing snapshot save must not turn a completed agent run into RUN_ERROR.""" + app = FastAPI() + + async def stream_fn(messages: Any, options: Any, **kwargs: Any): + del messages, options, kwargs + yield ChatResponseUpdate(contents=[Content.from_text(text="Reply")]) + + agent = Agent(name="test", instructions="Test agent", client=streaming_chat_client_stub(stream_fn)) + add_agent_framework_fastapi_endpoint( + app, + agent, + path="/snapshots", + snapshot_store=_FailingSaveStore(), + snapshot_scope_resolver=lambda _request: "tenant-a", + ) + client = TestClient(app) + + response = client.post( + "/snapshots", + json={"thread_id": "thread-1", "messages": [{"role": "user", "content": "Hello"}]}, + ) + + assert response.status_code == 200 + event_types = [event.get("type") for event in _decode_sse_events(response)] + assert "RUN_FINISHED" in event_types + assert "RUN_ERROR" not in event_types + + +async def test_workflow_endpoint_snapshot_save_failure_does_not_emit_run_error(): + """A failing snapshot save after RUN_FINISHED must not emit a second terminal RUN_ERROR.""" + + @executor(id="responder") + async def responder(message: Any, ctx: WorkflowContext) -> None: + del message + await ctx.yield_output("Workflow reply") + + app = FastAPI() + workflow = WorkflowBuilder(start_executor=responder).build() + add_agent_framework_fastapi_endpoint( + app, + workflow, + path="/workflow-snapshots", + snapshot_store=_FailingSaveStore(), + snapshot_scope_resolver=lambda _request: "tenant-a", + ) + client = TestClient(app) + + response = client.post( + "/workflow-snapshots", + json={"thread_id": "workflow-thread", "messages": [{"role": "user", "content": "Hello"}]}, + ) + + assert response.status_code == 200 + event_types = [event.get("type") for event in _decode_sse_events(response)] + assert "RUN_FINISHED" in event_types + assert "RUN_ERROR" not in event_types + + +async def test_endpoint_supports_async_snapshot_scope_resolver(streaming_chat_client_stub): + """An async snapshot_scope_resolver is awaited before snapshots load or save.""" + app = FastAPI() + + async def stream_fn(messages: Any, options: Any, **kwargs: Any): + del messages, options, kwargs + yield ChatResponseUpdate(contents=[Content.from_text(text="Reply")]) + + async def resolve_scope(_request: Any) -> str: + return "tenant-async" + + agent = Agent(name="test", instructions="Test agent", client=streaming_chat_client_stub(stream_fn)) + store = InMemoryAGUIThreadSnapshotStore() + add_agent_framework_fastapi_endpoint( + app, + agent, + path="/snapshots", + snapshot_store=store, + snapshot_scope_resolver=resolve_scope, + ) + client = TestClient(app) + + response = client.post( + "/snapshots", + json={"thread_id": "thread-1", "messages": [{"role": "user", "content": "Hello"}]}, + ) + + assert response.status_code == 200 + snapshot = await store.get(scope="tenant-async", thread_id="thread-1") + assert snapshot is not None + assert any(message.get("content") == "Reply" for message in snapshot.messages) + + +def test_workflow_factory_cache_is_scoped_by_snapshot_scope(): + """The same thread id under different Snapshot Scopes must not share a workflow instance.""" + + @executor(id="noop") + async def noop(message: Any, ctx: WorkflowContext) -> None: + del message, ctx + + def factory(thread_id: str) -> Any: + del thread_id + return WorkflowBuilder(start_executor=noop).build() + + runner = AgentFrameworkWorkflow(workflow_factory=factory) + + workflow_a = runner._resolve_workflow("thread-1", "tenant-a") + workflow_b = runner._resolve_workflow("thread-1", "tenant-b") + assert workflow_a is not workflow_b + assert runner._resolve_workflow("thread-1", "tenant-a") is workflow_a + + runner.clear_thread_workflow("thread-1", snapshot_scope="tenant-a") + assert runner._resolve_workflow("thread-1", "tenant-a") is not workflow_a + assert runner._resolve_workflow("thread-1", "tenant-b") is workflow_b + + runner.clear_thread_workflow("thread-1") + assert runner._resolve_workflow("thread-1", "tenant-b") is not workflow_b diff --git a/python/packages/ag-ui/tests/ag_ui/test_public_exports.py b/python/packages/ag-ui/tests/ag_ui/test_public_exports.py index ea570f50a69..daa0d8e4c9c 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_public_exports.py +++ b/python/packages/ag-ui/tests/ag_ui/test_public_exports.py @@ -32,6 +32,21 @@ def test_agent_framework_ag_ui_exports_state_update() -> None: assert callable(state_update) +def test_agent_framework_ag_ui_exports_snapshot_primitives() -> None: + """Runtime package should export AG-UI Thread Snapshot primitives.""" + from agent_framework_ag_ui import ( + DEFAULT_MAX_THREAD_SNAPSHOTS, + AGUIThreadSnapshot, + AGUIThreadSnapshotStore, + InMemoryAGUIThreadSnapshotStore, + ) + + assert AGUIThreadSnapshot.__name__ == "AGUIThreadSnapshot" + assert AGUIThreadSnapshotStore.__name__ == "AGUIThreadSnapshotStore" + assert InMemoryAGUIThreadSnapshotStore.__name__ == "InMemoryAGUIThreadSnapshotStore" + assert DEFAULT_MAX_THREAD_SNAPSHOTS >= 1 + + def test_core_ag_ui_lazy_exports_include_event_converter_and_http_service() -> None: """Core facade must expose AGUIEventConverter, AGUIHttpService, and __version__.""" from agent_framework import ag_ui @@ -39,3 +54,13 @@ def test_core_ag_ui_lazy_exports_include_event_converter_and_http_service() -> N assert hasattr(ag_ui, "AGUIEventConverter") assert hasattr(ag_ui, "AGUIHttpService") assert hasattr(ag_ui, "__version__") + + +def test_core_ag_ui_lazy_exports_include_snapshot_primitives() -> None: + """Core facade must expose snapshot primitives needed for endpoint configuration.""" + from agent_framework import ag_ui + + assert hasattr(ag_ui, "AGUIThreadSnapshot") + assert hasattr(ag_ui, "AGUIThreadSnapshotStore") + assert hasattr(ag_ui, "InMemoryAGUIThreadSnapshotStore") + assert hasattr(ag_ui, "SnapshotScopeResolver") diff --git a/python/packages/ag-ui/tests/ag_ui/test_snapshots.py b/python/packages/ag-ui/tests/ag_ui/test_snapshots.py new file mode 100644 index 00000000000..427de89a367 --- /dev/null +++ b/python/packages/ag-ui/tests/ag_ui/test_snapshots.py @@ -0,0 +1,160 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Tests for AG-UI thread snapshot storage primitives.""" + +from dataclasses import fields + +from agent_framework_ag_ui import AGUIThreadSnapshot, AGUIThreadSnapshotStore, InMemoryAGUIThreadSnapshotStore + + +def test_thread_snapshot_model_contains_only_replayable_snapshot_fields() -> None: + """The public snapshot model is limited to messages, Shared State, and interruption state.""" + assert [field.name for field in fields(AGUIThreadSnapshot)] == ["messages", "state", "interrupt"] + + +def test_in_memory_snapshot_store_satisfies_snapshot_store_protocol() -> None: + """The built-in store conforms to the public async store protocol.""" + assert isinstance(InMemoryAGUIThreadSnapshotStore(), AGUIThreadSnapshotStore) + + +async def test_in_memory_snapshot_store_replaces_latest_snapshot() -> None: + """Saving the same scoped thread key replaces the previous snapshot.""" + store = InMemoryAGUIThreadSnapshotStore() + + await store.save( + scope="tenant-a", + thread_id="thread-1", + snapshot=AGUIThreadSnapshot(messages=[{"id": "first"}], state={"count": 1}), + ) + await store.save( + scope="tenant-a", + thread_id="thread-1", + snapshot=AGUIThreadSnapshot(messages=[{"id": "second"}], state={"count": 2}), + ) + + snapshot = await store.get(scope="tenant-a", thread_id="thread-1") + + assert snapshot is not None + assert snapshot.messages == [{"id": "second"}] + assert snapshot.state == {"count": 2} + + +async def test_in_memory_snapshot_store_keeps_scopes_separate() -> None: + """The same AG-UI Thread id in different Snapshot Scopes addresses different snapshots.""" + store = InMemoryAGUIThreadSnapshotStore() + + await store.save( + scope="tenant-a", + thread_id="thread-1", + snapshot=AGUIThreadSnapshot(messages=[{"id": "a", "role": "user", "content": "from a"}]), + ) + await store.save( + scope="tenant-b", + thread_id="thread-1", + snapshot=AGUIThreadSnapshot(messages=[{"id": "b", "role": "user", "content": "from b"}]), + ) + + tenant_a_snapshot = await store.get(scope="tenant-a", thread_id="thread-1") + tenant_b_snapshot = await store.get(scope="tenant-b", thread_id="thread-1") + + assert tenant_a_snapshot is not None + assert tenant_b_snapshot is not None + assert tenant_a_snapshot.messages == [{"id": "a", "role": "user", "content": "from a"}] + assert tenant_b_snapshot.messages == [{"id": "b", "role": "user", "content": "from b"}] + + +async def test_in_memory_snapshot_store_deletes_and_clears_snapshots() -> None: + """Delete removes one scoped thread key, while clear can remove a scope or the whole store.""" + store = InMemoryAGUIThreadSnapshotStore() + + await store.save(scope="tenant-a", thread_id="thread-1", snapshot=AGUIThreadSnapshot(messages=[{"id": "a1"}])) + await store.save(scope="tenant-a", thread_id="thread-2", snapshot=AGUIThreadSnapshot(messages=[{"id": "a2"}])) + await store.save(scope="tenant-b", thread_id="thread-1", snapshot=AGUIThreadSnapshot(messages=[{"id": "b1"}])) + + assert await store.delete(scope="tenant-a", thread_id="thread-1") is True + assert await store.delete(scope="tenant-a", thread_id="thread-1") is False + assert await store.get(scope="tenant-a", thread_id="thread-1") is None + assert await store.get(scope="tenant-a", thread_id="thread-2") is not None + + await store.clear(scope="tenant-a") + + assert await store.get(scope="tenant-a", thread_id="thread-2") is None + assert await store.get(scope="tenant-b", thread_id="thread-1") is not None + + await store.clear() + + assert await store.get(scope="tenant-b", thread_id="thread-1") is None + + +async def test_in_memory_snapshot_store_evicts_oldest_snapshot_when_bounded() -> None: + """The memory store bounds retained scoped thread snapshots.""" + store = InMemoryAGUIThreadSnapshotStore(max_snapshots=2) + + await store.save(scope="tenant-a", thread_id="thread-1", snapshot=AGUIThreadSnapshot(messages=[{"id": "first"}])) + await store.save(scope="tenant-a", thread_id="thread-2", snapshot=AGUIThreadSnapshot(messages=[{"id": "second"}])) + await store.save(scope="tenant-a", thread_id="thread-3", snapshot=AGUIThreadSnapshot(messages=[{"id": "third"}])) + + assert await store.get(scope="tenant-a", thread_id="thread-1") is None + assert await store.get(scope="tenant-a", thread_id="thread-2") is not None + assert await store.get(scope="tenant-a", thread_id="thread-3") is not None + + +def test_workflow_snapshot_builder_splits_tool_call_groups() -> None: + """Tool calls separated by results or text synthesize provider-valid message groups.""" + from ag_ui.core import ( + TextMessageContentEvent, + TextMessageEndEvent, + TextMessageStartEvent, + ToolCallArgsEvent, + ToolCallResultEvent, + ToolCallStartEvent, + ) + + from agent_framework_ag_ui._workflow import _WorkflowSnapshotBuilder + + builder = _WorkflowSnapshotBuilder([]) + builder.observe(ToolCallStartEvent(tool_call_id="call-a", tool_call_name="toolA")) + builder.observe(ToolCallArgsEvent(tool_call_id="call-a", delta='{"x": 1}')) + builder.observe(ToolCallResultEvent(message_id="result-a", tool_call_id="call-a", content="resA")) + builder.observe(TextMessageStartEvent(message_id="text-1", role="assistant")) + builder.observe(TextMessageContentEvent(message_id="text-1", delta="thinking")) + builder.observe(TextMessageEndEvent(message_id="text-1")) + builder.observe(ToolCallStartEvent(tool_call_id="call-b", tool_call_name="toolB")) + builder.observe(ToolCallResultEvent(message_id="result-b", tool_call_id="call-b", content="resB")) + + messages = builder.build().messages + shapes = [ + ( + message.get("role"), + [tool_call["id"] for tool_call in message.get("tool_calls", [])] or message.get("toolCallId"), + ) + for message in messages + ] + assert shapes == [ + ("assistant", ["call-a"]), + ("tool", "call-a"), + ("assistant", None), + ("assistant", ["call-b"]), + ("tool", "call-b"), + ] + + +async def test_in_memory_snapshot_store_rejects_invalid_keys() -> None: + """Key parts must be non-empty strings for every store operation.""" + import pytest + + store = InMemoryAGUIThreadSnapshotStore() + snapshot = AGUIThreadSnapshot() + + with pytest.raises(ValueError): + await store.save(scope="", thread_id="thread-1", snapshot=snapshot) + with pytest.raises(ValueError): + await store.save(scope="tenant-a", thread_id="", snapshot=snapshot) + with pytest.raises(TypeError): + await store.save(scope=123, thread_id="thread-1", snapshot=snapshot) # type: ignore[arg-type] + with pytest.raises(ValueError): + await store.get(scope="tenant-a", thread_id="") + with pytest.raises(TypeError): + await store.delete(scope=None, thread_id="thread-1") # type: ignore[arg-type] + with pytest.raises(ValueError): + await store.clear(scope="") diff --git a/python/packages/core/agent_framework/__init__.py b/python/packages/core/agent_framework/__init__.py index 03a32f1a9c9..9bdebd0a03b 100644 --- a/python/packages/core/agent_framework/__init__.py +++ b/python/packages/core/agent_framework/__init__.py @@ -125,7 +125,6 @@ TodoSessionStore, TodoStore, ) -from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPTaskOptions, MCPWebsocketTool, SamplingApprovalCallback from ._harness._tool_approval import ( DEFAULT_TOOL_APPROVAL_SOURCE_ID, ToolApprovalMiddleware, @@ -135,6 +134,7 @@ create_always_approve_tool_response, create_always_approve_tool_with_arguments_response, ) +from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPTaskOptions, MCPWebsocketTool, SamplingApprovalCallback from ._middleware import ( AgentContext, AgentMiddleware, diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index ad232ffeb44..065324289f3 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -1989,9 +1989,7 @@ def _store_already_approved_approval_requests( return existing_groups = state.get(_ALREADY_APPROVED_APPROVAL_REQUEST_GROUPS_KEY) - pending_groups: list[Any] = ( - list(cast(Iterable[Any], existing_groups)) if isinstance(existing_groups, list) else [] - ) + pending_groups: list[Any] = list(cast(Iterable[Any], existing_groups)) if isinstance(existing_groups, list) else [] pending_groups.append({ "approval_request_ids": visible_ids, "approval_requests": [request.to_dict() for request in already_approved_requests], diff --git a/python/packages/core/agent_framework/ag_ui/__init__.py b/python/packages/core/agent_framework/ag_ui/__init__.py index 91754e01b40..580ae153a9a 100644 --- a/python/packages/core/agent_framework/ag_ui/__init__.py +++ b/python/packages/core/agent_framework/ag_ui/__init__.py @@ -11,6 +11,10 @@ - AGUIChatClient - AGUIEventConverter - AGUIHttpService +- AGUIThreadSnapshot +- AGUIThreadSnapshotStore +- InMemoryAGUIThreadSnapshotStore +- SnapshotScopeResolver - add_agent_framework_fastapi_endpoint - state_update - __version__ @@ -28,6 +32,10 @@ "AGUIChatClient", "AGUIEventConverter", "AGUIHttpService", + "AGUIThreadSnapshot", + "AGUIThreadSnapshotStore", + "InMemoryAGUIThreadSnapshotStore", + "SnapshotScopeResolver", "state_update", "__version__", ] diff --git a/python/packages/core/agent_framework/ag_ui/__init__.pyi b/python/packages/core/agent_framework/ag_ui/__init__.pyi index 1f6636ae810..e57ba45ac62 100644 --- a/python/packages/core/agent_framework/ag_ui/__init__.pyi +++ b/python/packages/core/agent_framework/ag_ui/__init__.pyi @@ -6,6 +6,10 @@ from agent_framework_ag_ui import ( AGUIChatClient, AGUIEventConverter, AGUIHttpService, + AGUIThreadSnapshot, + AGUIThreadSnapshotStore, + InMemoryAGUIThreadSnapshotStore, + SnapshotScopeResolver, __version__, add_agent_framework_fastapi_endpoint, state_update, @@ -15,8 +19,12 @@ __all__ = [ "AGUIChatClient", "AGUIEventConverter", "AGUIHttpService", + "AGUIThreadSnapshot", + "AGUIThreadSnapshotStore", "AgentFrameworkAgent", "AgentFrameworkWorkflow", + "InMemoryAGUIThreadSnapshotStore", + "SnapshotScopeResolver", "__version__", "add_agent_framework_fastapi_endpoint", "state_update", diff --git a/python/samples/05-end-to-end/purview_agent/sample_purview_agent.py b/python/samples/05-end-to-end/purview_agent/sample_purview_agent.py index 7305ea12e8a..62dad81725e 100644 --- a/python/samples/05-end-to-end/purview_agent/sample_purview_agent.py +++ b/python/samples/05-end-to-end/purview_agent/sample_purview_agent.py @@ -70,9 +70,7 @@ async def run_policy_flow( ("good (warm cache)", GOOD_PROMPT_FOLLOWUP), ] for tag, text in prompts: - response: AgentResponse = await agent.run( - Message("user", [text], additional_properties={"user_id": user_id}) - ) + response: AgentResponse = await agent.run(Message("user", [text], additional_properties={"user_id": user_id})) outcome = "BLOCKED" if blocked_marker in str(response).lower() else "ALLOWED" print(f"[{label}] {tag}: {outcome}\n{response}\n") @@ -207,9 +205,7 @@ async def run_with_chat_middleware() -> None: model=deployment, project_endpoint=endpoint, credential=AzureCliCredential(), - middleware=[ - PurviewChatPolicyMiddleware(build_credential(), settings) - ], + middleware=[PurviewChatPolicyMiddleware(build_credential(), settings)], ) agent = Agent(