From a202256485d98e629ab66f9ba90a40a0536132fd Mon Sep 17 00:00:00 2001 From: minorun365 Date: Tue, 20 Jan 2026 23:00:34 +0900 Subject: [PATCH] feat(a2a): Add A2AClient with synchronous and streaming API --- src/strands/multiagent/a2a/__init__.py | 14 +- src/strands/multiagent/a2a/client.py | 346 ++++++++++++++++++++ tests/strands/multiagent/a2a/test_client.py | 223 +++++++++++++ 3 files changed, 581 insertions(+), 2 deletions(-) create mode 100644 src/strands/multiagent/a2a/client.py create mode 100644 tests/strands/multiagent/a2a/test_client.py diff --git a/src/strands/multiagent/a2a/__init__.py b/src/strands/multiagent/a2a/__init__.py index 75f8b1b19..7dc70b9ef 100644 --- a/src/strands/multiagent/a2a/__init__.py +++ b/src/strands/multiagent/a2a/__init__.py @@ -6,10 +6,20 @@ Docs: https://google-a2a.github.io/A2A/latest/ Classes: - A2AAgent: A wrapper that adapts a Strands Agent to be A2A-compatible. + A2AServer: A wrapper that adapts a Strands Agent to be an A2A server. + A2AClient: A client for communicating with remote A2A agents. + StrandsA2AExecutor: The executor that handles A2A requests for Strands Agents. """ +from .client import A2AClient, A2AError, build_agentcore_url, extract_region_from_arn from .executor import StrandsA2AExecutor from .server import A2AServer -__all__ = ["A2AServer", "StrandsA2AExecutor"] +__all__ = [ + "A2AServer", + "A2AClient", + "A2AError", + "StrandsA2AExecutor", + "build_agentcore_url", + "extract_region_from_arn", +] diff --git a/src/strands/multiagent/a2a/client.py b/src/strands/multiagent/a2a/client.py new file mode 100644 index 000000000..543e3dbd7 --- /dev/null +++ b/src/strands/multiagent/a2a/client.py @@ -0,0 +1,346 @@ +"""A2A client for communicating with remote A2A agents.""" + +import json +import logging +import uuid +from typing import Iterator, Optional +from urllib.parse import quote + +import httpx + +logger = logging.getLogger(__name__) + + +def build_agentcore_url(agent_arn: str) -> str: + """Build the invocation URL from an AgentCore Runtime ARN. + + Args: + agent_arn: The ARN of the AgentCore Runtime agent. + + Returns: + The full invocation URL for the agent. + + Raises: + ValueError: If the ARN format is invalid. + """ + if not agent_arn.startswith("arn:aws:bedrock-agentcore:"): + raise ValueError( + f"Invalid AgentCore ARN format. Expected 'arn:aws:bedrock-agentcore:...' but got '{agent_arn}'" + ) + + parts = agent_arn.split(":") + if len(parts) < 6: + raise ValueError(f"Invalid ARN format: {agent_arn}") + + region = parts[3] + encoded_arn = quote(agent_arn, safe="") + + return f"https://bedrock-agentcore.{region}.amazonaws.com/runtimes/{encoded_arn}/invocations" + + +def extract_region_from_arn(agent_arn: str) -> str: + """Extract the AWS region from an AgentCore Runtime ARN. + + Args: + agent_arn: The ARN of the AgentCore Runtime agent. + + Returns: + The AWS region (e.g., "us-east-1"). + + Raises: + ValueError: If the ARN format is invalid. + """ + parts = agent_arn.split(":") + if len(parts) < 4: + raise ValueError(f"Invalid ARN format: {agent_arn}") + return parts[3] + + +class A2AError(Exception): + """Exception raised for A2A protocol errors.""" + + def __init__(self, code: int, message: str): + """Initialize an A2AError.""" + self.code = code + self.message = message + super().__init__(f"A2A Error {code}: {message}") + + +class A2AClient: + """Client for communicating with remote A2A agents. + + This client implements the A2A protocol for sending tasks to remote agents. + It supports synchronous APIs with optional streaming responses. + """ + + def __init__( + self, + url: str, + auth: Optional[httpx.Auth] = None, + timeout: float = 300.0, + headers: Optional[dict[str, str]] = None, + ): + """Initialize an A2A client. + + Args: + url: The base URL of the A2A agent. + auth: Optional authentication object (e.g., SigV4 auth). + timeout: Request timeout in seconds. Defaults to 300. + headers: Optional additional HTTP headers. + """ + self._url = url.rstrip("/") + self._auth = auth + self._timeout = timeout + self._headers = headers or {} + self._agent_card: Optional[dict] = None + + @classmethod + def from_agentcore_arn( + cls, + agent_arn: str, + region: Optional[str] = None, + timeout: float = 300.0, + ) -> "A2AClient": + """Create a client from an AgentCore Runtime ARN with IAM authentication. + + Args: + agent_arn: The ARN of the AgentCore Runtime agent. + region: AWS region for authentication. If None, extracted from ARN. + timeout: Request timeout in seconds. Defaults to 300. + + Returns: + An A2AClient configured for the specified AgentCore agent. + + Raises: + ImportError: If mcp-proxy-for-aws is not installed. + """ + try: + from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session + except ImportError as e: + raise ImportError( + "mcp-proxy-for-aws is required for IAM authentication. " + "Please install it with: pip install mcp-proxy-for-aws" + ) from e + + url = build_agentcore_url(agent_arn) + + if region is None: + region = extract_region_from_arn(agent_arn) + + session = create_aws_session() + credentials = session.get_credentials() + auth = SigV4HTTPXAuth(credentials, "bedrock-agentcore", region) + + logger.debug("Created A2AClient with SigV4 auth for region=%s", region) + + return cls(url=url, auth=auth, timeout=timeout) + + @classmethod + def from_url( + cls, + url: str, + timeout: float = 300.0, + ) -> "A2AClient": + """Create a client from a URL without authentication. + + Args: + url: The URL of the A2A agent. + timeout: Request timeout in seconds. Defaults to 300. + + Returns: + An A2AClient configured for the specified URL. + """ + return cls(url=url, auth=None, timeout=timeout) + + @property + def url(self) -> str: + """Get the base URL of the A2A agent.""" + return self._url + + def get_agent_card(self, force_refresh: bool = False) -> dict: + """Get the agent card (metadata). + + Args: + force_refresh: If True, bypass the cache and fetch fresh data. + + Returns: + The agent card as a dictionary. + """ + if self._agent_card is not None and not force_refresh: + return self._agent_card + + with httpx.Client(auth=self._auth, timeout=self._timeout) as http_client: + response = http_client.get( + f"{self._url}/.well-known/agent.json", + headers=self._headers, + ) + response.raise_for_status() + self._agent_card = response.json() + return self._agent_card + + def send_task( + self, + message: str, + session_id: Optional[str] = None, + ) -> str: + """Send a task and wait for the result. + + Args: + message: The message to send to the agent. + session_id: Optional session ID for conversation continuity. + + Returns: + The agent's response as a string. + + Raises: + A2AError: If the agent returns an error response. + """ + if session_id is None: + session_id = str(uuid.uuid4()) + + task_id = str(uuid.uuid4()) + request_body = self._build_task_request(task_id, session_id, message) + + with httpx.Client(auth=self._auth, timeout=self._timeout) as http_client: + response = http_client.post( + self._url, + json=request_body, + headers={ + "Content-Type": "application/json", + **self._headers, + }, + ) + response.raise_for_status() + result = response.json() + + return self._extract_text_from_response(result) + + def send_task_streaming( + self, + message: str, + session_id: Optional[str] = None, + ) -> Iterator[str]: + """Send a task and stream the results. + + Args: + message: The message to send to the agent. + session_id: Optional session ID for conversation continuity. + + Yields: + Text chunks as they are received from the agent. + """ + if session_id is None: + session_id = str(uuid.uuid4()) + + task_id = str(uuid.uuid4()) + request_body = self._build_task_subscribe_request(task_id, session_id, message) + + with httpx.Client(auth=self._auth, timeout=self._timeout) as http_client: + with http_client.stream( + "POST", + self._url, + json=request_body, + headers={ + "Content-Type": "application/json", + "Accept": "text/event-stream", + **self._headers, + }, + ) as response: + response.raise_for_status() + for line in response.iter_lines(): + text = self._parse_sse_line(line) + if text: + yield text + + def _build_task_request(self, task_id: str, session_id: str, message: str) -> dict: + """Build a JSON-RPC request for tasks/send.""" + return { + "jsonrpc": "2.0", + "method": "tasks/send", + "id": task_id, + "params": { + "id": task_id, + "sessionId": session_id, + "message": { + "role": "user", + "parts": [{"kind": "text", "text": message}], + }, + }, + } + + def _build_task_subscribe_request(self, task_id: str, session_id: str, message: str) -> dict: + """Build a JSON-RPC request for tasks/sendSubscribe (streaming).""" + return { + "jsonrpc": "2.0", + "method": "tasks/sendSubscribe", + "id": task_id, + "params": { + "id": task_id, + "sessionId": session_id, + "message": { + "role": "user", + "parts": [{"kind": "text", "text": message}], + }, + }, + } + + def _extract_text_from_response(self, response: dict) -> str: + """Extract text content from a JSON-RPC response.""" + if "error" in response: + error = response["error"] + raise A2AError( + code=error.get("code", -1), + message=error.get("message", "Unknown error"), + ) + + result = response.get("result", {}) + artifacts = result.get("artifacts", []) + texts = [] + + for artifact in artifacts: + parts = artifact.get("parts", []) + for part in parts: + if part.get("kind") == "text": + texts.append(part.get("text", "")) + + return "".join(texts) + + def _parse_sse_line(self, line: str) -> Optional[str]: + """Parse a Server-Sent Events line and extract text content.""" + if not line.startswith("data: "): + return None + + data = line[6:] # Remove "data: " prefix + if not data.strip(): + return None + + try: + event = json.loads(data) + return self._extract_text_from_event(event) + except json.JSONDecodeError: + logger.debug("Failed to parse SSE data as JSON: %s", data) + return None + + def _extract_text_from_event(self, event: dict) -> Optional[str]: + """Extract text content from a streaming event.""" + result = event.get("result", {}) + status = result.get("status", {}) + message = status.get("message", {}) + parts = message.get("parts", []) + + texts = [] + for part in parts: + if part.get("kind") == "text": + texts.append(part.get("text", "")) + + if texts: + return "".join(texts) + + # Check for final artifacts + artifacts = result.get("artifacts", []) + for artifact in artifacts: + for part in artifact.get("parts", []): + if part.get("kind") == "text": + texts.append(part.get("text", "")) + + return "".join(texts) if texts else None diff --git a/tests/strands/multiagent/a2a/test_client.py b/tests/strands/multiagent/a2a/test_client.py new file mode 100644 index 000000000..754ec93f6 --- /dev/null +++ b/tests/strands/multiagent/a2a/test_client.py @@ -0,0 +1,223 @@ +"""Tests for the A2A client module.""" + +from unittest.mock import MagicMock, patch + +import httpx +import pytest +from strands.multiagent.a2a.client import ( + A2AClient, + A2AError, + build_agentcore_url, + extract_region_from_arn, +) + + +class TestBuildAgentcoreUrl: + """Tests for build_agentcore_url function.""" + + def test_valid_arn(self): + """Test URL building with a valid ARN.""" + arn = "arn:aws:bedrock-agentcore:us-east-1:123456789012:runtime/my-agent" + url = build_agentcore_url(arn) + + assert url.startswith("https://bedrock-agentcore.us-east-1.amazonaws.com/runtimes/") + assert url.endswith("/invocations") + assert "arn%3Aaws%3Abedrock-agentcore" in url + + def test_invalid_arn_prefix(self): + """Test that invalid ARN prefix raises ValueError.""" + with pytest.raises(ValueError, match="Invalid AgentCore ARN format"): + build_agentcore_url("arn:aws:lambda:us-east-1:123456789012:function/my-function") + + def test_invalid_arn_format(self): + """Test that malformed ARN raises ValueError.""" + with pytest.raises(ValueError, match="Invalid"): + build_agentcore_url("arn:aws:bedrock-agentcore") + + +class TestExtractRegionFromArn: + """Tests for extract_region_from_arn function.""" + + def test_valid_arn(self): + """Test region extraction from a valid ARN.""" + arn = "arn:aws:bedrock-agentcore:us-east-1:123456789012:runtime/my-agent" + assert extract_region_from_arn(arn) == "us-east-1" + + def test_invalid_arn_format(self): + """Test that malformed ARN raises ValueError.""" + with pytest.raises(ValueError, match="Invalid ARN format"): + extract_region_from_arn("invalid-arn") + + +class TestA2AClientInitialization: + """Tests for A2AClient initialization.""" + + def test_init_with_url(self): + """Test basic initialization with URL.""" + client = A2AClient(url="http://localhost:9000") + + assert client.url == "http://localhost:9000" + assert client._auth is None + assert client._timeout == 300.0 + + def test_init_strips_trailing_slash(self): + """Test that trailing slashes are stripped from URL.""" + client = A2AClient(url="http://localhost:9000/") + assert client.url == "http://localhost:9000" + + def test_init_with_auth_and_headers(self): + """Test initialization with authentication and headers.""" + mock_auth = MagicMock(spec=httpx.Auth) + headers = {"X-Custom-Header": "value"} + client = A2AClient(url="http://localhost:9000", auth=mock_auth, timeout=600.0, headers=headers) + + assert client._auth == mock_auth + assert client._timeout == 600.0 + assert client._headers == headers + + +class TestA2AClientFromUrl: + """Tests for A2AClient.from_url factory method.""" + + def test_from_url_basic(self): + """Test creating client from URL.""" + client = A2AClient.from_url("http://localhost:9000") + + assert client.url == "http://localhost:9000" + assert client._auth is None + + +class TestA2AClientRequestBuilding: + """Tests for A2A request building methods.""" + + def test_build_task_request(self): + """Test building a task/send request.""" + client = A2AClient(url="http://localhost:9000") + request = client._build_task_request("task-123", "session-456", "Hello") + + assert request["jsonrpc"] == "2.0" + assert request["method"] == "tasks/send" + assert request["id"] == "task-123" + assert request["params"]["id"] == "task-123" + assert request["params"]["sessionId"] == "session-456" + assert request["params"]["message"]["role"] == "user" + assert request["params"]["message"]["parts"][0]["text"] == "Hello" + + def test_build_task_subscribe_request(self): + """Test building a task/sendSubscribe request.""" + client = A2AClient(url="http://localhost:9000") + request = client._build_task_subscribe_request("task-123", "session-456", "Hello") + + assert request["method"] == "tasks/sendSubscribe" + + +class TestA2AClientResponseParsing: + """Tests for A2A response parsing methods.""" + + def test_extract_text_from_response_success(self): + """Test extracting text from a successful response.""" + client = A2AClient(url="http://localhost:9000") + response = { + "jsonrpc": "2.0", + "id": "task-123", + "result": {"artifacts": [{"parts": [{"kind": "text", "text": "Hello World"}]}]}, + } + assert client._extract_text_from_response(response) == "Hello World" + + def test_extract_text_from_response_error(self): + """Test that error response raises A2AError.""" + client = A2AClient(url="http://localhost:9000") + response = { + "jsonrpc": "2.0", + "id": "task-123", + "error": {"code": -32600, "message": "Invalid Request"}, + } + + with pytest.raises(A2AError) as exc_info: + client._extract_text_from_response(response) + + assert exc_info.value.code == -32600 + + def test_extract_text_from_response_empty(self): + """Test extracting text from response with no artifacts.""" + client = A2AClient(url="http://localhost:9000") + response = {"jsonrpc": "2.0", "id": "task-123", "result": {}} + assert client._extract_text_from_response(response) == "" + + def test_parse_sse_line_valid(self): + """Test parsing valid SSE line.""" + client = A2AClient(url="http://localhost:9000") + line = 'data: {"result": {"status": {"message": {"parts": [{"kind": "text", "text": "Hello"}]}}}}' + assert client._parse_sse_line(line) == "Hello" + + def test_parse_sse_line_non_data(self): + """Test parsing non-data SSE line returns None.""" + client = A2AClient(url="http://localhost:9000") + assert client._parse_sse_line("event: message") is None + + def test_parse_sse_line_empty_data(self): + """Test parsing empty data SSE line returns None.""" + client = A2AClient(url="http://localhost:9000") + assert client._parse_sse_line("data: ") is None + + +class TestA2AClientMethods: + """Tests for A2AClient methods.""" + + def test_get_agent_card(self): + """Test getting agent card.""" + client = A2AClient(url="http://localhost:9000") + expected_card = {"name": "Test Agent"} + + mock_response = MagicMock() + mock_response.json.return_value = expected_card + mock_response.raise_for_status = MagicMock() + + with patch("httpx.Client") as mock_client_class: + mock_http_client = MagicMock() + mock_http_client.get.return_value = mock_response + mock_http_client.__enter__ = MagicMock(return_value=mock_http_client) + mock_http_client.__exit__ = MagicMock(return_value=None) + mock_client_class.return_value = mock_http_client + + assert client.get_agent_card() == expected_card + + def test_get_agent_card_cached(self): + """Test that agent card is cached.""" + client = A2AClient(url="http://localhost:9000") + client._agent_card = {"name": "Cached Agent"} + assert client.get_agent_card() == {"name": "Cached Agent"} + + def test_send_task(self): + """Test sending a task.""" + client = A2AClient(url="http://localhost:9000") + response_data = { + "jsonrpc": "2.0", + "id": "task-123", + "result": {"artifacts": [{"parts": [{"kind": "text", "text": "8"}]}]}, + } + + mock_response = MagicMock() + mock_response.json.return_value = response_data + mock_response.raise_for_status = MagicMock() + + with patch("httpx.Client") as mock_client_class: + mock_http_client = MagicMock() + mock_http_client.post.return_value = mock_response + mock_http_client.__enter__ = MagicMock(return_value=mock_http_client) + mock_http_client.__exit__ = MagicMock(return_value=None) + mock_client_class.return_value = mock_http_client + + assert client.send_task("Calculate 3 + 5") == "8" + + +class TestA2AError: + """Tests for A2AError exception.""" + + def test_a2a_error(self): + """Test A2AError has correct attributes and string representation.""" + error = A2AError(code=-32600, message="Invalid Request") + + assert error.code == -32600 + assert error.message == "Invalid Request" + assert str(error) == "A2A Error -32600: Invalid Request"