diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py
index d9266212b..df55265e8 100644
--- a/src/strands/models/openai.py
+++ b/src/strands/models/openai.py
@@ -27,6 +27,15 @@ T = TypeVar("T", bound=BaseModel)
 
 
+# Alternative context overflow error messages
+# These are commonly returned by OpenAI-compatible endpoints wrapping other providers
+# (e.g., Databricks serving Bedrock models)
+CONTEXT_OVERFLOW_MESSAGES = [
+    "Input is too long for requested model",
+    "input length and `max_tokens` exceed context limit",
+    "too many total text bytes",
+]
+
 
 class Client(Protocol):
     """Protocol defining the OpenAI-compatible interface for the underlying provider client."""
 
@@ -594,6 +603,14 @@ async def stream(
             # Rate limits (including TPM) require waiting/retrying, not context reduction
             logger.warning("OpenAI threw rate limit error")
             raise ModelThrottledException(str(e)) from e
+        except openai.APIError as e:
+            # Check for alternative context overflow error messages
+            error_message = str(e)
+            if any(overflow_msg in error_message for overflow_msg in CONTEXT_OVERFLOW_MESSAGES):
+                logger.warning("context window overflow error detected")
+                raise ContextWindowOverflowException(error_message) from e
+            # Re-raise other APIError exceptions
+            raise
 
         logger.debug("got response from model")
         yield self.format_chunk({"chunk_type": "message_start"})
@@ -717,6 +734,14 @@ async def structured_output(
             # Rate limits (including TPM) require waiting/retrying, not context reduction
             logger.warning("OpenAI threw rate limit error")
             raise ModelThrottledException(str(e)) from e
+        except openai.APIError as e:
+            # Check for alternative context overflow error messages
+            error_message = str(e)
+            if any(overflow_msg in error_message for overflow_msg in CONTEXT_OVERFLOW_MESSAGES):
+                logger.warning("context window overflow error detected")
+                raise ContextWindowOverflowException(error_message) from e
+            # Re-raise other APIError exceptions
+            raise
 
         parsed: T | None = None
         # Find the first choice with tool_calls
diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py
index 7c1d18998..849672af0 100644
--- a/tests/strands/models/test_openai.py
+++ b/tests/strands/models/test_openai.py
@@ -1034,6 +1034,92 @@ async def test_stream_context_overflow_exception(openai_client, model, messages)
     assert exc_info.value.__cause__ == mock_error
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "error_message",
+    [
+        "Input is too long for requested model",
+        "input length and `max_tokens` exceed context limit",
+        "too many total text bytes",
+    ],
+)
+async def test_stream_alternative_context_overflow_messages(openai_client, model, messages, error_message):
+    """Test that alternative context overflow messages in APIError are properly converted."""
+    # Create a mock OpenAI APIError with alternative context overflow message
+    mock_error = openai.APIError(
+        message=error_message,
+        request=unittest.mock.MagicMock(),
+        body={"error": {"message": error_message}},
+    )
+
+    # Configure the mock client to raise the APIError
+    openai_client.chat.completions.create.side_effect = mock_error
+
+    # Test that the stream method converts the error properly
+    with pytest.raises(ContextWindowOverflowException) as exc_info:
+        async for _ in model.stream(messages):
+            pass
+
+    # Verify the exception message contains the original error
+    assert error_message in str(exc_info.value)
+    assert exc_info.value.__cause__ == mock_error
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "error_message",
+    [
+        "Input is too long for requested model",
+        "input length and `max_tokens` exceed context limit",
+        "too many total text bytes",
+    ],
+)
+async def test_structured_output_alternative_context_overflow_messages(
+    openai_client, model, messages, test_output_model_cls, error_message
+):
+    """Test that alternative context overflow messages in APIError are properly converted in structured output."""
+    # Create a mock OpenAI APIError with alternative context overflow message
+    mock_error = openai.APIError(
+        message=error_message,
+        request=unittest.mock.MagicMock(),
+        body={"error": {"message": error_message}},
+    )
+
+    # Configure the mock client to raise the APIError
+    openai_client.beta.chat.completions.parse.side_effect = mock_error
+
+    # Test that the structured_output method converts the error properly
+    with pytest.raises(ContextWindowOverflowException) as exc_info:
+        async for _ in model.structured_output(test_output_model_cls, messages):
+            pass
+
+    # Verify the exception message contains the original error
+    assert error_message in str(exc_info.value)
+    assert exc_info.value.__cause__ == mock_error
+
+
+@pytest.mark.asyncio
+async def test_stream_api_error_passthrough(openai_client, model, messages):
+    """Test that APIError without overflow messages passes through unchanged."""
+    # Create a mock OpenAI APIError without overflow message
+    mock_error = openai.APIError(
+        message="Some other API error",
+        request=unittest.mock.MagicMock(),
+        body={"error": {"message": "Some other API error"}},
+    )
+
+    # Configure the mock client to raise the APIError
+    openai_client.chat.completions.create.side_effect = mock_error
+
+    # Test that APIError without overflow messages passes through
+    with pytest.raises(openai.APIError) as exc_info:
+        async for _ in model.stream(messages):
+            pass
+
+    # Verify the original exception is raised, not ContextWindowOverflowException
+    assert exc_info.value == mock_error
+
+
 @pytest.mark.asyncio
 async def test_stream_other_bad_request_errors_passthrough(openai_client, model, messages):
     """Test that other BadRequestError exceptions are not converted to ContextWindowOverflowException."""
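
For reviewers who want to see the change from the caller's side, here is a minimal sketch, not part of this diff. It assumes the OpenAIModel class from strands.models.openai and the ContextWindowOverflowException import path strands.types.exceptions used elsewhere in the SDK; stream_with_fallback and the message-trimming fallback are hypothetical illustrations, not code from this change.

from strands.models.openai import OpenAIModel
from strands.types.exceptions import ContextWindowOverflowException


async def stream_with_fallback(model: OpenAIModel, messages: list) -> None:
    """Stream a response, retrying once with a shorter conversation on overflow."""
    try:
        async for event in model.stream(messages):
            print(event)
    except ContextWindowOverflowException:
        # With this change, the overflow exception is raised not only for
        # BadRequestError, but also when an OpenAI-compatible proxy (e.g.,
        # Databricks serving Bedrock models) reports the overflow through a
        # generic openai.APIError whose message matches CONTEXT_OVERFLOW_MESSAGES.
        trimmed = messages[-2:]  # hypothetical recovery: keep only the newest turns
        async for event in model.stream(trimmed):
            print(event)

Matching on message substrings keeps the mapping provider-agnostic: callers get one exception type to handle regardless of which wrapped backend produced the overflow, while non-matching APIError instances still propagate unchanged, as the passthrough test verifies.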