diff --git a/python/packages/core/agent_framework/_workflows/_edge.py b/python/packages/core/agent_framework/_workflows/_edge.py index 02544ad3df..b9dbd266ec 100644 --- a/python/packages/core/agent_framework/_workflows/_edge.py +++ b/python/packages/core/agent_framework/_workflows/_edge.py @@ -9,6 +9,7 @@ from dataclasses import dataclass, field from typing import Any, ClassVar, TypeAlias, TypeVar +from .._agents import SupportsAgentRun from ._const import INTERNAL_SOURCE_ID from ._executor import Executor from ._model_utils import DictConvertible, encode_value @@ -264,7 +265,7 @@ def __init__(self) -> None: """ condition: Callable[[Any], bool] - target: Executor | str + target: Executor | SupportsAgentRun @dataclass @@ -287,7 +288,7 @@ def __init__(self) -> None: assert fallback.target.id == "dead_letter" """ - target: Executor | str + target: Executor | SupportsAgentRun @dataclass(init=False) diff --git a/python/packages/core/tests/workflow/test_workflow_builder.py b/python/packages/core/tests/workflow/test_workflow_builder.py index 073a24e5a3..f1abdd619a 100644 --- a/python/packages/core/tests/workflow/test_workflow_builder.py +++ b/python/packages/core/tests/workflow/test_workflow_builder.py @@ -11,6 +11,8 @@ AgentResponseUpdate, AgentSession, BaseAgent, + Case, + Default, Executor, Message, WorkflowBuilder, @@ -193,6 +195,29 @@ def condition_func(msg: MockMessage) -> bool: assert "Target" in workflow.executors +def test_switch_case_with_agents(): + """Test add_switch_case_edge_group with Case and Default edges using agents.""" + router = DummyAgent(id="router_agent", name="router") + handler = DummyAgent(id="handler", name="handler") + fallback = DummyAgent(id="fallback_agent", name="fallback") + + workflow = ( + WorkflowBuilder(start_executor=router) + .add_switch_case_edge_group( + router, + [ + Case(condition=lambda _: True, target=handler), + Default(target=fallback), + ], + ) + .build() + ) + + # All three agents should be AgentExecutor wrappers + agent_executors = [e for e in workflow.executors.values() if isinstance(e, AgentExecutor)] + assert len(agent_executors) == 3 + + # region with_output_from tests