Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 224 additions & 2 deletions src/agents/models/chatcmpl_stream_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@

from collections.abc import AsyncIterator, Iterator
from dataclasses import dataclass, field
from typing import Any
from typing import Any, cast

from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import (
Choice,
ChoiceDelta,
ChoiceDeltaToolCall,
ChoiceDeltaToolCallFunction,
)
from openai.types.completion_usage import CompletionUsage
from openai.types.responses import (
Response,
Expand Down Expand Up @@ -42,7 +48,7 @@
)
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails

from ..exceptions import UserError
from ..exceptions import ModelBehaviorError, UserError
from ..items import TResponseStreamEvent
from ..logger import logger
from .chatcmpl_helpers import ChatCmplHelpers
Expand Down Expand Up @@ -76,6 +82,42 @@ class StreamingState:
has_warned_unsupported_choice: bool = False


@dataclass
class _BufferedToolCall:
"""Accumulates a streamed Chat Completions function tool call."""

index: int
call_id: str | None = None
name: str | None = None
arguments: str = ""
provider_specific_fields: dict[str, Any] | None = None
extra_content: dict[str, Any] | None = None


def _merge_buffered_metadata(
current: dict[str, Any] | None,
incoming: dict[str, Any],
) -> dict[str, Any] | None:
"""Merge provider metadata without letting empty chunks erase earlier fields."""
if not incoming:
return current

if current is None:
return incoming.copy()

merged = current.copy()
for key, value in incoming.items():
current_value = merged.get(key)
if isinstance(current_value, dict) and isinstance(value, dict):
merged[key] = _merge_buffered_metadata(current_value, value) or {}
elif isinstance(value, dict) and not value and key in merged:
continue
else:
merged[key] = value

return merged


class SequenceNumber:
def __init__(self):
self._sequence_number = 0
Expand Down Expand Up @@ -163,6 +205,186 @@ def function_calls_after_message(


class ChatCmplStreamHandler:
@staticmethod
def _choice_finished_tool_calls(choice: Choice) -> bool:
return choice.finish_reason == "tool_calls"

@staticmethod
def _should_buffer_tool_call_delta(tool_call_delta: ChoiceDeltaToolCall) -> bool:
tool_call_type = getattr(tool_call_delta, "type", None)
return tool_call_type in (None, "function")

@staticmethod
def _delta_has_passthrough_output(delta: ChoiceDelta | None) -> bool:
if delta is None:
return False

if delta.content is not None or delta.tool_calls:
return True

if hasattr(delta, "refusal") and delta.refusal:
return True

if hasattr(delta, "reasoning_content") and delta.reasoning_content:
return True

if hasattr(delta, "reasoning") and delta.reasoning:
return True

if hasattr(delta, "thinking_blocks") and delta.thinking_blocks:
return True

return False

@staticmethod
def _accumulate_tool_call_delta(
buffered_calls: dict[int, _BufferedToolCall],
tool_call_delta: ChoiceDeltaToolCall,
) -> None:
buffered_call = buffered_calls.setdefault(
tool_call_delta.index,
_BufferedToolCall(index=tool_call_delta.index),
)

if tool_call_delta.id:
buffered_call.call_id = tool_call_delta.id

if tool_call_delta.function:
if tool_call_delta.function.name:
buffered_call.name = tool_call_delta.function.name
if tool_call_delta.function.arguments:
buffered_call.arguments += tool_call_delta.function.arguments

provider_specific_fields = getattr(tool_call_delta, "provider_specific_fields", None)
if isinstance(provider_specific_fields, dict):
buffered_call.provider_specific_fields = _merge_buffered_metadata(
buffered_call.provider_specific_fields,
provider_specific_fields,
)

extra_content = getattr(tool_call_delta, "extra_content", None)
if isinstance(extra_content, dict):
buffered_call.extra_content = _merge_buffered_metadata(
buffered_call.extra_content,
extra_content,
)

@staticmethod
def _buffered_tool_call_delta(
buffered_call: _BufferedToolCall,
) -> ChoiceDeltaToolCall:
if not buffered_call.call_id:
raise ModelBehaviorError(
"Buffered Chat Completions tool call stream ended without a tool call id."
)

if not buffered_call.name:
raise ModelBehaviorError(
"Buffered Chat Completions tool call stream ended without a function name."
)

tool_call_delta = ChoiceDeltaToolCall(
index=buffered_call.index,
id=buffered_call.call_id,
function=ChoiceDeltaToolCallFunction(
name=buffered_call.name,
arguments=buffered_call.arguments,
),
type="function",
)

tool_call_delta_any = cast(Any, tool_call_delta)
if buffered_call.provider_specific_fields is not None:
tool_call_delta_any.provider_specific_fields = buffered_call.provider_specific_fields
if buffered_call.extra_content is not None:
tool_call_delta_any.extra_content = buffered_call.extra_content

return tool_call_delta

@classmethod
def _buffered_tool_calls_chunk(
cls,
template_chunk: ChatCompletionChunk,
buffered_calls: dict[int, _BufferedToolCall],
) -> ChatCompletionChunk:
tool_call_deltas = [
cls._buffered_tool_call_delta(buffered_call)
for _, buffered_call in sorted(buffered_calls.items())
]
choice = Choice(
index=0,
delta=ChoiceDelta(tool_calls=tool_call_deltas),
finish_reason="tool_calls",
)
return template_chunk.model_copy(update={"choices": [choice], "usage": None})

@classmethod
async def buffer_tool_call_stream(
cls,
stream: AsyncIterator[ChatCompletionChunk],
) -> AsyncIterator[ChatCompletionChunk]:
"""Buffer streamed function tool-call deltas until they are complete."""
buffered_calls: dict[int, _BufferedToolCall] = {}
passthrough_tool_call_indexes: set[int] = set()
saw_passthrough_tool_call = False
last_chunk: ChatCompletionChunk | None = None

async for chunk in stream:
last_chunk = chunk

if not chunk.choices:
yield chunk
continue

passthrough_choices: list[Choice] = []
for choice in chunk.choices:
if choice.index != 0:
if choice.delta and choice.delta.tool_calls:
saw_passthrough_tool_call = True
passthrough_choices.append(choice)
continue

delta = choice.delta

if tool_call_deltas := (delta.tool_calls if delta and delta.tool_calls else None):
remaining_tool_calls: list[ChoiceDeltaToolCall] = []
for tool_call_delta in tool_call_deltas:
if tool_call_delta.index in passthrough_tool_call_indexes:
saw_passthrough_tool_call = True
remaining_tool_calls.append(tool_call_delta)
elif cls._should_buffer_tool_call_delta(tool_call_delta):
cls._accumulate_tool_call_delta(buffered_calls, tool_call_delta)
else:
passthrough_tool_call_indexes.add(tool_call_delta.index)
saw_passthrough_tool_call = True
remaining_tool_calls.append(tool_call_delta)

delta = delta.model_copy(update={"tool_calls": remaining_tool_calls or None})
choice = choice.model_copy(update={"delta": delta})

has_passthrough_output = cls._delta_has_passthrough_output(choice.delta)
if (
cls._choice_finished_tool_calls(choice)
and not buffered_calls
and not saw_passthrough_tool_call
and not has_passthrough_output
Comment thread
incoffeemonster marked this conversation as resolved.
):
raise ModelBehaviorError(
"Chat Completions stream finished with finish_reason='tool_calls' "
"but did not include any streamed tool call deltas."
)

if has_passthrough_output:
passthrough_choices.append(choice)

if passthrough_choices or chunk.usage is not None:
yield chunk.model_copy(update={"choices": passthrough_choices})

if buffered_calls:
if last_chunk is None:
return
yield cls._buffered_tool_calls_chunk(last_chunk, buffered_calls)

@staticmethod
def _merged_provider_data(
state: StreamingState,
Expand Down
5 changes: 5 additions & 0 deletions src/agents/models/multi_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(
unknown_prefix_mode: MultiProviderUnknownPrefixMode = "error",
openai_agent_registration: OpenAIAgentRegistrationConfig | None = None,
openai_responses_websocket_options: OpenAIResponsesWebSocketOptions | None = None,
openai_buffer_streamed_tool_calls: bool = False,
) -> None:
"""Create a new OpenAI provider.

Expand Down Expand Up @@ -126,6 +127,9 @@ def __init__(
provider.
openai_responses_websocket_options: Optional low-level websocket keepalive options for
the OpenAI Responses websocket transport.
openai_buffer_streamed_tool_calls: Whether OpenAI Chat Completions models should buffer
streamed function tool-call deltas and emit them to the SDK only after the provider
stream finishes.
"""
self.provider_map = provider_map
self.openai_provider = OpenAIProvider(
Expand All @@ -140,6 +144,7 @@ def __init__(
strict_feature_validation=openai_strict_feature_validation,
agent_registration=openai_agent_registration,
responses_websocket_options=openai_responses_websocket_options,
buffer_streamed_tool_calls=openai_buffer_streamed_tool_calls,
)
self._openai_prefix_mode = self._validate_openai_prefix_mode(openai_prefix_mode)
self._unknown_prefix_mode = self._validate_unknown_prefix_mode(unknown_prefix_mode)
Expand Down
10 changes: 9 additions & 1 deletion src/agents/models/openai_chatcompletions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,13 @@ def __init__(
openai_client: AsyncOpenAI,
should_replay_reasoning_content: ShouldReplayReasoningContent | None = None,
strict_feature_validation: bool = False,
buffer_streamed_tool_calls: bool = False,
) -> None:
self.model = model
self._client = openai_client
self.should_replay_reasoning_content = should_replay_reasoning_content
self._strict_feature_validation = strict_feature_validation
self._buffer_streamed_tool_calls = buffer_streamed_tool_calls
self._has_warned_unsupported_prompt = False
self._has_warned_unsupported_conversation_state = False

Expand Down Expand Up @@ -299,9 +301,15 @@ async def stream_response(
)

final_response: Response | None = None
stream_for_handler: AsyncIterator[ChatCompletionChunk]
if self._buffer_streamed_tool_calls:
stream_for_handler = ChatCmplStreamHandler.buffer_tool_call_stream(stream)
else:
stream_for_handler = stream

async for chunk in ChatCmplStreamHandler.handle_stream(
response,
stream,
cast(AsyncStream[ChatCompletionChunk], stream_for_handler),
model=self.model,
strict_feature_validation=self._strict_feature_validation,
):
Expand Down
7 changes: 7 additions & 0 deletions src/agents/models/openai_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(
strict_feature_validation: bool = False,
agent_registration: OpenAIAgentRegistrationConfig | None = None,
responses_websocket_options: OpenAIResponsesWebSocketOptions | None = None,
buffer_streamed_tool_calls: bool = False,
) -> None:
"""Create a new OpenAI provider.

Expand All @@ -79,6 +80,10 @@ def __init__(
agent_registration: Optional agent registration configuration.
responses_websocket_options: Optional low-level websocket keepalive options for the
OpenAI Responses websocket transport.
buffer_streamed_tool_calls: Whether Chat Completions models should buffer streamed
function tool-call deltas and emit them to the SDK only after the provider stream
finishes. This is useful for OpenAI-compatible providers whose streamed tool-call
chunk semantics are not reliable enough for incremental processing.
"""
if openai_client is not None:
assert api_key is None and base_url is None and websocket_base_url is None, (
Expand Down Expand Up @@ -109,6 +114,7 @@ def __init__(
self._use_responses_websocket = self._responses_transport == "websocket"
self._strict_feature_validation = strict_feature_validation
self._responses_websocket_options = responses_websocket_options
self._buffer_streamed_tool_calls = buffer_streamed_tool_calls

# Reuse websocket model wrappers so websocket transport can keep a persistent connection
# when callers pass model names as strings through a shared provider.
Expand Down Expand Up @@ -230,6 +236,7 @@ def get_model(self, model_name: str | None) -> Model:
model=resolved_model_name,
openai_client=client,
strict_feature_validation=self._strict_feature_validation,
buffer_streamed_tool_calls=self._buffer_streamed_tool_calls,
)

if use_websocket_transport:
Expand Down
Loading