diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index e2ac3aa71..05c3af191 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -474,7 +474,7 @@ async def structured_output_async(self, output_model: type[T], prompt: AgentInpu category=DeprecationWarning, stacklevel=2, ) - await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self)) + await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self, invocation_state={})) with self.tracer.tracer.start_as_current_span( "execute_structured_output", kind=trace_api.SpanKind.CLIENT ) as structured_output_span: @@ -515,7 +515,7 @@ async def structured_output_async(self, output_model: type[T], prompt: AgentInpu return event["output"] finally: - await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self)) + await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self, invocation_state={})) def cleanup(self) -> None: """Clean up resources used by the agent. @@ -657,7 +657,7 @@ async def _run_loop( Events from the event loop cycle. """ before_invocation_event, _interrupts = await self.hooks.invoke_callbacks_async( - BeforeInvocationEvent(agent=self, messages=messages) + BeforeInvocationEvent(agent=self, invocation_state=invocation_state, messages=messages) ) messages = before_invocation_event.messages if before_invocation_event.messages is not None else messages @@ -695,7 +695,9 @@ async def _run_loop( finally: self.conversation_manager.apply_management(self) - await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self, result=agent_result)) + await self.hooks.invoke_callbacks_async( + AfterInvocationEvent(agent=self, invocation_state=invocation_state, result=agent_result) + ) async def _execute_event_loop_cycle( self, invocation_state: dict[str, Any], structured_output_context: StructuredOutputContext | None = None diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 41122efc5..9fe645f80 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -318,6 +318,7 @@ async def _handle_model_execution( await agent.hooks.invoke_callbacks_async( BeforeModelCallEvent( agent=agent, + invocation_state=invocation_state, ) ) @@ -343,6 +344,7 @@ async def _handle_model_execution( after_model_call_event = AfterModelCallEvent( agent=agent, + invocation_state=invocation_state, stop_response=AfterModelCallEvent.ModelStopResponse( stop_reason=stop_reason, message=message, @@ -370,6 +372,7 @@ async def _handle_model_execution( # Exception is automatically recorded by use_span with end_on_exit=True after_model_call_event = AfterModelCallEvent( agent=agent, + invocation_state=invocation_state, exception=e, ) await agent.hooks.invoke_callbacks_async(after_model_call_event) diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index ad40dfd7f..8d3e5d280 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -4,7 +4,7 @@ """ import uuid -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any from typing_extensions import override @@ -48,10 +48,14 @@ class BeforeInvocationEvent(HookEvent): - Agent.structured_output Attributes: + invocation_state: State and configuration passed through the agent invocation. + This can include shared context for multi-agent coordination, request tracking, + and dynamic configuration. messages: The input messages for this invocation. Can be modified by hooks to redact or transform content before processing. """ + invocation_state: dict[str, Any] = field(default_factory=dict) messages: Messages | None = None def _can_write(self, name: str) -> bool: @@ -75,11 +79,15 @@ class AfterInvocationEvent(HookEvent): - Agent.structured_output Attributes: + invocation_state: State and configuration passed through the agent invocation. + This can include shared context for multi-agent coordination, request tracking, + and dynamic configuration. result: The result of the agent invocation, if available. This will be None when invoked from structured_output methods, as those return typed output directly rather than AgentResult. """ + invocation_state: dict[str, Any] = field(default_factory=dict) result: "AgentResult | None" = None @property @@ -208,9 +216,14 @@ class BeforeModelCallEvent(HookEvent): that will be sent to the model. Note: This event is not fired for invocations to structured_output. + + Attributes: + invocation_state: State and configuration passed through the agent invocation. + This can include shared context for multi-agent coordination, request tracking, + and dynamic configuration. """ - pass + invocation_state: dict[str, Any] = field(default_factory=dict) @dataclass @@ -239,6 +252,9 @@ class AfterModelCallEvent(HookEvent): conversation history Attributes: + invocation_state: State and configuration passed through the agent invocation. + This can include shared context for multi-agent coordination, request tracking, + and dynamic configuration. stop_response: The model response data if invocation was successful, None if failed. exception: Exception if the model invocation failed, None if successful. retry: Whether to retry the model invocation. Can be set by hook callbacks @@ -258,6 +274,7 @@ class ModelStopResponse: message: Message stop_reason: StopReason + invocation_state: dict[str, Any] = field(default_factory=dict) stop_response: ModelStopResponse | None = None exception: Exception | None = None retry: bool = False diff --git a/tests/strands/agent/hooks/test_events.py b/tests/strands/agent/hooks/test_events.py index 83cb1af24..762b77452 100644 --- a/tests/strands/agent/hooks/test_events.py +++ b/tests/strands/agent/hooks/test_events.py @@ -5,9 +5,11 @@ from strands.agent.agent_result import AgentResult from strands.hooks import ( AfterInvocationEvent, + AfterModelCallEvent, AfterToolCallEvent, AgentInitializedEvent, BeforeInvocationEvent, + BeforeModelCallEvent, BeforeToolCallEvent, MessageAddedEvent, ) @@ -170,6 +172,41 @@ def test_after_invocation_event_properties_not_writable(agent): with pytest.raises(AttributeError, match="Property agent is not writable"): event.agent = Mock() + with pytest.raises(AttributeError, match="Property invocation_state is not writable"): + event.invocation_state = {} + + +def test_invocation_state_is_available_in_invocation_events(agent): + """Test that invocation_state is accessible in BeforeInvocationEvent and AfterInvocationEvent.""" + invocation_state = {"session_id": "test-123", "request_id": "req-456"} + + before_event = BeforeInvocationEvent(agent=agent, invocation_state=invocation_state) + assert before_event.invocation_state == invocation_state + assert before_event.invocation_state["session_id"] == "test-123" + assert before_event.invocation_state["request_id"] == "req-456" + + after_event = AfterInvocationEvent(agent=agent, invocation_state=invocation_state, result=None) + assert after_event.invocation_state == invocation_state + assert after_event.invocation_state["session_id"] == "test-123" + assert after_event.invocation_state["request_id"] == "req-456" + + +def test_invocation_state_is_available_in_model_call_events(agent): + """Test that invocation_state is accessible in BeforeModelCallEvent and AfterModelCallEvent.""" + invocation_state = {"session_id": "test-123", "request_id": "req-456"} + + before_event = BeforeModelCallEvent(agent=agent, invocation_state=invocation_state) + assert before_event.invocation_state == invocation_state + assert before_event.invocation_state["session_id"] == "test-123" + assert before_event.invocation_state["request_id"] == "req-456" + + after_event = AfterModelCallEvent(agent=agent, invocation_state=invocation_state) + assert after_event.invocation_state == invocation_state + assert after_event.invocation_state["session_id"] == "test-123" + assert after_event.invocation_state["request_id"] == "req-456" + + + def test_before_invocation_event_messages_default_none(agent): """Test that BeforeInvocationEvent.messages defaults to None for backward compatibility.""" diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index e8b7e5077..8ff81295a 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -160,14 +160,15 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u assert length == 12 - assert next(events) == BeforeInvocationEvent(agent=agent, messages=agent.messages[0:1]) + assert next(events) == BeforeInvocationEvent(agent=agent, invocation_state=ANY, messages=agent.messages[0:1]) assert next(events) == MessageAddedEvent( agent=agent, message=agent.messages[0], ) - assert next(events) == BeforeModelCallEvent(agent=agent) + assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY) assert next(events) == AfterModelCallEvent( agent=agent, + invocation_state=ANY, stop_response=AfterModelCallEvent.ModelStopResponse( message={ "content": [{"toolUse": tool_use}], @@ -193,9 +194,10 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"}, ) assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2]) - assert next(events) == BeforeModelCallEvent(agent=agent) + assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY) assert next(events) == AfterModelCallEvent( agent=agent, + invocation_state=ANY, stop_response=AfterModelCallEvent.ModelStopResponse( message=mock_model.agent_responses[1], stop_reason="end_turn", @@ -204,7 +206,7 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u ) assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3]) - assert next(events) == AfterInvocationEvent(agent=agent, result=result) + assert next(events) == AfterInvocationEvent(agent=agent, invocation_state=ANY, result=result) assert len(agent.messages) == 4 @@ -215,8 +217,9 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m iterator = agent.stream_async("test message") await anext(iterator) - # Verify first event is BeforeInvocationEvent with messages + # Verify first event is BeforeInvocationEvent with invocation_state and messages assert len(hook_provider.events_received) == 1 + assert hook_provider.events_received[0].invocation_state is not None assert hook_provider.events_received[0].messages is not None assert hook_provider.events_received[0].messages[0]["role"] == "user" @@ -230,14 +233,15 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m assert length == 12 - assert next(events) == BeforeInvocationEvent(agent=agent, messages=agent.messages[0:1]) + assert next(events) == BeforeInvocationEvent(agent=agent, invocation_state=ANY, messages=agent.messages[0:1]) assert next(events) == MessageAddedEvent( agent=agent, message=agent.messages[0], ) - assert next(events) == BeforeModelCallEvent(agent=agent) + assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY) assert next(events) == AfterModelCallEvent( agent=agent, + invocation_state=ANY, stop_response=AfterModelCallEvent.ModelStopResponse( message={ "content": [{"toolUse": tool_use}], @@ -263,9 +267,10 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"}, ) assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2]) - assert next(events) == BeforeModelCallEvent(agent=agent) + assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY) assert next(events) == AfterModelCallEvent( agent=agent, + invocation_state=ANY, stop_response=AfterModelCallEvent.ModelStopResponse( message=mock_model.agent_responses[1], stop_reason="end_turn", @@ -274,7 +279,7 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m ) assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3]) - assert next(events) == AfterInvocationEvent(agent=agent, result=result) + assert next(events) == AfterInvocationEvent(agent=agent, invocation_state=ANY, result=result) assert len(agent.messages) == 4 @@ -289,8 +294,8 @@ def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator): assert length == 2 - assert next(events) == BeforeInvocationEvent(agent=agent) - assert next(events) == AfterInvocationEvent(agent=agent) + assert next(events) == BeforeInvocationEvent(agent=agent, invocation_state=ANY) + assert next(events) == AfterInvocationEvent(agent=agent, invocation_state=ANY) assert len(agent.messages) == 0 # no new messages added @@ -306,8 +311,8 @@ async def test_agent_structured_async_output_hooks(agent, hook_provider, user, a assert length == 2 - assert next(events) == BeforeInvocationEvent(agent=agent) - assert next(events) == AfterInvocationEvent(agent=agent) + assert next(events) == BeforeInvocationEvent(agent=agent, invocation_state=ANY) + assert next(events) == AfterInvocationEvent(agent=agent, invocation_state=ANY) assert len(agent.messages) == 0 # no new messages added diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index ae18a9131..46876d8e5 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -362,7 +362,7 @@ def test_per_turn_dynamic_change(): mock_agent = MagicMock() mock_agent.messages = [] - event = BeforeModelCallEvent(agent=mock_agent) + event = BeforeModelCallEvent(agent=mock_agent, invocation_state={}) # Initially disabled with patch.object(manager, "apply_management") as mock_apply: diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index a76a5b6b5..8c6155e20 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -855,27 +855,28 @@ async def test_event_loop_cycle_exception_model_hooks(mock_sleep, agent, model, assert count == 9 # 1st call - throttled - assert next(events) == BeforeModelCallEvent(agent=agent) - expected_after = AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY) + expected_after = AfterModelCallEvent(agent=agent, invocation_state=ANY, stop_response=None, exception=exception) expected_after.retry = True assert next(events) == expected_after # 2nd call - throttled - assert next(events) == BeforeModelCallEvent(agent=agent) - expected_after = AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY) + expected_after = AfterModelCallEvent(agent=agent, invocation_state=ANY, stop_response=None, exception=exception) expected_after.retry = True assert next(events) == expected_after # 3rd call - throttled - assert next(events) == BeforeModelCallEvent(agent=agent) - expected_after = AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY) + expected_after = AfterModelCallEvent(agent=agent, invocation_state=ANY, stop_response=None, exception=exception) expected_after.retry = True assert next(events) == expected_after # 4th call - successful - assert next(events) == BeforeModelCallEvent(agent=agent) + assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY) assert next(events) == AfterModelCallEvent( agent=agent, + invocation_state=ANY, stop_response=AfterModelCallEvent.ModelStopResponse( message={"content": [{"text": "test text"}], "role": "assistant"}, stop_reason="end_turn" ), diff --git a/tests/strands/experimental/hooks/test_hook_aliases.py b/tests/strands/experimental/hooks/test_hook_aliases.py index 2da8a6f90..b229c1c2d 100644 --- a/tests/strands/experimental/hooks/test_hook_aliases.py +++ b/tests/strands/experimental/hooks/test_hook_aliases.py @@ -68,7 +68,7 @@ def test_after_tool_call_event_type_equality(): def test_before_model_call_event_type_equality(): """Verify that BeforeModelInvocationEvent alias has the same type identity.""" - before_model_event = BeforeModelCallEvent(agent=Mock()) + before_model_event = BeforeModelCallEvent(agent=Mock(), invocation_state={}) assert isinstance(before_model_event, BeforeModelInvocationEvent) assert isinstance(before_model_event, BeforeModelCallEvent) @@ -76,7 +76,7 @@ def test_before_model_call_event_type_equality(): def test_after_model_call_event_type_equality(): """Verify that AfterModelInvocationEvent alias has the same type identity.""" - after_model_event = AfterModelCallEvent(agent=Mock()) + after_model_event = AfterModelCallEvent(agent=Mock(), invocation_state={}) assert isinstance(after_model_event, AfterModelInvocationEvent) assert isinstance(after_model_event, AfterModelCallEvent)