From 20cf08299b39a48953a775340e241700ab1aa838 Mon Sep 17 00:00:00 2001 From: Containerized Agent Date: Wed, 11 Mar 2026 00:02:32 +0000 Subject: [PATCH] feat: pass A2A request context metadata as invocation state --- src/strands/multiagent/a2a/executor.py | 6 +- tests/strands/multiagent/a2a/test_executor.py | 136 +++++++++++++++++- 2 files changed, 140 insertions(+), 2 deletions(-) diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py index 2f8de99f7..c8c00600b 100644 --- a/src/strands/multiagent/a2a/executor.py +++ b/src/strands/multiagent/a2a/executor.py @@ -128,8 +128,12 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater self._current_artifact_id = str(uuid.uuid4()) self._is_first_chunk = True + # Pass the A2A RequestContext through invocation state so downstream + # tools and hooks can access request metadata, task info, configuration, etc. + invocation_state: dict[str, Any] = {"a2a_request_context": context} + try: - async for event in self.agent.stream_async(content_blocks): + async for event in self.agent.stream_async(content_blocks, invocation_state=invocation_state): await self._handle_streaming_event(event, updater) except Exception: logger.exception("Error in streaming execution") diff --git a/tests/strands/multiagent/a2a/test_executor.py b/tests/strands/multiagent/a2a/test_executor.py index 932f26247..dc90fbdd6 100644 --- a/tests/strands/multiagent/a2a/test_executor.py +++ b/tests/strands/multiagent/a2a/test_executor.py @@ -1,6 +1,7 @@ """Tests for the StrandsA2AExecutor class.""" import base64 +from typing import Any from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -1196,4 +1197,137 @@ async def test_a2a_compliant_handle_result_not_first_chunk(mock_strands_agent): assert mock_updater.add_artifact.call_args[1]["artifact_id"] == "artifact-abc" assert mock_updater.add_artifact.call_args[1]["append"] is True assert mock_updater.add_artifact.call_args[1]["last_chunk"] is True - mock_updater.complete.assert_called_once() + + +# Tests for invocation state propagation from A2A request context + + +def _setup_streaming_context( + mock_strands_agent: MagicMock, + mock_request_context: MagicMock, +) -> None: + """Set up common mocks for invocation state streaming tests. + + Args: + mock_strands_agent: The mock Strands Agent. + mock_request_context: The mock RequestContext. + """ + + async def mock_stream(content_blocks: list, **kwargs: Any) -> Any: + yield {"result": MagicMock(spec=SAAgentResult)} + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + # Set up message with a text part + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "test input" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + +@pytest.mark.asyncio +async def test_invocation_state_contains_request_context(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that the full RequestContext is passed as a2a_request_context in invocation state.""" + mock_task = MagicMock() + mock_task.id = "task-42" + mock_task.context_id = "ctx-99" + mock_request_context.current_task = mock_task + mock_request_context.metadata = {"caller": "test-client"} + + _setup_streaming_context(mock_strands_agent, mock_request_context) + + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + mock_strands_agent.stream_async.assert_called_once() + call_kwargs = mock_strands_agent.stream_async.call_args[1] + invocation_state = call_kwargs["invocation_state"] + + assert invocation_state is not None + assert invocation_state["a2a_request_context"] is mock_request_context + + +@pytest.mark.asyncio +async def test_invocation_state_context_exposes_metadata(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that metadata is accessible through the RequestContext in invocation state.""" + test_metadata = {"caller": "test-client", "session": "abc-123"} + mock_request_context.metadata = test_metadata + mock_task = MagicMock() + mock_task.id = "task-1" + mock_task.context_id = "ctx-1" + mock_request_context.current_task = mock_task + + _setup_streaming_context(mock_strands_agent, mock_request_context) + + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + call_kwargs = mock_strands_agent.stream_async.call_args[1] + context = call_kwargs["invocation_state"]["a2a_request_context"] + + assert context.metadata == test_metadata + + +@pytest.mark.asyncio +async def test_invocation_state_context_exposes_task_info(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that task info is accessible through the RequestContext in invocation state.""" + mock_task = MagicMock() + mock_task.id = "task-100" + mock_task.context_id = "ctx-200" + mock_request_context.current_task = mock_task + + _setup_streaming_context(mock_strands_agent, mock_request_context) + + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + call_kwargs = mock_strands_agent.stream_async.call_args[1] + context = call_kwargs["invocation_state"]["a2a_request_context"] + + assert context.current_task.id == "task-100" + assert context.current_task.context_id == "ctx-200" + + +@pytest.mark.asyncio +async def test_invocation_state_context_when_no_task(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that RequestContext is passed even when there is no current task.""" + mock_request_context.current_task = None + mock_request_context.metadata = {} + + _setup_streaming_context(mock_strands_agent, mock_request_context) + + executor = StrandsA2AExecutor(mock_strands_agent) + + with patch("strands.multiagent.a2a.executor.new_task") as mock_new_task: + mock_new_task.return_value = MagicMock(id="generated-id", context_id="generated-ctx") + await executor.execute(mock_request_context, mock_event_queue) + + call_kwargs = mock_strands_agent.stream_async.call_args[1] + invocation_state = call_kwargs["invocation_state"] + + assert invocation_state["a2a_request_context"] is mock_request_context + + +@pytest.mark.asyncio +async def test_invocation_state_with_a2a_compliant_streaming( + mock_strands_agent, mock_request_context, mock_event_queue +): + """Test that invocation state is passed correctly in A2A-compliant streaming mode.""" + mock_task = MagicMock() + mock_task.id = "task-compliant" + mock_task.context_id = "ctx-compliant" + mock_request_context.current_task = mock_task + + _setup_streaming_context(mock_strands_agent, mock_request_context) + + executor = StrandsA2AExecutor(mock_strands_agent, enable_a2a_compliant_streaming=True) + await executor.execute(mock_request_context, mock_event_queue) + + call_kwargs = mock_strands_agent.stream_async.call_args[1] + invocation_state = call_kwargs["invocation_state"] + + assert invocation_state is not None + assert invocation_state["a2a_request_context"] is mock_request_context