Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/packages/core/agent_framework/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@
TodoSessionStore,
TodoStore,
)
from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPTaskOptions, MCPWebsocketTool, SamplingApprovalCallback
from ._harness._tool_approval import (
DEFAULT_TOOL_APPROVAL_SOURCE_ID,
ToolApprovalMiddleware,
Expand All @@ -135,6 +134,7 @@
create_always_approve_tool_response,
create_always_approve_tool_with_arguments_response,
)
from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPTaskOptions, MCPWebsocketTool, SamplingApprovalCallback
from ._middleware import (
AgentContext,
AgentMiddleware,
Expand Down Expand Up @@ -252,6 +252,7 @@
)
from ._workflows._agent_utils import resolve_agent_id
from ._workflows._checkpoint import (
CheckpointID,
CheckpointStorage,
FileCheckpointStorage,
InMemoryCheckpointStorage,
Expand Down Expand Up @@ -295,7 +296,6 @@
workflow,
)
from ._workflows._request_info_mixin import response_handler
from ._workflows._runner import Runner
from ._workflows._runner_context import (
InProcRunnerContext,
RunnerContext,
Expand Down Expand Up @@ -390,6 +390,7 @@
"ChatResponse",
"ChatResponseUpdate",
"CheckResult",
"CheckpointID",
"CheckpointStorage",
"ClassSkill",
"CompactionProvider",
Expand Down Expand Up @@ -481,7 +482,6 @@
"RoleLiteral",
"RubricScore",
"RunContext",
"Runner",
Comment thread
TaoChenOSU marked this conversation as resolved.
"RunnerContext",
"SamplingApprovalCallback",
"SecretString",
Expand Down
4 changes: 1 addition & 3 deletions python/packages/core/agent_framework/_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1989,9 +1989,7 @@ def _store_already_approved_approval_requests(
return

existing_groups = state.get(_ALREADY_APPROVED_APPROVAL_REQUEST_GROUPS_KEY)
pending_groups: list[Any] = (
list(cast(Iterable[Any], existing_groups)) if isinstance(existing_groups, list) else []
)
pending_groups: list[Any] = list(cast(Iterable[Any], existing_groups)) if isinstance(existing_groups, list) else []
pending_groups.append({
"approval_request_ids": visible_ids,
"approval_requests": [request.to_dict() for request in already_approved_requests],
Expand Down
185 changes: 101 additions & 84 deletions python/packages/core/agent_framework/_workflows/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from ..exceptions import (
WorkflowCheckpointException,
WorkflowConvergenceException,
WorkflowRunnerException,
)
from ._checkpoint import CheckpointID, CheckpointStorage, WorkflowCheckpoint
from ._const import EXECUTOR_STATE_KEY
Expand Down Expand Up @@ -63,99 +62,105 @@ def __init__(
self._iteration = 0
self._max_iterations = max_iterations
self._state = state
self._running = False
self._resumed_from_checkpoint = False # Track whether we resumed

# Checkpointing related attributes
self._resumed_from_checkpoint = False
self.previous_checkpoint_id: CheckpointID | None = None

@property
def context(self) -> RunnerContext:
"""Get the workflow context."""
"""Get the runner context for message, event, and checkpoint handling."""
return self._ctx

@property
def state(self) -> State:
"""Get the shared state for the workflow."""
return self._state

def reset_iteration_count(self) -> None:
"""Reset the iteration count to zero."""
"""Reset the iteration count to zero.

This is useful when the workflow resumes from a new set of messages.

Note:
When a workflow is resumed from a response (for a request_info_event)
or a checkpoint, the iteration count is normally NOT reset.
"""
self._iteration = 0

async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]:
"""Run the workflow until no more messages are sent."""
if self._running:
raise WorkflowRunnerException("Runner is already running.")

self._running = True
previous_checkpoint_id: CheckpointID | None = None
try:
# Emit any events already produced prior to entering loop
if await self._ctx.has_events():
logger.info("Yielding pre-loop events")
for event in await self._ctx.drain_events():
yield event

# Create the first checkpoint. Checkpoints are usually considered to be created at the end of an iteration,
# we can think of the first checkpoint as being created at the end of a "superstep 0" which captures the
# states after which the start executor has run. Note that we execute the start executor outside of the
# main iteration loop.
if await self._ctx.has_messages() and not self._resumed_from_checkpoint:
previous_checkpoint_id = await self._create_checkpoint_if_enabled(previous_checkpoint_id)

while self._iteration < self._max_iterations:
logger.info(f"Starting superstep {self._iteration + 1}")
yield WorkflowEvent.superstep_started(iteration=self._iteration + 1)

# Run iteration concurrently with live event streaming: we poll
# for new events while the iteration coroutine progresses.
iteration_task = asyncio.create_task(self._run_iteration())
try:
while not iteration_task.done():
try:
# Wait briefly for any new event; timeout allows progress checks
event = await asyncio.wait_for(self._ctx.next_event(), timeout=0.05)
yield event
except asyncio.TimeoutError:
# Periodically continue to let iteration advance
continue
except asyncio.CancelledError:
# Propagate cancellation to the iteration task to avoid orphaned work
iteration_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await iteration_task
raise

# Propagate errors from iteration, but first surface any pending events
try:
# Emit any events already produced prior to entering loop
if await self._ctx.has_events():
logger.info("Yielding pre-loop events")
for event in await self._ctx.drain_events():
yield event

# Create a checkpoint before a run starts. Checkpoints are usually considered to be created at the
# end of an iteration, we can think of this checkpoint as being created at the end of "superstep 0"
# which captures the states after which the start executor has run. Note that we execute the start
# executor outside of the main iteration loop.
if await self._ctx.has_messages() and not self._resumed_from_checkpoint:
await self.create_checkpoint_if_enabled()

while self._iteration < self._max_iterations:
logger.info(f"Starting superstep {self._iteration + 1}")
yield WorkflowEvent.superstep_started(iteration=self._iteration + 1)

# Run iteration concurrently with live event streaming: we poll
# for new events while the iteration coroutine progresses.
iteration_task = asyncio.create_task(self._run_iteration())
try:
while not iteration_task.done():
try:
# Wait briefly for any new event; timeout allows progress checks
event = await asyncio.wait_for(self._ctx.next_event(), timeout=0.05)
yield event
except asyncio.TimeoutError:
# Periodically continue to let iteration advance
continue
except asyncio.CancelledError:
# Propagate cancellation to the iteration task to avoid orphaned work
iteration_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await iteration_task
except Exception:
# Make sure failure-related events (like ExecutorFailedEvent) are surfaced
if await self._ctx.has_events():
for event in await self._ctx.drain_events():
yield event
raise
self._iteration += 1

# Drain any straggler events emitted at tail end
raise

# Propagate errors from iteration, but first surface any pending events
try:
await iteration_task
except Exception:
# Make sure failure-related events (like ExecutorFailedEvent) are surfaced
if await self._ctx.has_events():
for event in await self._ctx.drain_events():
yield event
raise
self._iteration += 1

# Drain any straggler events emitted at tail end
if await self._ctx.has_events():
for event in await self._ctx.drain_events():
yield event

logger.info(f"Completed superstep {self._iteration}")
logger.info(f"Completed superstep {self._iteration}")

# Commit pending state changes at superstep boundary
self._state.commit()
# Commit pending state changes at superstep boundary
self._state.commit()

# Create checkpoint after each superstep iteration
previous_checkpoint_id = await self._create_checkpoint_if_enabled(previous_checkpoint_id)
# Create checkpoint after each superstep iteration
await self.create_checkpoint_if_enabled()

yield WorkflowEvent.superstep_completed(iteration=self._iteration)
yield WorkflowEvent.superstep_completed(iteration=self._iteration)

# Check for convergence: no more messages to process
if not await self._ctx.has_messages():
break
# Check for convergence: no more messages to process
if not await self._ctx.has_messages():
break

if self._iteration >= self._max_iterations and await self._ctx.has_messages():
raise WorkflowConvergenceException(f"Runner did not converge after {self._max_iterations} iterations.")
logger.info(f"Workflow completed after {self._iteration} supersteps")
self._resumed_from_checkpoint = False # Reset resume flag for next run
Comment thread
TaoChenOSU marked this conversation as resolved.

logger.info(f"Workflow completed after {self._iteration} supersteps")
self._resumed_from_checkpoint = False # Reset resume flag for next run
finally:
self._running = False
if self._iteration >= self._max_iterations and await self._ctx.has_messages():
raise WorkflowConvergenceException(f"Runner did not converge after {self._max_iterations} iterations.")

async def _run_iteration(self) -> None:
"""Run a single iteration of the workflow.
Expand Down Expand Up @@ -209,10 +214,10 @@ async def _deliver_messages_for_edge_runner(edge_runner: EdgeRunner) -> None:
]
await asyncio.gather(*tasks)

async def _create_checkpoint_if_enabled(self, previous_checkpoint_id: CheckpointID | None) -> CheckpointID | None:
async def create_checkpoint_if_enabled(self) -> None:
"""Create a checkpoint if checkpointing is enabled and attach a label and metadata."""
if not self._ctx.has_checkpointing():
return None
return

try:
# Save executor states into the shared state before creating the checkpoint,
Expand All @@ -227,22 +232,33 @@ async def _create_checkpoint_if_enabled(self, previous_checkpoint_id: Checkpoint
self._workflow_name,
self._graph_signature_hash,
self._state,
previous_checkpoint_id,
self.previous_checkpoint_id,
self._iteration,
)

logger.info(f"Created checkpoint: {checkpoint_id}")
return checkpoint_id
logger.info(
"Created checkpoint: %s with parent checkpoint at iteration %d: %s",
checkpoint_id,
self._iteration,
self.previous_checkpoint_id,
)
self.previous_checkpoint_id = checkpoint_id
except Exception as e:
logger.warning(f"Failed to create checkpoint: {e}")
return None
logger.warning(
"Failed to create checkpoint at iteration %d: %s. "
"Note that this does not fail the workflow run. "
"The next successfully-created checkpoint will be parented to the last successful checkpoint: %s",
self._iteration,
e,
self.previous_checkpoint_id,
)

async def restore_from_checkpoint(
self,
checkpoint_id: CheckpointID,
checkpoint_storage: CheckpointStorage | None = None,
) -> None:
"""Restore workflow state from a checkpoint.
"""Restore the runner from a checkpoint.

Args:
checkpoint_id: The ID of the checkpoint to restore from
Expand Down Expand Up @@ -290,7 +306,7 @@ async def restore_from_checkpoint(
# Apply the checkpoint to the context
await self._ctx.apply_checkpoint(checkpoint)
# Mark the runner as resumed
self._mark_resumed(checkpoint.iteration_count)
self._mark_resumed(checkpoint)

logger.info(f"Successfully restored workflow from checkpoint: {checkpoint_id}")
except WorkflowCheckpointException:
Expand Down Expand Up @@ -356,13 +372,14 @@ def _parse_edge_runners(self, edge_runners: list[EdgeRunner]) -> dict[str, list[

return parsed

def _mark_resumed(self, iteration: int) -> None:
def _mark_resumed(self, checkpoint: WorkflowCheckpoint) -> None:
"""Mark the runner as having resumed from a checkpoint.

Optionally set the current iteration and max iterations.
"""
self._resumed_from_checkpoint = True
self._iteration = iteration
self._iteration = checkpoint.iteration_count
self.previous_checkpoint_id = checkpoint.checkpoint_id

async def _set_executor_state(self, executor_id: str, state: dict[str, Any]) -> None:
"""Store executor state in state under a reserved key.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,12 +403,14 @@ async def load_checkpoint(self, checkpoint_id: CheckpointID) -> WorkflowCheckpoi
def reset_for_new_run(self) -> None:
"""Reset the context for a new workflow run.

This clears messages, events, and resets streaming flag.
Runtime checkpoint storage is NOT cleared here as it's managed at the workflow level.
Clears messages, the pending event queue, the pending request_info
correlation map, and the streaming flag. Runtime checkpoint storage is
NOT cleared here as it's managed at the workflow level.
"""
self._messages.clear()
# Clear any pending events (best-effort) by recreating the queue
self._event_queue = asyncio.Queue()
self._pending_request_info_events.clear()
Comment thread
TaoChenOSU marked this conversation as resolved.
self._streaming = False # Reset streaming flag

async def apply_checkpoint(self, checkpoint: WorkflowCheckpoint) -> None:
Expand Down
Loading
Loading