Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,11 @@ async def event_loop_cycle(
agent, cycle_span, cycle_trace, invocation_state, tracer, structured_output_context
)
async for model_event in model_events:
if isinstance(model_event, EventLoopStopEvent):
agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
yield model_event
await model_events.aclose() # clean-up async for-loop to avoid CancelledError
return
if not isinstance(model_event, ModelStopReason):
yield model_event

Expand Down Expand Up @@ -368,6 +373,18 @@ async def _handle_model_execution(
stop_reason,
)
continue # Retry the model call
elif after_model_call_event.terminate:
logger.debug(
"stop_reason=<%s>, termination_requested=<True> | hook requested agent termination",
stop_reason,
)
invocation_state["request_state"]["stop_event_loop"] = True
yield EventLoopStopEvent(
stop_reason,
message,
agent.event_loop_metrics,
invocation_state["request_state"],
)

if stop_reason == "max_tokens":
message = recover_message_on_max_tokens_reached(message)
Expand Down
6 changes: 5 additions & 1 deletion src/strands/hooks/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,13 @@ class ModelStopResponse:
stop_response: ModelStopResponse | None = None
exception: Exception | None = None
retry: bool = False
terminate: bool = False

def _can_write(self, name: str) -> bool:
return name == "retry"
return name in (
"retry",
"terminate",
)

@property
def should_reverse_callbacks(self) -> bool:
Expand Down
101 changes: 101 additions & 0 deletions tests/strands/agent/test_agent_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,3 +686,104 @@ async def capture_messages_hook(event: BeforeInvocationEvent):

# structured_output_async uses deprecated path that doesn't pass messages
assert received_messages is None


@pytest.mark.asyncio
async def test_hook_terminate_on_successful_call():
"""Test that hooks can terminate even on successful model calls based on response content."""

mock_provider = MockedModelProvider(
[
{
"role": "assistant",
"content": [{"text": "First conversation successful"}],
},
{
"role": "assistant",
"content": [{"text": "Unnecessary follow-up conversation"}],
},
]
)

# Hook that terminate if response is favorable
class SuccessfulTerminateHook:
def __init__(self, end_marker="success"):
self.end_marker = end_marker
self.call_count = 0

def register_hooks(self, registry):
registry.add_callback(strands.hooks.AfterModelCallEvent, self.handle_after_model_call)

async def handle_after_model_call(self, event):
self.call_count += 1

# Check successful responses for favorable markers
if event.stop_response:
message = event.stop_response.message
text_content = "".join(block.get("text", "") for block in message.get("content", []))

if self.end_marker in text_content:
event.terminate = True

terminate_hook = SuccessfulTerminateHook(end_marker="success")
agent = Agent(model=mock_provider, hooks=[terminate_hook])

result = agent("Generate a response")

# Verify hook was called only once (For first favorable response)
assert terminate_hook.call_count == 1

# Verify final result is the favorable response
assert result.message["content"][0]["text"] == "First conversation successful"


@pytest.mark.asyncio
async def test_hook_terminate_gracefully_on_limits(agent_tool, tool_use):
"""Test that hooks can terminate agent gracefully after maximum counts reached."""

mock_provider = MockedModelProvider(
[
{
"role": "assistant",
"content": [{"text": "First tool-use"}, {"toolUse": tool_use}],
},
{
"role": "assistant",
"content": [{"text": "Second tool-use"}, {"toolUse": tool_use}],
},
{
"role": "assistant",
"content": [{"text": "Third tool-use"}, {"toolUse": tool_use}],
},
]
)

# Hook that counts number of calls
class GracefulTerminateHook:
def __init__(self, max_counts):
self.max_counts = max_counts
self.call_count = 0

def register_hooks(self, registry):
registry.add_callback(strands.hooks.AfterModelCallEvent, self.handle_after_model_call)

async def handle_after_model_call(self, event):
self.call_count += 1

if self.call_count > self.max_counts - 1:
event.terminate = True

terminate_hook = GracefulTerminateHook(max_counts=2)
agent = Agent(
model=mock_provider,
tools=[agent_tool],
hooks=[terminate_hook],
)

result = agent("Generate a response")

# Verify hook was called two times
assert terminate_hook.call_count == 2

# Verify final result is the second tool-use
assert result.message["content"][0]["text"] == "Second tool-use"