Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/strands/multiagent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
standardized communication between agents.
"""

from .base import MultiAgentBase, MultiAgentResult, Status
from .base import EdgeExecutionMode, MultiAgentBase, MultiAgentResult, Status
from .graph import GraphBuilder, GraphResult
from .swarm import Swarm, SwarmResult

__all__ = [
"EdgeExecutionMode",
"GraphBuilder",
"GraphResult",
"MultiAgentBase",
Expand Down
18 changes: 18 additions & 0 deletions src/strands/multiagent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,24 @@ class Status(Enum):
INTERRUPTED = "interrupted"


class EdgeExecutionMode(Enum):
"""Edge execution mode for graph traversal.

Controls how the graph determines when a node is ready to execute
based on its incoming edges.

Attributes:
OR: Node executes when ANY incoming edge's source node completes (default).
This is the current behavior - a node becomes ready as soon as any
of its predecessors finish.
AND: Node executes when ALL incoming edges' source nodes complete.
The node waits for every predecessor to finish before executing.
"""

OR = "or"
AND = "and"


@dataclass
class NodeResult:
"""Unified result from node execution - handles both Agent and nested MultiAgentBase results."""
Expand Down
106 changes: 93 additions & 13 deletions src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from ..types.event_loop import Metrics, Usage
from ..types.multiagent import MultiAgentInput
from ..types.traces import AttributeValue
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
from .base import EdgeExecutionMode, MultiAgentBase, MultiAgentResult, NodeResult, Status

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -241,6 +241,7 @@ def __init__(self) -> None:
self._execution_timeout: float | None = None
self._node_timeout: float | None = None
self._reset_on_revisit: bool = False
self._edge_execution_mode: EdgeExecutionMode = EdgeExecutionMode.OR
self._id: str = _DEFAULT_GRAPH_ID
self._session_manager: SessionManager | None = None
self._hooks: list[HookProvider] | None = None
Expand Down Expand Up @@ -307,6 +308,25 @@ def reset_on_revisit(self, enabled: bool = True) -> "GraphBuilder":
self._reset_on_revisit = enabled
return self

def set_edge_execution_mode(self, mode: EdgeExecutionMode) -> "GraphBuilder":
"""Set the edge execution mode for determining node readiness.

Controls how the graph determines when a node is ready to execute
based on its incoming edges.

Args:
mode: EdgeExecutionMode.OR (default) - node executes when ANY predecessor completes.
EdgeExecutionMode.AND - node executes when ALL predecessors complete.

Example:
For a graph where nodes A, B, C all connect to node Z:

- OR mode (default): Z executes as soon as A, B, or C completes
- AND mode: Z waits until A AND B AND C have all completed
"""
self._edge_execution_mode = mode
return self

def set_max_node_executions(self, max_executions: int) -> "GraphBuilder":
"""Set maximum number of node executions allowed.

Expand Down Expand Up @@ -389,6 +409,7 @@ def build(self) -> "Graph":
session_manager=self._session_manager,
hooks=self._hooks,
id=self._id,
edge_execution_mode=self._edge_execution_mode,
)

def _validate_graph(self) -> None:
Expand Down Expand Up @@ -420,6 +441,7 @@ def __init__(
hooks: list[HookProvider] | None = None,
id: str = _DEFAULT_GRAPH_ID,
trace_attributes: Mapping[str, AttributeValue] | None = None,
edge_execution_mode: EdgeExecutionMode = EdgeExecutionMode.OR,
) -> None:
"""Initialize Graph with execution limits and reset behavior.

Expand All @@ -435,6 +457,9 @@ def __init__(
hooks: List of hook providers for monitoring and extending graph execution behavior (default: None)
id: Unique graph id (default: None)
trace_attributes: Custom trace attributes to apply to the agent's trace span (default: None)
edge_execution_mode: Controls when nodes execute based on incoming edges (default: OR).
OR - node executes when ANY predecessor completes.
AND - node executes when ALL predecessors complete.
"""
super().__init__()

Expand All @@ -448,6 +473,7 @@ def __init__(
self.execution_timeout = execution_timeout
self.node_timeout = node_timeout
self.reset_on_revisit = reset_on_revisit
self.edge_execution_mode = edge_execution_mode
self.state = GraphState()
self._interrupt_state = _InterruptState()
self.tracer = get_tracer()
Expand Down Expand Up @@ -824,23 +850,65 @@ def _find_newly_ready_nodes(self, completed_batch: list["GraphNode"]) -> list["G
return newly_ready

def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list["GraphNode"]) -> bool:
"""Check if a node is ready considering conditional edges."""
"""Check if a node is ready considering conditional edges and execution mode.

For OR mode (default): Node is ready when ANY incoming edge's source has completed.
For AND mode: Node is ready when ALL incoming edges' sources have completed.
"""
# Get incoming edges to this node
incoming_edges = [edge for edge in self.edges if edge.to_node == node]

# Check if at least one incoming edge condition is satisfied
for edge in incoming_edges:
if edge.from_node in completed_batch:
if edge.should_traverse(self.state):
if not incoming_edges:
return False

if self.edge_execution_mode == EdgeExecutionMode.AND:
# AND mode: ALL incoming edges must have their source nodes completed
# and all conditions must be satisfied
all_completed_nodes = self.state.completed_nodes | set(completed_batch)
has_new_completion = False

for edge in incoming_edges:
# Check if source node has completed (either previously or in current batch)
if edge.from_node not in all_completed_nodes:
logger.debug(
"from=<%s>, to=<%s> | edge ready via satisfied condition", edge.from_node.node_id, node.node_id
"from=<%s>, to=<%s> | AND mode: source not completed", edge.from_node.node_id, node.node_id
)
return True
else:
return False
# Check if condition is satisfied
if not edge.should_traverse(self.state):
logger.debug(
"from=<%s>, to=<%s> | edge condition not satisfied", edge.from_node.node_id, node.node_id
"from=<%s>, to=<%s> | AND mode: edge condition not satisfied",
edge.from_node.node_id,
node.node_id,
)
return False
return False
# Track if at least one edge is from the current batch
if edge.from_node in completed_batch:
has_new_completion = True

# Only ready if at least one edge is newly completed (prevents re-triggering)
if has_new_completion:
logger.debug("node=<%s> | AND mode: all dependencies satisfied", node.node_id)
return True
return False
else:
# OR mode (default): ANY incoming edge condition being satisfied is enough
for edge in incoming_edges:
if edge.from_node in completed_batch:
if edge.should_traverse(self.state):
logger.debug(
"from=<%s>, to=<%s> | OR mode: edge ready via satisfied condition",
edge.from_node.node_id,
node.node_id,
)
return True
else:
logger.debug(
"from=<%s>, to=<%s> | OR mode: edge condition not satisfied",
edge.from_node.node_id,
node.node_id,
)
return False

async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
"""Execute a single node and yield TypedEvent objects."""
Expand Down Expand Up @@ -1172,6 +1240,12 @@ def deserialize_state(self, payload: dict[str, Any]) -> None:
self._resume_from_session = True

def _compute_ready_nodes_for_resume(self) -> list[GraphNode]:
"""Compute which nodes should be ready to execute when resuming.

Respects the edge_execution_mode setting:
- OR mode: Node is ready if ANY incoming edge's source has completed
- AND mode: Node is ready if ALL incoming edges' sources have completed
"""
if self.state.status == Status.PENDING:
return []
ready_nodes: list[GraphNode] = []
Expand All @@ -1182,8 +1256,14 @@ def _compute_ready_nodes_for_resume(self) -> list[GraphNode]:
incoming = [e for e in self.edges if e.to_node is node]
if not incoming:
ready_nodes.append(node)
elif all(e.from_node in completed_nodes and e.should_traverse(self.state) for e in incoming):
ready_nodes.append(node)
elif self.edge_execution_mode == EdgeExecutionMode.AND:
# AND mode: ALL incoming edges must have completed sources with satisfied conditions
if all(e.from_node in completed_nodes and e.should_traverse(self.state) for e in incoming):
ready_nodes.append(node)
else:
# OR mode: ANY incoming edge with completed source and satisfied condition
if any(e.from_node in completed_nodes and e.should_traverse(self.state) for e in incoming):
ready_nodes.append(node)

return ready_nodes

Expand Down
Loading
Loading