diff --git a/python/packages/core/agent_framework/__init__.py b/python/packages/core/agent_framework/__init__.py index 03a32f1a9c9..445f138c885 100644 --- a/python/packages/core/agent_framework/__init__.py +++ b/python/packages/core/agent_framework/__init__.py @@ -125,7 +125,6 @@ TodoSessionStore, TodoStore, ) -from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPTaskOptions, MCPWebsocketTool, SamplingApprovalCallback from ._harness._tool_approval import ( DEFAULT_TOOL_APPROVAL_SOURCE_ID, ToolApprovalMiddleware, @@ -135,6 +134,7 @@ create_always_approve_tool_response, create_always_approve_tool_with_arguments_response, ) +from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPTaskOptions, MCPWebsocketTool, SamplingApprovalCallback from ._middleware import ( AgentContext, AgentMiddleware, @@ -252,6 +252,7 @@ ) from ._workflows._agent_utils import resolve_agent_id from ._workflows._checkpoint import ( + CheckpointID, CheckpointStorage, FileCheckpointStorage, InMemoryCheckpointStorage, @@ -295,7 +296,6 @@ workflow, ) from ._workflows._request_info_mixin import response_handler -from ._workflows._runner import Runner from ._workflows._runner_context import ( InProcRunnerContext, RunnerContext, @@ -390,6 +390,7 @@ "ChatResponse", "ChatResponseUpdate", "CheckResult", + "CheckpointID", "CheckpointStorage", "ClassSkill", "CompactionProvider", @@ -481,7 +482,6 @@ "RoleLiteral", "RubricScore", "RunContext", - "Runner", "RunnerContext", "SamplingApprovalCallback", "SecretString", diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index ad232ffeb44..065324289f3 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -1989,9 +1989,7 @@ def _store_already_approved_approval_requests( return existing_groups = state.get(_ALREADY_APPROVED_APPROVAL_REQUEST_GROUPS_KEY) - pending_groups: list[Any] = ( - list(cast(Iterable[Any], existing_groups)) if 
isinstance(existing_groups, list) else [] - ) + pending_groups: list[Any] = list(cast(Iterable[Any], existing_groups)) if isinstance(existing_groups, list) else [] pending_groups.append({ "approval_request_ids": visible_ids, "approval_requests": [request.to_dict() for request in already_approved_requests], diff --git a/python/packages/core/agent_framework/_workflows/_runner.py b/python/packages/core/agent_framework/_workflows/_runner.py index 51a3312e2ba..db6b61af1f7 100644 --- a/python/packages/core/agent_framework/_workflows/_runner.py +++ b/python/packages/core/agent_framework/_workflows/_runner.py @@ -10,7 +10,6 @@ from ..exceptions import ( WorkflowCheckpointException, WorkflowConvergenceException, - WorkflowRunnerException, ) from ._checkpoint import CheckpointID, CheckpointStorage, WorkflowCheckpoint from ._const import EXECUTOR_STATE_KEY @@ -63,99 +62,105 @@ def __init__( self._iteration = 0 self._max_iterations = max_iterations self._state = state - self._running = False - self._resumed_from_checkpoint = False # Track whether we resumed + + # Checkpointing related attributes + self._resumed_from_checkpoint = False + self.previous_checkpoint_id: CheckpointID | None = None @property def context(self) -> RunnerContext: - """Get the workflow context.""" + """Get the runner context for message, event, and checkpoint handling.""" return self._ctx + @property + def state(self) -> State: + """Get the shared state for the workflow.""" + return self._state + def reset_iteration_count(self) -> None: - """Reset the iteration count to zero.""" + """Reset the iteration count to zero. + + This is useful when the workflow resumes from a new set of messages. + + Note: + When a workflow is resumed from a response (for a request_info_event) + or a checkpoint, the iteration count is normally NOT reset. 
+ """ self._iteration = 0 async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]: """Run the workflow until no more messages are sent.""" - if self._running: - raise WorkflowRunnerException("Runner is already running.") - - self._running = True - previous_checkpoint_id: CheckpointID | None = None - try: - # Emit any events already produced prior to entering loop - if await self._ctx.has_events(): - logger.info("Yielding pre-loop events") - for event in await self._ctx.drain_events(): - yield event - - # Create the first checkpoint. Checkpoints are usually considered to be created at the end of an iteration, - # we can think of the first checkpoint as being created at the end of a "superstep 0" which captures the - # states after which the start executor has run. Note that we execute the start executor outside of the - # main iteration loop. - if await self._ctx.has_messages() and not self._resumed_from_checkpoint: - previous_checkpoint_id = await self._create_checkpoint_if_enabled(previous_checkpoint_id) - - while self._iteration < self._max_iterations: - logger.info(f"Starting superstep {self._iteration + 1}") - yield WorkflowEvent.superstep_started(iteration=self._iteration + 1) - - # Run iteration concurrently with live event streaming: we poll - # for new events while the iteration coroutine progresses. 
- iteration_task = asyncio.create_task(self._run_iteration()) - try: - while not iteration_task.done(): - try: - # Wait briefly for any new event; timeout allows progress checks - event = await asyncio.wait_for(self._ctx.next_event(), timeout=0.05) - yield event - except asyncio.TimeoutError: - # Periodically continue to let iteration advance - continue - except asyncio.CancelledError: - # Propagate cancellation to the iteration task to avoid orphaned work - iteration_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await iteration_task - raise - - # Propagate errors from iteration, but first surface any pending events - try: + # Emit any events already produced prior to entering loop + if await self._ctx.has_events(): + logger.info("Yielding pre-loop events") + for event in await self._ctx.drain_events(): + yield event + + # Create a checkpoint before a run starts. Checkpoints are usually considered to be created at the + # end of an iteration, we can think of this checkpoint as being created at the end of "superstep 0" + # which captures the states after which the start executor has run. Note that we execute the start + # executor outside of the main iteration loop. + if await self._ctx.has_messages() and not self._resumed_from_checkpoint: + await self.create_checkpoint_if_enabled() + + while self._iteration < self._max_iterations: + logger.info(f"Starting superstep {self._iteration + 1}") + yield WorkflowEvent.superstep_started(iteration=self._iteration + 1) + + # Run iteration concurrently with live event streaming: we poll + # for new events while the iteration coroutine progresses. 
+ iteration_task = asyncio.create_task(self._run_iteration()) + try: + while not iteration_task.done(): + try: + # Wait briefly for any new event; timeout allows progress checks + event = await asyncio.wait_for(self._ctx.next_event(), timeout=0.05) + yield event + except asyncio.TimeoutError: + # Periodically continue to let iteration advance + continue + except asyncio.CancelledError: + # Propagate cancellation to the iteration task to avoid orphaned work + iteration_task.cancel() + with contextlib.suppress(asyncio.CancelledError): await iteration_task - except Exception: - # Make sure failure-related events (like ExecutorFailedEvent) are surfaced - if await self._ctx.has_events(): - for event in await self._ctx.drain_events(): - yield event - raise - self._iteration += 1 - - # Drain any straggler events emitted at tail end + raise + + # Propagate errors from iteration, but first surface any pending events + try: + await iteration_task + except Exception: + # Make sure failure-related events (like ExecutorFailedEvent) are surfaced if await self._ctx.has_events(): for event in await self._ctx.drain_events(): yield event + raise + self._iteration += 1 + + # Drain any straggler events emitted at tail end + if await self._ctx.has_events(): + for event in await self._ctx.drain_events(): + yield event - logger.info(f"Completed superstep {self._iteration}") + logger.info(f"Completed superstep {self._iteration}") - # Commit pending state changes at superstep boundary - self._state.commit() + # Commit pending state changes at superstep boundary + self._state.commit() - # Create checkpoint after each superstep iteration - previous_checkpoint_id = await self._create_checkpoint_if_enabled(previous_checkpoint_id) + # Create checkpoint after each superstep iteration + await self.create_checkpoint_if_enabled() - yield WorkflowEvent.superstep_completed(iteration=self._iteration) + yield WorkflowEvent.superstep_completed(iteration=self._iteration) - # Check for convergence: no 
more messages to process - if not await self._ctx.has_messages(): - break + # Check for convergence: no more messages to process + if not await self._ctx.has_messages(): + break - if self._iteration >= self._max_iterations and await self._ctx.has_messages(): - raise WorkflowConvergenceException(f"Runner did not converge after {self._max_iterations} iterations.") + logger.info(f"Workflow completed after {self._iteration} supersteps") + self._resumed_from_checkpoint = False # Reset resume flag for next run - logger.info(f"Workflow completed after {self._iteration} supersteps") - self._resumed_from_checkpoint = False # Reset resume flag for next run - finally: - self._running = False + if self._iteration >= self._max_iterations and await self._ctx.has_messages(): + raise WorkflowConvergenceException(f"Runner did not converge after {self._max_iterations} iterations.") async def _run_iteration(self) -> None: """Run a single iteration of the workflow. @@ -209,10 +214,10 @@ async def _deliver_messages_for_edge_runner(edge_runner: EdgeRunner) -> None: ] await asyncio.gather(*tasks) - async def _create_checkpoint_if_enabled(self, previous_checkpoint_id: CheckpointID | None) -> CheckpointID | None: + async def create_checkpoint_if_enabled(self) -> None: """Create a checkpoint if checkpointing is enabled and attach a label and metadata.""" if not self._ctx.has_checkpointing(): - return None + return try: # Save executor states into the shared state before creating the checkpoint, @@ -227,22 +232,33 @@ async def _create_checkpoint_if_enabled(self, previous_checkpoint_id: Checkpoint self._workflow_name, self._graph_signature_hash, self._state, - previous_checkpoint_id, + self.previous_checkpoint_id, self._iteration, ) - logger.info(f"Created checkpoint: {checkpoint_id}") - return checkpoint_id + logger.info( + "Created checkpoint: %s with parent checkpoint at iteration %d: %s", + checkpoint_id, + self._iteration, + self.previous_checkpoint_id, + ) + self.previous_checkpoint_id 
= checkpoint_id except Exception as e: - logger.warning(f"Failed to create checkpoint: {e}") - return None + logger.warning( + "Failed to create checkpoint at iteration %d: %s. " + "Note that this does not fail the workflow run. " + "The next successfully-created checkpoint will be parented to the last successful checkpoint: %s", + self._iteration, + e, + self.previous_checkpoint_id, + ) async def restore_from_checkpoint( self, checkpoint_id: CheckpointID, checkpoint_storage: CheckpointStorage | None = None, ) -> None: - """Restore workflow state from a checkpoint. + """Restore the runner from a checkpoint. Args: checkpoint_id: The ID of the checkpoint to restore from @@ -290,7 +306,7 @@ async def restore_from_checkpoint( # Apply the checkpoint to the context await self._ctx.apply_checkpoint(checkpoint) # Mark the runner as resumed - self._mark_resumed(checkpoint.iteration_count) + self._mark_resumed(checkpoint) logger.info(f"Successfully restored workflow from checkpoint: {checkpoint_id}") except WorkflowCheckpointException: @@ -356,13 +372,14 @@ def _parse_edge_runners(self, edge_runners: list[EdgeRunner]) -> dict[str, list[ return parsed - def _mark_resumed(self, iteration: int) -> None: + def _mark_resumed(self, checkpoint: WorkflowCheckpoint) -> None: """Mark the runner as having resumed from a checkpoint. Optionally set the current iteration and max iterations. """ self._resumed_from_checkpoint = True - self._iteration = iteration + self._iteration = checkpoint.iteration_count + self.previous_checkpoint_id = checkpoint.checkpoint_id async def _set_executor_state(self, executor_id: str, state: dict[str, Any]) -> None: """Store executor state in state under a reserved key. 
diff --git a/python/packages/core/agent_framework/_workflows/_runner_context.py b/python/packages/core/agent_framework/_workflows/_runner_context.py index 2e4901f4118..30f82955244 100644 --- a/python/packages/core/agent_framework/_workflows/_runner_context.py +++ b/python/packages/core/agent_framework/_workflows/_runner_context.py @@ -403,12 +403,14 @@ async def load_checkpoint(self, checkpoint_id: CheckpointID) -> WorkflowCheckpoi def reset_for_new_run(self) -> None: """Reset the context for a new workflow run. - This clears messages, events, and resets streaming flag. - Runtime checkpoint storage is NOT cleared here as it's managed at the workflow level. + Clears messages, the pending event queue, the pending request_info + correlation map, and the streaming flag. Runtime checkpoint storage is + NOT cleared here as it's managed at the workflow level. """ self._messages.clear() # Clear any pending events (best-effort) by recreating the queue self._event_queue = asyncio.Queue() + self._pending_request_info_events.clear() self._streaming = False # Reset streaming flag async def apply_checkpoint(self, checkpoint: WorkflowCheckpoint) -> None: diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index c4840bb0455..09ff6c28c18 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -11,14 +11,16 @@ import types import uuid import warnings +import weakref from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, Sequence from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Literal, overload from .._sessions import ContextProvider from .._types import ResponseStream +from ..exceptions import WorkflowCheckpointException, WorkflowException from ..observability import OtelAttr, capture_exception, create_workflow_span -from ._checkpoint import CheckpointStorage 
+from ._checkpoint import CheckpointID, CheckpointStorage from ._const import DEFAULT_MAX_ITERATIONS, GLOBAL_KWARGS_KEY, WORKFLOW_RUN_KWARGS_KEY from ._edge import ( EdgeGroup, @@ -346,25 +348,29 @@ def __init__( # Store non-serializable runtime objects as private attributes self._runner_context = runner_context self._runner_context.set_yield_output_classifier(self._output_designation.classify) - self._state = State() self._runner: Runner = Runner( self.edge_groups, self.executors, - self._state, + State(), runner_context, self.name, self.graph_signature_hash, max_iterations=max_iterations, ) - # Flag to prevent concurrent workflow executions - self._is_running = False - # Current run-level status of this workflow instance. Updated in lockstep with # the status events emitted from `_run_workflow_with_tracing`. Defaults to IDLE # for a freshly built workflow that has not yet been run. self._status: WorkflowRunState = WorkflowRunState.IDLE + # Weak reference to the in-flight run's ``ResponseStream``. Used as the single + # concurrency lock: if the previous stream is still alive, ``run()`` rejects a + # new run synchronously (before any await). When the stream is fully consumed + # ``_run_core``'s finally clears this; if the caller drops the stream without + # ever iterating, the weakref dereferences to ``None`` once Python collects it, + # so a subsequent ``run()`` is allowed. + self._active_run: weakref.ref[ResponseStream[WorkflowEvent, WorkflowRunResult]] | None = None + @property def status(self) -> WorkflowRunState: """Return the current run-level status of this workflow instance. @@ -376,16 +382,6 @@ def status(self) -> WorkflowRunState: """ return self._status - def _ensure_not_running(self) -> None: - """Ensure the workflow is not already running.""" - if self._is_running: - raise RuntimeError("Workflow is already running. 
Concurrent executions are not allowed.") - self._is_running = True - - def _reset_running_flag(self) -> None: - """Reset the running flag.""" - self._is_running = False - def to_dict(self) -> dict[str, Any]: """Serialize the workflow definition into a JSON-ready dictionary.""" data: dict[str, Any] = { @@ -478,6 +474,50 @@ def get_executors_list(self) -> list[Executor]: """Get the list of executors in the workflow.""" return list(self.executors.values()) + async def create_checkpoint(self, checkpoint_storage: CheckpointStorage | None) -> CheckpointID: + """Create a checkpoint of the current workflow state in the provided storage. + + Args: + checkpoint_storage: The CheckpointStorage instance where the checkpoint will be stored. + If None, will use the workflow's default checkpoint storage if configured, or raise + if checkpointing is not enabled. + + Notes: + - Checkpoints can only be created when the workflow is idle (not actively running). + - Checkpoints are automatically created at the end of each superstep if a checkpoint storage is configured. + Use this method only when necessary, for example to capture the initial state of the workflow prior to the + first run. + - Creating a checkpoint manually will alter the checkpoint lineage. The new checkpoint will become the + parent of the next checkpoint created automatically (if checkpointing is enabled by providing a storage). + """ + if self._is_run_active(): + raise WorkflowException( + "Cannot create checkpoint while a workflow run is active. " + "Checkpointing is only allowed between runs when the workflow is idle." + ) + + if checkpoint_storage is None and not self._runner.context.has_checkpointing(): + raise WorkflowCheckpointException( + "Checkpoint storage must be provided to create a checkpoint when checkpointing is not enabled." 
+ ) + if checkpoint_storage is not None: + self._runner.context.set_runtime_checkpoint_storage(checkpoint_storage) + + # Capture the runner's checkpoint id before attempting to save. The runner + # log-and-swallows storage save errors and only updates + # ``previous_checkpoint_id`` on success, so a failed save would otherwise + # leave the prior id in place and we'd return it as if a fresh checkpoint + # had been created. + previous_id_before = self._runner.previous_checkpoint_id + try: + await self._runner.create_checkpoint_if_enabled() + new_id = self._runner.previous_checkpoint_id + if new_id is None or new_id == previous_id_before: + raise WorkflowCheckpointException("Failed to create checkpoint.") + return new_id + finally: + self._runner.context.clear_runtime_checkpoint_storage() + async def _run_workflow_with_tracing( self, initial_executor_fn: Callable[[], Awaitable[None]] | None = None, @@ -535,13 +575,12 @@ async def _run_workflow_with_tracing( yield in_progress # noqa: RUF070 # Per-run reset for fresh-message runs only. We deliberately - # do NOT clear shared workflow state (`_state.clear()`) or the - # runner context's in-flight messages (`reset_for_new_run()`) - # here - state and pending work persist across `run()` calls - # so that a `WorkflowAgent` can deliver multi-turn input on - # the same instance and have prior turns' context survive. - # Iteration counting and per-run kwargs ARE per-run though, - # so they're reset here. + # do NOT clear shared workflow state or the runner context's + # in-flight messages here - state and pending work persist + # across `run()` calls so that a `WorkflowAgent` can deliver + # multi-turn input on the same instance and have prior turns' + # context survive. Iteration counting and per-run kwargs ARE + # per-run though, so they're reset here. 
if not is_continuation: self._runner.reset_iteration_count() @@ -564,14 +603,13 @@ async def _run_workflow_with_tracing( combined_kwargs["client_kwargs"] = self._resolve_invocation_kwargs( client_kwargs, "client_kwargs" ) - self._state.set(WORKFLOW_RUN_KWARGS_KEY, combined_kwargs) + self._runner.state.set(WORKFLOW_RUN_KWARGS_KEY, combined_kwargs) elif not is_continuation: - self._state.set(WORKFLOW_RUN_KWARGS_KEY, {}) - self._state.commit() # Commit immediately so kwargs are available + self._runner.state.set(WORKFLOW_RUN_KWARGS_KEY, {}) + self._runner.state.commit() # Commit immediately so kwargs are available - # Set streaming mode (always set explicitly per run since - # reset_for_new_run() no longer runs to clear it). - self._runner_context.set_streaming(streaming) + # Explicitly set streaming mode per run + self._runner.context.set_streaming(streaming) # Execute initial setup if provided if initial_executor_fn: @@ -665,7 +703,7 @@ async def _execute_with_message_or_checkpoint( await executor.execute( message, [self.__class__.__name__], - self._state, + self._runner.state, self._runner.context, trace_contexts=None, source_span_ids=None, @@ -745,9 +783,22 @@ def run( Raises: ValueError: If parameter combination is invalid. """ - # Validate parameters and set running flag eagerly (before any async work) + # Validate parameters first so misuse fails before we touch any run state. self._validate_run_params(message, responses, checkpoint_id) - self._ensure_not_running() + + # Concurrency check: reject a second run synchronously - before constructing + # the ResponseStream or yielding control to the event loop - so a concurrent + # ``run`` call can't slip past the guard while the first call is suspended + # inside its async generator. The ``ResponseStream`` returned below is the + # lock: as long as the caller holds a reference to it, ``self._active_run()`` + # resolves to a live object and a new ``run`` is rejected. 
When the stream is + # fully consumed, ``_run_core``'s finally clears the attribute. When the + # caller drops the stream without iterating, garbage collection invalidates + # the weakref, so a subsequent ``run`` is permitted. + if self._is_run_active(): + raise WorkflowException( + "Workflow is already running; concurrent runs are not allowed on the same instance." + ) response_stream = ResponseStream[WorkflowEvent, WorkflowRunResult]( self._run_core( @@ -760,10 +811,8 @@ def run( client_kwargs=client_kwargs, ), finalizer=functools.partial(self._finalize_events, include_status_events=include_status_events), - cleanup_hooks=[ - functools.partial(self._run_cleanup, checkpoint_storage), - ], ) + self._active_run = weakref.ref(response_stream) if stream: return response_stream @@ -789,51 +838,57 @@ async def _run_core( if checkpoint_storage is not None: self._runner.context.set_runtime_checkpoint_storage(checkpoint_storage) - # Async validation: a fresh-message run is only allowed when the - # runner context has fully drained from any prior run. If it still - # has in-flight executor messages, the prior run didn't complete - - # the caller must either resume from a checkpoint or wait for the - # prior run to drain. (Pending request_info events are intentionally - # NOT blocked here: a follow-up run with message=... is the normal - # way to deliver a response to those pending requests, e.g. via - # WorkflowAgent._process_pending_requests.) - # NOTE: _validate_run_params already enforces that ``message`` is - # mutually exclusive with both ``checkpoint_id`` and ``responses``, - # so we don't need to re-check those here. - if message is not None and await self._runner.context.has_messages(): - raise RuntimeError( - "Cannot start a new run with 'message' while in-flight executor " - "messages remain from a prior run. Resume from a checkpoint " - "(checkpoint_id=...) or wait for the prior run to complete. 
" - "Workflows that need to recover from a mid-run failure must use " - "checkpointing; there is no in-process recovery path." - ) + try: + # Async validation: a fresh-message run is only allowed when the + # runner context has fully drained from any prior run. If it still + # has in-flight executor messages, the prior run didn't complete - + # the caller must either resume from a checkpoint or wait for the + # prior run to drain. (Pending request_info events are intentionally + # NOT blocked here: a follow-up run with message=... is the normal + # way to deliver a response to those pending requests, e.g. via + # WorkflowAgent._process_pending_requests.) + # NOTE: _validate_run_params already enforces that ``message`` is + # mutually exclusive with both ``checkpoint_id`` and ``responses``, + # so we don't need to re-check those here. + if message is not None and await self._runner.context.has_messages(): + raise RuntimeError( + "Cannot start a new run with 'message' while in-flight executor " + "messages remain from a prior run. Resume from a checkpoint " + "(checkpoint_id=...) or wait for the prior run to complete. " + "Workflows that need to recover from a mid-run failure must use " + "checkpointing; there is no in-process recovery path." + ) - initial_executor_fn = self._resolve_execution_mode(message, responses, checkpoint_id, checkpoint_storage) - - async for event in self._run_workflow_with_tracing( - initial_executor_fn=initial_executor_fn, - is_continuation=(message is None), - streaming=streaming, - function_invocation_kwargs=function_invocation_kwargs, - client_kwargs=client_kwargs, - ): - if event.type == "request_info" and event.request_id in (responses or {}): - # Don't yield request_info events for which we have responses to send - - # these are considered "handled". This prevents the caller from seeing - # events for requests they are already responding to. 
- # This usually happens when responses are provided with a checkpoint - # (restore then send), because the request_info events are stored in the - # checkpoint and would be emitted on restoration by the runner regardless - # of if a response is provided or not. - continue - yield event + initial_executor_fn = self._resolve_execution_mode(message, responses, checkpoint_id, checkpoint_storage) - async def _run_cleanup(self, checkpoint_storage: CheckpointStorage | None) -> None: - """Cleanup hook called after stream consumption.""" - if checkpoint_storage is not None: - self._runner.context.clear_runtime_checkpoint_storage() - self._reset_running_flag() + async for event in self._run_workflow_with_tracing( + initial_executor_fn=initial_executor_fn, + is_continuation=(message is None), + streaming=streaming, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, + ): + if event.type == "request_info" and event.request_id in (responses or {}): + # Don't yield request_info events for which we have responses to send - + # these are considered "handled". This prevents the caller from seeing + # events for requests they are already responding to. + # This usually happens when responses are provided with a checkpoint + # (restore then send), because the request_info events are stored in the + # checkpoint and would be emitted on restoration by the runner regardless + # of if a response is provided or not. + continue + yield event + finally: + # Clear the active-run weakref so a subsequent ``run()`` is allowed. + # ``run()`` set this synchronously after constructing the ResponseStream; + # we clear it here once the run has finished (success, error, early + # close, or partial iteration). This is in-band, so by the time the + # caller's stream is later garbage collected, ``_active_run`` is already + # ``None`` (or has been replaced by a newer run's weakref) - no GC-time + # finalizer is needed. 
+ self._active_run = None + if checkpoint_storage is not None: + self._runner.context.clear_runtime_checkpoint_storage() @staticmethod def _finalize_events( @@ -935,7 +990,7 @@ async def _restore_and_send_responses( async def _send_responses_internal(self, responses: Mapping[str, Any]) -> None: """Internal method to validate and send responses to the executors.""" - pending_requests = await self._runner_context.get_pending_request_info_events() + pending_requests = await self._runner.context.get_pending_request_info_events() if not pending_requests: raise RuntimeError("No pending requests found in workflow context.") @@ -955,7 +1010,7 @@ async def _send_responses_internal(self, responses: Mapping[str, Any]) -> None: coerced_responses[request_id] = response await asyncio.gather(*[ - self._runner_context.send_request_info_response(request_id, response) + self._runner.context.send_request_info_response(request_id, response) for request_id, response in coerced_responses.items() ]) @@ -1151,3 +1206,12 @@ def as_agent( context_providers=context_providers, **kwargs, ) + + def _is_run_active(self) -> bool: + """Check if a workflow run is currently active. + + Returns: + True if a run is active, False otherwise. 
+ """ + existing_stream = self._active_run() if self._active_run is not None else None + return existing_stream is not None diff --git a/python/packages/core/tests/workflow/test_checkpoint.py b/python/packages/core/tests/workflow/test_checkpoint.py index e395655afaf..fe1df65822d 100644 --- a/python/packages/core/tests/workflow/test_checkpoint.py +++ b/python/packages/core/tests/workflow/test_checkpoint.py @@ -336,6 +336,97 @@ async def finish(self, message: str, ctx: WorkflowContext[Never, str]) -> None: ) +async def test_workflow_checkpoint_ancestry_preserved_after_resume(): + """Resuming from a checkpoint must preserve ancestry: future checkpoints chain back to the resumed one.""" + from typing_extensions import Never + + from agent_framework import WorkflowBuilder, WorkflowContext, handler + from agent_framework._workflows._executor import Executor + + class StartExecutor(Executor): + @handler + async def run(self, message: str, ctx: WorkflowContext[str]) -> None: + await ctx.send_message(message, target_id="middle") + + class MiddleExecutor(Executor): + @handler + async def process(self, message: str, ctx: WorkflowContext[str]) -> None: + await ctx.send_message(message + "-processed", target_id="finish") + + class FinishExecutor(Executor): + @handler + async def finish(self, message: str, ctx: WorkflowContext[Never, str]) -> None: + await ctx.yield_output(message + "-done") + + storage = InMemoryCheckpointStorage() + + def _build_workflow() -> Any: + start = StartExecutor(id="start") + middle = MiddleExecutor(id="middle") + finish = FinishExecutor(id="finish") + return ( + WorkflowBuilder( + name="resume-ancestry-test", + max_iterations=10, + start_executor=start, + checkpoint_storage=storage, + ) + .add_edge(start, middle) + .add_edge(middle, finish) + .build() + ) + + # First run: produce an initial chain of checkpoints + workflow = _build_workflow() + workflow_name = workflow.name + _ = [event async for event in workflow.run("hello", stream=True)] + + 
initial_checkpoints = sorted(await storage.list_checkpoints(workflow_name=workflow_name), key=lambda c: c.timestamp) + assert len(initial_checkpoints) >= 3, ( + f"Need at least 3 initial checkpoints to pick a middle one, got {len(initial_checkpoints)}" + ) + initial_ids = {cp.checkpoint_id for cp in initial_checkpoints} + + # Pick an intermediate checkpoint to resume from (not the first, not the last) + resume_from = initial_checkpoints[len(initial_checkpoints) // 2] + + # Resume on a fresh workflow instance (same graph signature) and run to completion + resumed_workflow = _build_workflow() + assert resumed_workflow.name == workflow_name + _ = [event async for event in resumed_workflow.run(checkpoint_id=resume_from.checkpoint_id, stream=True)] + + # Inspect new checkpoints created after resuming + all_checkpoints = sorted(await storage.list_checkpoints(workflow_name=workflow_name), key=lambda c: c.timestamp) + new_checkpoints = [cp for cp in all_checkpoints if cp.checkpoint_id not in initial_ids] + assert new_checkpoints, "Resuming from an intermediate checkpoint should produce new checkpoints" + + # The very first checkpoint created after resuming must chain back to the resumed checkpoint + assert new_checkpoints[0].previous_checkpoint_id == resume_from.checkpoint_id, ( + "First post-resume checkpoint must chain to the checkpoint that was resumed from; " + f"got previous_checkpoint_id={new_checkpoints[0].previous_checkpoint_id!r}, " + f"expected {resume_from.checkpoint_id!r}" + ) + + # Subsequent post-resume checkpoints must continue chaining + for i in range(1, len(new_checkpoints)): + assert new_checkpoints[i].previous_checkpoint_id == new_checkpoints[i - 1].checkpoint_id, ( + f"Post-resume checkpoint {i} should chain to checkpoint {i - 1}" + ) + + # Walking the chain backwards from the most recent checkpoint must reach the original root + # without breaks (i.e. the full ancestry across the resume boundary is intact). 
+ checkpoints_by_id = {cp.checkpoint_id: cp for cp in all_checkpoints} + chain: list[str] = [] + cursor: str | None = new_checkpoints[-1].checkpoint_id + while cursor is not None: + chain.append(cursor) + cursor = checkpoints_by_id[cursor].previous_checkpoint_id + # Chain must include the resumed-from checkpoint and terminate at the original root + assert resume_from.checkpoint_id in chain + assert chain[-1] == initial_checkpoints[0].checkpoint_id + assert checkpoints_by_id[chain[-1]].previous_checkpoint_id is None + + async def test_memory_checkpoint_storage_roundtrip_json_native_types(): """Test that JSON-native types (str, int, float, bool, None) roundtrip correctly.""" storage = InMemoryCheckpointStorage() diff --git a/python/packages/core/tests/workflow/test_runner.py b/python/packages/core/tests/workflow/test_runner.py index 4fef26bd2d2..5c458be7c45 100644 --- a/python/packages/core/tests/workflow/test_runner.py +++ b/python/packages/core/tests/workflow/test_runner.py @@ -17,7 +17,6 @@ WorkflowContext, WorkflowConvergenceException, WorkflowEvent, - WorkflowRunnerException, WorkflowRunState, handler, ) @@ -305,40 +304,62 @@ async def handle(self, message: MockMessage, ctx: WorkflowContext[MockMessage, i assert probe_target.call_count == 1 -async def test_runner_already_running(): - """Test that running the runner while it is already running raises an error.""" +async def test_runner_run_until_convergence_runs_sequentially(): + """run_until_convergence can be invoked back-to-back on the same Runner. + + The Runner itself does not enforce concurrency; that responsibility lives on + :class:`Workflow`. This test simply confirms the Runner is reusable across + sequential runs. 
+ """ + runner = _make_runner() + async for _ in runner.run_until_convergence(): + pass + async for _ in runner.run_until_convergence(): + pass + + +def _make_runner() -> Runner: + """Build a minimal runner for runner-level tests.""" + return Runner( + [], + {}, + State(), + InProcRunnerContext(), + "test_name", + graph_signature_hash="test_hash", + ) + + +async def test_runner_accepts_new_run_after_previous_failure(): + """A failed run must not leave the Runner unable to start a new run. + + After the first run raises, ``run_until_convergence()`` must be callable + again and not surface any lifecycle-related rejection. + """ executor_a = MockExecutor(id="executor_a") executor_b = MockExecutor(id="executor_b") - - # Create a loop edges = [ SingleEdgeGroup(executor_a.id, executor_b.id), SingleEdgeGroup(executor_b.id, executor_a.id), ] - - executors: dict[str, Executor] = { - executor_a.id: executor_a, - executor_b.id: executor_b, - } + executors: dict[str, Executor] = {executor_a.id: executor_a, executor_b.id: executor_b} state = State() ctx = InProcRunnerContext() + runner = Runner(edges, executors, state, ctx, "test_name", graph_signature_hash="test_hash", max_iterations=2) - runner = Runner(edges, executors, state, ctx, "test_name", graph_signature_hash="test_hash") - - await executor_a.execute( - MockMessage(data=0), - ["START"], # source_executor_ids - state, # state - ctx, # runner_context - ) + await executor_a.execute(MockMessage(data=0), ["START"], state, ctx) - with pytest.raises(WorkflowRunnerException, match="Runner is already running."): - - async def _run(): - async for _ in runner.run_until_convergence(): - pass + with pytest.raises(WorkflowConvergenceException): + async for _ in runner.run_until_convergence(): + pass - await asyncio.gather(_run(), _run()) + # A second run on the same Runner must not be blocked by stale lifecycle + # state from the failed run. 
+ try: + async for _ in runner.run_until_convergence(): + pass + except Exception as exc: + assert "Runner is already running" not in str(exc), "Runner stayed locked after a failed run" async def test_runner_emits_runner_completion_for_agent_response_without_targets(): @@ -862,7 +883,13 @@ async def test_runner_checkpoint_with_resumed_flag(): state = State() runner = Runner(edges, executors, state, ctx, "test_name", graph_signature_hash="test_hash") - runner._mark_resumed(5) # pyright: ignore[reportPrivateUsage] + resumed_checkpoint = WorkflowCheckpoint( + checkpoint_id="resumed-cp", + workflow_name="test_name", + graph_signature_hash="test_hash", + iteration_count=5, + ) + runner._mark_resumed(resumed_checkpoint) # pyright: ignore[reportPrivateUsage] # Add a message to trigger the checkpoint creation path await ctx.send_message(WorkflowMessage(data=MockMessage(data=8), source_id="START")) @@ -882,6 +909,86 @@ async def test_runner_checkpoint_with_resumed_flag(): assert runner._resumed_from_checkpoint is False # pyright: ignore[reportPrivateUsage] +async def test_runner_mark_resumed_sets_previous_checkpoint_id(): + """_mark_resumed must populate previous_checkpoint_id so future checkpoints chain back to the resume point.""" + runner = Runner( + [], + {}, + State(), + InProcRunnerContext(), + "test_name", + graph_signature_hash="test_hash", + ) + + # Pre-condition: nothing to chain back to + assert runner.previous_checkpoint_id is None + + resumed_checkpoint = WorkflowCheckpoint( + checkpoint_id="resumed-cp-id", + workflow_name="test_name", + graph_signature_hash="test_hash", + iteration_count=3, + ) + runner._mark_resumed(resumed_checkpoint) # pyright: ignore[reportPrivateUsage] + + assert runner._resumed_from_checkpoint is True # pyright: ignore[reportPrivateUsage] + assert runner._iteration == 3 # pyright: ignore[reportPrivateUsage] + assert runner.previous_checkpoint_id == "resumed-cp-id" + + +async def 
test_runner_post_resume_checkpoint_chains_to_resumed_checkpoint(): + """After resuming, the next checkpoint created must reference the resumed checkpoint as its parent.""" + storage = InMemoryCheckpointStorage() + ctx = CheckpointingContext(storage) + executor_a = MockExecutor(id="executor_a") + executor_b = MockExecutor(id="executor_b") + + edges = [ + SingleEdgeGroup(executor_a.id, executor_b.id), + SingleEdgeGroup(executor_b.id, executor_a.id), + ] + + executors: dict[str, Executor] = { + executor_a.id: executor_a, + executor_b.id: executor_b, + } + state = State() + + runner = Runner(edges, executors, state, ctx, "test_name", graph_signature_hash="test_hash") + + # Simulate having resumed from a prior checkpoint + resumed_checkpoint = WorkflowCheckpoint( + checkpoint_id="parent-checkpoint-id", + workflow_name="test_name", + graph_signature_hash="test_hash", + iteration_count=1, + ) + runner._mark_resumed(resumed_checkpoint) # pyright: ignore[reportPrivateUsage] + + # Seed a message so the runner has work to do (and creates checkpoints at superstep boundaries) + await ctx.send_message(WorkflowMessage(data=MockMessage(data=8), source_id=executor_a.id)) + + async for _ in runner.run_until_convergence(): + pass + + # List the checkpoints stored under this workflow name, oldest first; all were created after the resume + new_checkpoints = sorted( + await storage.list_checkpoints(workflow_name="test_name"), + key=lambda c: c.timestamp, + ) + assert new_checkpoints, "Resuming and running should produce at least one new checkpoint" + + # The first new checkpoint must chain to the resumed-from checkpoint, not to None + assert new_checkpoints[0].previous_checkpoint_id == "parent-checkpoint-id", ( + "First post-resume checkpoint must chain to the resumed checkpoint id; " + f"got {new_checkpoints[0].previous_checkpoint_id!r}" + ) + + # Subsequent post-resume checkpoints continue the chain + for i in range(1, len(new_checkpoints)): + assert 
new_checkpoints[i].previous_checkpoint_id == new_checkpoints[i - 1].checkpoint_id + + class ExecutorThatFailsWithEvents(Executor): """An executor that emits events and then raises an exception after receiving messages.""" diff --git a/python/packages/core/tests/workflow/test_runner_context.py b/python/packages/core/tests/workflow/test_runner_context.py new file mode 100644 index 00000000000..2a4662db574 --- /dev/null +++ b/python/packages/core/tests/workflow/test_runner_context.py @@ -0,0 +1,72 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Tests for `InProcRunnerContext`.""" + +import pytest + +from agent_framework import ( + InProcRunnerContext, + WorkflowEvent, + WorkflowMessage, +) + + +def _make_request_info_event(request_id: str, source_executor_id: str = "executor") -> WorkflowEvent[str]: + return WorkflowEvent.request_info( + request_id=request_id, + source_executor_id=source_executor_id, + request_data="please respond", + response_type=str, + ) + + +class TestInProcRunnerContextResetForNewRun: + """Verify `reset_for_new_run` clears per-run state, including pending request_info events.""" + + async def test_reset_clears_pending_request_info_events(self) -> None: + ctx = InProcRunnerContext() + + await ctx.add_request_info_event(_make_request_info_event("req-1")) + await ctx.add_request_info_event(_make_request_info_event("req-2")) + + assert set((await ctx.get_pending_request_info_events()).keys()) == {"req-1", "req-2"} + + ctx.reset_for_new_run() + + assert await ctx.get_pending_request_info_events() == {} + + async def test_reset_clears_pending_request_info_events_when_already_empty(self) -> None: + ctx = InProcRunnerContext() + + assert await ctx.get_pending_request_info_events() == {} + + ctx.reset_for_new_run() + + assert await ctx.get_pending_request_info_events() == {} + + async def test_reset_after_pending_event_blocks_response_correlation(self) -> None: + """After `reset_for_new_run`, prior request ids must no longer correlate to a 
response.""" + ctx = InProcRunnerContext() + await ctx.add_request_info_event(_make_request_info_event("req-1")) + + ctx.reset_for_new_run() + + with pytest.raises(ValueError, match="No pending request found for request_id: req-1"): + await ctx.send_request_info_response("req-1", "answer") + + async def test_reset_clears_messages_events_and_streaming_flag(self) -> None: + """Sanity-check the other state `reset_for_new_run` is documented to clear.""" + ctx = InProcRunnerContext() + await ctx.send_message(WorkflowMessage(data="hello", source_id="executor")) + await ctx.add_event(WorkflowEvent("status", data="running")) + ctx.set_streaming(True) + + assert await ctx.has_messages() is True + assert await ctx.has_events() is True + assert ctx.is_streaming() is True + + ctx.reset_for_new_run() + + assert await ctx.has_messages() is False + assert await ctx.has_events() is False + assert ctx.is_streaming() is False diff --git a/python/packages/core/tests/workflow/test_workflow.py b/python/packages/core/tests/workflow/test_workflow.py index 27f24d26f9c..9884088cdf9 100644 --- a/python/packages/core/tests/workflow/test_workflow.py +++ b/python/packages/core/tests/workflow/test_workflow.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio +import gc import tempfile from collections.abc import AsyncIterable, Awaitable, Sequence from dataclasses import dataclass, field @@ -19,6 +20,7 @@ Content, Executor, FileCheckpointStorage, + InMemoryCheckpointStorage, Message, ResponseStream, WorkflowBuilder, @@ -26,6 +28,7 @@ WorkflowContext, WorkflowConvergenceException, WorkflowEvent, + WorkflowException, WorkflowMessage, WorkflowRunState, handler, @@ -759,8 +762,7 @@ async def run_workflow(): # Try to start a second concurrent execution - this should fail with pytest.raises( - RuntimeError, - match="Workflow is already running. 
Concurrent executions are not allowed.", + WorkflowException, match="Workflow is already running; concurrent runs are not allowed on the same instance." ): await workflow.run(NumberMessage(data=0)) @@ -795,8 +797,7 @@ async def consume_stream_slowly(): # Try to start a second concurrent execution - this should fail with pytest.raises( - RuntimeError, - match="Workflow is already running. Concurrent executions are not allowed.", + WorkflowException, match="Workflow is already running; concurrent runs are not allowed on the same instance." ): await workflow.run(NumberMessage(data=0)) @@ -828,14 +829,12 @@ async def consume_stream(): # Try different execution methods - all should fail with pytest.raises( - RuntimeError, - match="Workflow is already running. Concurrent executions are not allowed.", + WorkflowException, match="Workflow is already running; concurrent runs are not allowed on the same instance." ): await workflow.run(NumberMessage(data=0)) with pytest.raises( - RuntimeError, - match="Workflow is already running. Concurrent executions are not allowed.", + WorkflowException, match="Workflow is already running; concurrent runs are not allowed on the same instance." ): async for _ in workflow.run(NumberMessage(data=0), stream=True): break @@ -848,6 +847,92 @@ async def consume_stream(): assert result.get_final_state() == WorkflowRunState.IDLE +async def test_workflow_sequential_runs_after_completion() -> None: + """A completed run must release the runner so the next ``run`` succeeds. + + This is the happy-path counterpart to the concurrent-run guard tests: + those tests verify that a *concurrent* run is rejected, but they do not + verify that the lock is actually released afterwards. This test + exercises that release path explicitly across the three call shapes + (non-streaming, streaming-iterated, streaming-via-get_final_response) + and across multiple consecutive turns to catch lock leaks. 
+ """ + executor = IncrementExecutor(id="seq_executor", limit=3, increment=1) + workflow = WorkflowBuilder(start_executor=executor).build() + + # Non-streaming -> non-streaming + r1 = await workflow.run(NumberMessage(data=0)) + assert r1.get_final_state() == WorkflowRunState.IDLE + + r2 = await workflow.run(NumberMessage(data=0)) + assert r2.get_final_state() == WorkflowRunState.IDLE + + # Non-streaming -> streaming-iterated + stream_events: list[WorkflowEvent] = [] + async for event in workflow.run(NumberMessage(data=0), stream=True): + stream_events.append(event) + assert any(e.type == "status" and e.state == WorkflowRunState.IDLE for e in stream_events) + + # Streaming -> streaming via get_final_response (no manual iteration) + r3 = await workflow.run(NumberMessage(data=0), stream=True).get_final_response() + assert r3.get_final_state() == WorkflowRunState.IDLE + + # Streaming -> non-streaming (back to the start) + r4 = await workflow.run(NumberMessage(data=0)) + assert r4.get_final_state() == WorkflowRunState.IDLE + + +async def test_workflow_unconsumed_stream_releases_run_lock() -> None: + """An unconsumed stream must not leak the run lock. + + ``Workflow.run`` reserves the runner *synchronously* so that concurrent + callers are rejected immediately. The reservation is normally released + by ``_run_core``'s ``finally`` once the stream is iterated. If the + caller never iterates the stream, a GC-time finalizer must release the + reservation instead - otherwise every subsequent ``Workflow.run`` call + on this instance would fail with the concurrent-run error. + """ + executor = IncrementExecutor(id="unconsumed_stream_exec", limit=3, increment=1) + workflow = WorkflowBuilder(start_executor=executor).build() + + # Build a stream and immediately drop it without iterating. 
+ stream = workflow.run(NumberMessage(data=0), stream=True) + assert stream is not None # silence unused-variable warnings; stream is GC'd below + del stream + gc.collect() + # Yield to the event loop so any scheduled finalizer work can run. + await asyncio.sleep(0) + + # The runner should be back to IDLE; a fresh run must succeed. + result = await workflow.run(NumberMessage(data=0)) + assert result.get_final_state() == WorkflowRunState.IDLE + + +async def test_workflow_unawaited_run_coroutine_releases_run_lock() -> None: + """An un-awaited non-streaming ``run()`` coroutine must also not leak the lock. + + ``Workflow.run`` (non-streaming) returns a coroutine produced by + ``ResponseStream.get_final_response``. The underlying ResponseStream is + held alive by that coroutine, so dropping the coroutine without + awaiting it must still release the reservation via the same GC-time + fallback used for unconsumed streams. + """ + executor = IncrementExecutor(id="unawaited_run_exec", limit=3, increment=1) + workflow = WorkflowBuilder(start_executor=executor).build() + + coro = workflow.run(NumberMessage(data=0)) + # Closing suppresses the "coroutine was never awaited" warning. We cast to + # ``Any`` because the typed return is ``Awaitable[...]``; in practice it is + # a coroutine that exposes ``close``. 
+ cast(Any, coro).close() + del coro + gc.collect() + await asyncio.sleep(0) + + result = await workflow.run(NumberMessage(data=0)) + assert result.get_final_state() == WorkflowRunState.IDLE + + class _StreamingTestAgent(BaseAgent): """Test agent that supports both streaming and non-streaming modes.""" @@ -1269,3 +1354,143 @@ async def test_output_executors_filtering_with_run_responses_streaming() -> None # endregion + + +# region Workflow.create_checkpoint + + +class TestWorkflowCreateCheckpoint: + """Tests for :meth:`Workflow.create_checkpoint`.""" + + async def test_returns_checkpoint_id_with_runtime_storage(self, simple_executor: Executor) -> None: + """Calling `create_checkpoint` with a runtime storage persists a checkpoint and returns its id.""" + storage = InMemoryCheckpointStorage() + workflow = WorkflowBuilder(start_executor=simple_executor).add_edge(simple_executor, simple_executor).build() + + checkpoint_id = await workflow.create_checkpoint(storage) + + assert checkpoint_id + loaded = await storage.load(checkpoint_id) + assert loaded is not None + assert loaded.checkpoint_id == checkpoint_id + assert loaded.workflow_name == workflow.name + assert loaded.graph_signature_hash == workflow.graph_signature_hash + + async def test_uses_buildtime_storage_when_none_provided(self, simple_executor: Executor) -> None: + """When called with `None`, the build-time storage is used.""" + storage = InMemoryCheckpointStorage() + workflow = ( + WorkflowBuilder(start_executor=simple_executor, checkpoint_storage=storage) + .add_edge(simple_executor, simple_executor) + .build() + ) + + checkpoint_id = await workflow.create_checkpoint(None) + + loaded = await storage.load(checkpoint_id) + assert loaded is not None + assert loaded.checkpoint_id == checkpoint_id + + async def test_raises_when_no_storage_available(self, simple_executor: Executor) -> None: + """Without build-time or runtime storage, `create_checkpoint(None)` raises.""" + workflow = 
WorkflowBuilder(start_executor=simple_executor).add_edge(simple_executor, simple_executor).build() + + with pytest.raises(WorkflowCheckpointException, match="Checkpoint storage must be provided"): + await workflow.create_checkpoint(None) + + async def test_raises_while_run_active(self, simple_executor: Executor) -> None: + """`create_checkpoint` must reject while a workflow run is still active.""" + storage = InMemoryCheckpointStorage() + workflow = WorkflowBuilder(start_executor=simple_executor).add_edge(simple_executor, simple_executor).build() + + # Hold a live reference to a streaming run without iterating it so that + # ``_is_run_active`` remains True (the active-run weakref still resolves). + active_stream = workflow.run(WorkflowMessage(data="hi", source_id="test"), stream=True) + try: + with pytest.raises(WorkflowException, match="Cannot create checkpoint while a workflow run is active"): + await workflow.create_checkpoint(storage) + finally: + # Drain the stream so the run completes cleanly and the active-run + # weakref is cleared; otherwise pytest's asyncio teardown can leak + # the unconsumed generator. 
+ async for _ in active_stream: + pass + + async def test_clears_runtime_storage_after_call(self, simple_executor: Executor) -> None: + """The runtime storage override must not leak past the call.""" + storage = InMemoryCheckpointStorage() + workflow = WorkflowBuilder(start_executor=simple_executor).add_edge(simple_executor, simple_executor).build() + + await workflow.create_checkpoint(storage) + + assert workflow._runner.context.has_checkpointing() is False + assert workflow._runner.context._runtime_checkpoint_storage is None # type: ignore[attr-defined] + + async def test_clears_runtime_storage_after_failure(self, simple_executor: Executor) -> None: + """The runtime storage override must be cleared even if checkpoint creation fails.""" + from unittest.mock import AsyncMock + + storage = InMemoryCheckpointStorage() + workflow = WorkflowBuilder(start_executor=simple_executor).add_edge(simple_executor, simple_executor).build() + + # The runner logs-and-swallows storage save errors, so a failed save + # surfaces as the "Failed to create checkpoint." path when + # ``previous_checkpoint_id`` remains ``None``. Either way, the + # ``finally`` cleanup must still clear the runtime override. 
+ storage.save = AsyncMock(side_effect=RuntimeError("boom")) # type: ignore[method-assign] + + with pytest.raises(WorkflowCheckpointException, match="Failed to create checkpoint"): + await workflow.create_checkpoint(storage) + + assert workflow._runner.context._runtime_checkpoint_storage is None # type: ignore[attr-defined] + + async def test_alters_lineage_for_next_checkpoint(self, simple_executor: Executor) -> None: + """A manually created checkpoint becomes the parent of the next checkpoint.""" + storage = InMemoryCheckpointStorage() + workflow = ( + WorkflowBuilder(start_executor=simple_executor, checkpoint_storage=storage) + .add_edge(simple_executor, simple_executor) + .build() + ) + + first_id = await workflow.create_checkpoint(None) + second_id = await workflow.create_checkpoint(None) + + assert first_id != second_id + second = await storage.load(second_id) + assert second is not None + assert second.previous_checkpoint_id == first_id + + async def test_raises_when_save_fails_after_prior_success(self, simple_executor: Executor) -> None: + """A failed save after an earlier successful checkpoint must not return the stale id. + + The runner logs-and-swallows storage save errors and only updates + ``previous_checkpoint_id`` on success. Without an explicit transition check, + ``create_checkpoint`` would silently return the previously stored id as if a + new checkpoint had been created. + """ + from unittest.mock import AsyncMock + + storage = InMemoryCheckpointStorage() + workflow = WorkflowBuilder(start_executor=simple_executor).add_edge(simple_executor, simple_executor).build() + + # First call succeeds and seeds ``previous_checkpoint_id``. + first_id = await workflow.create_checkpoint(storage) + assert first_id + + # Second call fails to save, so the runner leaves ``previous_checkpoint_id`` + # pointing at ``first_id``. The method must detect that the id did not + # transition and raise instead of returning the stale value. 
+ original_save = storage.save + storage.save = AsyncMock(side_effect=RuntimeError("boom")) # type: ignore[method-assign] + try: + with pytest.raises(WorkflowCheckpointException, match="Failed to create checkpoint"): + await workflow.create_checkpoint(storage) + finally: + storage.save = original_save # type: ignore[method-assign] + + # The runner's bookkeeping is unchanged after the failed call. + assert workflow._runner.previous_checkpoint_id == first_id # type: ignore[attr-defined] + + +# endregion diff --git a/python/packages/declarative/tests/test_http_request_executor.py b/python/packages/declarative/tests/test_http_request_executor.py index 4030cf42946..f07f19d27a3 100644 --- a/python/packages/declarative/tests/test_http_request_executor.py +++ b/python/packages/declarative/tests/test_http_request_executor.py @@ -90,7 +90,7 @@ async def _run(yaml_def: dict[str, Any], handler: HttpRequestHandler) -> Any: def _state(workflow: Any, events: Any) -> dict[str, Any]: """Read declarative state out of the workflow after run completes.""" - return workflow._state.get(DECLARATIVE_STATE_KEY) or {} + return workflow._runner.state.get(DECLARATIVE_STATE_KEY) or {} # Helper used by parametrised path tests @@ -151,7 +151,7 @@ async def test_get_parses_json_object(self) -> None: workflow = factory.create_workflow_from_definition(_yaml(_action(method="GET", response="Local.Result"))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) assert decl["Local"]["Result"] == {"key": "value", "number": 42} assert handler.last_info is not None assert handler.last_info.method == "GET" @@ -164,7 +164,7 @@ async def test_get_parses_plain_string(self) -> None: workflow = factory.create_workflow_from_definition(_yaml(_action(response="Local.Result"))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) assert 
decl["Local"]["Result"] == "not-json content" @pytest.mark.asyncio @@ -174,7 +174,7 @@ async def test_get_empty_body_yields_none(self) -> None: workflow = factory.create_workflow_from_definition(_yaml(_action(response="Local.Result"))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) assert decl["Local"]["Result"] is None @pytest.mark.asyncio @@ -184,7 +184,7 @@ async def test_response_object_form_path(self) -> None: workflow = factory.create_workflow_from_definition(_yaml(_action(response={"path": "Local.Result"}))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) assert decl["Local"]["Result"] == {"x": 1} @pytest.mark.asyncio @@ -517,7 +517,7 @@ async def test_response_headers_folded_with_commas(self) -> None: factory = WorkflowFactory(http_request_handler=handler) workflow = factory.create_workflow_from_definition(_yaml(_action(response_headers="Local.H"))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) h = decl["Local"]["H"] assert h["Content-Type"] == "application/json" assert h["Set-Cookie"] == "a=1,b=2" @@ -528,7 +528,7 @@ async def test_response_headers_empty_assigned_none(self) -> None: factory = WorkflowFactory(http_request_handler=handler) workflow = factory.create_workflow_from_definition(_yaml(_action(response_headers="Local.H"))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) assert decl["Local"]["H"] is None @pytest.mark.asyncio @@ -538,7 +538,7 @@ async def test_non_2xx_still_publishes_headers(self) -> None: workflow = factory.create_workflow_from_definition(_yaml(_action(response_headers="Local.H"))) with pytest.raises(DeclarativeActionError): await workflow.run({}) - decl = 
workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) assert decl["Local"]["H"] == {"X-Trace": "abc"} @@ -559,7 +559,7 @@ async def test_conversation_id_appends_message(self) -> None: ) ) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) conv = decl["System"]["conversations"].get("conv-test-1") assert conv is not None assert len(conv["messages"]) == 1 @@ -570,7 +570,7 @@ async def test_empty_conversation_id_does_not_append(self) -> None: factory = WorkflowFactory(http_request_handler=handler) workflow = factory.create_workflow_from_definition(_yaml(_action(response="Local.Result", conversation_id=""))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) # Auto-init creates an entry for the System.ConversationId conversation, # but it should NOT have HTTP-appended messages from us. for _cid, conv in decl["System"]["conversations"].items(): @@ -582,7 +582,7 @@ async def test_empty_body_skips_conversation_append(self) -> None: factory = WorkflowFactory(http_request_handler=handler) workflow = factory.create_workflow_from_definition(_yaml(_action(conversation_id="conv-test-1"))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) # No conversation entry should have been created either. 
assert "conv-test-1" not in decl["System"]["conversations"] diff --git a/python/packages/declarative/tests/test_http_request_yaml_integration.py b/python/packages/declarative/tests/test_http_request_yaml_integration.py index 49cd0d15e83..64454e2b841 100644 --- a/python/packages/declarative/tests/test_http_request_yaml_integration.py +++ b/python/packages/declarative/tests/test_http_request_yaml_integration.py @@ -73,7 +73,7 @@ async def test_http_request_yaml_roundtrip() -> None: workflow = factory.create_workflow_from_yaml_path(FIXTURE_PATH) await workflow.run({}) - decl: dict[str, Any] = workflow._state.get(DECLARATIVE_STATE_KEY) or {} + decl: dict[str, Any] = workflow._runner.state.get(DECLARATIVE_STATE_KEY) or {} local = decl.get("Local") or {} assert local.get("RepoOwner") == "dotnet" diff --git a/python/packages/declarative/tests/test_invoke_mcp_tool_executor.py b/python/packages/declarative/tests/test_invoke_mcp_tool_executor.py index fdee1f7df1d..cfff22b84ea 100644 --- a/python/packages/declarative/tests/test_invoke_mcp_tool_executor.py +++ b/python/packages/declarative/tests/test_invoke_mcp_tool_executor.py @@ -244,7 +244,7 @@ async def test_output_result_parses_json_text(self) -> None: factory = WorkflowFactory(mcp_tool_handler=handler) workflow = factory.create_workflow_from_definition(_yaml(_action(output={"result": "Local.Result"}))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) assert decl["Local"]["Result"] == [{"k": "v", "n": 1}] @pytest.mark.asyncio @@ -253,7 +253,7 @@ async def test_output_result_falls_back_to_raw_text(self) -> None: factory = WorkflowFactory(mcp_tool_handler=handler) workflow = factory.create_workflow_from_definition(_yaml(_action(output={"result": "Local.Result"}))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) assert decl["Local"]["Result"] == ["plain 
text not json"] @pytest.mark.asyncio @@ -262,7 +262,7 @@ async def test_output_messages_writes_single_tool_role_message(self) -> None: factory = WorkflowFactory(mcp_tool_handler=handler) workflow = factory.create_workflow_from_definition(_yaml(_action(output={"messages": "Local.Messages"}))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) msg = decl["Local"]["Messages"] # Single Tool-role message containing both contents (parity with .NET). assert isinstance(msg, Message) @@ -276,7 +276,7 @@ async def test_uri_content_serialised_as_uri_string(self) -> None: factory = WorkflowFactory(mcp_tool_handler=handler) workflow = factory.create_workflow_from_definition(_yaml(_action(output={"result": "Local.Result"}))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) assert decl["Local"]["Result"] == ["https://example.com/file.txt"] @pytest.mark.asyncio @@ -285,7 +285,7 @@ async def test_output_path_object_form(self) -> None: factory = WorkflowFactory(mcp_tool_handler=handler) workflow = factory.create_workflow_from_definition(_yaml(_action(output={"result": {"path": "Local.Result"}}))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) assert decl["Local"]["Result"] == ["ok"] @@ -306,7 +306,7 @@ async def test_conversation_id_appends_assistant_message(self) -> None: ) ) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) conv = decl["System"]["conversations"]["conv-42"] msgs = conv["messages"] if isinstance(conv, dict) else conv.messages assert len(msgs) == 1 @@ -328,7 +328,7 @@ async def test_empty_conversation_id_does_not_append(self) -> None: ) ) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = 
workflow._runner.state.get(DECLARATIVE_STATE_KEY) # Empty conversation id must not produce a `""` entry under System.conversations. conversations = decl.get("System", {}).get("conversations", {}) assert "" not in conversations @@ -562,7 +562,7 @@ async def test_handler_returns_error_result_assigns_error_string(self) -> None: factory = WorkflowFactory(mcp_tool_handler=handler) workflow = factory.create_workflow_from_definition(_yaml(_action(output={"result": "Local.Result"}))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) assert decl["Local"]["Result"] == "Error: server down" @pytest.mark.asyncio @@ -571,7 +571,7 @@ async def test_tool_execution_exception_becomes_error_result(self) -> None: factory = WorkflowFactory(mcp_tool_handler=handler) workflow = factory.create_workflow_from_definition(_yaml(_action(output={"result": "Local.Result"}))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) assert decl["Local"]["Result"] == "Error: invalid arguments" @pytest.mark.asyncio @@ -580,7 +580,7 @@ async def test_httpx_error_becomes_error_result(self) -> None: factory = WorkflowFactory(mcp_tool_handler=handler) workflow = factory.create_workflow_from_definition(_yaml(_action(output={"result": "Local.Result"}))) await workflow.run({}) - decl = workflow._state.get(DECLARATIVE_STATE_KEY) + decl = workflow._runner.state.get(DECLARATIVE_STATE_KEY) result = decl["Local"]["Result"] assert isinstance(result, str) assert result.startswith("Error:") diff --git a/python/packages/declarative/tests/test_workflow_factory.py b/python/packages/declarative/tests/test_workflow_factory.py index f163cd18f0b..b6d0aaa6dad 100644 --- a/python/packages/declarative/tests/test_workflow_factory.py +++ b/python/packages/declarative/tests/test_workflow_factory.py @@ -289,11 +289,11 @@ async def 
test_as_agent_continuation_preserves_prior_state(self): # Stamp a marker into the declarative state between turns. The # continuation branch must preserve it; a state-clearing run would # wipe ``DECLARATIVE_STATE_KEY`` and force re-initialization. - state_data = workflow._state.get(DECLARATIVE_STATE_KEY) + state_data = workflow._runner.state.get(DECLARATIVE_STATE_KEY) assert isinstance(state_data, dict), "Expected declarative state to be initialized after turn 1" state_data["Local"] = {"persisted_marker": "kept-from-turn-1"} - workflow._state.set(DECLARATIVE_STATE_KEY, state_data) - workflow._state.commit() + workflow._runner.state.set(DECLARATIVE_STATE_KEY, state_data) + workflow._runner.state.commit() second = await agent.run("turn-2-msg") assert second.text == "turn-2-msg", ( @@ -303,7 +303,7 @@ async def test_as_agent_continuation_preserves_prior_state(self): # The continuation branch in ``_ensure_state_initialized`` must: # 1. preserve the cross-turn marker we stamped above # 2. refresh Inputs.input and System.LastMessage* to the new turn - post_state = workflow._state.get(DECLARATIVE_STATE_KEY) + post_state = workflow._runner.state.get(DECLARATIVE_STATE_KEY) assert isinstance(post_state, dict), "declarative state vanished between turns" local = post_state.get("Local", {}) assert local.get("persisted_marker") == "kept-from-turn-1", ( diff --git a/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py b/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py index ec4a9d85336..1f64d24eff4 100644 --- a/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py +++ b/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py @@ -17,6 +17,7 @@ from agent_framework import ( ChatOptions, + CheckpointID, Content, ContextProvider, FileCheckpointStorage, @@ -343,6 +344,7 @@ class ResponsesHostServer(ResponsesAgentServerHost): # TODO(@taochen): Allow a different checkpoint storage that 
stores checkpoints externally CHECKPOINT_STORAGE_PATH = "/.checkpoints" + INITIAL_CHECKPOINT_STORAGE_NAME = "initial" FUNCTION_APPROVAL_STORAGE_PATH = "/.function_approvals/approval_requests.json" def __init__( @@ -386,7 +388,6 @@ def __init__( ) self._is_workflow_agent = False - self._checkpoint_storage_path = None if isinstance(agent, WorkflowAgent): if agent.workflow._runner_context.has_checkpointing(): # pyright: ignore[reportPrivateUsage] raise RuntimeError( @@ -399,6 +400,12 @@ def __init__( else os.path.join(os.getcwd(), self.CHECKPOINT_STORAGE_PATH.lstrip("/")) ) self._is_workflow_agent = True + # The initial checkpoint storage that stores the workflow's initial state. We will use this checkpoint + # to restore the workflow when no conversation_id or previous_response_id is supplied in a request. + self._initial_checkpoint_storage = _checkpoint_storage_for_context( + self._checkpoint_storage_path, self.INITIAL_CHECKPOINT_STORAGE_NAME + ) + self._initial_checkpoint_id: CheckpointID | None = None self._agent = agent self._approval_storage = ( @@ -580,8 +587,6 @@ async def _handle_inner_workflow( # The following should never happen due to the checks above. # This is for type safety and defensive programming. - if self._checkpoint_storage_path is None: - raise RuntimeError("Checkpoint storage path is not configured for workflow agent.") if not isinstance(self._agent, WorkflowAgent): raise RuntimeError("Agent is not a workflow agent.") @@ -590,6 +595,12 @@ async def _handle_inner_workflow( # any future async resources owned by the workflow are entered here. await self._ensure_agent_ready() + # Create a checkpoint to store the initial state of the workflow, if it doesn't already exist. + # This allows us to restore to a clean slate when no conversation_id or previous_response_id + # is supplied in a request. 
+ if self._initial_checkpoint_id is None: + self._initial_checkpoint_id = await self._agent.workflow.create_checkpoint(self._initial_checkpoint_storage) + # Determine the latest checkpoint (if any) so we can resume the # workflow's prior state for this turn. The directory is keyed by # the inbound context id (conversation_id when set, otherwise @@ -599,14 +610,40 @@ async def _handle_inner_workflow( # the only place that state lives is the workflow checkpoint, so # on every turn we restore the latest checkpoint and feed the new # input back into the start executor as a continuation rather than - # a fresh run. - latest_checkpoint_id: str | None = None - restore_storage: FileCheckpointStorage | None = None + # a fresh run. If no conversation_id or previous_response_id is + # supplied, the workflow will be restored to the initial checkpoint + # to avoid context bleed between requests. + latest_checkpoint_id: str = self._initial_checkpoint_id + restore_storage: FileCheckpointStorage = self._initial_checkpoint_storage if context_id is not None: - restore_storage = _checkpoint_storage_for_context(self._checkpoint_storage_path, context_id) - latest_checkpoint = await restore_storage.get_latest(workflow_name=self._agent.workflow.name) + context_storage = _checkpoint_storage_for_context(self._checkpoint_storage_path, context_id) + latest_checkpoint = await context_storage.get_latest(workflow_name=self._agent.workflow.name) if latest_checkpoint is not None: + # Only switch the restore storage when a checkpoint was actually + # found under the per-context directory. Otherwise the initial + # checkpoint id would not resolve in `context_storage` and the + # restore call below would fail. latest_checkpoint_id = latest_checkpoint.checkpoint_id + restore_storage = context_storage + + # Restore the workflow to the latest checkpoint and run it with the + # new input. 
Events (including request info events) will not be emitted + # during restoration (in streaming) or after restoration (in non-streaming) + # since we assume the client had already seen those events and we don't want + # to emit duplicates. + if is_streaming_request: + async for _ in self._agent.run( + stream=True, + checkpoint_id=latest_checkpoint_id, + checkpoint_storage=restore_storage, + ): + pass + else: + await self._agent.run( + stream=False, + checkpoint_id=latest_checkpoint_id, + checkpoint_storage=restore_storage, + ) # Storage that will receive checkpoints written during this turn. # When the caller chains with previous_response_id, the next turn @@ -619,37 +656,6 @@ async def _handle_inner_workflow( write_context_id = context.conversation_id or context.response_id write_storage = _checkpoint_storage_for_context(self._checkpoint_storage_path, write_context_id) - # Multi-turn pattern: when we have a prior checkpoint, restore it - # first (drive the workflow back to idle with prior state intact), - # then make a separate call that delivers the new user input. This - # depends on Workflow.run preserving shared state across calls. The - # restore-only call may yield events from any pending in-flight - # work in the checkpoint; we consume those internally here so they - # don't surface to the response stream as duplicates. - # - # If the restored checkpoint had pending request_info events, the - # restore-only call replays them through - # ``WorkflowAgent._convert_workflow_event_to_agent_response_updates`` - # and populates ``self._agent.pending_requests``. That is the correct - # state: those requests are genuinely outstanding, and the next - # ``run(input_messages, ...)`` call may contain ``function_call_output`` - # items (carried as FunctionResult/FunctionApprovalResponse content) - # that fulfill them via :meth:`WorkflowAgent._process_pending_requests`. 
- if latest_checkpoint_id is not None: - if is_streaming_request: - async for _ in self._agent.run( - stream=True, - checkpoint_id=latest_checkpoint_id, - checkpoint_storage=restore_storage, - ): - pass - else: - await self._agent.run( - stream=False, - checkpoint_id=latest_checkpoint_id, - checkpoint_storage=restore_storage, - ) - # Now run the agent with the latest input response_event_stream = ResponseEventStream(response_id=context.response_id, model=request.model) diff --git a/python/packages/foundry_hosting/tests/test_responses.py b/python/packages/foundry_hosting/tests/test_responses.py index 8bc09c5d13d..13923cc9d2d 100644 --- a/python/packages/foundry_hosting/tests/test_responses.py +++ b/python/packages/foundry_hosting/tests/test_responses.py @@ -3032,6 +3032,7 @@ async def test_handle_inner_workflow_restores_message_role_checkpoint_from_previ agent.workflow = MagicMock() agent.workflow.name = "wf" agent.workflow._runner_context.has_checkpointing = MagicMock(return_value=False) + agent.workflow.create_checkpoint = AsyncMock(return_value="cp_initial") agent.run = AsyncMock( side_effect=[ AgentResponse(messages=[]), @@ -3062,6 +3063,181 @@ async def test_handle_inner_workflow_restores_message_role_checkpoint_from_previ assert new_turn_messages[0].text == "next turn" assert new_turn_call.kwargs["checkpoint_storage"].storage_path == (root / response_id).resolve() + async def test_handle_inner_workflow_restores_initial_checkpoint_when_no_context_id(self, tmp_path: Any) -> None: + """When neither previous_response_id nor conversation_id is supplied, the workflow + must be restored from the initial checkpoint to avoid context bleed between requests. 
+ """ + from agent_framework import WorkflowAgent + from azure.ai.agentserver.responses import ResponseContext + from azure.ai.agentserver.responses.models import CreateResponse, ItemMessage + + response_id = "resp_current" + root = tmp_path / "root" + root.mkdir() + + agent = MagicMock(spec=WorkflowAgent) + agent.id = "wf-agent" + agent.name = "wf" + agent.description = "" + agent.context_providers = [] + agent.workflow = MagicMock() + agent.workflow.name = "wf" + agent.workflow._runner_context.has_checkpointing = MagicMock(return_value=False) + agent.workflow.create_checkpoint = AsyncMock(return_value="cp_initial") + agent.run = AsyncMock( + side_effect=[ + AgentResponse(messages=[]), + AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("ok")])]), + ] + ) + server = ResponsesHostServer(agent, store=InMemoryResponseProvider()) + server._checkpoint_storage_path = str(root) # pyright: ignore[reportPrivateUsage] + + # No previous_response_id and no conversation_id. + request = CreateResponse(model="m", input="hi") + context = ResponseContext(response_id=response_id, mode_flags=MagicMock()) + input_item = ItemMessage({"type": "message", "role": "user", "content": "fresh turn"}) + + with patch.object(ResponseContext, "get_input_items", new=AsyncMock(return_value=[input_item])): + async for _ in server._handle_inner_workflow(request, context): # pyright: ignore[reportPrivateUsage] + pass + + # The initial checkpoint must have been created exactly once, against the + # initial checkpoint storage owned by the server. + assert agent.workflow.create_checkpoint.await_count == 1 + (initial_storage_arg,) = agent.workflow.create_checkpoint.await_args.args + assert initial_storage_arg is server._initial_checkpoint_storage # pyright: ignore[reportPrivateUsage] + + # First run() call is the restoration: no positional input, restored from + # the initial checkpoint id, using the initial checkpoint storage (NOT a + # per-context directory). 
+ assert agent.run.call_count == 2 + restore_call = agent.run.call_args_list[0] + assert restore_call.args == () + assert restore_call.kwargs["checkpoint_id"] == "cp_initial" + assert restore_call.kwargs["checkpoint_storage"] is server._initial_checkpoint_storage # pyright: ignore[reportPrivateUsage] + + # Second run() call delivers the new input; checkpoints land under response_id + # (the write-sink directory keyed by the current response id). + new_turn_call = agent.run.call_args_list[1] + new_turn_messages = new_turn_call.args[0] + assert len(new_turn_messages) == 1 + assert new_turn_messages[0].text == "fresh turn" + assert new_turn_call.kwargs["checkpoint_storage"].storage_path == (root / response_id).resolve() + + async def test_handle_inner_workflow_creates_initial_checkpoint_once_across_requests(self, tmp_path: Any) -> None: + """The initial checkpoint must be created exactly once and reused across + subsequent requests, regardless of whether the requests carry a context id. + """ + from agent_framework import WorkflowAgent + from azure.ai.agentserver.responses import ResponseContext + from azure.ai.agentserver.responses.models import CreateResponse, ItemMessage + + root = tmp_path / "root" + root.mkdir() + + agent = MagicMock(spec=WorkflowAgent) + agent.id = "wf-agent" + agent.name = "wf" + agent.description = "" + agent.context_providers = [] + agent.workflow = MagicMock() + agent.workflow.name = "wf" + agent.workflow._runner_context.has_checkpointing = MagicMock(return_value=False) + agent.workflow.create_checkpoint = AsyncMock(return_value="cp_initial") + # Four run() calls total: restore + new turn for each of the two requests. 
+ agent.run = AsyncMock(return_value=AgentResponse(messages=[])) + + server = ResponsesHostServer(agent, store=InMemoryResponseProvider()) + server._checkpoint_storage_path = str(root) # pyright: ignore[reportPrivateUsage] + + request1 = CreateResponse(model="m", input="hi") + context1 = ResponseContext(response_id="resp_first", mode_flags=MagicMock()) + request2 = CreateResponse(model="m", input="hi again") + context2 = ResponseContext(response_id="resp_second", mode_flags=MagicMock()) + input_item = ItemMessage({"type": "message", "role": "user", "content": "turn"}) + + with patch.object(ResponseContext, "get_input_items", new=AsyncMock(return_value=[input_item])): + async for _ in server._handle_inner_workflow(request1, context1): # pyright: ignore[reportPrivateUsage] + pass + async for _ in server._handle_inner_workflow(request2, context2): # pyright: ignore[reportPrivateUsage] + pass + + # Initial checkpoint creation must not be repeated on the second request. + assert agent.workflow.create_checkpoint.await_count == 1 + + # Both requests' restoration calls must use the same initial checkpoint id + # and the same initial checkpoint storage instance. + restore_call_1 = agent.run.call_args_list[0] + restore_call_2 = agent.run.call_args_list[2] + assert restore_call_1.kwargs["checkpoint_id"] == "cp_initial" + assert restore_call_2.kwargs["checkpoint_id"] == "cp_initial" + assert ( + restore_call_1.kwargs["checkpoint_storage"] + is restore_call_2.kwargs["checkpoint_storage"] + is server._initial_checkpoint_storage # pyright: ignore[reportPrivateUsage] + ) + + async def test_handle_inner_workflow_falls_back_to_initial_storage_when_context_dir_is_empty( + self, tmp_path: Any + ) -> None: + """When ``previous_response_id`` is supplied but its checkpoint directory has no + checkpoints, the restoration must fall back to BOTH the initial checkpoint id + and the initial checkpoint storage. 
Otherwise the initial id would be looked up + inside the per-context storage where it does not exist, and the restore would + fail. + """ + from agent_framework import WorkflowAgent + from azure.ai.agentserver.responses import ResponseContext + from azure.ai.agentserver.responses.models import CreateResponse, ItemMessage + + previous_response_id = "resp_previous" + response_id = "resp_current" + root = tmp_path / "root" + root.mkdir() + # The per-context storage exists but contains no checkpoints. + (root / previous_response_id).mkdir() + + agent = MagicMock(spec=WorkflowAgent) + agent.id = "wf-agent" + agent.name = "wf" + agent.description = "" + agent.context_providers = [] + agent.workflow = MagicMock() + agent.workflow.name = "wf" + agent.workflow._runner_context.has_checkpointing = MagicMock(return_value=False) + agent.workflow.create_checkpoint = AsyncMock(return_value="cp_initial") + agent.run = AsyncMock( + side_effect=[ + AgentResponse(messages=[]), + AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("ok")])]), + ] + ) + server = ResponsesHostServer(agent, store=InMemoryResponseProvider()) + server._checkpoint_storage_path = str(root) # pyright: ignore[reportPrivateUsage] + + request = CreateResponse(model="m", input="hi", previous_response_id=previous_response_id) + context = ResponseContext( + response_id=response_id, previous_response_id=previous_response_id, mode_flags=MagicMock() + ) + input_item = ItemMessage({"type": "message", "role": "user", "content": "next turn"}) + + with patch.object(ResponseContext, "get_input_items", new=AsyncMock(return_value=[input_item])): + async for _ in server._handle_inner_workflow(request, context): # pyright: ignore[reportPrivateUsage] + pass + + # The restoration call must use the initial id AND the initial storage, + # not the empty per-context storage. Mismatching the two would attempt + # to load ``cp_initial`` from a directory that doesn't contain it. 
+ assert agent.run.call_count == 2 + restore_call = agent.run.call_args_list[0] + assert restore_call.kwargs["checkpoint_id"] == "cp_initial" + assert restore_call.kwargs["checkpoint_storage"] is server._initial_checkpoint_storage # pyright: ignore[reportPrivateUsage] + + # The new turn still writes checkpoints under the current response id. + new_turn_call = agent.run.call_args_list[1] + assert new_turn_call.kwargs["checkpoint_storage"].storage_path == (root / response_id).resolve() + @pytest.mark.parametrize( "bad_id", [ @@ -3155,6 +3331,8 @@ async def test_handle_inner_workflow_rejects_malicious_context_id( agent.workflow = MagicMock() agent.workflow.name = "wf" agent.workflow._runner_context.has_checkpointing = MagicMock(return_value=False) + agent.workflow.create_checkpoint = AsyncMock(return_value="cp_initial") + agent.run = AsyncMock(return_value=AgentResponse(messages=[])) # Constructor inspects WorkflowAgent.workflow internals; bypass setup # by feeding a configured mock through a normal init.