From e37a6a6f9f3ec36d07da7db4ded49086782ffe2b Mon Sep 17 00:00:00 2001 From: HuxleyHu98 Date: Tue, 17 Mar 2026 06:20:20 +0800 Subject: [PATCH] docs: clarify ToolContext in function tool lifecycle hooks --- docs/agents.md | 1 + docs/context.md | 2 ++ src/agents/lifecycle.py | 32 ++++++++++++++++++++++++++++---- tests/test_agent_hooks.py | 18 ++++++++++++++++++ tests/test_global_hooks.py | 18 ++++++++++++++++++ tests/test_run_hooks.py | 5 +++++ 6 files changed, 72 insertions(+), 4 deletions(-) diff --git a/docs/agents.md b/docs/agents.md index 8637005f2c..01d96f6841 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -257,6 +257,7 @@ Typical hook timing: - `on_agent_start` / `on_agent_end`: when a specific agent begins or finishes producing a final output. - `on_llm_start` / `on_llm_end`: immediately around each model call. - `on_tool_start` / `on_tool_end`: around each local tool invocation. + For function tools, the hook `context` is typically a `ToolContext`, so you can inspect tool-call metadata such as `tool_call_id`. - `on_handoff`: when control moves from one agent to another. Use `RunHooks` when you want a single observer for the whole workflow, and `AgentHooks` when one agent needs custom side effects. diff --git a/docs/context.md b/docs/context.md index 1c7f19bef0..47ba2bddb8 100644 --- a/docs/context.md +++ b/docs/context.md @@ -13,6 +13,8 @@ This is represented via the [`RunContextWrapper`][agents.run_context.RunContextW 2. You pass that object to the various run methods (e.g. `Runner.run(..., context=whatever)`). 3. All your tool calls, lifecycle hooks etc will be passed a wrapper object, `RunContextWrapper[T]`, where `T` represents your context object type which you can access via `wrapper.context`. +For some runtime-specific callbacks, the SDK may pass a more specialized subclass of `RunContextWrapper[T]`. For example, function-tool lifecycle hooks typically receive `ToolContext`, which also exposes tool-call metadata like `tool_call_id`, `tool_name`, and `tool_arguments`. + The **most important** thing to be aware of: every agent, tool function, lifecycle etc for a given agent run must use the same _type_ of context. You can use the context for things like: diff --git a/src/agents/lifecycle.py b/src/agents/lifecycle.py index 38744471fb..8c6e62d049 100644 --- a/src/agents/lifecycle.py +++ b/src/agents/lifecycle.py @@ -73,7 +73,13 @@ async def on_tool_start( agent: TAgent, tool: Tool, ) -> None: - """Called immediately before a local tool is invoked.""" + """Called immediately before a local tool is invoked. + + For function-tool invocations, ``context`` is typically a ``ToolContext`` instance, + which exposes tool-call-specific metadata such as ``tool_call_id``, ``tool_name``, + and ``tool_arguments``. Other local tool families may provide a plain + ``RunContextWrapper`` instead. + """ pass async def on_tool_end( @@ -83,7 +89,13 @@ async def on_tool_end( tool: Tool, result: str, ) -> None: - """Called immediately after a local tool is invoked.""" + """Called immediately after a local tool is invoked. + + For function-tool invocations, ``context`` is typically a ``ToolContext`` instance, + which exposes tool-call-specific metadata such as ``tool_call_id``, ``tool_name``, + and ``tool_arguments``. Other local tool families may provide a plain + ``RunContextWrapper`` instead. + """ pass @@ -135,7 +147,13 @@ async def on_tool_start( agent: TAgent, tool: Tool, ) -> None: - """Called immediately before a local tool is invoked.""" + """Called immediately before a local tool is invoked. + + For function-tool invocations, ``context`` is typically a ``ToolContext`` instance, + which exposes tool-call-specific metadata such as ``tool_call_id``, ``tool_name``, + and ``tool_arguments``. Other local tool families may provide a plain + ``RunContextWrapper`` instead. + """ pass async def on_tool_end( @@ -145,7 +163,13 @@ async def on_tool_end( tool: Tool, result: str, ) -> None: - """Called immediately after a local tool is invoked.""" + """Called immediately after a local tool is invoked. + + For function-tool invocations, ``context`` is typically a ``ToolContext`` instance, + which exposes tool-call-specific metadata such as ``tool_call_id``, ``tool_name``, + and ``tool_arguments``. Other local tool families may provide a plain + ``RunContextWrapper`` instead. + """ pass async def on_llm_start( diff --git a/tests/test_agent_hooks.py b/tests/test_agent_hooks.py index 855ad57a1c..b97f2763e7 100644 --- a/tests/test_agent_hooks.py +++ b/tests/test_agent_hooks.py @@ -12,6 +12,7 @@ from agents.run import Runner from agents.run_context import AgentHookContext, RunContextWrapper, TContext from agents.tool import Tool +from agents.tool_context import ToolContext from .fake_model import FakeModel from .test_responses import ( @@ -26,9 +27,11 @@ class AgentHooksForTests(AgentHooks): def __init__(self): self.events: dict[str, int] = defaultdict(int) + self.tool_context_ids: list[str] = [] def reset(self): self.events.clear() + self.tool_context_ids.clear() async def on_start(self, context: AgentHookContext[TContext], agent: Agent[TContext]) -> None: self.events["on_start"] += 1 @@ -56,6 +59,8 @@ async def on_tool_start( tool: Tool, ) -> None: self.events["on_tool_start"] += 1 + if isinstance(context, ToolContext): + self.tool_context_ids.append(context.tool_call_id) async def on_tool_end( self, @@ -65,6 +70,8 @@ async def on_tool_end( result: str, ) -> None: self.events["on_tool_end"] += 1 + if isinstance(context, ToolContext): + self.tool_context_ids.append(context.tool_call_id) @pytest.mark.asyncio @@ -94,6 +101,17 @@ async def test_non_streamed_agent_hooks(): assert hooks.events == {"on_start": 1, "on_end": 1}, f"{output}" hooks.reset() + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("some_function", json.dumps({"a": "b"}))], + [get_text_message("done")], + ] + ) + await Runner.run(agent_3, input="user_message") + assert len(hooks.tool_context_ids) == 2 + assert len(set(hooks.tool_context_ids)) == 1 + hooks.reset() + model.add_multiple_turn_outputs( [ # First turn: a tool call diff --git a/tests/test_global_hooks.py b/tests/test_global_hooks.py index 45854410df..d6780d6217 100644 --- a/tests/test_global_hooks.py +++ b/tests/test_global_hooks.py @@ -8,6 +8,7 @@ from typing_extensions import TypedDict from agents import Agent, RunContextWrapper, RunHooks, Runner, TContext, Tool +from agents.tool_context import ToolContext from .fake_model import FakeModel from .test_responses import ( @@ -22,9 +23,11 @@ class RunHooksForTests(RunHooks): def __init__(self): self.events: dict[str, int] = defaultdict(int) + self.tool_context_ids: list[str] = [] def reset(self): self.events.clear() + self.tool_context_ids.clear() async def on_agent_start( self, context: RunContextWrapper[TContext], agent: Agent[TContext] @@ -54,6 +57,8 @@ async def on_tool_start( tool: Tool, ) -> None: self.events["on_tool_start"] += 1 + if isinstance(context, ToolContext): + self.tool_context_ids.append(context.tool_call_id) async def on_tool_end( self, @@ -63,6 +68,8 @@ async def on_tool_end( result: str, ) -> None: self.events["on_tool_end"] += 1 + if isinstance(context, ToolContext): + self.tool_context_ids.append(context.tool_call_id) @pytest.mark.asyncio @@ -85,6 +92,17 @@ async def test_non_streamed_agent_hooks(): assert hooks.events == {"on_agent_start": 1, "on_agent_end": 1}, f"{output}" hooks.reset() + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("some_function", json.dumps({"a": "b"}))], + [get_text_message("done")], + ] + ) + await Runner.run(agent_3, input="user_message", hooks=hooks) + assert len(hooks.tool_context_ids) == 2 + assert len(set(hooks.tool_context_ids)) == 1 + hooks.reset() + model.add_multiple_turn_outputs( [ # First turn: a tool call diff --git a/tests/test_run_hooks.py b/tests/test_run_hooks.py index d729905408..92fc7e699a 100644 --- a/tests/test_run_hooks.py +++ b/tests/test_run_hooks.py @@ -10,6 +10,7 @@ from agents.run import Runner from agents.run_context import AgentHookContext, RunContextWrapper, TContext from agents.tool import Tool +from agents.tool_context import ToolContext from tests.test_agent_llm_hooks import AgentHooksForTests from .fake_model import FakeModel @@ -22,9 +23,11 @@ class RunHooksForTests(RunHooks): def __init__(self): self.events: dict[str, int] = defaultdict(int) + self.tool_context_ids: list[str] = [] def reset(self): self.events.clear() + self.tool_context_ids.clear() async def on_agent_start( self, context: AgentHookContext[TContext], agent: Agent[TContext] @@ -57,6 +60,8 @@ async def on_tool_end( result: str, ) -> None: self.events["on_tool_end"] += 1 + if isinstance(context, ToolContext): + self.tool_context_ids.append(context.tool_call_id) async def on_llm_start( self,