25 changes: 25 additions & 0 deletions src/strands/models/openai.py
@@ -27,6 +27,15 @@

T = TypeVar("T", bound=BaseModel)

# Alternative context overflow error messages
# These are commonly returned by OpenAI-compatible endpoints wrapping other providers
# (e.g., Databricks serving Bedrock models)
CONTEXT_OVERFLOW_MESSAGES = [
"Input is too long for requested model",
"input length and `max_tokens` exceed context limit",
"too many total text bytes",
]


class Client(Protocol):
"""Protocol defining the OpenAI-compatible interface for the underlying provider client."""
@@ -594,6 +603,14 @@ async def stream(
            # Rate limits (including TPM) require waiting/retrying, not context reduction
            logger.warning("OpenAI threw rate limit error")
            raise ModelThrottledException(str(e)) from e
        except openai.APIError as e:
            # Check for alternative context overflow error messages
            error_message = str(e)
            if any(overflow_msg in error_message for overflow_msg in CONTEXT_OVERFLOW_MESSAGES):
                logger.warning("context window overflow error detected")
                raise ContextWindowOverflowException(error_message) from e
            # Re-raise other APIError exceptions
            raise

        logger.debug("got response from model")
        yield self.format_chunk({"chunk_type": "message_start"})
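Because the handler above raises before the first chunk is yielded (the error comes from the create call itself), callers can retry with a reduced context without duplicating output. A caller-side sketch, assuming ContextWindowOverflowException is importable from strands.types.exceptions; the helper name and retry policy are illustrative, not part of the patch:

from strands.types.exceptions import ContextWindowOverflowException

async def stream_with_trimming(model, messages, keep_last=10):
    """Stream a response, retrying once with truncated history on overflow."""
    try:
        async for chunk in model.stream(messages):
            yield chunk
    except ContextWindowOverflowException:
        # Hypothetical recovery policy: drop the oldest messages and retry once.
        async for chunk in model.stream(messages[-keep_last:]):
            yield chunk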
@@ -717,6 +734,14 @@ async def structured_output(
            # Rate limits (including TPM) require waiting/retrying, not context reduction
            logger.warning("OpenAI threw rate limit error")
            raise ModelThrottledException(str(e)) from e
        except openai.APIError as e:
            # Check for alternative context overflow error messages
            error_message = str(e)
            if any(overflow_msg in error_message for overflow_msg in CONTEXT_OVERFLOW_MESSAGES):
                logger.warning("context window overflow error detected")
                raise ContextWindowOverflowException(error_message) from e
            # Re-raise other APIError exceptions
            raise

        parsed: T | None = None
        # Find the first choice with tool_calls
86 changes: 86 additions & 0 deletions tests/strands/models/test_openai.py

@@ -1034,6 +1034,92 @@ async def test_stream_context_overflow_exception(openai_client, model, messages)
    assert exc_info.value.__cause__ == mock_error


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "error_message",
    [
        "Input is too long for requested model",
        "input length and `max_tokens` exceed context limit",
        "too many total text bytes",
    ],
)
async def test_stream_alternative_context_overflow_messages(openai_client, model, messages, error_message):
    """Test that alternative context overflow messages in APIError are properly converted."""
    # Create a mock OpenAI APIError with alternative context overflow message
    mock_error = openai.APIError(
        message=error_message,
        request=unittest.mock.MagicMock(),
        body={"error": {"message": error_message}},
    )

    # Configure the mock client to raise the APIError
    openai_client.chat.completions.create.side_effect = mock_error

    # Test that the stream method converts the error properly
    with pytest.raises(ContextWindowOverflowException) as exc_info:
        async for _ in model.stream(messages):
            pass

    # Verify the exception message contains the original error
    assert error_message in str(exc_info.value)
    assert exc_info.value.__cause__ == mock_error


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "error_message",
    [
        "Input is too long for requested model",
        "input length and `max_tokens` exceed context limit",
        "too many total text bytes",
    ],
)
async def test_structured_output_alternative_context_overflow_messages(
    openai_client, model, messages, test_output_model_cls, error_message
):
    """Test that alternative context overflow messages in APIError are properly converted in structured output."""
    # Create a mock OpenAI APIError with alternative context overflow message
    mock_error = openai.APIError(
        message=error_message,
        request=unittest.mock.MagicMock(),
        body={"error": {"message": error_message}},
    )

    # Configure the mock client to raise the APIError
    openai_client.beta.chat.completions.parse.side_effect = mock_error

    # Test that the structured_output method converts the error properly
    with pytest.raises(ContextWindowOverflowException) as exc_info:
        async for _ in model.structured_output(test_output_model_cls, messages):
            pass

    # Verify the exception message contains the original error
    assert error_message in str(exc_info.value)
    assert exc_info.value.__cause__ == mock_error


@pytest.mark.asyncio
async def test_stream_api_error_passthrough(openai_client, model, messages):
    """Test that APIError without overflow messages passes through unchanged."""
    # Create a mock OpenAI APIError without overflow message
    mock_error = openai.APIError(
        message="Some other API error",
        request=unittest.mock.MagicMock(),
        body={"error": {"message": "Some other API error"}},
    )

    # Configure the mock client to raise the APIError
    openai_client.chat.completions.create.side_effect = mock_error

    # Test that APIError without overflow messages passes through
    with pytest.raises(openai.APIError) as exc_info:
        async for _ in model.stream(messages):
            pass

    # Verify the original exception is raised, not ContextWindowOverflowException
    assert exc_info.value == mock_error
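The three new tests above can be run in isolation with pytest's keyword filter, assuming the repository layout shown in the diff headers:

python -m pytest tests/strands/models/test_openai.py -k "alternative_context_overflow or api_error_passthrough" -v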


@pytest.mark.asyncio
async def test_stream_other_bad_request_errors_passthrough(openai_client, model, messages):
"""Test that other BadRequestError exceptions are not converted to ContextWindowOverflowException."""