diff --git a/src/strands/agent/a2a_agent.py b/src/strands/agent/a2a_agent.py index e18da2f4a..da8d07f0d 100644 --- a/src/strands/agent/a2a_agent.py +++ b/src/strands/agent/a2a_agent.py @@ -7,12 +7,12 @@ """ import logging +import warnings from collections.abc import AsyncIterator -from contextlib import asynccontextmanager from typing import Any import httpx -from a2a.client import A2ACardResolver, ClientConfig, ClientFactory +from a2a.client import A2ACardResolver, Client, ClientConfig, ClientFactory from a2a.types import AgentCard, Message, TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent from .._async import run_async @@ -38,6 +38,7 @@ def __init__( name: str | None = None, description: str | None = None, timeout: int = _DEFAULT_TIMEOUT, + client_config: ClientConfig | None = None, a2a_client_factory: ClientFactory | None = None, ): """Initialize A2A agent. @@ -47,17 +48,33 @@ def __init__( name: Agent name. If not provided, will be populated from agent card. description: Agent description. If not provided, will be populated from agent card. timeout: Timeout for HTTP operations in seconds (defaults to 300). - a2a_client_factory: Optional pre-configured A2A ClientFactory. If provided, - it will be used to create the A2A client after discovering the agent card. - Note: When providing a custom factory, you are responsible for managing - the lifecycle of any httpx client it uses. + client_config: Optional A2A ClientConfig for authentication and transport settings. + When provided, the config (including any authenticated httpx client) is used + for both agent card discovery and A2A message sending. This is the recommended + way to configure authentication (e.g. SigV4, OAuth bearer tokens). + a2a_client_factory: Deprecated. Use ``client_config`` instead. Optional pre-configured + A2A ClientFactory. When provided, ``factory.create()`` is used for client creation, + preserving any configured interceptors, consumers, and custom transports. For card + resolution, ``client_config`` is preferred if provided; otherwise the factory's + internal config is used as a fallback. The caller is responsible for managing the + lifecycle of any httpx client configured in the factory. """ + if a2a_client_factory is not None: + warnings.warn( + "a2a_client_factory is deprecated, use client_config instead. " + "a2a_client_factory will be removed in a future version.", + DeprecationWarning, + stacklevel=2, + ) + self.endpoint = endpoint self.name = name self.description = description self.timeout = timeout - self._agent_card: AgentCard | None = None + self._client_config = client_config self._a2a_client_factory: ClientFactory | None = a2a_client_factory + self._agent_card: AgentCard | None = None + self._a2a_client: Client | None = None def __call__( self, @@ -164,15 +181,23 @@ async def get_agent_card(self) -> AgentCard: populating name and description if not already set. The card is cached after the first fetch. + When ``client_config`` is provided, its httpx client is used for card resolution + so that any pre-configured authentication (SigV4, OAuth bearer tokens, etc.) is applied. + Returns: The remote agent's AgentCard containing name, description, capabilities, skills, etc. """ if self._agent_card is not None: return self._agent_card - async with httpx.AsyncClient(timeout=self.timeout) as client: - resolver = A2ACardResolver(httpx_client=client, base_url=self.endpoint) + config = self._resolve_client_config() + if config is not None and config.httpx_client is not None: + resolver = A2ACardResolver(httpx_client=config.httpx_client, base_url=self.endpoint) self._agent_card = await resolver.get_agent_card() + else: + async with httpx.AsyncClient(timeout=self.timeout) as client: + resolver = A2ACardResolver(httpx_client=client, base_url=self.endpoint) + self._agent_card = await resolver.get_agent_card() # Populate name from card if not set if self.name is None and self._agent_card.name: @@ -185,25 +210,66 @@ async def get_agent_card(self) -> AgentCard: logger.debug("agent=<%s>, endpoint=<%s> | discovered agent card", self.name, self.endpoint) return self._agent_card - @asynccontextmanager - async def _get_a2a_client(self) -> AsyncIterator[Any]: - """Get A2A client for sending messages. + def _resolve_client_config(self) -> ClientConfig | None: + """Resolve the effective client config for card resolution and client creation. - If a custom factory was provided, uses that (caller manages httpx lifecycle). - Otherwise creates a per-call httpx client with proper cleanup. + Precedence: + 1. Explicit ``client_config`` parameter (always preferred) + 2. Factory's internal config (fallback for deprecated factory path) + 3. None (use defaults) - Yields: - Configured A2A client instance. + Returns: + Resolved ClientConfig, or None if no config is available. """ - agent_card = await self.get_agent_card() + if self._client_config is not None: + return self._client_config + + if self._a2a_client_factory is not None: + config = getattr(self._a2a_client_factory, "_config", None) + if config is None: + logger.warning( + "endpoint=<%s> | could not access factory client config, " + "falling back to default config for card resolution", + self.endpoint, + ) + return config + + return None + async def _get_or_create_client(self) -> Client: + """Get or create an A2A client for communicating with the remote agent. + + When a deprecated factory is provided, ``factory.create()`` is used for client creation + to preserve interceptors, consumers, and custom transports. The resulting client is cached. + + When ``client_config`` is provided without a factory, ``ClientFactory.connect()`` is used + with the config for both card resolution and client creation. The client is cached. + + When neither is provided, a transient client is created per call to avoid long-lived + httpx connections which can cause connection breakdown and deadlocks on Windows. + + Returns: + Configured A2A Client instance. + """ + # Deprecated factory path: use factory.create() to preserve interceptors/consumers/transports if self._a2a_client_factory is not None: - yield self._a2a_client_factory.create(agent_card) - return + if self._a2a_client is not None: + return self._a2a_client + + agent_card = await self.get_agent_card() + self._a2a_client = self._a2a_client_factory.create(agent_card) + return self._a2a_client + + # client_config path: cache the client + if self._client_config is not None: + if self._a2a_client is not None: + return self._a2a_client + + self._a2a_client = await ClientFactory.connect(self.endpoint, client_config=self._client_config) + return self._a2a_client - async with httpx.AsyncClient(timeout=self.timeout) as httpx_client: - config = ClientConfig(httpx_client=httpx_client, streaming=True) - yield ClientFactory(config).create(agent_card) + # No factory or config: create transient client per call + return await ClientFactory.connect(self.endpoint) async def _send_message(self, prompt: AgentInput) -> AsyncIterator[A2AResponse]: """Send message to A2A agent. @@ -223,9 +289,9 @@ async def _send_message(self, prompt: AgentInput) -> AsyncIterator[A2AResponse]: message = convert_input_to_message(prompt) logger.debug("agent=<%s>, endpoint=<%s> | sending message", self.name, self.endpoint) - async with self._get_a2a_client() as client: - async for event in client.send_message(message): - yield event + client = await self._get_or_create_client() + async for event in client.send_message(message): + yield event def _is_complete_event(self, event: A2AResponse) -> bool: """Check if an A2A event represents a complete response. diff --git a/tests/strands/agent/test_a2a_agent.py b/tests/strands/agent/test_a2a_agent.py index 26a34476d..544fd10c5 100644 --- a/tests/strands/agent/test_a2a_agent.py +++ b/tests/strands/agent/test_a2a_agent.py @@ -1,10 +1,11 @@ """Tests for A2AAgent class.""" -from contextlib import asynccontextmanager +import warnings from unittest.mock import AsyncMock, MagicMock, patch from uuid import uuid4 import pytest +from a2a.client import ClientConfig from a2a.types import AgentCard, Message, Part, Role, TextPart from strands.agent.a2a_agent import A2AAgent @@ -41,21 +42,7 @@ def mock_httpx_client(): return mock_client -@asynccontextmanager -async def mock_a2a_client_context(send_message_func): - """Helper to create mock A2A client setup for _send_message tests.""" - mock_client = MagicMock() - mock_client.send_message = send_message_func - with patch("strands.agent.a2a_agent.httpx.AsyncClient") as mock_httpx_class: - mock_httpx = AsyncMock() - mock_httpx.__aenter__.return_value = mock_httpx - mock_httpx.__aexit__.return_value = None - mock_httpx_class.return_value = mock_httpx - with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: - mock_factory = MagicMock() - mock_factory.create.return_value = mock_client - mock_factory_class.return_value = mock_factory - yield mock_httpx_class, mock_factory_class +# --- Initialization tests --- def test_init_with_defaults(): @@ -64,6 +51,9 @@ def test_init_with_defaults(): assert agent.endpoint == "http://localhost:8000" assert agent.timeout == 300 assert agent._agent_card is None + assert agent._a2a_client is None + assert agent._client_config is None + assert agent._a2a_client_factory is None assert agent.name is None assert agent.description is None @@ -81,16 +71,32 @@ def test_init_with_custom_timeout(): assert agent.timeout == 600 -def test_init_with_external_a2a_client_factory(): - """Test initialization with external A2A client factory.""" - external_factory = MagicMock() - agent = A2AAgent(endpoint="http://localhost:8000", a2a_client_factory=external_factory) - assert agent._a2a_client_factory is external_factory +def test_init_with_client_config(): + """Test initialization with client_config.""" + config = ClientConfig() + agent = A2AAgent(endpoint="http://localhost:8000", client_config=config) + assert agent._client_config is config + assert agent._a2a_client_factory is None + + +def test_init_with_factory_emits_deprecation_warning(): + """Test that passing a2a_client_factory emits a DeprecationWarning.""" + factory = MagicMock() + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + agent = A2AAgent(endpoint="http://localhost:8000", a2a_client_factory=factory) + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "a2a_client_factory is deprecated" in str(w[0].message) + assert agent._a2a_client_factory is factory + + +# --- Card resolution tests --- @pytest.mark.asyncio -async def test_get_agent_card(a2a_agent, mock_agent_card, mock_httpx_client): - """Test agent card discovery.""" +async def test_get_agent_card_no_config(a2a_agent, mock_agent_card, mock_httpx_client): + """Test agent card discovery without config uses transient httpx client.""" with patch("strands.agent.a2a_agent.httpx.AsyncClient", return_value=mock_httpx_client): with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: mock_resolver = AsyncMock() @@ -103,6 +109,107 @@ async def test_get_agent_card(a2a_agent, mock_agent_card, mock_httpx_client): assert a2a_agent._agent_card == mock_agent_card +@pytest.mark.asyncio +async def test_get_agent_card_with_client_config(): + """Test agent card discovery with client_config uses its httpx client.""" + mock_httpx = MagicMock() + config = ClientConfig(httpx_client=mock_httpx) + agent = A2AAgent(endpoint="http://localhost:8000", client_config=config) + + mock_card = MagicMock(spec=AgentCard) + mock_card.name = "test" + mock_card.description = "desc" + + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_card) + mock_resolver_class.return_value = mock_resolver + + card = await agent.get_agent_card() + + # Should use the config's httpx client, not create a new one + mock_resolver_class.assert_called_once_with(httpx_client=mock_httpx, base_url="http://localhost:8000") + assert card == mock_card + + +@pytest.mark.asyncio +async def test_get_agent_card_with_factory_uses_factory_config(mock_agent_card): + """Test agent card discovery with deprecated factory extracts its config for auth.""" + mock_httpx = MagicMock() + mock_config = MagicMock(spec=ClientConfig) + mock_config.httpx_client = mock_httpx + + mock_factory = MagicMock() + mock_factory._config = mock_config + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + agent = A2AAgent(endpoint="http://localhost:8000", a2a_client_factory=mock_factory) + + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_agent_card) + mock_resolver_class.return_value = mock_resolver + + card = await agent.get_agent_card() + + # Should use the factory's httpx client for card resolution + mock_resolver_class.assert_called_once_with(httpx_client=mock_httpx, base_url="http://localhost:8000") + assert card == mock_agent_card + + +@pytest.mark.asyncio +async def test_get_agent_card_client_config_takes_precedence_over_factory(mock_agent_card): + """Test that client_config is preferred over factory config for card resolution.""" + explicit_httpx = MagicMock() + explicit_config = ClientConfig(httpx_client=explicit_httpx) + + factory_httpx = MagicMock() + factory_config = MagicMock(spec=ClientConfig) + factory_config.httpx_client = factory_httpx + mock_factory = MagicMock() + mock_factory._config = factory_config + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + agent = A2AAgent( + endpoint="http://localhost:8000", + client_config=explicit_config, + a2a_client_factory=mock_factory, + ) + + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_agent_card) + mock_resolver_class.return_value = mock_resolver + + await agent.get_agent_card() + + # Should use explicit client_config's httpx, not factory's + mock_resolver_class.assert_called_once_with(httpx_client=explicit_httpx, base_url="http://localhost:8000") + + +@pytest.mark.asyncio +async def test_get_agent_card_factory_without_config_attr(mock_agent_card, mock_httpx_client): + """Test fallback when factory has no _config attribute.""" + mock_factory = MagicMock(spec=[]) # No _config attribute + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + agent = A2AAgent(endpoint="http://localhost:8000", a2a_client_factory=mock_factory) + + with patch("strands.agent.a2a_agent.httpx.AsyncClient", return_value=mock_httpx_client): + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_agent_card) + mock_resolver_class.return_value = mock_resolver + + card = await agent.get_agent_card() + + # Should fall back to transient httpx client + assert card == mock_agent_card + + @pytest.mark.asyncio async def test_get_agent_card_cached(a2a_agent, mock_agent_card): """Test that agent card is cached after first discovery.""" @@ -147,6 +254,102 @@ async def test_get_agent_card_preserves_custom_name_and_description(mock_agent_c assert agent.description == "Custom description" +# --- Client creation tests --- + + +@pytest.mark.asyncio +async def test_get_or_create_client_with_client_config(mock_agent_card): + """Test _get_or_create_client with client_config uses ClientFactory.connect() and caches.""" + config = ClientConfig() + agent = A2AAgent(endpoint="http://localhost:8000", client_config=config) + + mock_client = AsyncMock() + + with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: + mock_factory_class.connect = AsyncMock(return_value=mock_client) + + client1 = await agent._get_or_create_client() + client2 = await agent._get_or_create_client() + + # Should connect once and cache + mock_factory_class.connect.assert_called_once_with("http://localhost:8000", client_config=config) + assert client1 is client2 + assert client1 is mock_client + + +@pytest.mark.asyncio +async def test_get_or_create_client_with_factory_uses_factory_create(mock_agent_card): + """Test _get_or_create_client with deprecated factory uses factory.create().""" + mock_factory = MagicMock() + mock_factory._config = MagicMock(spec=ClientConfig) + mock_factory._config.httpx_client = None + mock_created_client = MagicMock() + mock_factory.create.return_value = mock_created_client + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + agent = A2AAgent(endpoint="http://localhost:8000", a2a_client_factory=mock_factory) + + # Pre-set the agent card to avoid card resolution complexity + agent._agent_card = mock_agent_card + + client = await agent._get_or_create_client() + + # Should use factory.create() with the agent card + mock_factory.create.assert_called_once_with(mock_agent_card) + assert client is mock_created_client + + +@pytest.mark.asyncio +async def test_get_or_create_client_factory_caches(): + """Test _get_or_create_client caches the client when factory is provided.""" + mock_factory = MagicMock() + mock_factory._config = MagicMock(spec=ClientConfig) + mock_factory._config.httpx_client = None + mock_created_client = MagicMock() + mock_factory.create.return_value = mock_created_client + + mock_card = MagicMock(spec=AgentCard) + mock_card.name = "test" + mock_card.description = "desc" + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + agent = A2AAgent(endpoint="http://localhost:8000", a2a_client_factory=mock_factory) + + agent._agent_card = mock_card + + client1 = await agent._get_or_create_client() + client2 = await agent._get_or_create_client() + + # factory.create() should only be called once + mock_factory.create.assert_called_once() + assert client1 is client2 + + +@pytest.mark.asyncio +async def test_get_or_create_client_transient_without_config(): + """Test _get_or_create_client creates transient clients when no config or factory.""" + agent = A2AAgent(endpoint="http://localhost:8000") + + mock_client_1 = AsyncMock() + mock_client_2 = AsyncMock() + + with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: + mock_factory_class.connect = AsyncMock(side_effect=[mock_client_1, mock_client_2]) + + client1 = await agent._get_or_create_client() + client2 = await agent._get_or_create_client() + + # Should connect each time (transient) + assert mock_factory_class.connect.call_count == 2 + assert client1 is mock_client_1 + assert client2 is mock_client_2 + + +# --- Invocation tests --- + + @pytest.mark.asyncio async def test_invoke_async_success(a2a_agent, mock_agent_card): """Test successful async invocation.""" @@ -156,15 +359,18 @@ async def test_invoke_async_success(a2a_agent, mock_agent_card): parts=[Part(TextPart(kind="text", text="Response"))], ) + mock_client = AsyncMock() + async def mock_send_message(*args, **kwargs): yield mock_response - with patch.object(a2a_agent, "get_agent_card", return_value=mock_agent_card): - async with mock_a2a_client_context(mock_send_message): - result = await a2a_agent.invoke_async("Hello") + mock_client.send_message = mock_send_message + + with patch.object(a2a_agent, "_get_or_create_client", return_value=mock_client): + result = await a2a_agent.invoke_async("Hello") - assert isinstance(result, AgentResult) - assert result.message["content"][0]["text"] == "Response" + assert isinstance(result, AgentResult) + assert result.message["content"][0]["text"] == "Response" @pytest.mark.asyncio @@ -177,15 +383,17 @@ async def test_invoke_async_no_prompt(a2a_agent): @pytest.mark.asyncio async def test_invoke_async_no_response(a2a_agent, mock_agent_card): """Test that invoke_async raises RuntimeError when no response received.""" + mock_client = AsyncMock() async def mock_send_message(*args, **kwargs): return yield # Make it an async generator - with patch.object(a2a_agent, "get_agent_card", return_value=mock_agent_card): - async with mock_a2a_client_context(mock_send_message): - with pytest.raises(RuntimeError, match="No response received"): - await a2a_agent.invoke_async("Hello") + mock_client.send_message = mock_send_message + + with patch.object(a2a_agent, "_get_or_create_client", return_value=mock_client): + with pytest.raises(RuntimeError, match="No response received"): + await a2a_agent.invoke_async("Hello") def test_call_sync(a2a_agent): @@ -206,6 +414,9 @@ def test_call_sync(a2a_agent): mock_run_async.assert_called_once() +# --- Streaming tests --- + + @pytest.mark.asyncio async def test_stream_async_success(a2a_agent, mock_agent_card): """Test successful async streaming.""" @@ -215,23 +426,24 @@ async def test_stream_async_success(a2a_agent, mock_agent_card): parts=[Part(TextPart(kind="text", text="Response"))], ) + mock_client = AsyncMock() + async def mock_send_message(*args, **kwargs): yield mock_response - with patch.object(a2a_agent, "get_agent_card", return_value=mock_agent_card): - async with mock_a2a_client_context(mock_send_message): - events = [] - async for event in a2a_agent.stream_async("Hello"): - events.append(event) + mock_client.send_message = mock_send_message + + with patch.object(a2a_agent, "_get_or_create_client", return_value=mock_client): + events = [] + async for event in a2a_agent.stream_async("Hello"): + events.append(event) - assert len(events) == 2 - # First event is A2A stream event - assert events[0]["type"] == "a2a_stream" - assert events[0]["event"] == mock_response - # Final event is AgentResult - assert "result" in events[1] - assert isinstance(events[1]["result"], AgentResult) - assert events[1]["result"].message["content"][0]["text"] == "Response" + assert len(events) == 2 + assert events[0]["type"] == "a2a_stream" + assert events[0]["event"] == mock_response + assert "result" in events[1] + assert isinstance(events[1]["result"], AgentResult) + assert events[1]["result"].message["content"][0]["text"] == "Response" @pytest.mark.asyncio @@ -242,48 +454,7 @@ async def test_stream_async_no_prompt(a2a_agent): pass -@pytest.mark.asyncio -async def test_send_message_uses_provided_factory(mock_agent_card): - """Test _send_message uses provided factory instead of creating per-call client.""" - external_factory = MagicMock() - mock_a2a_client = MagicMock() - - async def mock_send_message(*args, **kwargs): - yield MagicMock() - - mock_a2a_client.send_message = mock_send_message - external_factory.create.return_value = mock_a2a_client - - agent = A2AAgent(endpoint="http://localhost:8000", a2a_client_factory=external_factory) - - with patch.object(agent, "get_agent_card", return_value=mock_agent_card): - # Consume the async iterator - async for _ in agent._send_message("Hello"): - pass - - external_factory.create.assert_called_once_with(mock_agent_card) - - -@pytest.mark.asyncio -async def test_send_message_creates_per_call_client(a2a_agent, mock_agent_card): - """Test _send_message creates a fresh httpx client for each call when no factory provided.""" - mock_response = Message( - message_id=uuid4().hex, - role=Role.agent, - parts=[Part(TextPart(kind="text", text="Response"))], - ) - - async def mock_send_message(*args, **kwargs): - yield mock_response - - with patch.object(a2a_agent, "get_agent_card", return_value=mock_agent_card): - async with mock_a2a_client_context(mock_send_message) as (mock_httpx_class, _): - # Consume the async iterator - async for _ in a2a_agent._send_message("Hello"): - pass - - # Verify httpx client was created with timeout - mock_httpx_class.assert_called_once_with(timeout=300) +# --- _is_complete_event tests --- def test_is_complete_event_message(a2a_agent): @@ -351,6 +522,9 @@ def test_is_complete_event_unknown_type(a2a_agent): assert a2a_agent._is_complete_event("unknown") is False +# --- Complete event tracking tests --- + + @pytest.mark.asyncio async def test_stream_async_tracks_complete_events(a2a_agent, mock_agent_card): """Test stream_async uses last complete event for final result.""" @@ -359,32 +533,32 @@ async def test_stream_async_tracks_complete_events(a2a_agent, mock_agent_card): mock_task = MagicMock() mock_task.artifacts = None - # First event: incomplete incomplete_event = MagicMock(spec=TaskStatusUpdateEvent) incomplete_event.status = MagicMock() incomplete_event.status.state = TaskState.working incomplete_event.status.message = None - # Second event: complete complete_event = MagicMock(spec=TaskStatusUpdateEvent) complete_event.status = MagicMock() complete_event.status.state = TaskState.completed complete_event.status.message = MagicMock() complete_event.status.message.parts = [] + mock_client = AsyncMock() + async def mock_send_message(*args, **kwargs): yield (mock_task, incomplete_event) yield (mock_task, complete_event) - with patch.object(a2a_agent, "get_agent_card", return_value=mock_agent_card): - async with mock_a2a_client_context(mock_send_message): - events = [] - async for event in a2a_agent.stream_async("Hello"): - events.append(event) + mock_client.send_message = mock_send_message + + with patch.object(a2a_agent, "_get_or_create_client", return_value=mock_client): + events = [] + async for event in a2a_agent.stream_async("Hello"): + events.append(event) - # Should have 2 stream events + 1 result event - assert len(events) == 3 - assert "result" in events[2] + assert len(events) == 3 + assert "result" in events[2] @pytest.mark.asyncio @@ -400,15 +574,17 @@ async def test_stream_async_falls_back_to_last_event(a2a_agent, mock_agent_card) incomplete_event.status.state = TaskState.working incomplete_event.status.message = None + mock_client = AsyncMock() + async def mock_send_message(*args, **kwargs): yield (mock_task, incomplete_event) - with patch.object(a2a_agent, "get_agent_card", return_value=mock_agent_card): - async with mock_a2a_client_context(mock_send_message): - events = [] - async for event in a2a_agent.stream_async("Hello"): - events.append(event) + mock_client.send_message = mock_send_message + + with patch.object(a2a_agent, "_get_or_create_client", return_value=mock_client): + events = [] + async for event in a2a_agent.stream_async("Hello"): + events.append(event) - # Should have 1 stream event + 1 result event (falls back to last) - assert len(events) == 2 - assert "result" in events[1] + assert len(events) == 2 + assert "result" in events[1]