From 2810a71aa77b9536e0de314e9b55c0d94490e002 Mon Sep 17 00:00:00 2001 From: instantcoffeemonster Date: Tue, 26 May 2026 18:33:50 +0800 Subject: [PATCH 1/3] Add buffered Chat Completions tool-call streaming --- src/agents/models/chatcmpl_stream_handler.py | 177 +++++++++++- src/agents/models/multi_provider.py | 5 + src/agents/models/openai_chatcompletions.py | 10 +- src/agents/models/openai_provider.py | 7 + .../test_openai_chatcompletions_stream.py | 258 +++++++++++++++++- 5 files changed, 453 insertions(+), 4 deletions(-) diff --git a/src/agents/models/chatcmpl_stream_handler.py b/src/agents/models/chatcmpl_stream_handler.py index b8a1272e0b..71df75f169 100644 --- a/src/agents/models/chatcmpl_stream_handler.py +++ b/src/agents/models/chatcmpl_stream_handler.py @@ -2,10 +2,16 @@ from collections.abc import AsyncIterator, Iterator from dataclasses import dataclass, field -from typing import Any +from typing import Any, cast from openai import AsyncStream from openai.types.chat import ChatCompletionChunk +from openai.types.chat.chat_completion_chunk import ( + Choice, + ChoiceDelta, + ChoiceDeltaToolCall, + ChoiceDeltaToolCallFunction, +) from openai.types.completion_usage import CompletionUsage from openai.types.responses import ( Response, @@ -42,7 +48,7 @@ ) from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails -from ..exceptions import UserError +from ..exceptions import ModelBehaviorError, UserError from ..items import TResponseStreamEvent from ..logger import logger from .chatcmpl_helpers import ChatCmplHelpers @@ -76,6 +82,18 @@ class StreamingState: has_warned_unsupported_choice: bool = False +@dataclass +class _BufferedToolCall: + """Accumulates a streamed Chat Completions function tool call.""" + + index: int + call_id: str | None = None + name: str | None = None + arguments: str = "" + provider_specific_fields: dict[str, Any] | None = None + extra_content: dict[str, Any] | None = None + + class SequenceNumber: def __init__(self): self._sequence_number = 0 @@ -163,6 +181,161 @@ def function_calls_after_message( class ChatCmplStreamHandler: + @staticmethod + def _choice_finished_tool_calls(choice: Choice) -> bool: + return choice.finish_reason == "tool_calls" + + @staticmethod + def _should_buffer_tool_call_delta(tool_call_delta: ChoiceDeltaToolCall) -> bool: + tool_call_type = getattr(tool_call_delta, "type", None) + return tool_call_type in (None, "function") + + @staticmethod + def _delta_has_passthrough_output(delta: ChoiceDelta | None) -> bool: + if delta is None: + return False + + if delta.content is not None or delta.tool_calls: + return True + + if hasattr(delta, "refusal") and delta.refusal: + return True + + if hasattr(delta, "reasoning_content") and delta.reasoning_content: + return True + + if hasattr(delta, "reasoning") and delta.reasoning: + return True + + if hasattr(delta, "thinking_blocks") and delta.thinking_blocks: + return True + + return False + + @staticmethod + def _accumulate_tool_call_delta( + buffered_calls: dict[int, _BufferedToolCall], + tool_call_delta: ChoiceDeltaToolCall, + ) -> None: + buffered_call = buffered_calls.setdefault( + tool_call_delta.index, + _BufferedToolCall(index=tool_call_delta.index), + ) + + if tool_call_delta.id: + buffered_call.call_id = tool_call_delta.id + + if tool_call_delta.function: + if tool_call_delta.function.name: + buffered_call.name = tool_call_delta.function.name + if tool_call_delta.function.arguments: + buffered_call.arguments += tool_call_delta.function.arguments + + provider_specific_fields = getattr(tool_call_delta, "provider_specific_fields", None) + if isinstance(provider_specific_fields, dict): + buffered_call.provider_specific_fields = provider_specific_fields + + extra_content = getattr(tool_call_delta, "extra_content", None) + if isinstance(extra_content, dict): + buffered_call.extra_content = extra_content + + @staticmethod + def _buffered_tool_call_delta( + buffered_call: _BufferedToolCall, + ) -> ChoiceDeltaToolCall: + if not buffered_call.call_id: + raise ModelBehaviorError( + "Buffered Chat Completions tool call stream ended without a tool call id." + ) + + if not buffered_call.name: + raise ModelBehaviorError( + "Buffered Chat Completions tool call stream ended without a function name." + ) + + tool_call_delta = ChoiceDeltaToolCall( + index=buffered_call.index, + id=buffered_call.call_id, + function=ChoiceDeltaToolCallFunction( + name=buffered_call.name, + arguments=buffered_call.arguments, + ), + type="function", + ) + + tool_call_delta_any = cast(Any, tool_call_delta) + if buffered_call.provider_specific_fields is not None: + tool_call_delta_any.provider_specific_fields = buffered_call.provider_specific_fields + if buffered_call.extra_content is not None: + tool_call_delta_any.extra_content = buffered_call.extra_content + + return tool_call_delta + + @classmethod + def _buffered_tool_calls_chunk( + cls, + template_chunk: ChatCompletionChunk, + buffered_calls: dict[int, _BufferedToolCall], + ) -> ChatCompletionChunk: + tool_call_deltas = [ + cls._buffered_tool_call_delta(buffered_call) + for _, buffered_call in sorted(buffered_calls.items()) + ] + choice = Choice( + index=0, + delta=ChoiceDelta(tool_calls=tool_call_deltas), + finish_reason="tool_calls", + ) + return template_chunk.model_copy(update={"choices": [choice], "usage": None}) + + @classmethod + async def buffer_tool_call_stream( + cls, + stream: AsyncIterator[ChatCompletionChunk], + ) -> AsyncIterator[ChatCompletionChunk]: + """Buffer streamed function tool-call deltas until they are complete.""" + buffered_calls: dict[int, _BufferedToolCall] = {} + last_chunk: ChatCompletionChunk | None = None + + async for chunk in stream: + last_chunk = chunk + + if not chunk.choices: + yield chunk + continue + + passthrough_choices: list[Choice] = [] + for choice in chunk.choices: + delta = choice.delta + + if tool_call_deltas := (delta.tool_calls if delta and delta.tool_calls else None): + remaining_tool_calls: list[ChoiceDeltaToolCall] = [] + for tool_call_delta in tool_call_deltas: + if cls._should_buffer_tool_call_delta(tool_call_delta): + cls._accumulate_tool_call_delta(buffered_calls, tool_call_delta) + else: + remaining_tool_calls.append(tool_call_delta) + + delta = delta.model_copy(update={"tool_calls": remaining_tool_calls or None}) + choice = choice.model_copy(update={"delta": delta}) + + if cls._choice_finished_tool_calls(choice) and not buffered_calls: + raise ModelBehaviorError( + "Chat Completions stream finished with finish_reason='tool_calls' " + "but did not include any streamed tool call deltas." + ) + + if cls._delta_has_passthrough_output(choice.delta): + passthrough_choices.append(choice) + + if passthrough_choices or chunk.usage is not None: + yield chunk.model_copy(update={"choices": passthrough_choices}) + + if buffered_calls: + if last_chunk is None: + return + yield cls._buffered_tool_calls_chunk(last_chunk, buffered_calls) + @staticmethod def _merged_provider_data( state: StreamingState, diff --git a/src/agents/models/multi_provider.py b/src/agents/models/multi_provider.py index 2dd5af2013..4737bb8c0c 100644 --- a/src/agents/models/multi_provider.py +++ b/src/agents/models/multi_provider.py @@ -89,6 +89,7 @@ def __init__( unknown_prefix_mode: MultiProviderUnknownPrefixMode = "error", openai_agent_registration: OpenAIAgentRegistrationConfig | None = None, openai_responses_websocket_options: OpenAIResponsesWebSocketOptions | None = None, + openai_buffer_streamed_tool_calls: bool = False, ) -> None: """Create a new OpenAI provider. @@ -126,6 +127,9 @@ def __init__( provider. openai_responses_websocket_options: Optional low-level websocket keepalive options for the OpenAI Responses websocket transport. + openai_buffer_streamed_tool_calls: Whether OpenAI Chat Completions models should buffer + streamed function tool-call deltas and emit them to the SDK only after the provider + stream finishes. """ self.provider_map = provider_map self.openai_provider = OpenAIProvider( @@ -140,6 +144,7 @@ def __init__( strict_feature_validation=openai_strict_feature_validation, agent_registration=openai_agent_registration, responses_websocket_options=openai_responses_websocket_options, + buffer_streamed_tool_calls=openai_buffer_streamed_tool_calls, ) self._openai_prefix_mode = self._validate_openai_prefix_mode(openai_prefix_mode) self._unknown_prefix_mode = self._validate_unknown_prefix_mode(unknown_prefix_mode) diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py index cba01163e9..4da04f0f83 100644 --- a/src/agents/models/openai_chatcompletions.py +++ b/src/agents/models/openai_chatcompletions.py @@ -57,11 +57,13 @@ def __init__( openai_client: AsyncOpenAI, should_replay_reasoning_content: ShouldReplayReasoningContent | None = None, strict_feature_validation: bool = False, + buffer_streamed_tool_calls: bool = False, ) -> None: self.model = model self._client = openai_client self.should_replay_reasoning_content = should_replay_reasoning_content self._strict_feature_validation = strict_feature_validation + self._buffer_streamed_tool_calls = buffer_streamed_tool_calls self._has_warned_unsupported_prompt = False self._has_warned_unsupported_conversation_state = False @@ -299,9 +301,15 @@ async def stream_response( ) final_response: Response | None = None + stream_for_handler: AsyncIterator[ChatCompletionChunk] + if self._buffer_streamed_tool_calls: + stream_for_handler = ChatCmplStreamHandler.buffer_tool_call_stream(stream) + else: + stream_for_handler = stream + async for chunk in ChatCmplStreamHandler.handle_stream( response, - stream, + cast(AsyncStream[ChatCompletionChunk], stream_for_handler), model=self.model, strict_feature_validation=self._strict_feature_validation, ): diff --git a/src/agents/models/openai_provider.py b/src/agents/models/openai_provider.py index 4153e659f5..59a04071c6 100644 --- a/src/agents/models/openai_provider.py +++ b/src/agents/models/openai_provider.py @@ -55,6 +55,7 @@ def __init__( strict_feature_validation: bool = False, agent_registration: OpenAIAgentRegistrationConfig | None = None, responses_websocket_options: OpenAIResponsesWebSocketOptions | None = None, + buffer_streamed_tool_calls: bool = False, ) -> None: """Create a new OpenAI provider. @@ -79,6 +80,10 @@ def __init__( agent_registration: Optional agent registration configuration. responses_websocket_options: Optional low-level websocket keepalive options for the OpenAI Responses websocket transport. + buffer_streamed_tool_calls: Whether Chat Completions models should buffer streamed + function tool-call deltas and emit them to the SDK only after the provider stream + finishes. This is useful for OpenAI-compatible providers whose streamed tool-call + chunk semantics are not reliable enough for incremental processing. """ if openai_client is not None: assert api_key is None and base_url is None and websocket_base_url is None, ( @@ -109,6 +114,7 @@ def __init__( self._use_responses_websocket = self._responses_transport == "websocket" self._strict_feature_validation = strict_feature_validation self._responses_websocket_options = responses_websocket_options + self._buffer_streamed_tool_calls = buffer_streamed_tool_calls # Reuse websocket model wrappers so websocket transport can keep a persistent connection # when callers pass model names as strings through a shared provider. @@ -230,6 +236,7 @@ def get_model(self, model_name: str | None) -> Model: model=resolved_model_name, openai_client=client, strict_feature_validation=self._strict_feature_validation, + buffer_streamed_tool_calls=self._buffer_streamed_tool_calls, ) if use_websocket_transport: diff --git a/tests/models/test_openai_chatcompletions_stream.py b/tests/models/test_openai_chatcompletions_stream.py index 10cc5f9b84..b6da7e915d 100644 --- a/tests/models/test_openai_chatcompletions_stream.py +++ b/tests/models/test_openai_chatcompletions_stream.py @@ -29,7 +29,7 @@ ResponseOutputText, ) -from agents.exceptions import UserError +from agents.exceptions import ModelBehaviorError, UserError from agents.model_settings import ModelSettings from agents.models.chatcmpl_stream_handler import ChatCmplStreamHandler from agents.models.interface import ModelTracing @@ -769,6 +769,262 @@ async def patched_fetch_response(self, *args, **kwargs): assert final_fn.arguments == "arg1arg2" +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_stream_response_buffers_tool_call_deltas_when_enabled(monkeypatch) -> None: + tool_call_delta1 = ChoiceDeltaToolCall( + index=0, + id="tool-id", + function=ChoiceDeltaToolCallFunction(name="my_func", arguments="arg1"), + type="function", + ) + tool_call_delta2 = ChoiceDeltaToolCall( + index=0, + id=None, + function=ChoiceDeltaToolCallFunction(name=None, arguments="arg2"), + type="function", + ) + chunk1 = ChatCompletionChunk( + id="chunk-id", + created=1, + model="fake", + object="chat.completion.chunk", + choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta1]))], + ) + chunk2 = ChatCompletionChunk( + id="chunk-id", + created=1, + model="fake", + object="chat.completion.chunk", + choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta2]))], + usage=CompletionUsage(completion_tokens=1, prompt_tokens=1, total_tokens=2), + ) + + async def fake_stream() -> AsyncIterator[ChatCompletionChunk]: + for chunk in (chunk1, chunk2): + yield chunk + + async def patched_fetch_response(self, *args, **kwargs): + return _empty_response(), fake_stream() + + monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response) + model = OpenAIProvider( + use_responses=False, + buffer_streamed_tool_calls=True, + ).get_model("gpt-4") + + output_events = [] + async for event in model.stream_response( + system_instructions=None, + input="", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + conversation_id=None, + prompt=None, + ): + output_events.append(event) + + argument_delta_events = [ + event for event in output_events if event.type == "response.function_call_arguments.delta" + ] + assert len(argument_delta_events) == 1 + assert argument_delta_events[0].delta == "arg1arg2" + + done_event = next(event for event in output_events if event.type == "response.output_item.done") + final_fn = done_event.item + assert isinstance(final_fn, ResponseFunctionToolCall) + assert final_fn.call_id == "tool-id" + assert final_fn.name == "my_func" + assert final_fn.arguments == "arg1arg2" + + completed_event = next(event for event in output_events if event.type == "response.completed") + assert isinstance(completed_event, ResponseCompletedEvent) + assert completed_event.response.usage + assert completed_event.response.usage.total_tokens == 2 + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_stream_response_buffers_tool_call_usage_chunk_without_replay( + monkeypatch, +) -> None: + tool_call_delta = ChoiceDeltaToolCall( + index=0, + id="tool-id", + function=ChoiceDeltaToolCallFunction(name="my_func", arguments="arg1"), + type="function", + ) + chunk = ChatCompletionChunk( + id="chunk-id", + created=1, + model="fake", + object="chat.completion.chunk", + choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta]))], + usage=CompletionUsage(completion_tokens=1, prompt_tokens=1, total_tokens=2), + ) + + async def fake_stream() -> AsyncIterator[ChatCompletionChunk]: + yield chunk + + async def patched_fetch_response(self, *args, **kwargs): + return _empty_response(), fake_stream() + + monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response) + model = OpenAIProvider( + use_responses=False, + buffer_streamed_tool_calls=True, + ).get_model("gpt-4") + + output_events = [] + async for event in model.stream_response( + system_instructions=None, + input="", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + conversation_id=None, + prompt=None, + ): + output_events.append(event) + + argument_delta_events = [ + event for event in output_events if event.type == "response.function_call_arguments.delta" + ] + assert len(argument_delta_events) == 1 + assert argument_delta_events[0].delta == "arg1" + + function_done_events = [ + event + for event in output_events + if event.type == "response.output_item.done" + and isinstance(event.item, ResponseFunctionToolCall) + ] + assert len(function_done_events) == 1 + + completed_event = next(event for event in output_events if event.type == "response.completed") + assert isinstance(completed_event, ResponseCompletedEvent) + assert completed_event.response.usage + assert completed_event.response.usage.total_tokens == 2 + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_stream_response_buffers_tool_call_provider_fields(monkeypatch) -> None: + tool_call_delta1 = ChoiceDeltaToolCall( + index=0, + id="tool-id", + function=ChoiceDeltaToolCallFunction(name="my_func", arguments=None), + type="function", + ) + cast(Any, tool_call_delta1).provider_specific_fields = {"thought_signature": "thought-sig"} + tool_call_delta2 = ChoiceDeltaToolCall( + index=0, + id=None, + function=ChoiceDeltaToolCallFunction(name=None, arguments="arg1"), + type="function", + ) + chunk1 = ChatCompletionChunk( + id="chunk-id", + created=1, + model="gemini/gemini-3-pro", + object="chat.completion.chunk", + choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta1]))], + ) + chunk2 = ChatCompletionChunk( + id="chunk-id", + created=1, + model="gemini/gemini-3-pro", + object="chat.completion.chunk", + choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta2]))], + ) + + async def fake_stream() -> AsyncIterator[ChatCompletionChunk]: + for chunk in (chunk1, chunk2): + yield chunk + + async def patched_fetch_response(self, *args, **kwargs): + return _empty_response(), fake_stream() + + monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response) + model = OpenAIProvider( + use_responses=False, + buffer_streamed_tool_calls=True, + ).get_model("gemini/gemini-3-pro") + + output_events = [] + async for event in model.stream_response( + system_instructions=None, + input="", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + conversation_id=None, + prompt=None, + ): + output_events.append(event) + + function_done_events = [ + event + for event in output_events + if event.type == "response.output_item.done" + and isinstance(event.item, ResponseFunctionToolCall) + ] + assert len(function_done_events) == 1 + provider_data = function_done_events[0].item.model_dump().get("provider_data", {}) + assert provider_data["thought_signature"] == "thought-sig" + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_stream_response_buffered_tool_calls_raise_for_missing_tool_call_delta( + monkeypatch, +) -> None: + chunk = ChatCompletionChunk( + id="chunk-id", + created=1, + model="fake", + object="chat.completion.chunk", + choices=[Choice(index=0, delta=ChoiceDelta(), finish_reason="tool_calls")], + ) + + async def fake_stream() -> AsyncIterator[ChatCompletionChunk]: + yield chunk + + async def patched_fetch_response(self, *args, **kwargs): + return _empty_response(), fake_stream() + + monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response) + model = OpenAIProvider( + use_responses=False, + buffer_streamed_tool_calls=True, + ).get_model("gpt-4") + + with pytest.raises(ModelBehaviorError, match="finish_reason='tool_calls'"): + async for _event in model.stream_response( + system_instructions=None, + input="", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + conversation_id=None, + prompt=None, + ): + pass + + @pytest.mark.allow_call_model_methods @pytest.mark.asyncio async def test_stream_response_with_custom_tool_call_raises_in_strict_mode(monkeypatch) -> None: From 09033e9b6cdf0d8f9a1ff37daa903407c25273c1 Mon Sep 17 00:00:00 2001 From: instantcoffeemonster Date: Tue, 26 May 2026 19:40:01 +0800 Subject: [PATCH 2/3] fix: preserve buffered chat completion tool-call validation --- src/agents/models/chatcmpl_stream_handler.py | 19 +- .../test_openai_chatcompletions_stream.py | 222 ++++++++++++++++++ 2 files changed, 238 insertions(+), 3 deletions(-) diff --git a/src/agents/models/chatcmpl_stream_handler.py b/src/agents/models/chatcmpl_stream_handler.py index 71df75f169..7f13dce71a 100644 --- a/src/agents/models/chatcmpl_stream_handler.py +++ b/src/agents/models/chatcmpl_stream_handler.py @@ -295,6 +295,7 @@ async def buffer_tool_call_stream( ) -> AsyncIterator[ChatCompletionChunk]: """Buffer streamed function tool-call deltas until they are complete.""" buffered_calls: dict[int, _BufferedToolCall] = {} + passthrough_tool_call_indexes: set[int] = set() last_chunk: ChatCompletionChunk | None = None async for chunk in stream: @@ -306,26 +307,38 @@ async def buffer_tool_call_stream( passthrough_choices: list[Choice] = [] for choice in chunk.choices: + if choice.index != 0: + passthrough_choices.append(choice) + continue + delta = choice.delta if tool_call_deltas := (delta.tool_calls if delta and delta.tool_calls else None): remaining_tool_calls: list[ChoiceDeltaToolCall] = [] for tool_call_delta in tool_call_deltas: - if cls._should_buffer_tool_call_delta(tool_call_delta): + if tool_call_delta.index in passthrough_tool_call_indexes: + remaining_tool_calls.append(tool_call_delta) + elif cls._should_buffer_tool_call_delta(tool_call_delta): cls._accumulate_tool_call_delta(buffered_calls, tool_call_delta) else: + passthrough_tool_call_indexes.add(tool_call_delta.index) remaining_tool_calls.append(tool_call_delta) delta = delta.model_copy(update={"tool_calls": remaining_tool_calls or None}) choice = choice.model_copy(update={"delta": delta}) - if cls._choice_finished_tool_calls(choice) and not buffered_calls: + has_passthrough_output = cls._delta_has_passthrough_output(choice.delta) + if ( + cls._choice_finished_tool_calls(choice) + and not buffered_calls + and not has_passthrough_output + ): raise ModelBehaviorError( "Chat Completions stream finished with finish_reason='tool_calls' " "but did not include any streamed tool call deltas." ) - if cls._delta_has_passthrough_output(choice.delta): + if has_passthrough_output: passthrough_choices.append(choice) if passthrough_choices or chunk.usage is not None: diff --git a/tests/models/test_openai_chatcompletions_stream.py b/tests/models/test_openai_chatcompletions_stream.py index b6da7e915d..ff2557cd28 100644 --- a/tests/models/test_openai_chatcompletions_stream.py +++ b/tests/models/test_openai_chatcompletions_stream.py @@ -1025,6 +1025,228 @@ async def patched_fetch_response(self, *args, **kwargs): pass +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_buffered_tool_calls_preserve_nonzero_choice_validation(monkeypatch) -> None: + tool_call_delta = ChoiceDeltaToolCall( + index=0, + id="tool-id", + function=ChoiceDeltaToolCallFunction(name="my_func", arguments="arg"), + type="function", + ) + chunk = ChatCompletionChunk( + id="chunk-id", + created=1, + model="fake", + object="chat.completion.chunk", + choices=[Choice(index=1, delta=ChoiceDelta(tool_calls=[tool_call_delta]))], + ) + + async def fake_stream() -> AsyncIterator[ChatCompletionChunk]: + yield chunk + + async def patched_fetch_response(self, *args, **kwargs): + return _empty_response(), fake_stream() + + monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response) + model = OpenAIProvider( + use_responses=False, + strict_feature_validation=True, + buffer_streamed_tool_calls=True, + ).get_model("gpt-4") + + with pytest.raises(UserError, match="multiple choices or nonzero choice indexes"): + async for _event in model.stream_response( + system_instructions=None, + input="", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + conversation_id=None, + prompt=None, + ): + pass + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_buffered_tool_calls_do_not_merge_nonzero_choice_tool_call_indexes( + monkeypatch, +) -> None: + choice_zero_tool_call = ChoiceDeltaToolCall( + index=0, + id="choice-zero-tool-id", + function=ChoiceDeltaToolCallFunction(name="choice_zero_func", arguments="choice-zero"), + type="function", + ) + choice_one_tool_call = ChoiceDeltaToolCall( + index=0, + id="choice-one-tool-id", + function=ChoiceDeltaToolCallFunction(name="choice_one_func", arguments="choice-one"), + type="function", + ) + chunk = ChatCompletionChunk( + id="chunk-id", + created=1, + model="fake", + object="chat.completion.chunk", + choices=[ + Choice(index=0, delta=ChoiceDelta(tool_calls=[choice_zero_tool_call])), + Choice(index=1, delta=ChoiceDelta(tool_calls=[choice_one_tool_call])), + ], + ) + + async def fake_stream() -> AsyncIterator[ChatCompletionChunk]: + yield chunk + + async def patched_fetch_response(self, *args, **kwargs): + return _empty_response(), fake_stream() + + monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response) + model = OpenAIProvider( + use_responses=False, + buffer_streamed_tool_calls=True, + ).get_model("gpt-4") + + output_events = [] + async for event in model.stream_response( + system_instructions=None, + input="", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + conversation_id=None, + prompt=None, + ): + output_events.append(event) + + function_done_events = [ + event + for event in output_events + if event.type == "response.output_item.done" + and isinstance(event.item, ResponseFunctionToolCall) + ] + assert len(function_done_events) == 1 + final_fn = function_done_events[0].item + assert isinstance(final_fn, ResponseFunctionToolCall) + assert final_fn.call_id == "choice-zero-tool-id" + assert final_fn.name == "choice_zero_func" + assert final_fn.arguments == "choice-zero" + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_buffered_tool_calls_preserve_custom_tool_call_strict_error( + monkeypatch, +) -> None: + custom_tool_call_delta = ChoiceDeltaToolCall.model_construct( + index=0, + id="tool-call-123", + type="custom", + ) + chunk = ChatCompletionChunk( + id="chunk-id", + created=1, + model="fake", + object="chat.completion.chunk", + choices=[ + Choice( + index=0, + delta=ChoiceDelta(tool_calls=[custom_tool_call_delta]), + finish_reason="tool_calls", + ) + ], + ) + + async def fake_stream() -> AsyncIterator[ChatCompletionChunk]: + yield chunk + + async def patched_fetch_response(self, *args, **kwargs): + return _empty_response(), fake_stream() + + monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response) + model = OpenAIProvider( + use_responses=False, + strict_feature_validation=True, + buffer_streamed_tool_calls=True, + ).get_model("gpt-4") + + with pytest.raises(UserError, match="Custom tool calls are not supported"): + async for _event in model.stream_response( + system_instructions=None, + input="", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + conversation_id=None, + prompt=None, + ): + pass + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_buffered_tool_calls_ignore_custom_tool_call_by_default(monkeypatch) -> None: + custom_tool_call_delta = ChoiceDeltaToolCall.model_construct( + index=0, + id="tool-call-123", + type="custom", + ) + chunk = ChatCompletionChunk( + id="chunk-id", + created=1, + model="fake", + object="chat.completion.chunk", + choices=[ + Choice( + index=0, + delta=ChoiceDelta(tool_calls=[custom_tool_call_delta]), + finish_reason="tool_calls", + ) + ], + ) + + async def fake_stream() -> AsyncIterator[ChatCompletionChunk]: + yield chunk + + async def patched_fetch_response(self, *args, **kwargs): + return _empty_response(), fake_stream() + + monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response) + model = OpenAIProvider( + use_responses=False, + buffer_streamed_tool_calls=True, + ).get_model("gpt-4") + + output_events = [] + async for event in model.stream_response( + system_instructions=None, + input="", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + conversation_id=None, + prompt=None, + ): + output_events.append(event) + + completed_event = next(event for event in output_events if event.type == "response.completed") + assert isinstance(completed_event, ResponseCompletedEvent) + assert completed_event.response.output == [] + + @pytest.mark.allow_call_model_methods @pytest.mark.asyncio async def test_stream_response_with_custom_tool_call_raises_in_strict_mode(monkeypatch) -> None: From 8e620da28f96d4f39beb9fdfc75ce09db72b5082 Mon Sep 17 00:00:00 2001 From: instantcoffeemonster Date: Tue, 26 May 2026 20:08:40 +0800 Subject: [PATCH 3/3] fix: preserve buffered tool-call stream metadata --- src/agents/models/chatcmpl_stream_handler.py | 40 +++++++++++++++++++- 1 file changed, 38 insertions(+), 2 deletions(-) diff --git a/src/agents/models/chatcmpl_stream_handler.py b/src/agents/models/chatcmpl_stream_handler.py index 7f13dce71a..052ecfc568 100644 --- a/src/agents/models/chatcmpl_stream_handler.py +++ b/src/agents/models/chatcmpl_stream_handler.py @@ -94,6 +94,30 @@ class _BufferedToolCall: extra_content: dict[str, Any] | None = None +def _merge_buffered_metadata( + current: dict[str, Any] | None, + incoming: dict[str, Any], +) -> dict[str, Any] | None: + """Merge provider metadata without letting empty chunks erase earlier fields.""" + if not incoming: + return current + + if current is None: + return incoming.copy() + + merged = current.copy() + for key, value in incoming.items(): + current_value = merged.get(key) + if isinstance(current_value, dict) and isinstance(value, dict): + merged[key] = _merge_buffered_metadata(current_value, value) or {} + elif isinstance(value, dict) and not value and key in merged: + continue + else: + merged[key] = value + + return merged + + class SequenceNumber: def __init__(self): self._sequence_number = 0 @@ -233,11 +257,17 @@ def _accumulate_tool_call_delta( provider_specific_fields = getattr(tool_call_delta, "provider_specific_fields", None) if isinstance(provider_specific_fields, dict): - buffered_call.provider_specific_fields = provider_specific_fields + buffered_call.provider_specific_fields = _merge_buffered_metadata( + buffered_call.provider_specific_fields, + provider_specific_fields, + ) extra_content = getattr(tool_call_delta, "extra_content", None) if isinstance(extra_content, dict): - buffered_call.extra_content = extra_content + buffered_call.extra_content = _merge_buffered_metadata( + buffered_call.extra_content, + extra_content, + ) @staticmethod def _buffered_tool_call_delta( @@ -296,6 +326,7 @@ async def buffer_tool_call_stream( """Buffer streamed function tool-call deltas until they are complete.""" buffered_calls: dict[int, _BufferedToolCall] = {} passthrough_tool_call_indexes: set[int] = set() + saw_passthrough_tool_call = False last_chunk: ChatCompletionChunk | None = None async for chunk in stream: @@ -308,6 +339,8 @@ async def buffer_tool_call_stream( passthrough_choices: list[Choice] = [] for choice in chunk.choices: if choice.index != 0: + if choice.delta and choice.delta.tool_calls: + saw_passthrough_tool_call = True passthrough_choices.append(choice) continue @@ -317,11 +350,13 @@ async def buffer_tool_call_stream( remaining_tool_calls: list[ChoiceDeltaToolCall] = [] for tool_call_delta in tool_call_deltas: if tool_call_delta.index in passthrough_tool_call_indexes: + saw_passthrough_tool_call = True remaining_tool_calls.append(tool_call_delta) elif cls._should_buffer_tool_call_delta(tool_call_delta): cls._accumulate_tool_call_delta(buffered_calls, tool_call_delta) else: passthrough_tool_call_indexes.add(tool_call_delta.index) + saw_passthrough_tool_call = True remaining_tool_calls.append(tool_call_delta) delta = delta.model_copy(update={"tool_calls": remaining_tool_calls or None}) @@ -331,6 +366,7 @@ async def buffer_tool_call_stream( if ( cls._choice_finished_tool_calls(choice) and not buffered_calls + and not saw_passthrough_tool_call and not has_passthrough_output ): raise ModelBehaviorError(