Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions config/system.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ api:
cache_enabled: true # Is lightspeed-stack cache enabled?
# Authentication via API_KEY environment variable only for MCP server

# Retry configuration for 429 Too Many Requests API errors
num_retries: 3 # Number of retries after the initial request (default 3)

# Default metrics metadata
metrics_metadata:
# Turn-level metrics metadata
Expand Down
102 changes: 72 additions & 30 deletions src/lightspeed_evaluation/core/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,18 @@
import json
import logging
import os
from typing import Optional, cast
from typing import Any, Optional, cast

import httpx
from diskcache import Cache
from tenacity import (
retry,
retry_if_exception,
stop_after_attempt,
wait_exponential,
before_sleep_log,
RetryError,
)

from lightspeed_evaluation.core.api.streaming_parser import parse_streaming_response
from lightspeed_evaluation.core.constants import (
Expand All @@ -19,6 +27,14 @@
logger = logging.getLogger(__name__)


def _is_too_many_requests_error(exception: BaseException) -> bool:
    """Return True only for an HTTP status error whose code is 429.

    Any exception that is not an ``httpx.HTTPStatusError`` yields False, so
    the response attribute is only inspected on status errors.
    """
    if not isinstance(exception, httpx.HTTPStatusError):
        return False
    return exception.response.status_code == 429


class APIClient:
"""API client for actual data generation."""

Expand All @@ -28,10 +44,6 @@ def __init__(
):
"""Initialize the client with configuration."""
self.config = config
self.api_base = config.api_base
self.version = config.version
self.endpoint_type = config.endpoint_type
self.timeout = config.timeout

self.client: Optional[httpx.Client] = None

Expand All @@ -43,11 +55,27 @@ def __init__(
self._validate_endpoint_type()
self._setup_client()

# Wrap methods with retry decorator for handling 429 Too Many Requests errors
retry_decorator = self._create_retry_decorator()
self._standard_query_with_retry = retry_decorator(self._standard_query)
self._streaming_query_with_retry = retry_decorator(self._streaming_query)

def _create_retry_decorator(self) -> Any:
    """Build the tenacity decorator that retries calls failing with HTTP 429.

    Only exceptions accepted by ``_is_too_many_requests_error`` trigger a
    retry. The total number of attempts is ``config.num_retries`` plus the
    initial call. Once attempts are exhausted a ``RetryError`` is raised
    instead of the last underlying exception.
    """
    # num_retries counts retries only, so add one for the initial attempt.
    max_attempts = self.config.num_retries + 1
    # Exponential backoff (multiplier * 2^x), bounded below by 4 and above by 60.
    backoff = wait_exponential(multiplier=1, min=4, max=60)
    # Log at WARNING level before sleeping between attempts.
    log_before_sleep = before_sleep_log(logger, logging.WARNING)
    return retry(
        retry=retry_if_exception(_is_too_many_requests_error),
        stop=stop_after_attempt(max_attempts),
        wait=backoff,
        before_sleep=log_before_sleep,
        reraise=False,  # exhausted retries surface as RetryError
    )

def _validate_endpoint_type(self) -> None:
"""Validate endpoint type is supported."""
if self.endpoint_type not in SUPPORTED_ENDPOINT_TYPES:
if self.config.endpoint_type not in SUPPORTED_ENDPOINT_TYPES:
raise APIError(
f"Unsupported endpoint type: {self.endpoint_type}. "
f"Unsupported endpoint type: {self.config.endpoint_type}. "
f"Must be one of {SUPPORTED_ENDPOINT_TYPES}"
)

Expand All @@ -57,7 +85,9 @@ def _setup_client(self) -> None:
# Enable verify, currently for eval it is set to False
verify = False
self.client = httpx.Client(
base_url=self.api_base, verify=verify, timeout=self.timeout
base_url=self.config.api_base,
verify=verify,
timeout=self.config.timeout,
)
self.client.headers.update({"Content-Type": "application/json"})

Expand Down Expand Up @@ -88,22 +118,28 @@ def query(
if not self.client:
raise APIError("API client not initialized")

api_request = self._prepare_request(query, conversation_id, attachments)
if self.config.cache_enabled:
cached_response = self._get_cached_response(api_request)
if cached_response is not None:
logger.debug("Returning cached response for query: '%s'", query)
return cached_response

if self.endpoint_type == "streaming":
response = self._streaming_query(api_request)
else:
response = self._standard_query(api_request)

if self.config.cache_enabled:
self._add_response_to_cache(api_request, response)

return response
try:
api_request = self._prepare_request(query, conversation_id, attachments)
if self.config.cache_enabled:
cached_response = self._get_cached_response(api_request)
if cached_response is not None:
logger.debug("Returning cached response for query: '%s'", query)
return cached_response

if self.config.endpoint_type == "streaming":
response = self._streaming_query_with_retry(api_request)
else:
response = self._standard_query_with_retry(api_request)

if self.config.cache_enabled:
self._add_response_to_cache(api_request, response)

return response
except RetryError as e:
raise APIError(
f"Maximum retry attempts ({self.config.num_retries}) reached "
"due to persistent rate limiting (HTTP 429)."
) from e

def _prepare_request(
self,
Expand All @@ -123,12 +159,12 @@ def _prepare_request(
)

def _standard_query(self, api_request: APIRequest) -> APIResponse:
"""Query the API using non-streaming endpoint."""
"""Query the API using non-streaming endpoint with retry on 429."""
if not self.client:
raise APIError("HTTP client not initialized")
try:
response = self.client.post(
f"/{self.version}/query",
f"/{self.config.version}/query",
Comment thread
asamal4 marked this conversation as resolved.
json=api_request.model_dump(exclude_none=True),
)
response.raise_for_status()
Expand Down Expand Up @@ -165,8 +201,11 @@ def _standard_query(self, api_request: APIRequest) -> APIResponse:
return APIResponse.from_raw_response(response_data)

except httpx.TimeoutException as e:
raise self._handle_timeout_error("standard", self.timeout) from e
raise self._handle_timeout_error("standard", self.config.timeout) from e
except httpx.HTTPStatusError as e:
# Re-raise 429 errors without conversion to allow retry decorator to handle them
if e.response.status_code == 429:
raise
raise self._handle_http_error(e) from e
except ValueError as e:
raise self._handle_validation_error(e) from e
Expand All @@ -182,17 +221,20 @@ def _streaming_query(self, api_request: APIRequest) -> APIResponse:
try:
with self.client.stream(
"POST",
f"/{self.version}/streaming_query",
f"/{self.config.version}/streaming_query",
json=api_request.model_dump(exclude_none=True),
) as response:
self._handle_response_errors(response)
raw_data = parse_streaming_response(response)
return APIResponse.from_raw_response(raw_data)

except httpx.TimeoutException as e:
raise self._handle_timeout_error("streaming", self.timeout) from e
raise self._handle_timeout_error("streaming", self.config.timeout) from e
except httpx.HTTPStatusError as e:
raise APIError(str(e)) from e
# Re-raise 429 errors without conversion to allow retry decorator to handle them
if e.response.status_code == 429:
raise
raise self._handle_http_error(e) from e
except ValueError as e:
raise self._handle_validation_error(e) from e
except APIError:
Expand Down
2 changes: 2 additions & 0 deletions src/lightspeed_evaluation/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@
SUPPORTED_ENDPOINT_TYPES = ["streaming", "query"]
DEFAULT_API_CACHE_DIR = ".caches/api_cache"

DEFAULT_API_NUM_RETRIES = 3

DEFAULT_LLM_PROVIDER = "openai"
DEFAULT_LLM_MODEL = "gpt-4o-mini"
DEFAULT_SSL_VERIFY = True
Expand Down
9 changes: 9 additions & 0 deletions src/lightspeed_evaluation/core/models/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
DEFAULT_API_CACHE_DIR,
DEFAULT_API_TIMEOUT,
DEFAULT_API_VERSION,
DEFAULT_API_NUM_RETRIES,
DEFAULT_BASE_FILENAME,
DEFAULT_EMBEDDING_CACHE_DIR,
DEFAULT_EMBEDDING_MODEL,
Expand Down Expand Up @@ -195,6 +196,14 @@ class APIConfig(BaseModel):
cache_enabled: bool = Field(
default=True, description="Is caching of lightspeed-stack queries enabled?"
)
num_retries: int = Field(
default=DEFAULT_API_NUM_RETRIES,
ge=0,
description=(
"Maximum number of retry attempts for API calls on "
"429 Too Many Requests errors"
),
)

@field_validator("endpoint_type")
@classmethod
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/core/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@


@pytest.fixture
def api_config() -> APIConfig:
"""Create test API config."""
def basic_api_config_query_endpoint() -> APIConfig:
"""Create test API config for query endpoint."""
return APIConfig(
enabled=True,
api_base="http://localhost:8080",
Expand All @@ -22,8 +22,8 @@ def api_config() -> APIConfig:


@pytest.fixture
def basic_api_config() -> APIConfig:
"""Create basic API configuration for streaming."""
def basic_api_config_streaming_endpoint() -> APIConfig:
"""Create basic API configuration for streaming endpoint."""
return APIConfig(
enabled=True,
api_base="http://localhost:8080",
Expand Down
Loading
Loading