diff --git a/src/agents/extensions/models/any_llm_model.py b/src/agents/extensions/models/any_llm_model.py
index 5302e49779..15296086da 100644
--- a/src/agents/extensions/models/any_llm_model.py
+++ b/src/agents/extensions/models/any_llm_model.py
@@ -42,8 +42,8 @@
 )
 from ...retry import ModelRetryAdvice, ModelRetryAdviceRequest
 from ...tool import Tool
-from ...tracing import generation_span, response_span
-from ...tracing.span_data import GenerationSpanData
+from ...tracing import SpanError, generation_span, response_span as create_response_span
+from ...tracing.span_data import GenerationSpanData, ResponseSpanData
 from ...tracing.spans import Span
 from ...usage import Usage
 from ...util._json import _to_dump_compatible
@@ -260,6 +260,7 @@ async def get_response(
         previous_response_id: str | None = None,
         conversation_id: str | None = None,
         prompt: ResponsePromptParam | None = None,
+        response_span: Span[ResponseSpanData] | None = None,
     ) -> ModelResponse:
         if self._selected_api() == "responses":
             return await self._get_response_via_responses(
@@ -273,6 +274,7 @@
                 previous_response_id=previous_response_id,
                 conversation_id=conversation_id,
                 prompt=prompt,
+                response_span=response_span,
             )

         return await self._get_response_via_chat(
@@ -298,6 +300,7 @@ async def stream_response(
         previous_response_id: str | None = None,
         conversation_id: str | None = None,
         prompt: ResponsePromptParam | None = None,
+        response_span: Span[ResponseSpanData] | None = None,
     ) -> AsyncIterator[TResponseStreamEvent]:
         if self._selected_api() == "responses":
             async for chunk in self._stream_response_via_responses(
@@ -311,6 +314,7 @@
                 previous_response_id=previous_response_id,
                 conversation_id=conversation_id,
                 prompt=prompt,
+                response_span=response_span,
             ):
                 yield chunk
             return
@@ -340,56 +344,76 @@ async def _get_response_via_responses(
         previous_response_id: str | None,
         conversation_id: str | None,
         prompt: ResponsePromptParam | None,
+        response_span: Span[ResponseSpanData] | None,
     ) -> ModelResponse:
-        with response_span(disabled=tracing.is_disabled()) as span_response:
-            response = await self._fetch_responses_response(
-                system_instructions=system_instructions,
-                input=input,
-                model_settings=model_settings,
-                tools=tools,
-                output_schema=output_schema,
-                handoffs=handoffs,
-                previous_response_id=previous_response_id,
-                conversation_id=conversation_id,
-                stream=False,
-                prompt=prompt,
-            )
-
-            if _debug.DONT_LOG_MODEL_DATA:
-                logger.debug("LLM responded")
-            else:
-                logger.debug(
-                    "LLM resp:\n%s\n",
-                    json.dumps(
-                        [item.model_dump() for item in response.output],
-                        indent=2,
-                        ensure_ascii=False,
-                    ),
+        span_response = response_span or create_response_span(disabled=tracing.is_disabled())
+        owns_response_span = response_span is None
+        if owns_response_span:
+            span_response.start(mark_as_current=True)
+        try:
+            try:
+                response = await self._fetch_responses_response(
+                    system_instructions=system_instructions,
+                    input=input,
+                    model_settings=model_settings,
+                    tools=tools,
+                    output_schema=output_schema,
+                    handoffs=handoffs,
+                    previous_response_id=previous_response_id,
+                    conversation_id=conversation_id,
+                    stream=False,
+                    prompt=prompt,
                 )
-            usage = (
-                Usage(
-                    requests=1,
-                    input_tokens=response.usage.input_tokens,
-                    output_tokens=response.usage.output_tokens,
-                    total_tokens=response.usage.total_tokens,
-                    input_tokens_details=response.usage.input_tokens_details,
-                    output_tokens_details=response.usage.output_tokens_details,
+
+                if _debug.DONT_LOG_MODEL_DATA:
+                    logger.debug("LLM responded")
+                else:
+                    logger.debug(
+                        "LLM resp:\n%s\n",
+                        json.dumps(
+                            [item.model_dump() for item in response.output],
+                            indent=2,
+                            ensure_ascii=False,
+                        ),
+                    )
+
+                usage = (
+                    Usage(
+                        requests=1,
+                        input_tokens=response.usage.input_tokens,
+                        output_tokens=response.usage.output_tokens,
+                        total_tokens=response.usage.total_tokens,
+                        input_tokens_details=response.usage.input_tokens_details,
+                        output_tokens_details=response.usage.output_tokens_details,
+                    )
+                    if response.usage
+                    else Usage()
                 )
-                if response.usage
-                else Usage()
-            )
-            if tracing.include_data():
-                span_response.span_data.response = response
-                span_response.span_data.input = input
-
-            return ModelResponse(
-                output=response.output,
-                usage=usage,
-                response_id=response.id,
-                request_id=getattr(response, "_request_id", None),
-            )
+                if tracing.include_data():
+                    span_response.span_data.response = response
+                    span_response.span_data.input = input
+
+                return ModelResponse(
+                    output=response.output,
+                    usage=usage,
+                    response_id=response.id,
+                    request_id=getattr(response, "_request_id", None),
+                )
+            except Exception as e:
+                if owns_response_span:
+                    span_response.set_error(
+                        SpanError(
+                            message="Error getting response",
+                            data={
+                                "error": str(e) if tracing.include_data() else e.__class__.__name__,
+                            },
+                        )
+                    )
+                raise
+        finally:
+            if owns_response_span:
+                span_response.finish(reset_current=True)

     async def _stream_response_via_responses(
         self,
@@ -404,37 +428,60 @@
         previous_response_id: str | None,
         conversation_id: str | None,
         prompt: ResponsePromptParam | None,
+        response_span: Span[ResponseSpanData] | None,
     ) -> AsyncIterator[ResponseStreamEvent]:
-        with response_span(disabled=tracing.is_disabled()) as span_response:
-            stream = await self._fetch_responses_response(
-                system_instructions=system_instructions,
-                input=input,
-                model_settings=model_settings,
-                tools=tools,
-                output_schema=output_schema,
-                handoffs=handoffs,
-                previous_response_id=previous_response_id,
-                conversation_id=conversation_id,
-                stream=True,
-                prompt=prompt,
-            )
-
-            final_response: Response | None = None
-            try:
-                async for chunk in stream:
-                    if isinstance(chunk, ResponseCompletedEvent):
-                        final_response = chunk.response
-                    elif getattr(chunk, "type", None) in {"response.failed", "response.incomplete"}:
-                        terminal_response = getattr(chunk, "response", None)
-                        if isinstance(terminal_response, Response):
-                            final_response = terminal_response
-                    yield chunk
-            finally:
-                await self._maybe_aclose(stream)
-
-            if tracing.include_data() and final_response:
-                span_response.span_data.response = final_response
-                span_response.span_data.input = input
+        span_response = response_span or create_response_span(disabled=tracing.is_disabled())
+        owns_response_span = response_span is None
+        if owns_response_span:
+            span_response.start(mark_as_current=True)
+        try:
+            try:
+                stream = await self._fetch_responses_response(
+                    system_instructions=system_instructions,
+                    input=input,
+                    model_settings=model_settings,
+                    tools=tools,
+                    output_schema=output_schema,
+                    handoffs=handoffs,
+                    previous_response_id=previous_response_id,
+                    conversation_id=conversation_id,
+                    stream=True,
+                    prompt=prompt,
+                )

+                final_response: Response | None = None
+                try:
+                    async for chunk in stream:
+                        if isinstance(chunk, ResponseCompletedEvent):
+                            final_response = chunk.response
+                        elif getattr(chunk, "type", None) in {
+                            "response.failed",
+                            "response.incomplete",
+                        }:
+                            terminal_response = getattr(chunk, "response", None)
+                            if isinstance(terminal_response, Response):
+                                final_response = terminal_response
+                        yield chunk
+                finally:
+                    await self._maybe_aclose(stream)
+
+                if tracing.include_data() and final_response:
+                    span_response.span_data.response = final_response
+                    span_response.span_data.input = input
+            except Exception as e:
+                if owns_response_span:
+                    span_response.set_error(
+                        SpanError(
+                            message="Error streaming response",
+                            data={
+                                "error": str(e) if tracing.include_data() else e.__class__.__name__,
+                            },
+                        )
+                    )
+                raise
+        finally:
+            if owns_response_span:
+                span_response.finish(reset_current=True)

     async def _get_response_via_chat(
         self,
diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py
index 906130b87f..2e91a7262c 100644
--- a/src/agents/extensions/models/litellm_model.py
+++ b/src/agents/extensions/models/litellm_model.py
@@ -53,7 +53,7 @@
 from ...retry import ModelRetryAdvice, ModelRetryAdviceRequest
 from ...tool import Tool
 from ...tracing import generation_span
-from ...tracing.span_data import GenerationSpanData
+from ...tracing.span_data import GenerationSpanData, ResponseSpanData
 from ...tracing.spans import Span
 from ...usage import Usage
 from ...util._json import _to_dump_compatible
@@ -171,7 +171,9 @@ async def get_response(
         previous_response_id: str | None = None,  # unused
         conversation_id: str | None = None,  # unused
         prompt: Any | None = None,
+        response_span: Span[ResponseSpanData] | None = None,
     ) -> ModelResponse:
+        del response_span
         with generation_span(
             model=str(self.model),
             model_config=model_settings.to_json_dict()
@@ -285,7 +287,9 @@ async def stream_response(
         previous_response_id: str | None = None,  # unused
         conversation_id: str | None = None,  # unused
         prompt: Any | None = None,
+        response_span: Span[ResponseSpanData] | None = None,
     ) -> AsyncIterator[TResponseStreamEvent]:
+        del response_span
         with generation_span(
             model=str(self.model),
             model_config=model_settings.to_json_dict()
diff --git a/src/agents/models/interface.py b/src/agents/models/interface.py
index 8d18a9a363..a361a572c7 100644
--- a/src/agents/models/interface.py
+++ b/src/agents/models/interface.py
@@ -15,6 +15,8 @@
 if TYPE_CHECKING:
     from ..model_settings import ModelSettings
     from ..retry import ModelRetryAdvice, ModelRetryAdviceRequest
+    from ..tracing import Span
+    from ..tracing.span_data import ResponseSpanData


 class ModelTracing(enum.Enum):
@@ -67,6 +69,7 @@ async def get_response(
         previous_response_id: str | None,
         conversation_id: str | None,
         prompt: ResponsePromptParam | None,
+        response_span: Span[ResponseSpanData] | None = None,
     ) -> ModelResponse:
         """Get a response from the model.

@@ -102,6 +105,7 @@ def stream_response(
         previous_response_id: str | None,
         conversation_id: str | None,
         prompt: ResponsePromptParam | None,
+        response_span: Span[ResponseSpanData] | None = None,
     ) -> AsyncIterator[TResponseStreamEvent]:
         """Stream a response from the model.
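
The interface change above is purely additive: `response_span` is a keyword argument with a `None` default, so third-party `Model` implementations keep working unchanged. A minimal sketch of a custom model that opts in — the `EchoModel` class and its canned output are illustrative assumptions, not part of this diff:

    from typing import Any

    from agents.items import ModelResponse
    from agents.usage import Usage


    class EchoModel:
        async def get_response(
            self, *args: Any, response_span: Any | None = None, **kwargs: Any
        ) -> ModelResponse:
            # The runner owns the span's lifecycle; an implementation only annotates it.
            if response_span is not None:
                response_span.span_data.input = kwargs.get("input")
            return ModelResponse(output=[], usage=Usage(), response_id="resp-echo")

Models that do not care about tracing can simply ignore the argument, as the `del response_span` pattern in litellm_model.py does.
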
diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py
index 5b48cb1707..bbbdc9381c 100644
--- a/src/agents/models/openai_chatcompletions.py
+++ b/src/agents/models/openai_chatcompletions.py
@@ -27,7 +27,7 @@
 from ..retry import ModelRetryAdvice, ModelRetryAdviceRequest
 from ..tool import Tool
 from ..tracing import generation_span
-from ..tracing.span_data import GenerationSpanData
+from ..tracing.span_data import GenerationSpanData, ResponseSpanData
 from ..tracing.spans import Span
 from ..usage import Usage
 from ..util._json import _to_dump_compatible
@@ -108,7 +108,9 @@ async def get_response(
         previous_response_id: str | None = None,  # unused
         conversation_id: str | None = None,  # unused
         prompt: ResponsePromptParam | None = None,
+        response_span: Span[ResponseSpanData] | None = None,
     ) -> ModelResponse:
+        del response_span
         with generation_span(
             model=str(self.model),
             model_config=model_settings.to_json_dict()
             | {"base_url": str(self._client.base_url)},
@@ -221,10 +223,12 @@ async def stream_response(
         previous_response_id: str | None = None,  # unused
         conversation_id: str | None = None,  # unused
         prompt: ResponsePromptParam | None = None,
+        response_span: Span[ResponseSpanData] | None = None,
     ) -> AsyncIterator[TResponseStreamEvent]:
         """
         Yields a partial message as it is generated, as well as the usage information.
         """
+        del response_span
         with generation_span(
             model=str(self.model),
             model_config=model_settings.to_json_dict()
             | {"base_url": str(self._client.base_url)},
diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py
index f98da12344..8d862b48de 100644
--- a/src/agents/models/openai_responses.py
+++ b/src/agents/models/openai_responses.py
@@ -60,7 +60,7 @@
     has_required_tool_search_surface,
     validate_responses_tool_search_configuration,
 )
-from ..tracing import SpanError, response_span
+from ..tracing import SpanError, response_span as create_response_span
 from ..usage import Usage
 from ..util._json import _to_dump_compatible
 from ..version import __version__
@@ -74,6 +74,8 @@

 if TYPE_CHECKING:
     from ..model_settings import ModelSettings
+    from ..tracing import Span
+    from ..tracing.span_data import ResponseSpanData


 _USER_AGENT = f"Agents/Python {__version__}"
@@ -430,8 +432,13 @@ async def get_response(
         previous_response_id: str | None = None,
         conversation_id: str | None = None,
         prompt: ResponsePromptParam | None = None,
+        response_span: Span[ResponseSpanData] | None = None,
     ) -> ModelResponse:
-        with response_span(disabled=tracing.is_disabled()) as span_response:
+        span_response = response_span or create_response_span(disabled=tracing.is_disabled())
+        owns_response_span = response_span is None
+        if owns_response_span:
+            span_response.start(mark_as_current=True)
+        try:
             try:
                 response = await self._fetch_response(
                     system_instructions,
@@ -477,24 +484,27 @@
                     span_response.span_data.response = response
                     span_response.span_data.input = input
             except Exception as e:
-                span_response.set_error(
-                    SpanError(
-                        message="Error getting response",
-                        data={
-                            "error": str(e) if tracing.include_data() else e.__class__.__name__,
-                        },
+                if owns_response_span:
+                    span_response.set_error(
+                        SpanError(
+                            message="Error getting response",
+                            data={
+                                "error": str(e) if tracing.include_data() else e.__class__.__name__,
+                            },
+                        )
                     )
-                )
                 request_id = getattr(e, "request_id", None)
                 logger.error(f"Error getting response: {e}. (request_id: {request_id})")
                 raise
-
-        return ModelResponse(
-            output=response.output,
-            usage=usage,
-            response_id=response.id,
-            request_id=getattr(response, "_request_id", None),
-        )
+            return ModelResponse(
+                output=response.output,
+                usage=usage,
+                response_id=response.id,
+                request_id=getattr(response, "_request_id", None),
+            )
+        finally:
+            if owns_response_span:
+                span_response.finish(reset_current=True)

     async def stream_response(
         self,
@@ -508,11 +518,16 @@
         previous_response_id: str | None = None,
         conversation_id: str | None = None,
         prompt: ResponsePromptParam | None = None,
+        response_span: Span[ResponseSpanData] | None = None,
     ) -> AsyncIterator[ResponseStreamEvent]:
         """
         Yields a partial message as it is generated, as well as the usage information.
         """
-        with response_span(disabled=tracing.is_disabled()) as span_response:
+        span_response = response_span or create_response_span(disabled=tracing.is_disabled())
+        owns_response_span = response_span is None
+        if owns_response_span:
+            span_response.start(mark_as_current=True)
+        try:
             try:
                 stream = await self._fetch_response(
                     system_instructions,
@@ -571,16 +586,20 @@
                     span_response.span_data.input = input

             except Exception as e:
-                span_response.set_error(
-                    SpanError(
-                        message="Error streaming response",
-                        data={
-                            "error": str(e) if tracing.include_data() else e.__class__.__name__,
-                        },
+                if owns_response_span:
+                    span_response.set_error(
+                        SpanError(
+                            message="Error streaming response",
+                            data={
+                                "error": str(e) if tracing.include_data() else e.__class__.__name__,
+                            },
+                        )
                     )
-                )
                 logger.error(f"Error streaming response: {e}")
                 raise
+        finally:
+            if owns_response_span:
+                span_response.finish(reset_current=True)

     @overload
     async def _fetch_response(
diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py
index 36e34b4f56..910e3f4c99 100644
--- a/src/agents/run_internal/run_loop.py
+++ b/src/agents/run_internal/run_loop.py
@@ -7,6 +7,7 @@

 import asyncio
 import dataclasses as _dc
+import inspect
 import json
 from collections.abc import Awaitable, Callable, Mapping
 from typing import Any, TypeVar, cast
@@ -64,7 +65,13 @@
     RunItemStreamEvent,
 )
 from ..tool import FunctionTool, Tool, dispose_resolved_computers
-from ..tracing import Span, SpanError, agent_span, get_current_trace
+from ..tracing import (
+    Span,
+    SpanError,
+    agent_span,
+    get_current_trace,
+    response_span as create_response_span,
+)
 from ..tracing.model_tracing import get_model_tracing_impl
 from ..tracing.span_data import AgentSpanData
 from ..usage import Usage
@@ -231,6 +238,23 @@
 ]


+def _maybe_attach_response_span_kwarg(
+    method: Callable[..., Any],
+    request_kwargs: dict[str, Any],
+    response_span: Span[Any],
+) -> None:
+    try:
+        parameters = inspect.signature(method).parameters.values()
+    except (TypeError, ValueError):
+        return
+
+    if any(
+        parameter.kind is inspect.Parameter.VAR_KEYWORD or parameter.name == "response_span"
+        for parameter in parameters
+    ):
+        request_kwargs["response_span"] = response_span
+
+
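The helper above forwards `response_span` only when the target signature can actually accept it, which is what keeps legacy models working. A standalone re-statement of the rule it applies — the function and example names here are invented for the demonstration:

    import inspect
    from typing import Any


    def accepts_response_span(method: Any) -> bool:
        # Mirrors _maybe_attach_response_span_kwarg: a **kwargs catch-all or an
        # explicit "response_span" parameter both qualify.
        try:
            params = inspect.signature(method).parameters.values()
        except (TypeError, ValueError):
            return False
        return any(
            p.kind is inspect.Parameter.VAR_KEYWORD or p.name == "response_span" for p in params
        )


    def legacy(prompt: str) -> None: ...
    def modern(prompt: str, response_span: object | None = None) -> None: ...
    def flexible(prompt: str, **kwargs: Any) -> None: ...


    assert not accepts_response_span(legacy)
    assert accepts_response_span(modern)
    assert accepts_response_span(flexible)
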
 def _should_attach_generic_agent_error(exc: Exception) -> bool:
     return not isinstance(
         exc,
@@ -1230,212 +1254,238 @@ def _tool_search_fingerprint(raw_item: Any) -> str:
     if not filtered.input and server_conversation_tracker is None:
         raise RuntimeError("Prepared model input is empty")

-    await asyncio.gather(
-        hooks.on_llm_start(context_wrapper, agent, filtered.instructions, filtered.input),
-        (
-            agent.hooks.on_llm_start(context_wrapper, agent, filtered.instructions, filtered.input)
-            if agent.hooks
-            else _coro.noop_coroutine()
-        ),
-    )
-
-    if (
-        not streamed_result._stream_input_persisted
-        and session is not None
-        and server_conversation_tracker is None
-        and streamed_result._original_input_for_persistence
-        and len(streamed_result._original_input_for_persistence) > 0
-    ):
-        streamed_result._stream_input_persisted = True
-        input_items_to_save = [
-            ensure_input_item_format(item)
-            for item in ItemHelpers.input_to_new_input_list(
-                streamed_result._original_input_for_persistence
-            )
-        ]
-        if input_items_to_save:
-            await save_result_to_session(session, input_items_to_save, [], streamed_result._state)
-
-    previous_response_id = (
-        server_conversation_tracker.previous_response_id
-        if server_conversation_tracker
-        and server_conversation_tracker.previous_response_id is not None
-        else None
-    )
-    conversation_id = (
-        server_conversation_tracker.conversation_id if server_conversation_tracker else None
-    )
-    if conversation_id:
-        logger.debug("Using conversation_id=%s", conversation_id)
-    else:
-        logger.debug("No conversation_id available for request")
-
-    async def rewind_model_request() -> None:
-        items_to_rewind = session_items_to_rewind if session_items_to_rewind is not None else []
-        await rewind_session_items(session, items_to_rewind, server_conversation_tracker)
-        if server_conversation_tracker is not None:
-            server_conversation_tracker.rewind_input(filtered.input)
-
-    stream_failed_retry_attempts: list[int] = [0]
-    retry_stream = stream_response_with_retry(
-        get_stream=lambda: model.stream_response(
-            filtered.instructions,
-            filtered.input,
-            model_settings,
-            all_tools,
-            output_schema,
-            handoffs,
-            get_model_tracing_impl(
-                run_config.tracing_disabled, run_config.trace_include_sensitive_data
-            ),
-            previous_response_id=previous_response_id,
-            conversation_id=conversation_id,
-            prompt=prompt_config,
-        ),
-        rewind=rewind_model_request,
-        retry_settings=model_settings.retry,
-        get_retry_advice=model.get_retry_advice,
-        previous_response_id=previous_response_id,
-        conversation_id=conversation_id,
-        failed_retry_attempts_out=stream_failed_retry_attempts,
-    )
-
-    async for event in retry_stream:
-        streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))
-
-        terminal_response: Response | None = None
-        if isinstance(event, ResponseCompletedEvent):
-            terminal_response = event.response
-        elif getattr(event, "type", None) in {"response.incomplete", "response.failed"}:
-            maybe_response = getattr(event, "response", None)
-            if isinstance(maybe_response, Response):
-                terminal_response = maybe_response
-
-        if terminal_response is not None:
-            usage = (
-                apply_retry_attempt_usage(
-                    Usage(
-                        requests=1,
-                        input_tokens=terminal_response.usage.input_tokens,
-                        output_tokens=terminal_response.usage.output_tokens,
-                        total_tokens=terminal_response.usage.total_tokens,
-                        input_tokens_details=terminal_response.usage.input_tokens_details,
-                        output_tokens_details=terminal_response.usage.output_tokens_details,
-                    ),
-                    stream_failed_retry_attempts[0],
-                )
-                if terminal_response.usage
-                else Usage()
-            )
-            final_response = ModelResponse(
-                output=terminal_response.output,
-                usage=usage,
-                response_id=terminal_response.id,
-                request_id=getattr(terminal_response, "_request_id", None),
-            )
-
-        if isinstance(event, ResponseOutputItemDoneEvent):
-            output_item = event.item
-            output_item_type = getattr(output_item, "type", None)
-
-            if output_item_type == "tool_search_call":
-                emitted_tool_search_fingerprints.add(_tool_search_fingerprint(output_item))
-                streamed_result._event_queue.put_nowait(
-                    RunItemStreamEvent(
-                        item=ToolSearchCallItem(
-                            raw_item=coerce_tool_search_call_raw_item(output_item),
-                            agent=agent,
-                        ),
-                        name="tool_search_called",
-                    )
-                )
-
-            elif output_item_type == "tool_search_output":
-                emitted_tool_search_fingerprints.add(_tool_search_fingerprint(output_item))
-                streamed_result._event_queue.put_nowait(
-                    RunItemStreamEvent(
-                        item=ToolSearchOutputItem(
-                            raw_item=coerce_tool_search_output_raw_item(output_item),
-                            agent=agent,
-                        ),
-                        name="tool_search_output_created",
-                    )
-                )
-
-            elif isinstance(output_item, McpListTools):
-                hosted_mcp_tool_metadata.update(collect_mcp_list_tools_metadata([output_item]))
-
-            elif isinstance(output_item, TOOL_CALL_TYPES):
-                output_call_id: str | None = getattr(
-                    output_item, "call_id", getattr(output_item, "id", None)
-                )
-                if (
-                    output_call_id
-                    and isinstance(output_call_id, str)
-                    and output_call_id not in emitted_tool_call_ids
-                ):
-                    emitted_tool_call_ids.add(output_call_id)
-
-                    # Look up tool description from precomputed map ("last wins" matches
-                    # execution behavior in process_model_response).
-                    tool_lookup_key = get_function_tool_lookup_key_for_call(output_item)
-                    matched_tool = (
-                        tool_map.get(tool_lookup_key) if tool_lookup_key is not None else None
-                    )
-                    tool_description: str | None = None
-                    tool_title: str | None = None
-                    if isinstance(output_item, McpCall):
-                        metadata = hosted_mcp_tool_metadata.get(
-                            (output_item.server_label, output_item.name)
-                        )
-                        if metadata is not None:
-                            tool_description = metadata.description
-                            tool_title = metadata.title
-                    elif matched_tool is not None:
-                        tool_description = getattr(matched_tool, "description", None)
-                        tool_title = getattr(matched_tool, "_mcp_title", None)
-
-                    tool_item = ToolCallItem(
-                        raw_item=cast(ToolCallItemTypes, output_item),
-                        agent=agent,
-                        description=tool_description,
-                        title=tool_title,
-                    )
-                    streamed_result._event_queue.put_nowait(
-                        RunItemStreamEvent(item=tool_item, name="tool_called")
-                    )
-
-            elif isinstance(output_item, ResponseReasoningItem):
-                reasoning_id: str | None = getattr(output_item, "id", None)
-
-                if reasoning_id and reasoning_id not in emitted_reasoning_item_ids:
-                    emitted_reasoning_item_ids.add(reasoning_id)
-
-                    reasoning_item = ReasoningItem(raw_item=output_item, agent=agent)
-                    streamed_result._event_queue.put_nowait(
-                        RunItemStreamEvent(item=reasoning_item, name="reasoning_item_created")
-                    )
-
-    if final_response is not None:
-        context_wrapper.usage.add(final_response.usage)
-        await asyncio.gather(
-            (
-                agent.hooks.on_llm_end(context_wrapper, agent, final_response)
-                if agent.hooks
-                else _coro.noop_coroutine()
-            ),
-            hooks.on_llm_end(context_wrapper, agent, final_response),
-        )
-
-    if not final_response:
-        raise ModelBehaviorError("Model did not produce a final response!")
-
-    if server_conversation_tracker is not None:
-        # Streaming uses the same rewind helper, so a successful retry must restore delivered
-        # input tracking before the next turn computes server-managed deltas.
-        server_conversation_tracker.mark_input_as_sent(filtered.input)
-        server_conversation_tracker.track_server_items(final_response)
+    model_tracing = get_model_tracing_impl(
+        run_config.tracing_disabled, run_config.trace_include_sensitive_data
+    )
+    active_response_span = create_response_span(disabled=model_tracing.is_disabled())
+    active_response_span.start(mark_as_current=True)
+
+    try:
+        await asyncio.gather(
+            hooks.on_llm_start(context_wrapper, agent, filtered.instructions, filtered.input),
+            (
+                agent.hooks.on_llm_start(
+                    context_wrapper, agent, filtered.instructions, filtered.input
+                )
+                if agent.hooks
+                else _coro.noop_coroutine()
+            ),
+        )
+
+        if (
+            not streamed_result._stream_input_persisted
+            and session is not None
+            and server_conversation_tracker is None
+            and streamed_result._original_input_for_persistence
+            and len(streamed_result._original_input_for_persistence) > 0
+        ):
+            streamed_result._stream_input_persisted = True
+            input_items_to_save = [
+                ensure_input_item_format(item)
+                for item in ItemHelpers.input_to_new_input_list(
+                    streamed_result._original_input_for_persistence
+                )
+            ]
+            if input_items_to_save:
+                await save_result_to_session(
+                    session, input_items_to_save, [], streamed_result._state
+                )
+
+        previous_response_id = (
+            server_conversation_tracker.previous_response_id
+            if server_conversation_tracker
+            and server_conversation_tracker.previous_response_id is not None
+            else None
+        )
+        conversation_id = (
+            server_conversation_tracker.conversation_id if server_conversation_tracker else None
+        )
+        if conversation_id:
+            logger.debug("Using conversation_id=%s", conversation_id)
+        else:
+            logger.debug("No conversation_id available for request")
+
+        async def rewind_model_request() -> None:
+            items_to_rewind = session_items_to_rewind if session_items_to_rewind is not None else []
+            await rewind_session_items(session, items_to_rewind, server_conversation_tracker)
+            if server_conversation_tracker is not None:
+                server_conversation_tracker.rewind_input(filtered.input)
+
+        stream_request_kwargs: dict[str, Any] = {
+            "system_instructions": filtered.instructions,
+            "input": filtered.input,
+            "model_settings": model_settings,
+            "tools": all_tools,
+            "output_schema": output_schema,
+            "handoffs": handoffs,
+            "tracing": model_tracing,
+            "previous_response_id": previous_response_id,
+            "conversation_id": conversation_id,
+            "prompt": prompt_config,
+        }
+        _maybe_attach_response_span_kwarg(
+            model.stream_response, stream_request_kwargs, active_response_span
+        )
+
+        stream_failed_retry_attempts: list[int] = [0]
+        retry_stream = stream_response_with_retry(
+            get_stream=lambda: model.stream_response(**stream_request_kwargs),
+            rewind=rewind_model_request,
+            retry_settings=model_settings.retry,
+            get_retry_advice=model.get_retry_advice,
+            previous_response_id=previous_response_id,
+            conversation_id=conversation_id,
+            failed_retry_attempts_out=stream_failed_retry_attempts,
+        )
+
+        async for event in retry_stream:
+            streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))
+
+            terminal_response: Response | None = None
+            if isinstance(event, ResponseCompletedEvent):
+                terminal_response = event.response
+            elif getattr(event, "type", None) in {"response.incomplete", "response.failed"}:
+                maybe_response = getattr(event, "response", None)
+                if isinstance(maybe_response, Response):
+                    terminal_response = maybe_response
+
+            if terminal_response is not None:
+                usage = (
+                    apply_retry_attempt_usage(
+                        Usage(
+                            requests=1,
+                            input_tokens=terminal_response.usage.input_tokens,
+                            output_tokens=terminal_response.usage.output_tokens,
+                            total_tokens=terminal_response.usage.total_tokens,
+                            input_tokens_details=terminal_response.usage.input_tokens_details,
+                            output_tokens_details=terminal_response.usage.output_tokens_details,
+                        ),
+                        stream_failed_retry_attempts[0],
+                    )
+                    if terminal_response.usage
+                    else Usage()
+                )
+                final_response = ModelResponse(
+                    output=terminal_response.output,
+                    usage=usage,
+                    response_id=terminal_response.id,
+                    request_id=getattr(terminal_response, "_request_id", None),
+                )
+
+            if isinstance(event, ResponseOutputItemDoneEvent):
+                output_item = event.item
+                output_item_type = getattr(output_item, "type", None)
+
+                if output_item_type == "tool_search_call":
+                    emitted_tool_search_fingerprints.add(_tool_search_fingerprint(output_item))
+                    streamed_result._event_queue.put_nowait(
+                        RunItemStreamEvent(
+                            item=ToolSearchCallItem(
+                                raw_item=coerce_tool_search_call_raw_item(output_item),
+                                agent=agent,
+                            ),
+                            name="tool_search_called",
+                        )
+                    )
+
+                elif output_item_type == "tool_search_output":
+                    emitted_tool_search_fingerprints.add(_tool_search_fingerprint(output_item))
+                    streamed_result._event_queue.put_nowait(
+                        RunItemStreamEvent(
+                            item=ToolSearchOutputItem(
+                                raw_item=coerce_tool_search_output_raw_item(output_item),
+                                agent=agent,
+                            ),
+                            name="tool_search_output_created",
+                        )
+                    )
+
+                elif isinstance(output_item, McpListTools):
+                    hosted_mcp_tool_metadata.update(collect_mcp_list_tools_metadata([output_item]))
+
+                elif isinstance(output_item, TOOL_CALL_TYPES):
+                    output_call_id: str | None = getattr(
+                        output_item, "call_id", getattr(output_item, "id", None)
+                    )
+                    if (
+                        output_call_id
+                        and isinstance(output_call_id, str)
+                        and output_call_id not in emitted_tool_call_ids
+                    ):
+                        emitted_tool_call_ids.add(output_call_id)
+
+                        # Look up tool description from precomputed map ("last wins" matches
+                        # execution behavior in process_model_response).
+                        tool_lookup_key = get_function_tool_lookup_key_for_call(output_item)
+                        matched_tool = (
+                            tool_map.get(tool_lookup_key) if tool_lookup_key is not None else None
+                        )
+                        tool_description: str | None = None
+                        tool_title: str | None = None
+                        if isinstance(output_item, McpCall):
+                            metadata = hosted_mcp_tool_metadata.get(
+                                (output_item.server_label, output_item.name)
+                            )
+                            if metadata is not None:
+                                tool_description = metadata.description
+                                tool_title = metadata.title
+                        elif matched_tool is not None:
+                            tool_description = getattr(matched_tool, "description", None)
+                            tool_title = getattr(matched_tool, "_mcp_title", None)
+
+                        tool_item = ToolCallItem(
+                            raw_item=cast(ToolCallItemTypes, output_item),
+                            agent=agent,
+                            description=tool_description,
+                            title=tool_title,
+                        )
+                        streamed_result._event_queue.put_nowait(
+                            RunItemStreamEvent(item=tool_item, name="tool_called")
+                        )
+
+                elif isinstance(output_item, ResponseReasoningItem):
+                    reasoning_id: str | None = getattr(output_item, "id", None)
+
+                    if reasoning_id and reasoning_id not in emitted_reasoning_item_ids:
+                        emitted_reasoning_item_ids.add(reasoning_id)
+
+                        reasoning_item = ReasoningItem(raw_item=output_item, agent=agent)
+                        streamed_result._event_queue.put_nowait(
+                            RunItemStreamEvent(item=reasoning_item, name="reasoning_item_created")
+                        )
+
+        if final_response is not None:
+            context_wrapper.usage.add(final_response.usage)
+            await asyncio.gather(
+                (
+                    agent.hooks.on_llm_end(context_wrapper, agent, final_response)
+                    if agent.hooks
+                    else _coro.noop_coroutine()
+                ),
+                hooks.on_llm_end(context_wrapper, agent, final_response),
+            )
+
+        if not final_response:
+            raise ModelBehaviorError("Model did not produce a final response!")
+
+        if server_conversation_tracker is not None:
+            # Streaming uses the same rewind helper, so a successful retry must restore delivered
+            # input tracking before the next turn computes server-managed deltas.
+            server_conversation_tracker.mark_input_as_sent(filtered.input)
+            server_conversation_tracker.track_server_items(final_response)
+    except Exception as e:
+        active_response_span.set_error(
+            SpanError(
+                message="Error during streamed LLM execution",
+                data={"error": str(e) if model_tracing.include_data() else e.__class__.__name__},
+            )
+        )
+        raise
+    finally:
+        active_response_span.finish(reset_current=True)
+
+    assert final_response is not None

     single_step_result = await get_single_step_result_from_response(
         agent=agent,
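
Both run paths now follow the same ownership rule: whoever starts the response span is the only party that records its error and finishes it, while models that receive the span merely annotate it. Reduced to its essentials — the wrapper function below is a sketch, not SDK code; only `response_span`, `SpanError`, `start`, `set_error`, and `finish` come from the tracing API:

    from agents.tracing import SpanError, response_span


    async def run_in_owned_span(do_work, include_data: bool):
        span = response_span(disabled=False)
        span.start(mark_as_current=True)
        try:
            # While this runs, get_current_span() resolves to `span`, which is
            # what lets on_llm_start/on_llm_end hooks see and tag it.
            return await do_work()
        except Exception as e:
            span.set_error(
                SpanError(
                    message="Error during LLM execution",
                    data={"error": str(e) if include_data else e.__class__.__name__},
                )
            )
            raise
        finally:
            span.finish(reset_current=True)
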
"model_settings": model_settings, + "tools": all_tools, + "output_schema": output_schema, + "handoffs": handoffs, + "tracing": model_tracing, + "previous_response_id": previous_response_id, + "conversation_id": conversation_id, + "prompt": prompt_config, + } + _maybe_attach_response_span_kwarg( + model.get_response, response_request_kwargs, active_response_span + ) + + new_response = await get_response_with_retry( + get_response=lambda: model.get_response(**response_request_kwargs), + rewind=rewind_model_request, + retry_settings=model_settings.retry, + get_retry_advice=model.get_retry_advice, previous_response_id=previous_response_id, conversation_id=conversation_id, - prompt=prompt_config, - ), - rewind=rewind_model_request, - retry_settings=model_settings.retry, - get_retry_advice=model.get_retry_advice, - previous_response_id=previous_response_id, - conversation_id=conversation_id, - ) - if server_conversation_tracker is not None: - # Retry helpers rewind sent-input tracking before replaying a failed request. Mark the - # filtered input as delivered again once a retry succeeds so subsequent turns only send - # new deltas. - server_conversation_tracker.mark_input_as_sent(filtered.input) + ) + if server_conversation_tracker is not None: + # Retry helpers rewind sent-input tracking before replaying a failed request. Mark the + # filtered input as delivered again once a retry succeeds so subsequent turns only send + # new deltas. + server_conversation_tracker.mark_input_as_sent(filtered.input) - context_wrapper.usage.add(new_response.usage) + context_wrapper.usage.add(new_response.usage) - await asyncio.gather( - ( - agent.hooks.on_llm_end(context_wrapper, agent, new_response) - if agent.hooks - else _coro.noop_coroutine() - ), - hooks.on_llm_end(context_wrapper, agent, new_response), - ) + await asyncio.gather( + ( + agent.hooks.on_llm_end(context_wrapper, agent, new_response) + if agent.hooks + else _coro.noop_coroutine() + ), + hooks.on_llm_end(context_wrapper, agent, new_response), + ) - return new_response + return new_response + except Exception as e: + active_response_span.set_error( + SpanError( + message="Error during LLM execution", + data={"error": str(e) if model_tracing.include_data() else e.__class__.__name__}, + ) + ) + raise + finally: + active_response_span.finish(reset_current=True) diff --git a/src/agents/tracing/processors.py b/src/agents/tracing/processors.py index 7132faf1c8..1599836f8a 100644 --- a/src/agents/tracing/processors.py +++ b/src/agents/tracing/processors.py @@ -203,6 +203,12 @@ def _sanitize_for_openai_tracing_api(self, payload_item: dict[str, Any]) -> dict did_mutate = True sanitized_span_data[field_name] = sanitized_field + if span_data.get("type") == "response" and "metadata" in span_data: + if not did_mutate: + sanitized_span_data = dict(span_data) + did_mutate = True + sanitized_span_data.pop("metadata", None) + if span_data.get("type") != "generation": if not did_mutate: return payload_item diff --git a/src/agents/tracing/span_data.py b/src/agents/tracing/span_data.py index cb3e8491d3..b72fe33580 100644 --- a/src/agents/tracing/span_data.py +++ b/src/agents/tracing/span_data.py @@ -139,20 +139,22 @@ def export(self) -> dict[str, Any]: class ResponseSpanData(SpanData): """ Represents a Response Span in the trace. - Includes response and input. + Includes response, input, and user-defined metadata. 
""" - __slots__ = ("response", "input") + __slots__ = ("response", "input", "metadata") def __init__( self, response: Response | None = None, input: str | list[ResponseInputItemParam] | None = None, + metadata: dict[str, Any] | None = None, ) -> None: self.response = response # This is not used by the OpenAI trace processors, but is useful for other tracing # processor implementations self.input = input + self.metadata = metadata or {} @property def type(self) -> str: @@ -162,6 +164,7 @@ def export(self) -> dict[str, Any]: return { "type": self.type, "response_id": self.response.id if self.response else None, + "metadata": self.metadata or None, } diff --git a/tests/fake_model.py b/tests/fake_model.py index ed44d72d04..a1b118678f 100644 --- a/tests/fake_model.py +++ b/tests/fake_model.py @@ -44,7 +44,8 @@ from agents.model_settings import ModelSettings from agents.models.interface import Model, ModelTracing from agents.tool import Tool -from agents.tracing import SpanError, generation_span +from agents.tracing import Span, SpanError, generation_span +from agents.tracing.span_data import ResponseSpanData from agents.usage import Usage @@ -91,6 +92,7 @@ async def get_response( previous_response_id: str | None, conversation_id: str | None, prompt: Any | None, + response_span: Span[ResponseSpanData] | None = None, ) -> ModelResponse: turn_args = { "system_instructions": system_instructions, @@ -143,11 +145,19 @@ async def get_response( else: converted_output.append(item) - return ModelResponse( + model_response = ModelResponse( output=converted_output, usage=self.hardcoded_usage or Usage(), response_id="resp-789", ) + if response_span is not None: + response_span.span_data.response = get_response_obj( + converted_output, + response_id=model_response.response_id, + usage=model_response.usage, + ) + response_span.span_data.input = input + return model_response async def stream_response( self, @@ -162,6 +172,7 @@ async def stream_response( previous_response_id: str | None = None, conversation_id: str | None = None, prompt: Any | None = None, + response_span: Span[ResponseSpanData] | None = None, ) -> AsyncIterator[TResponseStreamEvent]: turn_args = { "system_instructions": system_instructions, @@ -192,6 +203,9 @@ async def stream_response( raise output response = get_response_obj(output, usage=self.hardcoded_usage) + if response_span is not None: + response_span.span_data.response = response + response_span.span_data.input = input sequence_number = 0 yield ResponseCreatedEvent( diff --git a/tests/mcp/test_mcp_tracing.py b/tests/mcp/test_mcp_tracing.py index 9cb3454b1b..ac4ff7be2a 100644 --- a/tests/mcp/test_mcp_tracing.py +++ b/tests/mcp/test_mcp_tracing.py @@ -57,6 +57,10 @@ async def test_mcp_tracing(): "output_type": "str", }, "children": [ + { + "type": "response", + "data": {"response_id": "resp-789"}, + }, { "type": "function", "data": { @@ -70,6 +74,10 @@ async def test_mcp_tracing(): "type": "mcp_tools", "data": {"server": "fake_mcp_server", "result": ["test_tool_1"]}, }, + { + "type": "response", + "data": {"response_id": "resp-789"}, + }, ], }, ], @@ -120,6 +128,10 @@ async def test_mcp_tracing(): "output_type": "str", }, "children": [ + { + "type": "response", + "data": {"response_id": "resp-789"}, + }, { "type": "function", "data": { @@ -144,6 +156,10 @@ async def test_mcp_tracing(): "result": ["test_tool_1", "test_tool_2"], }, }, + { + "type": "response", + "data": {"response_id": "resp-789"}, + }, ], }, ], @@ -192,6 +208,10 @@ async def test_mcp_tracing(): "output_type": "str", }, 
"children": [ + { + "type": "response", + "data": {"response_id": "resp-789"}, + }, { "type": "function", "data": { @@ -208,6 +228,10 @@ async def test_mcp_tracing(): "result": ["test_tool_1", "test_tool_2", "test_tool_3"], }, }, + { + "type": "response", + "data": {"response_id": "resp-789"}, + }, ], }, ], diff --git a/tests/test_agent_prompt.py b/tests/test_agent_prompt.py index e3ed40fbe1..b9165851b4 100644 --- a/tests/test_agent_prompt.py +++ b/tests/test_agent_prompt.py @@ -31,6 +31,7 @@ async def get_response( previous_response_id, conversation_id, prompt, + response_span=None, ): # Record the prompt that the agent resolved and passed in. self.last_prompt = prompt @@ -45,6 +46,7 @@ async def get_response( previous_response_id=previous_response_id, conversation_id=conversation_id, prompt=prompt, + response_span=response_span, ) diff --git a/tests/test_agent_runner_streamed.py b/tests/test_agent_runner_streamed.py index 1c28fafbc2..e6e7b6f2d0 100644 --- a/tests/test_agent_runner_streamed.py +++ b/tests/test_agent_runner_streamed.py @@ -161,6 +161,7 @@ async def stream_response( previous_response_id=None, conversation_id=None, prompt=None, + response_span=None, ): self.last_turn_args = { "system_instructions": system_instructions, @@ -177,6 +178,9 @@ async def stream_response( response = get_response_obj( [get_text_message("partial final")], response_id="resp-partial" ) + if response_span is not None: + response_span.span_data.response = response + response_span.span_data.input = input yield terminal_event_cls( type=terminal_event_type, response=response, @@ -211,11 +215,15 @@ async def stream_response( previous_response_id=None, conversation_id=None, prompt=None, + response_span=None, ): response = get_response_obj( [get_text_message("partial final")], response_id="resp-partial" ) response._request_id = "req_streamed_result_123" + if response_span is not None: + response_span.span_data.response = response + response_span.span_data.input = input yield ResponseCompletedEvent( type="response.completed", response=response, diff --git a/tests/test_agent_tracing.py b/tests/test_agent_tracing.py index 14ab62b2b2..2ca2dca144 100644 --- a/tests/test_agent_tracing.py +++ b/tests/test_agent_tracing.py @@ -51,6 +51,7 @@ async def test_single_run_is_single_trace(): "tools": [], "output_type": "str", }, + "children": [{"type": "response", "data": {"response_id": "resp-789"}}], } ], } @@ -88,6 +89,7 @@ async def test_multiple_runs_are_multiple_traces(): "tools": [], "output_type": "str", }, + "children": [{"type": "response", "data": {"response_id": "resp-789"}}], } ], }, @@ -102,6 +104,7 @@ async def test_multiple_runs_are_multiple_traces(): "tools": [], "output_type": "str", }, + "children": [{"type": "response", "data": {"response_id": "resp-789"}}], } ], }, @@ -278,6 +281,7 @@ async def test_wrapped_trace_is_single_trace(): "tools": [], "output_type": "str", }, + "children": [{"type": "response", "data": {"response_id": "resp-789"}}], }, { "type": "agent", @@ -287,6 +291,7 @@ async def test_wrapped_trace_is_single_trace(): "tools": [], "output_type": "str", }, + "children": [{"type": "response", "data": {"response_id": "resp-789"}}], }, { "type": "agent", @@ -296,6 +301,7 @@ async def test_wrapped_trace_is_single_trace(): "tools": [], "output_type": "str", }, + "children": [{"type": "response", "data": {"response_id": "resp-789"}}], }, ], } @@ -362,6 +368,7 @@ async def test_trace_config_works(): "tools": [], "output_type": "str", }, + "children": [{"type": "response", "data": {"response_id": 
"resp-789"}}], } ], } @@ -399,6 +406,7 @@ async def test_not_starting_streaming_creates_trace(): "tools": [], "output_type": "str", }, + "children": [{"type": "response", "data": {"response_id": "resp-789"}}], } ], } @@ -436,6 +444,7 @@ async def test_streaming_single_run_is_single_trace(): "tools": [], "output_type": "str", }, + "children": [{"type": "response", "data": {"response_id": "resp-789"}}], } ], } @@ -478,6 +487,7 @@ async def test_multiple_streamed_runs_are_multiple_traces(): "tools": [], "output_type": "str", }, + "children": [{"type": "response", "data": {"response_id": "resp-789"}}], } ], }, @@ -492,6 +502,7 @@ async def test_multiple_streamed_runs_are_multiple_traces(): "tools": [], "output_type": "str", }, + "children": [{"type": "response", "data": {"response_id": "resp-789"}}], } ], }, @@ -571,6 +582,7 @@ async def test_wrapped_streaming_trace_is_single_trace(): "tools": [], "output_type": "str", }, + "children": [{"type": "response", "data": {"response_id": "resp-789"}}], }, { "type": "agent", @@ -580,6 +592,7 @@ async def test_wrapped_streaming_trace_is_single_trace(): "tools": [], "output_type": "str", }, + "children": [{"type": "response", "data": {"response_id": "resp-789"}}], }, { "type": "agent", @@ -589,6 +602,7 @@ async def test_wrapped_streaming_trace_is_single_trace(): "tools": [], "output_type": "str", }, + "children": [{"type": "response", "data": {"response_id": "resp-789"}}], }, ], } @@ -635,6 +649,7 @@ async def test_wrapped_mixed_trace_is_single_trace(): "tools": [], "output_type": "str", }, + "children": [{"type": "response", "data": {"response_id": "resp-789"}}], }, { "type": "agent", @@ -644,6 +659,7 @@ async def test_wrapped_mixed_trace_is_single_trace(): "tools": [], "output_type": "str", }, + "children": [{"type": "response", "data": {"response_id": "resp-789"}}], }, { "type": "agent", @@ -653,6 +669,7 @@ async def test_wrapped_mixed_trace_is_single_trace(): "tools": [], "output_type": "str", }, + "children": [{"type": "response", "data": {"response_id": "resp-789"}}], }, ], } diff --git a/tests/test_cancel_streaming.py b/tests/test_cancel_streaming.py index fc697728e1..07dddac71d 100644 --- a/tests/test_cancel_streaming.py +++ b/tests/test_cancel_streaming.py @@ -196,6 +196,7 @@ async def stream_response( previous_response_id=None, conversation_id=None, prompt=None, + response_span=None, ): await block_event.wait() async for event in super().stream_response( @@ -209,6 +210,7 @@ async def stream_response( previous_response_id=previous_response_id, conversation_id=conversation_id, prompt=prompt, + response_span=response_span, ): yield event diff --git a/tests/test_hitl_session_scenario.py b/tests/test_hitl_session_scenario.py index c7b3ab579d..c001e9a58f 100644 --- a/tests/test_hitl_session_scenario.py +++ b/tests/test_hitl_session_scenario.py @@ -86,7 +86,9 @@ async def get_response( previous_response_id: str | None, conversation_id: str | None, prompt: Any | None, + response_span: Any | None = None, ) -> ModelResponse: + del response_span if input_has_rejection(input): return ModelResponse( output=[get_text_message(HITL_REJECTION_MSG)], @@ -119,7 +121,9 @@ async def stream_response( previous_response_id: str | None, conversation_id: str | None, prompt: Any | None, + response_span: Any | None = None, ) -> AsyncIterator[TResponseStreamEvent]: + del response_span if False: yield cast(TResponseStreamEvent, {}) raise RuntimeError("Streaming is not supported in this scenario.") diff --git a/tests/test_llm_hook_tracing.py b/tests/test_llm_hook_tracing.py 
new file mode 100644 index 0000000000..ecfb2a0d03 --- /dev/null +++ b/tests/test_llm_hook_tracing.py @@ -0,0 +1,298 @@ +from __future__ import annotations + +from collections.abc import AsyncIterator +from typing import Any, cast + +import pytest + +from agents import Agent, AgentHooks, RunHooks, Runner, get_current_span +from agents.agent_output import AgentOutputSchemaBase +from agents.handoffs import Handoff +from agents.items import ModelResponse, TResponseInputItem, TResponseStreamEvent +from agents.model_settings import ModelSettings +from agents.models.interface import ModelTracing +from agents.run_context import RunContextWrapper, TContext +from agents.tool import Tool +from agents.tracing.span_data import ResponseSpanData + +from .fake_model import FakeModel +from .test_responses import get_function_tool, get_function_tool_call, get_text_message +from .testing_processor import fetch_normalized_spans, fetch_ordered_spans + + +class SpanAwareRunHooks(RunHooks): + def __init__(self) -> None: + self.start_span_types: list[str | None] = [] + self.end_span_types: list[str | None] = [] + + async def on_llm_start( + self, + context: RunContextWrapper[TContext], + agent: Agent[TContext], + system_prompt: str | None, + input_items: list[TResponseInputItem], + ) -> None: + current_span = get_current_span() + self.start_span_types.append(current_span.span_data.type if current_span else None) + if current_span is not None and isinstance(current_span.span_data, ResponseSpanData): + current_span.span_data.metadata["run_hook_start_agent"] = agent.name + + async def on_llm_end( + self, + context: RunContextWrapper[TContext], + agent: Agent[TContext], + response: ModelResponse, + ) -> None: + current_span = get_current_span() + self.end_span_types.append(current_span.span_data.type if current_span else None) + if current_span is not None and isinstance(current_span.span_data, ResponseSpanData): + current_span.span_data.metadata["run_hook_end_response_id"] = response.response_id + + +class SpanAwareAgentHooks(AgentHooks): + def __init__(self) -> None: + self.start_span_types: list[str | None] = [] + self.end_span_types: list[str | None] = [] + + async def on_llm_start( + self, + context: RunContextWrapper[TContext], + agent: Agent[TContext], + system_prompt: str | None, + input_items: list[TResponseInputItem], + ) -> None: + current_span = get_current_span() + self.start_span_types.append(current_span.span_data.type if current_span else None) + if current_span is not None and isinstance(current_span.span_data, ResponseSpanData): + current_span.span_data.metadata["agent_hook_start_agent"] = agent.name + + async def on_llm_end( + self, + context: RunContextWrapper[TContext], + agent: Agent[TContext], + response: ModelResponse, + ) -> None: + current_span = get_current_span() + self.end_span_types.append(current_span.span_data.type if current_span else None) + if current_span is not None and isinstance(current_span.span_data, ResponseSpanData): + current_span.span_data.metadata["agent_hook_end_response_id"] = response.response_id + + +def _find_response_spans() -> list[ResponseSpanData]: + return [ + span.span_data + for span in fetch_ordered_spans() + if isinstance(span.span_data, ResponseSpanData) + ] + + +def _find_exported_response_spans() -> list[dict[str, Any]]: + exported_spans: list[dict[str, Any]] = [] + + def _walk(node: dict[str, Any]) -> None: + if node.get("type") == "response": + exported_spans.append(node) + for child in node.get("children", []): + _walk(child) + + for trace in 
fetch_normalized_spans(): + _walk(trace) + + return exported_spans + + +def _make_legacy_signature_model() -> FakeModel: + model = FakeModel() + original_get_response = model.get_response + original_stream_response = model.stream_response + + async def legacy_get_response( + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + *, + previous_response_id: str | None, + conversation_id: str | None, + prompt: Any | None, + ) -> ModelResponse: + return await original_get_response( + system_instructions=system_instructions, + input=input, + model_settings=model_settings, + tools=tools, + output_schema=output_schema, + handoffs=handoffs, + tracing=tracing, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + prompt=prompt, + ) + + async def legacy_stream_response( + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + *, + previous_response_id: str | None = None, + conversation_id: str | None = None, + prompt: Any | None = None, + ) -> AsyncIterator[TResponseStreamEvent]: + async for event in original_stream_response( + system_instructions=system_instructions, + input=input, + model_settings=model_settings, + tools=tools, + output_schema=output_schema, + handoffs=handoffs, + tracing=tracing, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + prompt=prompt, + ): + yield event + + legacy_model = cast(Any, model) + legacy_model.get_response = legacy_get_response + legacy_model.stream_response = legacy_stream_response + return model + + +@pytest.mark.asyncio +async def test_non_streamed_llm_hooks_can_mutate_response_span_metadata() -> None: + run_hooks = SpanAwareRunHooks() + agent_hooks = SpanAwareAgentHooks() + model = FakeModel() + model.set_next_output([get_text_message("hello")]) + agent = Agent(name="hooked-agent", model=model, hooks=agent_hooks) + + await Runner.run(agent, input="hello", hooks=run_hooks) + + assert run_hooks.start_span_types == ["response"] + assert run_hooks.end_span_types == ["response"] + assert agent_hooks.start_span_types == ["response"] + assert agent_hooks.end_span_types == ["response"] + + response_spans = _find_response_spans() + assert len(response_spans) == 1 + assert response_spans[0].metadata == { + "run_hook_start_agent": "hooked-agent", + "run_hook_end_response_id": "resp-789", + "agent_hook_start_agent": "hooked-agent", + "agent_hook_end_response_id": "resp-789", + } + assert _find_exported_response_spans() == [ + { + "type": "response", + "data": { + "response_id": "resp-789", + "metadata": { + "run_hook_start_agent": "hooked-agent", + "run_hook_end_response_id": "resp-789", + "agent_hook_start_agent": "hooked-agent", + "agent_hook_end_response_id": "resp-789", + }, + }, + } + ] + + +@pytest.mark.asyncio +async def test_streamed_llm_hooks_can_mutate_response_span_metadata() -> None: + run_hooks = SpanAwareRunHooks() + agent_hooks = SpanAwareAgentHooks() + model = FakeModel() + model.set_next_output([get_text_message("hello")]) + agent = Agent(name="streamed-agent", model=model, hooks=agent_hooks) + + result = Runner.run_streamed(agent, input="hello", hooks=run_hooks) + async for _ in result.stream_events(): + pass + + assert run_hooks.start_span_types == 
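
The new test file above exercises the user-visible behavior: while `on_llm_start` and `on_llm_end` run, `get_current_span()` resolves to the response span, so hooks can tag it. A minimal hook along those lines — agent and model wiring omitted, and the metadata key is an arbitrary example:

    from agents import RunHooks, get_current_span
    from agents.tracing.span_data import ResponseSpanData


    class TaggingHooks(RunHooks):
        async def on_llm_start(self, context, agent, system_prompt, input_items) -> None:
            span = get_current_span()
            # Only response spans carry the new metadata dict; guard accordingly.
            if span is not None and isinstance(span.span_data, ResponseSpanData):
                span.span_data.metadata["agent"] = agent.name
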
["response"] + assert run_hooks.end_span_types == ["response"] + assert agent_hooks.start_span_types == ["response"] + assert agent_hooks.end_span_types == ["response"] + + response_spans = _find_response_spans() + assert len(response_spans) == 1 + assert response_spans[0].metadata == { + "run_hook_start_agent": "streamed-agent", + "run_hook_end_response_id": "resp-789", + "agent_hook_start_agent": "streamed-agent", + "agent_hook_end_response_id": "resp-789", + } + assert _find_exported_response_spans() == [ + { + "type": "response", + "data": { + "response_id": "resp-789", + "metadata": { + "run_hook_start_agent": "streamed-agent", + "run_hook_end_response_id": "resp-789", + "agent_hook_start_agent": "streamed-agent", + "agent_hook_end_response_id": "resp-789", + }, + }, + } + ] + + +@pytest.mark.asyncio +async def test_runner_accepts_legacy_models_without_response_span_kwarg() -> None: + run_hooks = SpanAwareRunHooks() + model = _make_legacy_signature_model() + model.set_next_output([get_text_message("legacy-ok")]) + agent = Agent(name="legacy-agent", model=model) + + result = await Runner.run(agent, input="hello", hooks=run_hooks) + + assert result.final_output == "legacy-ok" + assert run_hooks.start_span_types == ["response"] + assert run_hooks.end_span_types == ["response"] + + +@pytest.mark.asyncio +async def test_runner_streamed_accepts_legacy_models_without_response_span_kwarg() -> None: + run_hooks = SpanAwareRunHooks() + model = _make_legacy_signature_model() + model.set_next_output([get_text_message("legacy-stream-ok")]) + agent = Agent(name="legacy-stream-agent", model=model) + + result = Runner.run_streamed(agent, input="hello", hooks=run_hooks) + async for _ in result.stream_events(): + pass + + assert result.final_output == "legacy-stream-ok" + assert run_hooks.start_span_types == ["response"] + assert run_hooks.end_span_types == ["response"] + + +@pytest.mark.asyncio +async def test_streamed_tool_spans_are_not_nested_under_response_spans() -> None: + model = FakeModel(tracing_enabled=True) + model.add_multiple_turn_outputs( + [ + [get_text_message("a_message"), get_function_tool_call("foo", '{"a": "b"}')], + [get_text_message("done")], + ] + ) + agent = Agent( + name="streamed-tool-agent", + model=model, + tools=[get_function_tool("foo", "tool_result")], + ) + + result = Runner.run_streamed(agent, input="hello") + async for _ in result.stream_events(): + pass + + agent_children = fetch_normalized_spans()[0]["children"][0]["children"] + assert [child["type"] for child in agent_children] == ["response", "function", "response"] diff --git a/tests/test_run_state.py b/tests/test_run_state.py index 56cd61fab2..7a9bf8bf8c 100644 --- a/tests/test_run_state.py +++ b/tests/test_run_state.py @@ -2046,6 +2046,7 @@ async def get_response( previous_response_id: str | None, conversation_id: str | None, prompt: Any | None, + response_span: Any | None = None, ) -> ModelResponse: del ( system_instructions, @@ -2057,6 +2058,7 @@ async def get_response( previous_response_id, conversation_id, prompt, + response_span, ) if _has_function_call_output(input): return ModelResponse( @@ -2092,6 +2094,7 @@ async def stream_response( previous_response_id: str | None, conversation_id: str | None, prompt: Any | None, + response_span: Any | None = None, ) -> AsyncIterator[TResponseStreamEvent]: del ( system_instructions, @@ -2104,6 +2107,7 @@ async def stream_response( previous_response_id, conversation_id, prompt, + response_span, ) if False: yield cast(TResponseStreamEvent, {}) diff --git 
a/tests/test_streaming_tool_call_arguments.py b/tests/test_streaming_tool_call_arguments.py index ce476e59b1..37959ff46f 100644 --- a/tests/test_streaming_tool_call_arguments.py +++ b/tests/test_streaming_tool_call_arguments.py @@ -59,7 +59,9 @@ async def get_response( previous_response_id: Optional[str], conversation_id: Optional[str], prompt: Optional[Any], + response_span: Optional[Any] = None, ): + del response_span raise NotImplementedError("Use stream_response instead") async def stream_response( @@ -75,6 +77,7 @@ async def stream_response( previous_response_id: Optional[str] = None, conversation_id: Optional[str] = None, prompt: Optional[Any] = None, + response_span: Optional[Any] = None, ) -> AsyncIterator[TResponseStreamEvent]: """Stream events that simulate real OpenAI streaming behavior for tool calls.""" self.last_turn_args = { @@ -123,9 +126,13 @@ async def stream_response( sequence_number += 1 # Finally: emit completion + response = get_response_obj(output) + if response_span is not None: + response_span.span_data.response = response + response_span.span_data.input = input yield ResponseCompletedEvent( type="response.completed", - response=get_response_obj(output), + response=response, sequence_number=sequence_number, ) diff --git a/tests/test_trace_processor.py b/tests/test_trace_processor.py index ad061d7995..f740566d1a 100644 --- a/tests/test_trace_processor.py +++ b/tests/test_trace_processor.py @@ -488,6 +488,26 @@ def test_sanitize_for_openai_tracing_api_keeps_allowed_generation_usage(): exporter.close() +def test_sanitize_for_openai_tracing_api_drops_response_metadata(): + exporter = BackendSpanExporter(api_key="test_key") + payload: dict[str, Any] = { + "object": "trace.span", + "span_data": { + "type": "response", + "response_id": "resp_123", + "metadata": {"hook": "value"}, + }, + } + + sanitized = exporter._sanitize_for_openai_tracing_api(payload) + assert sanitized["span_data"] == { + "type": "response", + "response_id": "resp_123", + } + assert payload["span_data"]["metadata"] == {"hook": "value"} + exporter.close() + + @patch("httpx.Client") def test_backend_span_exporter_keeps_large_input_for_custom_endpoint(mock_client): class DummyItem: diff --git a/tests/test_tracing_errors.py b/tests/test_tracing_errors.py index 6149afc79f..69f2e7fd77 100644 --- a/tests/test_tracing_errors.py +++ b/tests/test_tracing_errors.py @@ -56,11 +56,20 @@ async def test_single_turn_model_error(): }, "children": [ { - "type": "generation", + "type": "response", "error": { - "message": "Error", - "data": {"name": "ValueError", "message": "test error"}, + "message": "Error during LLM execution", + "data": {"error": "test error"}, }, + "children": [ + { + "type": "generation", + "error": { + "message": "Error", + "data": {"name": "ValueError", "message": "test error"}, + }, + } + ], } ], } @@ -108,7 +117,11 @@ async def test_multi_turn_no_handoffs(): "output_type": "str", }, "children": [ - {"type": "generation"}, + { + "type": "response", + "data": {"response_id": "resp-789"}, + "children": [{"type": "generation"}], + }, { "type": "function", "data": { @@ -118,11 +131,20 @@ async def test_multi_turn_no_handoffs(): }, }, { - "type": "generation", + "type": "response", "error": { - "message": "Error", - "data": {"name": "ValueError", "message": "test error"}, + "message": "Error during LLM execution", + "data": {"error": "test error"}, }, + "children": [ + { + "type": "generation", + "error": { + "message": "Error", + "data": {"name": "ValueError", "message": "test error"}, + }, + } + ], 
}, ], } @@ -170,7 +192,11 @@ async def test_tool_call_error(): "output_type": "str", }, "children": [ - {"type": "generation"}, + { + "type": "response", + "data": {"response_id": "resp-789"}, + "children": [{"type": "generation"}], + }, { "type": "function", "error": { @@ -190,7 +216,11 @@ async def test_tool_call_error(): ), }, }, - {"type": "generation"}, + { + "type": "response", + "data": {"response_id": "resp-789"}, + "children": [{"type": "generation"}], + }, ], } ], @@ -249,7 +279,11 @@ async def test_multiple_handoff_doesnt_error(): "output_type": "str", }, "children": [ - {"type": "generation"}, + { + "type": "response", + "data": {"response_id": "resp-789"}, + "children": [{"type": "generation"}], + }, { "type": "function", "data": { @@ -258,7 +292,11 @@ async def test_multiple_handoff_doesnt_error(): "output": "result", }, }, - {"type": "generation"}, + { + "type": "response", + "data": {"response_id": "resp-789"}, + "children": [{"type": "generation"}], + }, { "type": "handoff", "data": {"from_agent": "test", "to_agent": "test"}, @@ -277,7 +315,13 @@ async def test_multiple_handoff_doesnt_error(): { "type": "agent", "data": {"name": "test", "handoffs": [], "tools": [], "output_type": "str"}, - "children": [{"type": "generation"}], + "children": [ + { + "type": "response", + "data": {"response_id": "resp-789"}, + "children": [{"type": "generation"}], + } + ], }, ], } @@ -317,7 +361,13 @@ async def test_multiple_final_output_doesnt_error(): { "type": "agent", "data": {"name": "test", "handoffs": [], "tools": [], "output_type": "Foo"}, - "children": [{"type": "generation"}], + "children": [ + { + "type": "response", + "data": {"response_id": "resp-789"}, + "children": [{"type": "generation"}], + } + ], } ], } @@ -387,7 +437,11 @@ async def test_handoffs_lead_to_correct_agent_spans(): "output_type": "str", }, "children": [ - {"type": "generation"}, + { + "type": "response", + "data": {"response_id": "resp-789"}, + "children": [{"type": "generation"}], + }, { "type": "function", "data": { @@ -396,7 +450,11 @@ async def test_handoffs_lead_to_correct_agent_spans(): "output": "result", }, }, - {"type": "generation"}, + { + "type": "response", + "data": {"response_id": "resp-789"}, + "children": [{"type": "generation"}], + }, { "type": "handoff", "data": {"from_agent": "test_agent_3", "to_agent": "test_agent_1"}, @@ -421,7 +479,11 @@ async def test_handoffs_lead_to_correct_agent_spans(): "output_type": "str", }, "children": [ - {"type": "generation"}, + { + "type": "response", + "data": {"response_id": "resp-789"}, + "children": [{"type": "generation"}], + }, { "type": "function", "data": { @@ -430,7 +492,11 @@ async def test_handoffs_lead_to_correct_agent_spans(): "output": "result", }, }, - {"type": "generation"}, + { + "type": "response", + "data": {"response_id": "resp-789"}, + "children": [{"type": "generation"}], + }, { "type": "handoff", "data": {"from_agent": "test_agent_1", "to_agent": "test_agent_3"}, @@ -445,7 +511,13 @@ async def test_handoffs_lead_to_correct_agent_spans(): "tools": ["some_function"], "output_type": "str", }, - "children": [{"type": "generation"}], + "children": [ + { + "type": "response", + "data": {"response_id": "resp-789"}, + "children": [{"type": "generation"}], + } + ], }, ], } @@ -492,12 +564,20 @@ async def test_max_turns_exceeded(): "output_type": "Foo", }, "children": [ - {"type": "generation"}, + { + "type": "response", + "data": {"response_id": "resp-789"}, + "children": [{"type": "generation"}], + }, { "type": "function", "data": {"name": 
"foo", "input": "", "output": "result"}, }, - {"type": "generation"}, + { + "type": "response", + "data": {"response_id": "resp-789"}, + "children": [{"type": "generation"}], + }, { "type": "function", "data": {"name": "foo", "input": "", "output": "result"}, diff --git a/tests/test_tracing_errors_streamed.py b/tests/test_tracing_errors_streamed.py index 35055d2ad0..9f0210b44f 100644 --- a/tests/test_tracing_errors_streamed.py +++ b/tests/test_tracing_errors_streamed.py @@ -80,11 +80,20 @@ async def test_single_turn_model_error(): }, "children": [ { - "type": "generation", + "type": "response", "error": { - "message": "Error", - "data": {"name": "ValueError", "message": "test error"}, + "message": "Error during streamed LLM execution", + "data": {"error": "test error"}, }, + "children": [ + { + "type": "generation", + "error": { + "message": "Error", + "data": {"name": "ValueError", "message": "test error"}, + }, + } + ], } ], } @@ -135,7 +144,11 @@ async def test_multi_turn_no_handoffs(): "output_type": "str", }, "children": [ - {"type": "generation"}, + { + "type": "response", + "data": {"response_id": "resp-789"}, + "children": [{"type": "generation"}], + }, { "type": "function", "data": { @@ -145,11 +158,20 @@ async def test_multi_turn_no_handoffs(): }, }, { - "type": "generation", + "type": "response", "error": { - "message": "Error", - "data": {"name": "ValueError", "message": "test error"}, + "message": "Error during streamed LLM execution", + "data": {"error": "test error"}, }, + "children": [ + { + "type": "generation", + "error": { + "message": "Error", + "data": {"name": "ValueError", "message": "test error"}, + }, + } + ], }, ], } @@ -199,7 +221,11 @@ async def test_tool_call_error(): "output_type": "str", }, "children": [ - {"type": "generation"}, + { + "type": "response", + "data": {"response_id": "resp-789"}, + "children": [{"type": "generation"}], + }, { "type": "function", "error": { @@ -214,12 +240,16 @@ async def test_tool_call_error(): "input": "bad_json", "output": ( "An error occurred while parsing tool arguments. " - "Please try again with valid JSON. Error: Expecting " - "value: line 1 column 1 (char 0)" + "Please try again with valid JSON. 
Error: " + "Expecting value: line 1 column 1 (char 0)" ), }, }, - {"type": "generation"}, + { + "type": "response", + "data": {"response_id": "resp-789"}, + "children": [{"type": "generation"}], + }, ], } ], @@ -281,7 +311,11 @@ async def test_multiple_handoff_doesnt_error(): "output_type": "str", }, "children": [ - {"type": "generation"}, + { + "type": "response", + "data": {"response_id": "resp-789"}, + "children": [{"type": "generation"}], + }, { "type": "function", "data": { @@ -290,21 +324,31 @@ async def test_multiple_handoff_doesnt_error(): "output": "result", }, }, - {"type": "generation"}, + { + "type": "response", + "data": {"response_id": "resp-789"}, + "children": [{"type": "generation"}], + }, { "type": "handoff", - "data": {"from_agent": "test", "to_agent": "test"}, "error": { - "data": {"requested_agents": ["test", "test"]}, "message": "Multiple handoffs requested", + "data": {"requested_agents": ["test", "test"]}, }, + "data": {"from_agent": "test", "to_agent": "test"}, }, ], }, { "type": "agent", "data": {"name": "test", "handoffs": [], "tools": [], "output_type": "str"}, - "children": [{"type": "generation"}], + "children": [ + { + "type": "response", + "data": {"response_id": "resp-789"}, + "children": [{"type": "generation"}], + } + ], }, ], } @@ -348,7 +392,13 @@ async def test_multiple_final_output_no_error(): { "type": "agent", "data": {"name": "test", "handoffs": [], "tools": [], "output_type": "Foo"}, - "children": [{"type": "generation"}], + "children": [ + { + "type": "response", + "data": {"response_id": "resp-789"}, + "children": [{"type": "generation"}], + } + ], } ], } @@ -420,7 +470,11 @@ async def test_handoffs_lead_to_correct_agent_spans(): "output_type": "str", }, "children": [ - {"type": "generation"}, + { + "type": "response", + "data": {"response_id": "resp-789"}, + "children": [{"type": "generation"}], + }, { "type": "function", "data": { @@ -429,7 +483,11 @@ async def test_handoffs_lead_to_correct_agent_spans(): "output": "result", }, }, - {"type": "generation"}, + { + "type": "response", + "data": {"response_id": "resp-789"}, + "children": [{"type": "generation"}], + }, { "type": "handoff", "error": { @@ -449,7 +507,11 @@ async def test_handoffs_lead_to_correct_agent_spans(): "output_type": "str", }, "children": [ - {"type": "generation"}, + { + "type": "response", + "data": {"response_id": "resp-789"}, + "children": [{"type": "generation"}], + }, { "type": "function", "data": { @@ -458,7 +520,11 @@ async def test_handoffs_lead_to_correct_agent_spans(): "output": "result", }, }, - {"type": "generation"}, + { + "type": "response", + "data": {"response_id": "resp-789"}, + "children": [{"type": "generation"}], + }, { "type": "handoff", "data": {"from_agent": "test_agent_1", "to_agent": "test_agent_3"}, @@ -473,7 +539,13 @@ async def test_handoffs_lead_to_correct_agent_spans(): "tools": ["some_function"], "output_type": "str", }, - "children": [{"type": "generation"}], + "children": [ + { + "type": "response", + "data": {"response_id": "resp-789"}, + "children": [{"type": "generation"}], + } + ], }, ], } @@ -522,12 +594,20 @@ async def test_max_turns_exceeded(): "output_type": "Foo", }, "children": [ - {"type": "generation"}, + { + "type": "response", + "data": {"response_id": "resp-789"}, + "children": [{"type": "generation"}], + }, { "type": "function", "data": {"name": "foo", "input": "", "output": "result"}, }, - {"type": "generation"}, + { + "type": "response", + "data": {"response_id": "resp-789"}, + "children": [{"type": "generation"}], + }, { 
"type": "function", "data": {"name": "foo", "input": "", "output": "result"}, @@ -584,7 +664,8 @@ async def test_input_guardrail_error(): { "type": "guardrail", "data": {"name": "input_guardrail_function", "triggered": True}, - } + }, + {"type": "response", "data": {"response_id": "resp-789"}}, ], } ], @@ -631,10 +712,11 @@ async def test_output_guardrail_error(): }, "data": {"name": "test", "handoffs": [], "tools": [], "output_type": "str"}, "children": [ + {"type": "response", "data": {"response_id": "resp-789"}}, { "type": "guardrail", "data": {"name": "output_guardrail_function", "triggered": True}, - } + }, ], } ], diff --git a/tests/voice/test_workflow.py b/tests/voice/test_workflow.py index 402c521280..06ecb139d7 100644 --- a/tests/voice/test_workflow.py +++ b/tests/voice/test_workflow.py @@ -57,7 +57,9 @@ async def get_response( previous_response_id: str | None, conversation_id: str | None, prompt: Any | None, + response_span: Any | None = None, ) -> ModelResponse: + del response_span raise NotImplementedError("Not implemented") async def stream_response( @@ -73,8 +75,13 @@ async def stream_response( previous_response_id: str | None, conversation_id: str | None, prompt: Any | None, + response_span: Any | None = None, ) -> AsyncIterator[TResponseStreamEvent]: output = self.get_next_output() + response = get_response_obj(output) + if response_span is not None: + response_span.span_data.response = response + response_span.span_data.input = input for item in output: if ( item.type == "message" @@ -93,7 +100,7 @@ async def stream_response( yield ResponseCompletedEvent( type="response.completed", - response=get_response_obj(output), + response=response, sequence_number=1, ) diff --git a/uv.lock b/uv.lock index f4e4b0e36f..fe9b1a28b4 100644 --- a/uv.lock +++ b/uv.lock @@ -1992,7 +1992,7 @@ requires-dist = [ { name = "graphviz", marker = "extra == 'viz'", specifier = ">=0.17" }, { name = "griffe", specifier = ">=1.5.6,<2" }, { name = "grpcio", marker = "extra == 'dapr'", specifier = ">=1.60.0" }, - { name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.81.0,<2" }, + { name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.81.0,<=1.82.6" }, { name = "mcp", marker = "python_full_version >= '3.10'", specifier = ">=1.19.0,<2" }, { name = "numpy", marker = "python_full_version >= '3.10' and extra == 'voice'", specifier = ">=2.2.0,<3" }, { name = "openai", specifier = ">=2.26.0,<3" },