Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 50 additions & 7 deletions src/app/endpoints/rlsapi_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,13 @@
from models.rlsapi.requests import RlsapiV1InferRequest, RlsapiV1SystemInfo
from models.rlsapi.responses import RlsapiV1InferData, RlsapiV1InferResponse
from observability import InferenceEventData, build_inference_event, send_splunk_event
from utils.query import handle_known_apistatus_errors
from utils.query import (
extract_provider_and_model_from_model_id,
handle_known_apistatus_errors,
)
from utils.responses import (
extract_text_from_response_items,
extract_token_usage,
get_mcp_tools,
)
from utils.suid import get_suid
Expand Down Expand Up @@ -191,6 +195,7 @@ async def retrieve_simple_response(
store=False,
)
response = cast(OpenAIResponseObject, response)
extract_token_usage(response.usage, model_id)

return extract_text_from_response_items(response.output)

Expand Down Expand Up @@ -242,6 +247,8 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po
request_id: str,
error: Exception,
start_time: float,
model: str,
provider: str,
) -> float:
"""Record metrics and queue Splunk event for an inference failure.

Expand All @@ -257,7 +264,7 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po
The total inference time in seconds.
"""
inference_time = time.monotonic() - start_time
metrics.llm_calls_failures_total.inc()
metrics.llm_calls_failures_total.labels(provider, model).inc()
_queue_splunk_event(
background_tasks,
infer_request,
Expand Down Expand Up @@ -307,6 +314,7 @@ async def infer_endpoint(
input_source = infer_request.get_input_source()
instructions = _build_instructions(infer_request.context.systeminfo)
model_id = _get_default_model_id()
model, provider = extract_provider_and_model_from_model_id(model_id)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Provider/model destructuring is reversed, so failure labels are flipped.

At Line 317, the helper returns (provider, model), but the assignment stores (model, provider). That inverts llm_calls_failures_total labels downstream.

🔧 Proposed fix
-    model, provider = extract_provider_and_model_from_model_id(model_id)
+    provider, model = extract_provider_and_model_from_model_id(model_id)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
model, provider = extract_provider_and_model_from_model_id(model_id)
provider, model = extract_provider_and_model_from_model_id(model_id)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/app/endpoints/rlsapi_v1.py` at line 317, The tuple returned by
extract_provider_and_model_from_model_id is being unpacked in reverse; change
the assignment at the call site so the first value is assigned to provider and
the second to model (i.e., provider, model =
extract_provider_and_model_from_model_id(model_id)) so downstream labels like
llm_calls_failures_total use the correct provider and model values.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this might be flipped and it should be provider, model.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AH, I AM SLOWER THAN @CodeRabbit

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Find the implementation of extract_provider_and_model_from_model_id function
ast-grep --pattern $'def extract_provider_and_model_from_model_id($$$) {
  $$$
}'

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Also search with rg to find the function definition
rg -n -A 10 'def extract_provider_and_model_from_model_id'

Length of output: 570


🏁 Script executed:

#!/bin/bash
# Get more lines to see the complete return statement
rg -n -A 15 'def extract_provider_and_model_from_model_id'

Length of output: 780


@major 😄 Great minds think alike! Thanks for double-checking - it's always good to have multiple eyes on these details. The function indeed returns (provider, model) based on how it splits the model_id string.

mcp_tools = await get_mcp_tools(request_headers=request.headers)
logger.debug(
"Request %s: Combined input source length: %d", request_id, len(input_source)
Expand All @@ -321,19 +329,40 @@ async def infer_endpoint(
except RuntimeError as e:
if "context_length" in str(e).lower():
_record_inference_failure(
background_tasks, infer_request, request, request_id, e, start_time
background_tasks,
infer_request,
request,
request_id,
e,
start_time,
model,
provider,
)
logger.error("Prompt too long for request %s: %s", request_id, e)
error_response = PromptTooLongResponse(model=model_id)
raise HTTPException(**error_response.model_dump()) from e
_record_inference_failure(
background_tasks, infer_request, request, request_id, e, start_time
background_tasks,
infer_request,
request,
request_id,
e,
start_time,
model,
provider,
)
logger.error("Unexpected RuntimeError for request %s: %s", request_id, e)
raise
except APIConnectionError as e:
_record_inference_failure(
background_tasks, infer_request, request, request_id, e, start_time
background_tasks,
infer_request,
request,
request_id,
e,
start_time,
model,
provider,
)
logger.error(
"Unable to connect to Llama Stack for request %s: %s", request_id, e
Expand All @@ -345,7 +374,14 @@ async def infer_endpoint(
raise HTTPException(**error_response.model_dump()) from e
except RateLimitError as e:
_record_inference_failure(
background_tasks, infer_request, request, request_id, e, start_time
background_tasks,
infer_request,
request,
request_id,
e,
start_time,
model,
provider,
)
logger.error("Rate limit exceeded for request %s: %s", request_id, e)
error_response = QuotaExceededResponse(
Expand All @@ -355,7 +391,14 @@ async def infer_endpoint(
raise HTTPException(**error_response.model_dump()) from e
except (APIStatusError, OpenAIAPIStatusError) as e:
_record_inference_failure(
background_tasks, infer_request, request, request_id, e, start_time
background_tasks,
infer_request,
request,
request_id,
e,
start_time,
model,
provider,
)
logger.exception("API error for request %s: %s", request_id, e)
error_response = handle_known_apistatus_errors(e, model_id)
Expand Down
4 changes: 3 additions & 1 deletion src/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
)

# Metric that counts how many LLM calls failed
llm_calls_failures_total = Counter("ls_llm_calls_failures_total", "LLM calls failures")
llm_calls_failures_total = Counter(
"ls_llm_calls_failures_total", "LLM calls failures", ["provider", "model"]
)

# Metric that counts how many LLM calls had validation errors
llm_calls_validation_errors_total = Counter(
Expand Down
1 change: 0 additions & 1 deletion tests/unit/app/endpoints/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ async def test_metrics_endpoint(mocker: MockerFixture) -> None:
assert "# TYPE ls_provider_model_configuration gauge" in response_body
assert "# TYPE ls_llm_calls_total counter" in response_body
assert "# TYPE ls_llm_calls_failures_total counter" in response_body
assert "# TYPE ls_llm_calls_failures_created gauge" in response_body
assert "# TYPE ls_llm_validation_errors_total counter" in response_body
assert "# TYPE ls_llm_validation_errors_created gauge" in response_body
assert "# TYPE ls_llm_token_sent_total counter" in response_body
Expand Down
Loading