Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/strands/multiagent/a2a/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,12 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater
self._current_artifact_id = str(uuid.uuid4())
self._is_first_chunk = True

# Pass the A2A RequestContext through invocation state so downstream
# tools and hooks can access request metadata, task info, configuration, etc.
invocation_state: dict[str, Any] = {"a2a_request_context": context}

try:
async for event in self.agent.stream_async(content_blocks):
async for event in self.agent.stream_async(content_blocks, invocation_state=invocation_state):
await self._handle_streaming_event(event, updater)
except Exception:
logger.exception("Error in streaming execution")
Expand Down
136 changes: 135 additions & 1 deletion tests/strands/multiagent/a2a/test_executor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for the StrandsA2AExecutor class."""

import base64
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
Expand Down Expand Up @@ -1196,4 +1197,137 @@ async def test_a2a_compliant_handle_result_not_first_chunk(mock_strands_agent):
assert mock_updater.add_artifact.call_args[1]["artifact_id"] == "artifact-abc"
assert mock_updater.add_artifact.call_args[1]["append"] is True
assert mock_updater.add_artifact.call_args[1]["last_chunk"] is True
mock_updater.complete.assert_called_once()


# Tests for invocation state propagation from A2A request context


def _setup_streaming_context(
mock_strands_agent: MagicMock,
mock_request_context: MagicMock,
) -> None:
"""Set up common mocks for invocation state streaming tests.
Args:
mock_strands_agent: The mock Strands Agent.
mock_request_context: The mock RequestContext.
"""

async def mock_stream(content_blocks: list, **kwargs: Any) -> Any:
yield {"result": MagicMock(spec=SAAgentResult)}

mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream)

# Set up message with a text part
mock_text_part = MagicMock(spec=TextPart)
mock_text_part.text = "test input"
mock_part = MagicMock()
mock_part.root = mock_text_part
mock_message = MagicMock()
mock_message.parts = [mock_part]
mock_request_context.message = mock_message


@pytest.mark.asyncio
async def test_invocation_state_contains_request_context(mock_strands_agent, mock_request_context, mock_event_queue):
"""Test that the full RequestContext is passed as a2a_request_context in invocation state."""
mock_task = MagicMock()
mock_task.id = "task-42"
mock_task.context_id = "ctx-99"
mock_request_context.current_task = mock_task
mock_request_context.metadata = {"caller": "test-client"}

_setup_streaming_context(mock_strands_agent, mock_request_context)

executor = StrandsA2AExecutor(mock_strands_agent)
await executor.execute(mock_request_context, mock_event_queue)

mock_strands_agent.stream_async.assert_called_once()
call_kwargs = mock_strands_agent.stream_async.call_args[1]
invocation_state = call_kwargs["invocation_state"]

assert invocation_state is not None
assert invocation_state["a2a_request_context"] is mock_request_context


@pytest.mark.asyncio
async def test_invocation_state_context_exposes_metadata(mock_strands_agent, mock_request_context, mock_event_queue):
"""Test that metadata is accessible through the RequestContext in invocation state."""
test_metadata = {"caller": "test-client", "session": "abc-123"}
mock_request_context.metadata = test_metadata
mock_task = MagicMock()
mock_task.id = "task-1"
mock_task.context_id = "ctx-1"
mock_request_context.current_task = mock_task

_setup_streaming_context(mock_strands_agent, mock_request_context)

executor = StrandsA2AExecutor(mock_strands_agent)
await executor.execute(mock_request_context, mock_event_queue)

call_kwargs = mock_strands_agent.stream_async.call_args[1]
context = call_kwargs["invocation_state"]["a2a_request_context"]

assert context.metadata == test_metadata


@pytest.mark.asyncio
async def test_invocation_state_context_exposes_task_info(mock_strands_agent, mock_request_context, mock_event_queue):
"""Test that task info is accessible through the RequestContext in invocation state."""
mock_task = MagicMock()
mock_task.id = "task-100"
mock_task.context_id = "ctx-200"
mock_request_context.current_task = mock_task

_setup_streaming_context(mock_strands_agent, mock_request_context)

executor = StrandsA2AExecutor(mock_strands_agent)
await executor.execute(mock_request_context, mock_event_queue)

call_kwargs = mock_strands_agent.stream_async.call_args[1]
context = call_kwargs["invocation_state"]["a2a_request_context"]

assert context.current_task.id == "task-100"
assert context.current_task.context_id == "ctx-200"


@pytest.mark.asyncio
async def test_invocation_state_context_when_no_task(mock_strands_agent, mock_request_context, mock_event_queue):
"""Test that RequestContext is passed even when there is no current task."""
mock_request_context.current_task = None
mock_request_context.metadata = {}

_setup_streaming_context(mock_strands_agent, mock_request_context)

executor = StrandsA2AExecutor(mock_strands_agent)

with patch("strands.multiagent.a2a.executor.new_task") as mock_new_task:
mock_new_task.return_value = MagicMock(id="generated-id", context_id="generated-ctx")
await executor.execute(mock_request_context, mock_event_queue)

call_kwargs = mock_strands_agent.stream_async.call_args[1]
invocation_state = call_kwargs["invocation_state"]

assert invocation_state["a2a_request_context"] is mock_request_context


@pytest.mark.asyncio
async def test_invocation_state_with_a2a_compliant_streaming(
mock_strands_agent, mock_request_context, mock_event_queue
):
"""Test that invocation state is passed correctly in A2A-compliant streaming mode."""
mock_task = MagicMock()
mock_task.id = "task-compliant"
mock_task.context_id = "ctx-compliant"
mock_request_context.current_task = mock_task

_setup_streaming_context(mock_strands_agent, mock_request_context)

executor = StrandsA2AExecutor(mock_strands_agent, enable_a2a_compliant_streaming=True)
await executor.execute(mock_request_context, mock_event_queue)

call_kwargs = mock_strands_agent.stream_async.call_args[1]
invocation_state = call_kwargs["invocation_state"]

assert invocation_state is not None
assert invocation_state["a2a_request_context"] is mock_request_context
Loading