diff --git a/src/openai/_streaming.py b/src/openai/_streaming.py index 45c13cc11d..3c43821066 100644 --- a/src/openai/_streaming.py +++ b/src/openai/_streaming.py @@ -63,10 +63,20 @@ def __stream__(self) -> Iterator[_T]: if sse.data.startswith("[DONE]"): break + # Skip SSE meta-only events that carry no data (e.g. standalone + # `retry:` or `id:` directives). Per the SSE spec these are valid + # but contain an empty data field; calling sse.json() on them + # raises JSONDecodeError. + if not sse.data or not sse.data.strip(): + continue + # we have to special case the Assistants `thread.` events since we won't have an "event" key in the data if sse.event and sse.event.startswith("thread."): data = sse.json() - + yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response) + else: + data = sse.json() + # Handle error events that carry an "error" event type (outside thread.* scope) if sse.event == "error" and is_mapping(data) and data.get("error"): message = None error = data.get("error") @@ -81,9 +91,6 @@ def __stream__(self) -> Iterator[_T]: body=data["error"], ) - yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response) - else: - data = sse.json() if is_mapping(data) and data.get("error"): message = None error = data.get("error") @@ -173,10 +180,20 @@ async def __stream__(self) -> AsyncIterator[_T]: if sse.data.startswith("[DONE]"): break + # Skip SSE meta-only events that carry no data (e.g. standalone + # `retry:` or `id:` directives). Per the SSE spec these are valid + # but contain an empty data field; calling sse.json() on them + # raises JSONDecodeError. + if not sse.data or not sse.data.strip(): + continue + # we have to special case the Assistants `thread.` events since we won't have an "event" key in the data if sse.event and sse.event.startswith("thread."): data = sse.json() - + yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response) + else: + data = sse.json() + # Handle error events that carry an "error" event type (outside thread.* scope) if sse.event == "error" and is_mapping(data) and data.get("error"): message = None error = data.get("error") @@ -191,9 +208,6 @@ async def __stream__(self) -> AsyncIterator[_T]: body=data["error"], ) - yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response) - else: - data = sse.json() if is_mapping(data) and data.get("error"): message = None error = data.get("error") @@ -238,6 +252,17 @@ async def close(self) -> None: """ await self.response.aclose() + async def aclose(self) -> None: + """Async-convention alias for :meth:`close`. + + Follows the standard Python async cleanup protocol used by + ``asyncio.StreamWriter``, ``httpx.AsyncByteStream``, and async + generators (PEP 525), allowing callers and instrumentation libraries + to call ``await stream.aclose()`` uniformly without special-casing + this class. + """ + await self.close() + class ServerSentEvent: def __init__(