Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ async def structured_output_async(self, output_model: type[T], prompt: AgentInpu
category=DeprecationWarning,
stacklevel=2,
)
await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self))
await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self, invocation_state={}))
with self.tracer.tracer.start_as_current_span(
"execute_structured_output", kind=trace_api.SpanKind.CLIENT
) as structured_output_span:
Expand Down Expand Up @@ -515,7 +515,7 @@ async def structured_output_async(self, output_model: type[T], prompt: AgentInpu
return event["output"]

finally:
await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self))
await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self, invocation_state={}))

def cleanup(self) -> None:
"""Clean up resources used by the agent.
Expand Down Expand Up @@ -657,7 +657,7 @@ async def _run_loop(
Events from the event loop cycle.
"""
before_invocation_event, _interrupts = await self.hooks.invoke_callbacks_async(
BeforeInvocationEvent(agent=self, messages=messages)
BeforeInvocationEvent(agent=self, invocation_state=invocation_state, messages=messages)
)
messages = before_invocation_event.messages if before_invocation_event.messages is not None else messages

Expand Down Expand Up @@ -695,7 +695,9 @@ async def _run_loop(

finally:
self.conversation_manager.apply_management(self)
await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self, result=agent_result))
await self.hooks.invoke_callbacks_async(
AfterInvocationEvent(agent=self, invocation_state=invocation_state, result=agent_result)
)

async def _execute_event_loop_cycle(
self, invocation_state: dict[str, Any], structured_output_context: StructuredOutputContext | None = None
Expand Down
3 changes: 3 additions & 0 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ async def _handle_model_execution(
await agent.hooks.invoke_callbacks_async(
BeforeModelCallEvent(
agent=agent,
invocation_state=invocation_state,
)
)

Expand All @@ -343,6 +344,7 @@ async def _handle_model_execution(

after_model_call_event = AfterModelCallEvent(
agent=agent,
invocation_state=invocation_state,
stop_response=AfterModelCallEvent.ModelStopResponse(
stop_reason=stop_reason,
message=message,
Expand Down Expand Up @@ -370,6 +372,7 @@ async def _handle_model_execution(
# Exception is automatically recorded by use_span with end_on_exit=True
after_model_call_event = AfterModelCallEvent(
agent=agent,
invocation_state=invocation_state,
exception=e,
)
await agent.hooks.invoke_callbacks_async(after_model_call_event)
Expand Down
21 changes: 19 additions & 2 deletions src/strands/hooks/events.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Hook events emitted as part of invoking Agents.

This module defines the events that are emitted as Agents run through the lifecycle of a request.
"""

import uuid
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

from typing_extensions import override
Expand Down Expand Up @@ -48,10 +48,14 @@
- Agent.structured_output

Attributes:
invocation_state: State and configuration passed through the agent invocation.
This can include shared context for multi-agent coordination, request tracking,
and dynamic configuration.
messages: The input messages for this invocation. Can be modified by hooks
to redact or transform content before processing.
"""

invocation_state: dict[str, Any] = field(default_factory=dict)
messages: Messages | None = None

def _can_write(self, name: str) -> bool:
Expand All @@ -75,11 +79,15 @@
- Agent.structured_output

Attributes:
invocation_state: State and configuration passed through the agent invocation.
This can include shared context for multi-agent coordination, request tracking,
and dynamic configuration.
result: The result of the agent invocation, if available.
This will be None when invoked from structured_output methods, as those return typed output directly rather
than AgentResult.
"""

invocation_state: dict[str, Any] = field(default_factory=dict)
result: "AgentResult | None" = None

@property
Expand Down Expand Up @@ -208,9 +216,14 @@
that will be sent to the model.

Note: This event is not fired for invocations to structured_output.

Attributes:
invocation_state: State and configuration passed through the agent invocation.
This can include shared context for multi-agent coordination, request tracking,
and dynamic configuration.
"""

pass
invocation_state: dict[str, Any] = field(default_factory=dict)


@dataclass
Expand Down Expand Up @@ -239,6 +252,9 @@
conversation history

Attributes:
invocation_state: State and configuration passed through the agent invocation.
This can include shared context for multi-agent coordination, request tracking,
and dynamic configuration.
stop_response: The model response data if invocation was successful, None if failed.
exception: Exception if the model invocation failed, None if successful.
retry: Whether to retry the model invocation. Can be set by hook callbacks
Expand All @@ -258,6 +274,7 @@
message: Message
stop_reason: StopReason

invocation_state: dict[str, Any] = field(default_factory=dict)
stop_response: ModelStopResponse | None = None
exception: Exception | None = None
retry: bool = False
Expand Down
37 changes: 37 additions & 0 deletions tests/strands/agent/hooks/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
from strands.agent.agent_result import AgentResult
from strands.hooks import (
AfterInvocationEvent,
AfterModelCallEvent,
AfterToolCallEvent,
AgentInitializedEvent,
BeforeInvocationEvent,
BeforeModelCallEvent,
BeforeToolCallEvent,
MessageAddedEvent,
)
Expand Down Expand Up @@ -170,6 +172,41 @@ def test_after_invocation_event_properties_not_writable(agent):
with pytest.raises(AttributeError, match="Property agent is not writable"):
event.agent = Mock()

with pytest.raises(AttributeError, match="Property invocation_state is not writable"):
event.invocation_state = {}


def test_invocation_state_is_available_in_invocation_events(agent):
"""Test that invocation_state is accessible in BeforeInvocationEvent and AfterInvocationEvent."""
invocation_state = {"session_id": "test-123", "request_id": "req-456"}

before_event = BeforeInvocationEvent(agent=agent, invocation_state=invocation_state)
assert before_event.invocation_state == invocation_state
assert before_event.invocation_state["session_id"] == "test-123"
assert before_event.invocation_state["request_id"] == "req-456"

after_event = AfterInvocationEvent(agent=agent, invocation_state=invocation_state, result=None)
assert after_event.invocation_state == invocation_state
assert after_event.invocation_state["session_id"] == "test-123"
assert after_event.invocation_state["request_id"] == "req-456"


def test_invocation_state_is_available_in_model_call_events(agent):
"""Test that invocation_state is accessible in BeforeModelCallEvent and AfterModelCallEvent."""
invocation_state = {"session_id": "test-123", "request_id": "req-456"}

before_event = BeforeModelCallEvent(agent=agent, invocation_state=invocation_state)
assert before_event.invocation_state == invocation_state
assert before_event.invocation_state["session_id"] == "test-123"
assert before_event.invocation_state["request_id"] == "req-456"

after_event = AfterModelCallEvent(agent=agent, invocation_state=invocation_state)
assert after_event.invocation_state == invocation_state
assert after_event.invocation_state["session_id"] == "test-123"
assert after_event.invocation_state["request_id"] == "req-456"




def test_before_invocation_event_messages_default_none(agent):
"""Test that BeforeInvocationEvent.messages defaults to None for backward compatibility."""
Expand Down
31 changes: 18 additions & 13 deletions tests/strands/agent/test_agent_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,15 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u

assert length == 12

assert next(events) == BeforeInvocationEvent(agent=agent, messages=agent.messages[0:1])
assert next(events) == BeforeInvocationEvent(agent=agent, invocation_state=ANY, messages=agent.messages[0:1])
assert next(events) == MessageAddedEvent(
agent=agent,
message=agent.messages[0],
)
assert next(events) == BeforeModelCallEvent(agent=agent)
assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY)
assert next(events) == AfterModelCallEvent(
agent=agent,
invocation_state=ANY,
stop_response=AfterModelCallEvent.ModelStopResponse(
message={
"content": [{"toolUse": tool_use}],
Expand All @@ -193,9 +194,10 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u
result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"},
)
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2])
assert next(events) == BeforeModelCallEvent(agent=agent)
assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY)
assert next(events) == AfterModelCallEvent(
agent=agent,
invocation_state=ANY,
stop_response=AfterModelCallEvent.ModelStopResponse(
message=mock_model.agent_responses[1],
stop_reason="end_turn",
Expand All @@ -204,7 +206,7 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u
)
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3])

assert next(events) == AfterInvocationEvent(agent=agent, result=result)
assert next(events) == AfterInvocationEvent(agent=agent, invocation_state=ANY, result=result)

assert len(agent.messages) == 4

Expand All @@ -215,8 +217,9 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m
iterator = agent.stream_async("test message")
await anext(iterator)

# Verify first event is BeforeInvocationEvent with messages
# Verify first event is BeforeInvocationEvent with invocation_state and messages
assert len(hook_provider.events_received) == 1
assert hook_provider.events_received[0].invocation_state is not None
assert hook_provider.events_received[0].messages is not None
assert hook_provider.events_received[0].messages[0]["role"] == "user"

Expand All @@ -230,14 +233,15 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m

assert length == 12

assert next(events) == BeforeInvocationEvent(agent=agent, messages=agent.messages[0:1])
assert next(events) == BeforeInvocationEvent(agent=agent, invocation_state=ANY, messages=agent.messages[0:1])
assert next(events) == MessageAddedEvent(
agent=agent,
message=agent.messages[0],
)
assert next(events) == BeforeModelCallEvent(agent=agent)
assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY)
assert next(events) == AfterModelCallEvent(
agent=agent,
invocation_state=ANY,
stop_response=AfterModelCallEvent.ModelStopResponse(
message={
"content": [{"toolUse": tool_use}],
Expand All @@ -263,9 +267,10 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m
result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"},
)
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2])
assert next(events) == BeforeModelCallEvent(agent=agent)
assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY)
assert next(events) == AfterModelCallEvent(
agent=agent,
invocation_state=ANY,
stop_response=AfterModelCallEvent.ModelStopResponse(
message=mock_model.agent_responses[1],
stop_reason="end_turn",
Expand All @@ -274,7 +279,7 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m
)
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3])

assert next(events) == AfterInvocationEvent(agent=agent, result=result)
assert next(events) == AfterInvocationEvent(agent=agent, invocation_state=ANY, result=result)

assert len(agent.messages) == 4

Expand All @@ -289,8 +294,8 @@ def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator):

assert length == 2

assert next(events) == BeforeInvocationEvent(agent=agent)
assert next(events) == AfterInvocationEvent(agent=agent)
assert next(events) == BeforeInvocationEvent(agent=agent, invocation_state=ANY)
assert next(events) == AfterInvocationEvent(agent=agent, invocation_state=ANY)

assert len(agent.messages) == 0 # no new messages added

Expand All @@ -306,8 +311,8 @@ async def test_agent_structured_async_output_hooks(agent, hook_provider, user, a

assert length == 2

assert next(events) == BeforeInvocationEvent(agent=agent)
assert next(events) == AfterInvocationEvent(agent=agent)
assert next(events) == BeforeInvocationEvent(agent=agent, invocation_state=ANY)
assert next(events) == AfterInvocationEvent(agent=agent, invocation_state=ANY)

assert len(agent.messages) == 0 # no new messages added

Expand Down
2 changes: 1 addition & 1 deletion tests/strands/agent/test_conversation_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def test_per_turn_dynamic_change():

mock_agent = MagicMock()
mock_agent.messages = []
event = BeforeModelCallEvent(agent=mock_agent)
event = BeforeModelCallEvent(agent=mock_agent, invocation_state={})

# Initially disabled
with patch.object(manager, "apply_management") as mock_apply:
Expand Down
15 changes: 8 additions & 7 deletions tests/strands/event_loop/test_event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,27 +855,28 @@ async def test_event_loop_cycle_exception_model_hooks(mock_sleep, agent, model,
assert count == 9

# 1st call - throttled
assert next(events) == BeforeModelCallEvent(agent=agent)
expected_after = AfterModelCallEvent(agent=agent, stop_response=None, exception=exception)
assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY)
expected_after = AfterModelCallEvent(agent=agent, invocation_state=ANY, stop_response=None, exception=exception)
expected_after.retry = True
assert next(events) == expected_after

# 2nd call - throttled
assert next(events) == BeforeModelCallEvent(agent=agent)
expected_after = AfterModelCallEvent(agent=agent, stop_response=None, exception=exception)
assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY)
expected_after = AfterModelCallEvent(agent=agent, invocation_state=ANY, stop_response=None, exception=exception)
expected_after.retry = True
assert next(events) == expected_after

# 3rd call - throttled
assert next(events) == BeforeModelCallEvent(agent=agent)
expected_after = AfterModelCallEvent(agent=agent, stop_response=None, exception=exception)
assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY)
expected_after = AfterModelCallEvent(agent=agent, invocation_state=ANY, stop_response=None, exception=exception)
expected_after.retry = True
assert next(events) == expected_after

# 4th call - successful
assert next(events) == BeforeModelCallEvent(agent=agent)
assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY)
assert next(events) == AfterModelCallEvent(
agent=agent,
invocation_state=ANY,
stop_response=AfterModelCallEvent.ModelStopResponse(
message={"content": [{"text": "test text"}], "role": "assistant"}, stop_reason="end_turn"
),
Expand Down
4 changes: 2 additions & 2 deletions tests/strands/experimental/hooks/test_hook_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ def test_after_tool_call_event_type_equality():

def test_before_model_call_event_type_equality():
"""Verify that BeforeModelInvocationEvent alias has the same type identity."""
before_model_event = BeforeModelCallEvent(agent=Mock())
before_model_event = BeforeModelCallEvent(agent=Mock(), invocation_state={})

assert isinstance(before_model_event, BeforeModelInvocationEvent)
assert isinstance(before_model_event, BeforeModelCallEvent)


def test_after_model_call_event_type_equality():
"""Verify that AfterModelInvocationEvent alias has the same type identity."""
after_model_event = AfterModelCallEvent(agent=Mock())
after_model_event = AfterModelCallEvent(agent=Mock(), invocation_state={})

assert isinstance(after_model_event, AfterModelInvocationEvent)
assert isinstance(after_model_event, AfterModelCallEvent)
Expand Down
Loading