8 changes: 4 additions & 4 deletions src/openai/_base_client.py
@@ -1021,8 +1021,8 @@ def request(

log.debug("Raising timeout error")
raise APITimeoutError(request=request) from err
except Exception as err:
log.debug("Encountered Exception", exc_info=True)
except httpx.HTTPError as err:
log.debug("Encountered httpx.HTTPError", exc_info=True)

if remaining_retries > 0:
self._sleep_for_retry(
@@ -1620,8 +1620,8 @@ async def request(

log.debug("Raising timeout error")
raise APITimeoutError(request=request) from err
except Exception as err:
log.debug("Encountered Exception", exc_info=True)
except httpx.HTTPError as err:
log.debug("Encountered httpx.HTTPError", exc_info=True)

if remaining_retries > 0:
await self._sleep_for_retry(
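
Narrowing the retry path from `except Exception` to `except httpx.HTTPError` means only transport-level failures are retried; unrelated bugs now propagate on the first attempt. A minimal sketch of that behaviour, assuming a simplified retry loop (`send_with_retries` and `max_retries` are illustrative names, not the SDK's real implementation):

import logging

import httpx

log = logging.getLogger(__name__)

def send_with_retries(client: httpx.Client, request: httpx.Request, max_retries: int = 2) -> httpx.Response:
    """Hypothetical helper (not part of this PR): retry transport failures only."""
    for attempt in range(max_retries + 1):
        try:
            return client.send(request)
        except httpx.HTTPError:
            # Connection errors, protocol errors and timeouts are retried.
            log.debug("Encountered httpx.HTTPError", exc_info=True)
            if attempt == max_retries:
                raise
    raise AssertionError("unreachable")

# A failure that is not a transport error (e.g. a TypeError raised while
# serialising the request body) is no longer swallowed by the retry loop;
# it reaches the caller immediately.
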
179 changes: 112 additions & 67 deletions src/openai/_streaming.py
@@ -2,6 +2,7 @@
from __future__ import annotations

import json
import logging
import inspect
from types import TracebackType
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, Optional, AsyncIterator, cast
@@ -16,10 +17,79 @@
from ._client import OpenAI, AsyncOpenAI
from ._models import FinalRequestOptions

log = logging.getLogger(__name__)

_T = TypeVar("_T")


def _check_stream_error(data: object, sse: ServerSentEvent, request: httpx.Request) -> None:
"""Check if an SSE event represents a stream error and raise APIError if so.

Handles both explicit error events (sse.event == "error") and events
where the data payload contains an "error" field.
"""
is_error_event = sse.event == "error"
has_error_field = is_mapping(data) and data.get("error")

if not is_error_event and not has_error_field:
return

# An explicit error event with no payload (or empty payload) should still
# surface as an error rather than being silently skipped.
if data is None and is_error_event:
raise APIError(
message="An error occurred during streaming",
request=request,
body=None,
)

message: str | None = None

if is_mapping(data):
error = data.get("error")
if is_mapping(error):
msg = error.get("message")
if msg and isinstance(msg, str):
message = msg
# Also check for top-level message field per API spec
if not message:
top_msg = data.get("message")
if top_msg and isinstance(top_msg, str):
message = top_msg

if not message:
message = "An error occurred during streaming"

body = data.get("error") if is_mapping(data) and data.get("error") else data

raise APIError(
message=message,
request=request,
body=body,
)


def _parse_sse_data(sse: ServerSentEvent) -> Any:
"""Parse the JSON data from an SSE event with proper error handling.

Returns the parsed data, or raises a more informative error if parsing fails.
"""
if not sse.data:
log.debug("Received SSE event with empty data (event=%s)", sse.event)
return None

try:
return json.loads(sse.data)
except json.JSONDecodeError as exc:
data_preview = sse.data[:200]
raise APIError(
message=f"Failed to parse streaming response data as JSON: {exc.msg} "
f"(event={sse.event!r}, data={data_preview!r})",
request=httpx.Request("POST", ""),
body=None,
) from exc


class Stream(Generic[_T]):
"""Provides the core interface to iterate over a synchronous stream response."""

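For reference, the message-extraction order implemented by `_check_stream_error` above, applied to a made-up error payload (the payload shape is illustrative, not quoted from the API reference):

import json

# Illustrative sketch, not part of this PR's diff.
sample = json.loads('{"error": {"message": "Rate limit exceeded", "type": "rate_limit_error"}}')

# The helper prefers the nested error message, then a top-level "message"
# field, and finally falls back to a generic string.
error = sample.get("error")
message = error.get("message") if isinstance(error, dict) else None
message = message or sample.get("message") or "An error occurred during streaming"
assert message == "Rate limit exceeded"

# An explicit `event: error` frame with an empty payload takes the early
# branch instead and raises APIError("An error occurred during streaming").
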
@@ -63,41 +133,20 @@ def __stream__(self) -> Iterator[_T]:
if sse.data.startswith("[DONE]"):
break

# we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
if sse.event and sse.event.startswith("thread."):
data = sse.json()

if sse.event == "error" and is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"

raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)
data = _parse_sse_data(sse)

# Check for error events before processing - handles both explicit
# error events (including those with empty payloads) and data
# payloads containing an "error" field
_check_stream_error(data, sse, self.response.request)

if data is None:
continue

# Assistants `thread.` events need special handling since we synthesize the event key
if sse.event and sse.event.startswith("thread."):
yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
else:
data = sse.json()
if is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"

raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)

yield process_data(
data={"data": data, "event": sse.event}
if self._options is not None and self._options.synthesize_event_and_data
@@ -173,41 +222,20 @@ async def __stream__(self) -> AsyncIterator[_T]:
if sse.data.startswith("[DONE]"):
break

# we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
if sse.event and sse.event.startswith("thread."):
data = sse.json()

if sse.event == "error" and is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"

raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)
data = _parse_sse_data(sse)

# Check for error events before processing - handles both explicit
# error events (including those with empty payloads) and data
# payloads containing an "error" field
_check_stream_error(data, sse, self.response.request)

if data is None:
continue

# Assistants `thread.` events need special handling since we synthesize the event key
if sse.event and sse.event.startswith("thread."):
yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
else:
data = sse.json()
if is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"

raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)

yield process_data(
data={"data": data, "event": sse.event}
if self._options is not None and self._options.synthesize_event_and_data
@@ -273,7 +301,16 @@ def data(self) -> str:
return self._data

def json(self) -> Any:
return json.loads(self.data)
try:
return json.loads(self.data)
except json.JSONDecodeError as exc:
data_preview = self.data[:200] if self.data else ""
raise json.JSONDecodeError(
f"Failed to parse SSE event data (event={self.event!r}, "
f"data_preview={data_preview!r}): {exc.msg}",
exc.doc,
exc.pos,
) from None

@override
def __repr__(self) -> str:
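
The re-raised JSONDecodeError keeps the original `doc` and `pos` while prefixing the event context. A small sketch of what that preserves (the event name and payload here are made up):

import json

# Illustrative sketch, not part of this PR's diff.
try:
    json.loads("not json")
except json.JSONDecodeError as exc:
    enriched = json.JSONDecodeError(
        f"Failed to parse SSE event data (event='completion', data_preview='not json'): {exc.msg}",
        exc.doc,
        exc.pos,
    )
    # Callers that inspect .doc / .pos (or catch ValueError) keep working,
    # but the message now says which event could not be parsed.
    assert enriched.doc == "not json" and enriched.pos == 0
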
@@ -297,7 +334,11 @@ def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
for chunk in self._iter_chunks(iterator):
# Split before decoding so splitlines() only uses \r and \n
for raw_line in chunk.splitlines():
line = raw_line.decode("utf-8")
try:
line = raw_line.decode("utf-8")
except UnicodeDecodeError:
log.debug("Skipping SSE line with invalid UTF-8: %r", raw_line[:100])
continue
sse = self.decode(line)
if sse:
yield sse
@@ -319,7 +360,11 @@ async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
async for chunk in self._aiter_chunks(iterator):
# Split before decoding so splitlines() only uses \r and \n
for raw_line in chunk.splitlines():
line = raw_line.decode("utf-8")
try:
line = raw_line.decode("utf-8")
except UnicodeDecodeError:
log.debug("Skipping SSE line with invalid UTF-8: %r", raw_line[:100])
continue
sse = self.decode(line)
if sse:
yield sse
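
To illustrate the decode fallback added in `iter_bytes` / `aiter_bytes` above: a chunk line containing invalid UTF-8 is now logged and skipped rather than aborting the stream (the byte payload below is made up):

# Illustrative sketch, not part of this PR's diff.
raw_line = b'data: \xff\xfe{"delta": "hi"}'

try:
    line = raw_line.decode("utf-8")
except UnicodeDecodeError:
    # With the change above, the line is logged at DEBUG level and skipped,
    # so a single corrupt line no longer raises out of the SSE decoder.
    line = None

assert line is None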