197 changes: 122 additions & 75 deletions src/agents/extensions/models/any_llm_model.py
@@ -42,8 +42,8 @@
)
from ...retry import ModelRetryAdvice, ModelRetryAdviceRequest
from ...tool import Tool
from ...tracing import generation_span, response_span
from ...tracing.span_data import GenerationSpanData
from ...tracing import SpanError, generation_span, response_span as create_response_span
from ...tracing.span_data import GenerationSpanData, ResponseSpanData
from ...tracing.spans import Span
from ...usage import Usage
from ...util._json import _to_dump_compatible
@@ -260,6 +260,7 @@ async def get_response(
previous_response_id: str | None = None,
conversation_id: str | None = None,
prompt: ResponsePromptParam | None = None,
response_span: Span[ResponseSpanData] | None = None,
) -> ModelResponse:
if self._selected_api() == "responses":
return await self._get_response_via_responses(
@@ -273,6 +274,7 @@
previous_response_id=previous_response_id,
conversation_id=conversation_id,
prompt=prompt,
response_span=response_span,
)

return await self._get_response_via_chat(
@@ -298,6 +300,7 @@ async def stream_response(
previous_response_id: str | None = None,
conversation_id: str | None = None,
prompt: ResponsePromptParam | None = None,
response_span: Span[ResponseSpanData] | None = None,
) -> AsyncIterator[TResponseStreamEvent]:
if self._selected_api() == "responses":
async for chunk in self._stream_response_via_responses(
@@ -311,6 +314,7 @@
previous_response_id=previous_response_id,
conversation_id=conversation_id,
prompt=prompt,
response_span=response_span,
):
yield chunk
return
@@ -340,56 +344,76 @@ async def _get_response_via_responses(
previous_response_id: str | None,
conversation_id: str | None,
prompt: ResponsePromptParam | None,
response_span: Span[ResponseSpanData] | None,
) -> ModelResponse:
with response_span(disabled=tracing.is_disabled()) as span_response:
response = await self._fetch_responses_response(
system_instructions=system_instructions,
input=input,
model_settings=model_settings,
tools=tools,
output_schema=output_schema,
handoffs=handoffs,
previous_response_id=previous_response_id,
conversation_id=conversation_id,
stream=False,
prompt=prompt,
)

if _debug.DONT_LOG_MODEL_DATA:
logger.debug("LLM responded")
else:
logger.debug(
"LLM resp:\n%s\n",
json.dumps(
[item.model_dump() for item in response.output],
indent=2,
ensure_ascii=False,
),
span_response = response_span or create_response_span(disabled=tracing.is_disabled())
owns_response_span = response_span is None
if owns_response_span:
span_response.start(mark_as_current=True)
try:
try:
response = await self._fetch_responses_response(
system_instructions=system_instructions,
input=input,
model_settings=model_settings,
tools=tools,
output_schema=output_schema,
handoffs=handoffs,
previous_response_id=previous_response_id,
conversation_id=conversation_id,
stream=False,
prompt=prompt,
)

usage = (
Usage(
requests=1,
input_tokens=response.usage.input_tokens,
output_tokens=response.usage.output_tokens,
total_tokens=response.usage.total_tokens,
input_tokens_details=response.usage.input_tokens_details,
output_tokens_details=response.usage.output_tokens_details,
if _debug.DONT_LOG_MODEL_DATA:
logger.debug("LLM responded")
else:
logger.debug(
"LLM resp:\n%s\n",
json.dumps(
[item.model_dump() for item in response.output],
indent=2,
ensure_ascii=False,
),
)

usage = (
Usage(
requests=1,
input_tokens=response.usage.input_tokens,
output_tokens=response.usage.output_tokens,
total_tokens=response.usage.total_tokens,
input_tokens_details=response.usage.input_tokens_details,
output_tokens_details=response.usage.output_tokens_details,
)
if response.usage
else Usage()
)
if response.usage
else Usage()
)

if tracing.include_data():
span_response.span_data.response = response
span_response.span_data.input = input

return ModelResponse(
output=response.output,
usage=usage,
response_id=response.id,
request_id=getattr(response, "_request_id", None),
)
if tracing.include_data():
span_response.span_data.response = response
span_response.span_data.input = input

return ModelResponse(
output=response.output,
usage=usage,
response_id=response.id,
request_id=getattr(response, "_request_id", None),
)
except Exception as e:
if owns_response_span:
span_response.set_error(
SpanError(
message="Error getting response",
data={
"error": str(e) if tracing.include_data() else e.__class__.__name__,
},
)
)
raise
finally:
if owns_response_span:
span_response.finish(reset_current=True)

async def _stream_response_via_responses(
self,
@@ -404,37 +428,60 @@ async def _stream_response_via_responses(
previous_response_id: str | None,
conversation_id: str | None,
prompt: ResponsePromptParam | None,
response_span: Span[ResponseSpanData] | None,
) -> AsyncIterator[ResponseStreamEvent]:
with response_span(disabled=tracing.is_disabled()) as span_response:
stream = await self._fetch_responses_response(
system_instructions=system_instructions,
input=input,
model_settings=model_settings,
tools=tools,
output_schema=output_schema,
handoffs=handoffs,
previous_response_id=previous_response_id,
conversation_id=conversation_id,
stream=True,
prompt=prompt,
)

final_response: Response | None = None
span_response = response_span or create_response_span(disabled=tracing.is_disabled())
owns_response_span = response_span is None
if owns_response_span:
span_response.start(mark_as_current=True)
try:
try:
async for chunk in stream:
if isinstance(chunk, ResponseCompletedEvent):
final_response = chunk.response
elif getattr(chunk, "type", None) in {"response.failed", "response.incomplete"}:
terminal_response = getattr(chunk, "response", None)
if isinstance(terminal_response, Response):
final_response = terminal_response
yield chunk
finally:
await self._maybe_aclose(stream)
stream = await self._fetch_responses_response(
system_instructions=system_instructions,
input=input,
model_settings=model_settings,
tools=tools,
output_schema=output_schema,
handoffs=handoffs,
previous_response_id=previous_response_id,
conversation_id=conversation_id,
stream=True,
prompt=prompt,
)

if tracing.include_data() and final_response:
span_response.span_data.response = final_response
span_response.span_data.input = input
final_response: Response | None = None
try:
async for chunk in stream:
if isinstance(chunk, ResponseCompletedEvent):
final_response = chunk.response
elif getattr(chunk, "type", None) in {
"response.failed",
"response.incomplete",
}:
terminal_response = getattr(chunk, "response", None)
if isinstance(terminal_response, Response):
final_response = terminal_response
yield chunk
finally:
await self._maybe_aclose(stream)

if tracing.include_data() and final_response:
span_response.span_data.response = final_response
span_response.span_data.input = input
except Exception as e:
if owns_response_span:
span_response.set_error(
SpanError(
message="Error streaming response",
data={
"error": str(e) if tracing.include_data() else e.__class__.__name__,
},
)
)
raise
finally:
if owns_response_span:
span_response.finish(reset_current=True)

async def _get_response_via_chat(
self,
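A note on the core pattern in this file: _get_response_via_responses and _stream_response_via_responses now reuse a caller-supplied response span when one is passed in, and only start, error-tag, and finish a span they created themselves. Below is a minimal, self-contained sketch of that ownership pattern; FakeResponseSpan and the simplified SpanError are illustrative stand-ins for the SDK's tracing types, not the real implementations.

# Minimal sketch of the span-ownership pattern applied in this diff:
# reuse a caller-provided span if given, otherwise create one locally and
# take responsibility for starting, error-tagging, and finishing it.
from __future__ import annotations

import asyncio
from dataclasses import dataclass, field


@dataclass
class SpanError:
    # Simplified stand-in for the SDK's SpanError payload.
    message: str
    data: dict | None = None


@dataclass
class FakeResponseSpan:
    # Stand-in for Span[ResponseSpanData]; records lifecycle events.
    events: list[str] = field(default_factory=list)
    error: SpanError | None = None

    def start(self) -> None:
        self.events.append("start")

    def set_error(self, error: SpanError) -> None:
        self.error = error
        self.events.append("error")

    def finish(self) -> None:
        self.events.append("finish")


async def call_model(prompt: str) -> str:
    # Stand-in for the real request to the Responses API.
    return f"response to: {prompt}"


async def get_response(prompt: str, response_span: FakeResponseSpan | None = None) -> str:
    # Only own the span's lifecycle when the caller did not supply one.
    span = response_span or FakeResponseSpan()
    owns_span = response_span is None
    if owns_span:
        span.start()
    try:
        return await call_model(prompt)
    except Exception as e:
        if owns_span:
            # The real code gates the error string on tracing.include_data();
            # this sketch always records it for simplicity.
            span.set_error(SpanError(message="Error getting response", data={"error": str(e)}))
        raise
    finally:
        if owns_span:
            span.finish()


if __name__ == "__main__":
    # Locally created span: started and finished inside get_response.
    print(asyncio.run(get_response("hello")))
    # Caller-owned span: get_response neither starts nor finishes it.
    shared = FakeResponseSpan()
    asyncio.run(get_response("hello", response_span=shared))
    print(shared.events)  # [] - untouched, the caller keeps full control

When a span is passed in, the function neither starts, errors, nor finishes it, which is presumably what lets the new response_span parameter share a single span owned by whatever outer scope created it.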
6 changes: 5 additions & 1 deletion src/agents/extensions/models/litellm_model.py
@@ -53,7 +53,7 @@
from ...retry import ModelRetryAdvice, ModelRetryAdviceRequest
from ...tool import Tool
from ...tracing import generation_span
from ...tracing.span_data import GenerationSpanData
from ...tracing.span_data import GenerationSpanData, ResponseSpanData
from ...tracing.spans import Span
from ...usage import Usage
from ...util._json import _to_dump_compatible
@@ -171,7 +171,9 @@ async def get_response(
previous_response_id: str | None = None, # unused
conversation_id: str | None = None, # unused
prompt: Any | None = None,
response_span: Span[ResponseSpanData] | None = None,
) -> ModelResponse:
del response_span
with generation_span(
model=str(self.model),
model_config=model_settings.to_json_dict()
@@ -285,7 +287,9 @@ async def stream_response(
previous_response_id: str | None = None, # unused
conversation_id: str | None = None, # unused
prompt: Any | None = None,
response_span: Span[ResponseSpanData] | None = None,
) -> AsyncIterator[TResponseStreamEvent]:
del response_span
with generation_span(
model=str(self.model),
model_config=model_settings.to_json_dict()
4 changes: 4 additions & 0 deletions src/agents/models/interface.py
@@ -15,6 +15,8 @@
if TYPE_CHECKING:
from ..model_settings import ModelSettings
from ..retry import ModelRetryAdvice, ModelRetryAdviceRequest
from ..tracing import Span
from ..tracing.span_data import ResponseSpanData


class ModelTracing(enum.Enum):
@@ -67,6 +69,7 @@ async def get_response(
previous_response_id: str | None,
conversation_id: str | None,
prompt: ResponsePromptParam | None,
response_span: Span[ResponseSpanData] | None = None,
) -> ModelResponse:
"""Get a response from the model.

@@ -102,6 +105,7 @@ def stream_response(
previous_response_id: str | None,
conversation_id: str | None,
prompt: ResponsePromptParam | None,
response_span: Span[ResponseSpanData] | None = None,
) -> AsyncIterator[TResponseStreamEvent]:
"""Stream a response from the model.

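Because the interface adds response_span as an optional parameter defaulting to None, backends that trace with their own generation span can simply accept and discard it, which is what the LitellmModel hunks above and the OpenAIChatCompletionsModel hunks below do with del response_span. A minimal sketch of a custom implementation adapting to the new signature; the types here are simplified stand-ins, not the SDK's full Model ABC, which carries many more parameters.

# Simplified stand-ins; the real Model.get_response signature also takes
# model_settings, tools, output_schema, handoffs, tracing, and more.
from __future__ import annotations

from typing import Any


class ResponseSpanData:
    # Stand-in for agents.tracing.span_data.ResponseSpanData.
    pass


class Span:
    # Stand-in for agents.tracing.spans.Span[ResponseSpanData].
    pass


class MyBackendModel:
    async def get_response(
        self,
        system_instructions: str | None,
        input: str,
        *,
        response_span: Span | None = None,  # new optional kwarg; the None default keeps old call sites working
        **kwargs: Any,
    ) -> str:
        # This backend does its own tracing (e.g. a generation span), so,
        # like LitellmModel and OpenAIChatCompletionsModel, it ignores any
        # caller-provided response span rather than managing its lifecycle.
        del response_span
        prefix = f"{system_instructions}\n" if system_instructions else ""
        return prefix + f"backend output for: {input}"

# Usage: asyncio.run(MyBackendModel().get_response(None, "hi"))

Of the files in this diff, only the Responses-API path in any_llm_model.py actually consumes the span; the other implementations keep the keyword purely for signature compatibility.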
6 changes: 5 additions & 1 deletion src/agents/models/openai_chatcompletions.py
@@ -27,7 +27,7 @@
from ..retry import ModelRetryAdvice, ModelRetryAdviceRequest
from ..tool import Tool
from ..tracing import generation_span
from ..tracing.span_data import GenerationSpanData
from ..tracing.span_data import GenerationSpanData, ResponseSpanData
from ..tracing.spans import Span
from ..usage import Usage
from ..util._json import _to_dump_compatible
@@ -108,7 +108,9 @@ async def get_response(
previous_response_id: str | None = None, # unused
conversation_id: str | None = None, # unused
prompt: ResponsePromptParam | None = None,
response_span: Span[ResponseSpanData] | None = None,
) -> ModelResponse:
del response_span
with generation_span(
model=str(self.model),
model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)},
@@ -221,10 +223,12 @@ async def stream_response(
previous_response_id: str | None = None, # unused
conversation_id: str | None = None, # unused
prompt: ResponsePromptParam | None = None,
response_span: Span[ResponseSpanData] | None = None,
) -> AsyncIterator[TResponseStreamEvent]:
"""
Yields a partial message as it is generated, as well as the usage information.
"""
del response_span
with generation_span(
model=str(self.model),
model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)},