diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 7b9e9c914..3ac678e09 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -33,8 +33,10 @@ from ..experimental.tools import ToolProvider from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from ..hooks import ( + AfterContextReductionEvent, AfterInvocationEvent, AgentInitializedEvent, + BeforeContextReductionEvent, BeforeInvocationEvent, HookProvider, HookRegistry, @@ -710,8 +712,35 @@ async def _execute_event_loop_cycle( yield event except ContextWindowOverflowException as e: + # Emit before context reduction event + original_message_count = len(self.messages) + await self.hooks.invoke_callbacks_async( + BeforeContextReductionEvent( + agent=self, + exception=e, + message_count=original_message_count, + ) + ) + # Try reducing the context size and retrying - self.conversation_manager.reduce_context(self, e=e) + reduction_exception: Exception | None = None + try: + self.conversation_manager.reduce_context(self, e=e) + except Exception as reduction_error: + reduction_exception = reduction_error + raise + finally: + # Emit after context reduction event + new_message_count = len(self.messages) + await self.hooks.invoke_callbacks_async( + AfterContextReductionEvent( + agent=self, + original_message_count=original_message_count, + new_message_count=new_message_count, + removed_count=original_message_count - new_message_count, + exception=reduction_exception, + ) + ) # Sync agent after reduce_context to keep conversation_manager_state up to date in the session if self._session_manager: diff --git a/src/strands/hooks/__init__.py b/src/strands/hooks/__init__.py index 30163f207..24121f203 100644 --- a/src/strands/hooks/__init__.py +++ b/src/strands/hooks/__init__.py @@ -30,10 +30,12 @@ def log_end(self, event: AfterInvocationEvent) -> None: """ from .events import ( + AfterContextReductionEvent, AfterInvocationEvent, AfterModelCallEvent, AfterToolCallEvent, AgentInitializedEvent, + BeforeContextReductionEvent, BeforeInvocationEvent, BeforeModelCallEvent, BeforeToolCallEvent, @@ -48,6 +50,8 @@ def log_end(self, event: AfterInvocationEvent) -> None: "AfterToolCallEvent", "BeforeModelCallEvent", "AfterModelCallEvent", + "BeforeContextReductionEvent", + "AfterContextReductionEvent", "AfterInvocationEvent", "MessageAddedEvent", "HookEvent", diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 8aa8a68d6..2fdf51eac 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -250,3 +250,55 @@ def _can_write(self, name: str) -> bool: def should_reverse_callbacks(self) -> bool: """True to invoke callbacks in reverse order.""" return True + + +@dataclass +class BeforeContextReductionEvent(HookEvent): + """Event triggered before context window overflow handling begins. + + This event is fired when the agent catches a ContextWindowOverflowException + and is about to reduce the context by calling the conversation manager's + reduce_context method. Hook providers can use this event for: + - Displaying "compacting conversation..." UI feedback to users + - Logging context reduction events for analytics + - Debugging context window management issues + + Attributes: + exception: The ContextWindowOverflowException that triggered the reduction. + message_count: The number of messages before context reduction begins. + """ + + exception: Exception + message_count: int + + +@dataclass +class AfterContextReductionEvent(HookEvent): + """Event triggered after context window overflow handling completes. + + This event is fired after the conversation manager's reduce_context method + has completed, regardless of whether it succeeded or failed. Hook providers + can use this event for: + - Displaying "compaction complete" UI feedback with statistics + - Tracking context reduction frequency and effectiveness + - Post-reduction cleanup or state updates + + Note: This event uses reverse callback ordering, meaning callbacks registered + later will be invoked first during cleanup. + + Attributes: + original_message_count: Number of messages before context reduction. + new_message_count: Number of messages after context reduction. + removed_count: Number of messages that were removed during reduction. + exception: Exception if context reduction failed, None if successful. + """ + + original_message_count: int + new_message_count: int + removed_count: int + exception: Exception | None = None + + @property + def should_reverse_callbacks(self) -> bool: + """True to invoke callbacks in reverse order.""" + return True diff --git a/tests/strands/agent/hooks/test_events.py b/tests/strands/agent/hooks/test_events.py index 83cb1af24..0a8598003 100644 --- a/tests/strands/agent/hooks/test_events.py +++ b/tests/strands/agent/hooks/test_events.py @@ -195,3 +195,71 @@ def test_before_invocation_event_agent_not_writable(start_request_event_with_mes """Test that BeforeInvocationEvent.agent is not writable.""" with pytest.raises(AttributeError, match="Property agent is not writable"): start_request_event_with_messages.agent = Mock() + + +# Tests for BeforeContextReductionEvent and AfterContextReductionEvent + + +@pytest.fixture +def context_overflow_exception(): + from strands.types.exceptions import ContextWindowOverflowException + + return ContextWindowOverflowException("Context window exceeded") + + +@pytest.fixture +def before_context_reduction_event(agent, context_overflow_exception): + from strands.hooks import BeforeContextReductionEvent + + return BeforeContextReductionEvent( + agent=agent, + exception=context_overflow_exception, + message_count=50, + ) + + +@pytest.fixture +def after_context_reduction_event(agent): + from strands.hooks import AfterContextReductionEvent + + return AfterContextReductionEvent( + agent=agent, + original_message_count=50, + new_message_count=25, + removed_count=25, + ) + + +@pytest.fixture +def after_context_reduction_event_with_exception(agent): + from strands.hooks import AfterContextReductionEvent + + return AfterContextReductionEvent( + agent=agent, + original_message_count=50, + new_message_count=50, + removed_count=0, + exception=RuntimeError("Reduction failed"), + ) + + +def test_before_context_reduction_event_properties(before_context_reduction_event, context_overflow_exception): + assert before_context_reduction_event.exception == context_overflow_exception + assert before_context_reduction_event.message_count == 50 + assert before_context_reduction_event.should_reverse_callbacks is False + + +def test_after_context_reduction_event_properties(after_context_reduction_event): + assert after_context_reduction_event.original_message_count == 50 + assert after_context_reduction_event.new_message_count == 25 + assert after_context_reduction_event.removed_count == 25 + assert after_context_reduction_event.exception is None + assert after_context_reduction_event.should_reverse_callbacks is True + + +def test_after_context_reduction_event_with_exception(after_context_reduction_event_with_exception): + assert after_context_reduction_event_with_exception.original_message_count == 50 + assert after_context_reduction_event_with_exception.new_message_count == 50 + assert after_context_reduction_event_with_exception.removed_count == 0 + assert after_context_reduction_event_with_exception.exception is not None + assert isinstance(after_context_reduction_event_with_exception.exception, RuntimeError)