diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 1faa8a917..ad40dfd7f 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -158,6 +158,18 @@ class AfterToolCallEvent(HookEvent): Note: This event uses reverse callback ordering, meaning callbacks registered later will be invoked first during cleanup. + Tool Retrying: + When ``retry`` is set to True by a hook callback, the tool executor will + discard the current tool result and invoke the tool again. This has important + implications for streaming consumers: + + - ToolStreamEvents (intermediate streaming events) from the discarded tool execution + will have already been emitted to callers before the retry occurs. Agent invokers + consuming streamed events should be prepared to handle this scenario, potentially + by tracking retry state or implementing idempotent event processing + - ToolResultEvent is NOT emitted for discarded attempts - only the final attempt's + result is emitted and added to the conversation history + Attributes: selected_tool: The tool that was invoked. It may be None if tool lookup failed. tool_use: The tool parameters that were passed to the tool invoked. @@ -165,6 +177,9 @@ class AfterToolCallEvent(HookEvent): result: The result of the tool invocation. Either a ToolResult on success or an Exception if the tool execution failed. cancel_message: The cancellation message if the user cancelled the tool call. + retry: Whether to retry the tool invocation. Can be set by hook callbacks + to trigger a retry. When True, the current result is discarded and the + tool is called again. Defaults to False. """ selected_tool: AgentTool | None @@ -173,9 +188,10 @@ class AfterToolCallEvent(HookEvent): result: ToolResult exception: Exception | None = None cancel_message: str | None = None + retry: bool = False def _can_write(self, name: str) -> bool: - return name == "result" + return name in ["result", "retry"] @property def should_reverse_callbacks(self) -> bool: diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 6d58c5c75..ef000fbd6 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -148,109 +148,127 @@ async def _stream( } ) - before_event, interrupts = await ToolExecutor._invoke_before_tool_call_hook( - agent, tool_func, tool_use, invocation_state - ) - - if interrupts: - yield ToolInterruptEvent(tool_use, interrupts) - return - - if before_event.cancel_tool: - cancel_message = ( - before_event.cancel_tool if isinstance(before_event.cancel_tool, str) else "tool cancelled by user" + # Retry loop for tool execution - hooks can set after_event.retry = True to retry + while True: + before_event, interrupts = await ToolExecutor._invoke_before_tool_call_hook( + agent, tool_func, tool_use, invocation_state ) - yield ToolCancelEvent(tool_use, cancel_message) - cancel_result: ToolResult = { - "toolUseId": str(tool_use.get("toolUseId")), - "status": "error", - "content": [{"text": cancel_message}], - } + if interrupts: + yield ToolInterruptEvent(tool_use, interrupts) + return - after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( - agent, None, tool_use, invocation_state, cancel_result, cancel_message=cancel_message - ) - yield ToolResultEvent(after_event.result) - tool_results.append(after_event.result) - return - - try: - selected_tool = before_event.selected_tool - tool_use = before_event.tool_use - invocation_state = before_event.invocation_state - - if not selected_tool: - if tool_func == selected_tool: - logger.error( - "tool_name=<%s>, available_tools=<%s> | tool not found in registry", - tool_name, - list(agent.tool_registry.registry.keys()), - ) - else: - logger.debug( - "tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call", - tool_name, - str(tool_use.get("toolUseId")), - ) + if before_event.cancel_tool: + cancel_message = ( + before_event.cancel_tool if isinstance(before_event.cancel_tool, str) else "tool cancelled by user" + ) + yield ToolCancelEvent(tool_use, cancel_message) - result: ToolResult = { + cancel_result: ToolResult = { "toolUseId": str(tool_use.get("toolUseId")), "status": "error", - "content": [{"text": f"Unknown tool: {tool_name}"}], + "content": [{"text": cancel_message}], } after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( - agent, selected_tool, tool_use, invocation_state, result + agent, None, tool_use, invocation_state, cancel_result, cancel_message=cancel_message ) yield ToolResultEvent(after_event.result) tool_results.append(after_event.result) return - if structured_output_context.is_enabled: - kwargs["structured_output_context"] = structured_output_context - async for event in selected_tool.stream(tool_use, invocation_state, **kwargs): - # Internal optimization; for built-in AgentTools, we yield TypedEvents out of .stream() - # so that we don't needlessly yield ToolStreamEvents for non-generator callbacks. - # In which case, as soon as we get a ToolResultEvent we're done and for ToolStreamEvent - # we yield it directly; all other cases (non-sdk AgentTools), we wrap events in - # ToolStreamEvent and the last event is just the result. - - if isinstance(event, ToolInterruptEvent): - yield event - return - - if isinstance(event, ToolResultEvent): - # below the last "event" must point to the tool_result - event = event.tool_result - break - if isinstance(event, ToolStreamEvent): - yield event - else: - yield ToolStreamEvent(tool_use, event) + try: + selected_tool = before_event.selected_tool + tool_use = before_event.tool_use + invocation_state = before_event.invocation_state + + if not selected_tool: + if tool_func == selected_tool: + logger.error( + "tool_name=<%s>, available_tools=<%s> | tool not found in registry", + tool_name, + list(agent.tool_registry.registry.keys()), + ) + else: + logger.debug( + "tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call", + tool_name, + str(tool_use.get("toolUseId")), + ) + + result: ToolResult = { + "toolUseId": str(tool_use.get("toolUseId")), + "status": "error", + "content": [{"text": f"Unknown tool: {tool_name}"}], + } + + after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( + agent, selected_tool, tool_use, invocation_state, result + ) + # Check if retry requested for unknown tool error + # Use getattr because BidiAfterToolCallEvent doesn't have retry attribute + if getattr(after_event, "retry", False): + logger.debug("tool_name=<%s> | retry requested, retrying tool call", tool_name) + continue + yield ToolResultEvent(after_event.result) + tool_results.append(after_event.result) + return + if structured_output_context.is_enabled: + kwargs["structured_output_context"] = structured_output_context + async for event in selected_tool.stream(tool_use, invocation_state, **kwargs): + # Internal optimization; for built-in AgentTools, we yield TypedEvents out of .stream() + # so that we don't needlessly yield ToolStreamEvents for non-generator callbacks. + # In which case, as soon as we get a ToolResultEvent we're done and for ToolStreamEvent + # we yield it directly; all other cases (non-sdk AgentTools), we wrap events in + # ToolStreamEvent and the last event is just the result. + + if isinstance(event, ToolInterruptEvent): + yield event + return + + if isinstance(event, ToolResultEvent): + # below the last "event" must point to the tool_result + event = event.tool_result + break + + if isinstance(event, ToolStreamEvent): + yield event + else: + yield ToolStreamEvent(tool_use, event) + + result = cast(ToolResult, event) - result = cast(ToolResult, event) + after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( + agent, selected_tool, tool_use, invocation_state, result + ) - after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( - agent, selected_tool, tool_use, invocation_state, result - ) + # Check if retry requested (getattr for BidiAfterToolCallEvent compatibility) + if getattr(after_event, "retry", False): + logger.debug("tool_name=<%s> | retry requested, retrying tool call", tool_name) + continue - yield ToolResultEvent(after_event.result) - tool_results.append(after_event.result) + yield ToolResultEvent(after_event.result) + tool_results.append(after_event.result) + return - except Exception as e: - logger.exception("tool_name=<%s> | failed to process tool", tool_name) - error_result: ToolResult = { - "toolUseId": str(tool_use.get("toolUseId")), - "status": "error", - "content": [{"text": f"Error: {str(e)}"}], - } + except Exception as e: + logger.exception("tool_name=<%s> | failed to process tool", tool_name) + error_result: ToolResult = { + "toolUseId": str(tool_use.get("toolUseId")), + "status": "error", + "content": [{"text": f"Error: {str(e)}"}], + } - after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( - agent, selected_tool, tool_use, invocation_state, error_result, exception=e - ) - yield ToolResultEvent(after_event.result) - tool_results.append(after_event.result) + after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( + agent, selected_tool, tool_use, invocation_state, error_result, exception=e + ) + # Check if retry requested (getattr for BidiAfterToolCallEvent compatibility) + if getattr(after_event, "retry", False): + logger.debug("tool_name=<%s> | retry requested after exception, retrying tool call", tool_name) + continue + yield ToolResultEvent(after_event.result) + tool_results.append(after_event.result) + return @staticmethod async def _stream_with_trace( diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py index 8139fbf66..78e35c2aa 100644 --- a/tests/strands/tools/executors/test_executor.py +++ b/tests/strands/tools/executors/test_executor.py @@ -4,6 +4,7 @@ import pytest import strands +from strands.experimental.hooks.events import BidiAfterToolCallEvent from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent from strands.interrupt import Interrupt from strands.telemetry.metrics import Trace @@ -479,3 +480,281 @@ async def test_executor_stream_updates_invocation_state_with_agent( # Verify that the invocation_state was updated with the agent assert "agent" in empty_invocation_state assert empty_invocation_state["agent"] is agent + + +@pytest.mark.asyncio +async def test_executor_stream_no_retry_set(executor, agent, tool_results, invocation_state, alist): + """Test default behavior when retry is not set - tool executes once.""" + call_count = {"count": 0} + + @strands.tool(name="counting_tool") + def counting_tool(): + call_count["count"] += 1 + return f"attempt_{call_count['count']}" + + agent.tool_registry.register_tool(counting_tool) + + tool_use: ToolUse = {"name": "counting_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + + # Tool should be called exactly once + assert call_count["count"] == 1 + + # Single result event with first attempt's content + assert len(tru_events) == 1 + assert tru_events[0].tool_result == {"toolUseId": "1", "status": "success", "content": [{"text": "attempt_1"}]} + + # tool_results should contain the result + assert len(tool_results) == 1 + assert tool_results[0] == {"toolUseId": "1", "status": "success", "content": [{"text": "attempt_1"}]} + + +@pytest.mark.asyncio +async def test_executor_stream_retry_true(executor, agent, tool_results, invocation_state, alist): + """Test that retry=True causes tool re-execution.""" + call_count = {"count": 0} + + @strands.tool(name="counting_tool") + def counting_tool(): + call_count["count"] += 1 + return f"attempt_{call_count['count']}" + + agent.tool_registry.register_tool(counting_tool) + + # Set retry=True on first call only + def retry_once(event): + if isinstance(event, AfterToolCallEvent) and call_count["count"] == 1: + event.retry = True + return event + + agent.hooks.add_callback(AfterToolCallEvent, retry_once) + + tool_use: ToolUse = {"name": "counting_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + + # Tool should be called twice due to retry + assert call_count["count"] == 2 + + # Only final result is yielded (first attempt's result was discarded) + assert len(tru_events) == 1 + assert tru_events[0].tool_result == {"toolUseId": "1", "status": "success", "content": [{"text": "attempt_2"}]} + + # tool_results only contains the final result + assert len(tool_results) == 1 + assert tool_results[0] == {"toolUseId": "1", "status": "success", "content": [{"text": "attempt_2"}]} + + +@pytest.mark.asyncio +async def test_executor_stream_retry_true_emits_events_from_both_attempts( + executor, agent, tool_results, invocation_state, alist +): + """Test that ToolStreamEvents from discarded attempt ARE emitted, but ToolResultEvent is NOT. + + This validates the documented behavior: 'Streaming events from the discarded + tool execution will have already been emitted to callers before the retry occurs.' + + Key distinction: + - ToolStreamEvent (intermediate): Yielded immediately, visible from BOTH attempts + - ToolResultEvent (final): Only yielded for the final attempt, discarded on retry + """ + call_count = {"count": 0} + + @strands.tool(name="streaming_tool") + def streaming_tool(): + return "unused" + + # Provide streaming implementation (same pattern as exception_tool fixture) + async def tool_stream(_tool_use, _invocation_state, **kwargs): + call_count["count"] += 1 + yield f"streaming_from_attempt_{call_count['count']}" + yield ToolResultEvent( + {"toolUseId": "1", "status": "success", "content": [{"text": f"result_{call_count['count']}"}]} + ) + + streaming_tool.stream = tool_stream + agent.tool_registry.register_tool(streaming_tool) + + # Set retry=True on first call + def retry_once(event): + if isinstance(event, AfterToolCallEvent) and call_count["count"] == 1: + event.retry = True + return event + + agent.hooks.add_callback(AfterToolCallEvent, retry_once) + + tool_use: ToolUse = {"name": "streaming_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + + # Tool called twice + assert call_count["count"] == 2 + + # Streaming events from BOTH attempts are emitted (documented behavior) + stream_events = [e for e in tru_events if isinstance(e, ToolStreamEvent)] + assert len(stream_events) == 2 + assert stream_events[0] == ToolStreamEvent(tool_use, "streaming_from_attempt_1") + assert stream_events[1] == ToolStreamEvent(tool_use, "streaming_from_attempt_2") + + # Only final ToolResultEvent is emitted + result_events = [e for e in tru_events if isinstance(e, ToolResultEvent)] + assert len(result_events) == 1 + assert result_events[0].tool_result["content"][0]["text"] == "result_2" + + +@pytest.mark.asyncio +async def test_executor_stream_retry_false(executor, agent, tool_results, invocation_state, alist): + """Test that explicitly setting retry=False does not retry.""" + call_count = {"count": 0} + + @strands.tool(name="counting_tool") + def counting_tool(): + call_count["count"] += 1 + return f"attempt_{call_count['count']}" + + agent.tool_registry.register_tool(counting_tool) + + # Explicitly set retry=False + def no_retry(event): + if isinstance(event, AfterToolCallEvent): + event.retry = False + return event + + agent.hooks.add_callback(AfterToolCallEvent, no_retry) + + tool_use: ToolUse = {"name": "counting_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + + # Tool should be called exactly once + assert call_count["count"] == 1 + + # Single result event + assert len(tru_events) == 1 + assert tru_events[0].tool_result == {"toolUseId": "1", "status": "success", "content": [{"text": "attempt_1"}]} + + # tool_results should contain the result + assert len(tool_results) == 1 + assert tool_results[0] == {"toolUseId": "1", "status": "success", "content": [{"text": "attempt_1"}]} + + +@pytest.mark.asyncio +async def test_executor_stream_bidi_event_no_retry_attribute(executor, agent, tool_results, invocation_state, alist): + """Test that BidiAfterToolCallEvent (which lacks retry attribute) doesn't cause retry. + + This tests the getattr(after_event, "retry", False) fallback for events without retry. + """ + call_count = {"count": 0} + + @strands.tool(name="counting_tool") + def counting_tool(): + call_count["count"] += 1 + return f"attempt_{call_count['count']}" + + agent.tool_registry.register_tool(counting_tool) + + tool_use: ToolUse = {"name": "counting_tool", "toolUseId": "1", "input": {}} + result: strands.types.tools.ToolResult = { + "toolUseId": "1", + "status": "success", + "content": [{"text": "attempt_1"}], + } + + # Create a BidiAfterToolCallEvent (which has no retry attribute) + bidi_event = BidiAfterToolCallEvent( + agent=agent, + selected_tool=counting_tool, + tool_use=tool_use, + invocation_state=invocation_state, + result=result, + ) + + # Patch _invoke_after_tool_call_hook to return BidiAfterToolCallEvent + async def mock_after_hook(*args, **kwargs): + return bidi_event, [] + + with unittest.mock.patch.object(ToolExecutor, "_invoke_after_tool_call_hook", mock_after_hook): + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + tru_events = await alist(stream) + + # Tool should be called once - no retry since BidiAfterToolCallEvent has no retry attr + assert call_count["count"] == 1 + + # Result should be returned + assert len(tru_events) == 1 + + +@pytest.mark.asyncio +async def test_executor_stream_retry_after_exception(executor, agent, tool_results, invocation_state, alist): + """Test that retry=True works when tool raises an exception. + + Covers the exception path retry check. + """ + call_count = {"count": 0} + + @strands.tool(name="flaky_tool") + def flaky_tool(): + call_count["count"] += 1 + if call_count["count"] == 1: + raise RuntimeError("First call fails") + return "success" + + agent.tool_registry.register_tool(flaky_tool) + + # Retry once on error (check result status, not exception attribute) + def retry_on_error(event): + if isinstance(event, AfterToolCallEvent) and event.result.get("status") == "error" and call_count["count"] == 1: + event.retry = True + return event + + agent.hooks.add_callback(AfterToolCallEvent, retry_on_error) + + tool_use: ToolUse = {"name": "flaky_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + tru_events = await alist(stream) + + # Tool called twice (1 exception + 1 success) + assert call_count["count"] == 2 + + # Final result is success + assert len(tru_events) == 1 + assert tru_events[0].tool_result["status"] == "success" + + +@pytest.mark.asyncio +async def test_executor_stream_retry_after_unknown_tool(executor, agent, tool_results, invocation_state, alist): + """Test that retry=True triggers retry loop for unknown tool. + + Covers the unknown tool path retry check. Tool lookup happens before retry loop, + so even after retry the tool remains unknown - this test verifies the retry + mechanism is triggered, not that it resolves the unknown tool. + """ + hook_call_count = {"count": 0} + + # Retry once on first unknown tool error + def retry_once_on_unknown(event): + if isinstance(event, AfterToolCallEvent): + hook_call_count["count"] += 1 + # Retry only on first call + if hook_call_count["count"] == 1: + event.retry = True + return event + + agent.hooks.add_callback(AfterToolCallEvent, retry_once_on_unknown) + + tool_use: ToolUse = {"name": "nonexistent_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + tru_events = await alist(stream) + + # Hook called twice (retry was triggered) + assert hook_call_count["count"] == 2 + + # Final result is still error (tool remains unknown after retry) + assert len(tru_events) == 1 + assert tru_events[0].tool_result["status"] == "error" + assert "Unknown tool" in tru_events[0].tool_result["content"][0]["text"] diff --git a/tests_integ/test_tool_retry_hook.py b/tests_integ/test_tool_retry_hook.py new file mode 100644 index 000000000..3e35ff5e6 --- /dev/null +++ b/tests_integ/test_tool_retry_hook.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 +"""Integration tests for tool retry hook mechanism. + +Tests that setting AfterToolCallEvent.retry=True causes tool re-execution. +Uses direct tool invocation to test the executor-level retry, not model behavior. +""" + +from strands import Agent, tool +from strands.hooks import AfterToolCallEvent + + +def test_tool_retry_hook_causes_reexecution(): + """Test that setting retry=True on AfterToolCallEvent causes tool re-execution. + + Verifies: + 1. Tool is called again when retry=True + 2. Hook receives AfterToolCallEvent for BOTH attempts + 3. Same tool_use_id is used (proves executor retry, not model re-calling) + """ + state = {"call_count": 0} + + @tool(name="flaky_tool") + def flaky_tool(message: str) -> str: + """A tool that fails once then succeeds. + + Args: + message: A message to include in the response. + """ + state["call_count"] += 1 + if state["call_count"] == 1: + raise RuntimeError("First call fails") + return f"Success on attempt {state['call_count']}" + + hook_calls: list[dict] = [] + + def retry_on_first_error(event: AfterToolCallEvent) -> None: + tool_use_id = str(event.tool_use.get("toolUseId", "")) + hook_calls.append( + { + "tool_use_id": tool_use_id, + "status": event.result.get("status"), + "attempt": state["call_count"], + } + ) + + # Retry once on error + if event.result.get("status") == "error" and state["call_count"] == 1: + event.retry = True + + agent = Agent(tools=[flaky_tool]) + agent.hooks.add_callback(AfterToolCallEvent, retry_on_first_error) + + # Direct tool invocation bypasses model - tests executor retry mechanism + result = agent.tool.flaky_tool(message="test") + + # Tool was called twice (1 failure + 1 success) + assert state["call_count"] == 2 + + # Hook received AfterToolCallEvent for BOTH attempts + assert len(hook_calls) == 2 + assert hook_calls[0]["status"] == "error" + assert hook_calls[0]["attempt"] == 1 + assert hook_calls[1]["status"] == "success" + assert hook_calls[1]["attempt"] == 2 + + # Both calls used the same tool_use_id (executor retry, not new model call) + assert hook_calls[0]["tool_use_id"] == hook_calls[1]["tool_use_id"] + + assert result["status"] == "success"