78 changes: 78 additions & 0 deletions tests/e2e/features/environment.py
@@ -16,6 +16,10 @@
from tests.e2e.utils.prow_utils import restore_llama_stack_pod
from behave.runner import Context

from tests.e2e.utils.llama_stack_shields import (
    register_shield,
    unregister_shield,
)
from tests.e2e.utils.utils import (
    create_config_backup,
    is_prow_environment,
@@ -169,6 +173,27 @@ def before_scenario(context: Context, scenario: Scenario) -> None:
scenario.skip("Skipped in library mode (no separate llama-stack container)")
return

# @disable-shields: unregister shield via client.shields.delete("llama-guard").
# Only in server mode: in library mode there is no separate Llama Stack to call,
# and unregistering in the test process would not affect the app's in-process instance.
if "disable-shields" in scenario.effective_tags:
if context.is_library_mode:
scenario.skip(
"Shield unregister/register only applies in server mode (Llama Stack as a "
"separate service). In library mode the app's shields cannot be disabled from e2e."
)
return
try:
saved = unregister_shield("llama-guard")
context.llama_guard_provider_id = saved[0] if saved else None
context.llama_guard_provider_shield_id = saved[1] if saved else None
print("Unregistered shield llama-guard for this scenario")
Comment on lines +176 to +190
Contributor
⚠️ Potential issue | 🟠 Major

Avoid unconditional shield re-registration; it can corrupt scenario isolation.

At Line 246, register_shield(...) runs for every @disable-shields scenario. But the unregister_shield call at Line 187 can return None (shield absent), in which case register_shield falls back to its defaults, potentially creating a shield that did not exist before the scenario.

🔧 Suggested fix
 def before_scenario(context: Context, scenario: Scenario) -> None:
+    context.llama_guard_restore_required = False
+
     # `@disable-shields`: unregister shield via client.shields.delete("llama-guard").
     # Only in server mode: in library mode there is no separate Llama Stack to call,
     # and unregistering in the test process would not affect the app's in-process instance.
     if "disable-shields" in scenario.effective_tags:
@@
         try:
             saved = unregister_shield("llama-guard")
-            context.llama_guard_provider_id = saved[0] if saved else None
-            context.llama_guard_provider_shield_id = saved[1] if saved else None
-            print("Unregistered shield llama-guard for this scenario")
+            if saved:
+                context.llama_guard_provider_id = saved[0]
+                context.llama_guard_provider_shield_id = saved[1]
+                context.llama_guard_restore_required = True
+                print("Unregistered shield llama-guard for this scenario")
+            else:
+                context.llama_guard_provider_id = None
+                context.llama_guard_provider_shield_id = None
+                print("Shield llama-guard was not registered; nothing to restore")
         except Exception as e:  # pylint: disable=broad-exception-caught
             scenario.skip(
                 f"Could not unregister shield (is Llama Stack reachable?): {e}"
             )
             return
@@
-    if "disable-shields" in scenario.effective_tags:
+    if (
+        "disable-shields" in scenario.effective_tags
+        and not context.is_library_mode
+        and getattr(context, "llama_guard_restore_required", False)
+    ):
         try:
             provider_id = getattr(context, "llama_guard_provider_id", None)
             provider_shield_id = getattr(
                 context, "llama_guard_provider_shield_id", None
             )
             register_shield(
                 "llama-guard",
                 provider_id=provider_id,
                 provider_shield_id=provider_shield_id,
             )
             print("Re-registered shield llama-guard")
+            context.llama_guard_restore_required = False
         except Exception as e:  # pylint: disable=broad-exception-caught
             print(f"Warning: Could not re-register shield: {e}")

Also applies to: 245-256

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `tests/e2e/features/environment.py` around lines 176-190: when restoring
shields after a `@disable-shields` scenario, don't unconditionally call
register_shield with defaults if unregister_shield returned no previous
provider; change the restore logic to check the saved tuple from
unregister_shield (the values stored on context.llama_guard_provider_id and
context.llama_guard_provider_shield_id) and only call register_shield to
re-create the shield when saved is truthy (i.e., the shield existed before). If
saved is falsy/None, skip register_shield so you don't create a new default
shield and thus preserve scenario isolation; update any code paths that assume
register_shield will always run to use this presence check instead.

        except Exception as e:  # pylint: disable=broad-exception-caught
            scenario.skip(
                f"Could not unregister shield (is Llama Stack reachable?): {e}"
            )
            return

mode_dir = "library-mode" if context.is_library_mode else "server-mode"

if "InvalidFeedbackStorageConfig" in scenario.effective_tags:
@@ -217,6 +242,53 @@ def after_scenario(context: Context, scenario: Scenario) -> None:
        switch_config(context.feature_config)
        restart_container("lightspeed-stack")

    # @disable-shields: re-register shield (server mode only; library mode skipped above)
    if "disable-shields" in scenario.effective_tags:
        try:
            provider_id = getattr(context, "llama_guard_provider_id", None)
            provider_shield_id = getattr(
                context, "llama_guard_provider_shield_id", None
            )
            register_shield(
                "llama-guard",
                provider_id=provider_id,
                provider_shield_id=provider_shield_id,
            )
            print("Re-registered shield llama-guard")
        except Exception as e:  # pylint: disable=broad-exception-caught
            print(f"Warning: Could not re-register shield: {e}")


def _print_llama_stack_diagnostics() -> None:
    """Print container state, health, and recent logs to diagnose why llama-stack did not recover."""
    print("--- llama-stack diagnostics ---")
    for label, cmd in [
        ("State", ["docker", "inspect", "--format={{.State}}", "llama-stack"]),
        ("Health", ["docker", "inspect", "--format={{.State.Health}}", "llama-stack"]),
    ]:
        try:
            r = subprocess.run(
                cmd, capture_output=True, text=True, timeout=5, check=False
            )
            print(f" {label}: {r.stdout.strip() if r.stdout else r.stderr or 'N/A'}")
        except subprocess.TimeoutExpired:
            print(f" {label}: (inspect timed out)")
    try:
        r = subprocess.run(
            ["docker", "logs", "--tail", "40", "llama-stack"],
            capture_output=True,
            text=True,
            timeout=10,
            check=False,
        )
        out = (r.stdout or "") + (r.stderr or "")
        print(" Logs (last 40 lines):")
        for line in out.strip().splitlines():
            print(f" {line}")
    except subprocess.TimeoutExpired:
        print(" Logs: (timed out)")
    print("--- end diagnostics ---")


def _restore_llama_stack(context: Context) -> None:
    """Restore Llama Stack connection after disruption."""
@@ -263,9 +335,15 @@ def _restore_llama_stack(context: Context) -> None:
            time.sleep(5)
        else:
            print("Warning: Llama Stack may not be fully healthy after restoration")
            _print_llama_stack_diagnostics()

    except subprocess.CalledProcessError as e:
        print(f"Warning: Could not restore Llama Stack connection: {e}")
        if e.stderr:
            print(f" docker start stderr: {e.stderr}")
        if e.stdout:
            print(f" docker start stdout: {e.stdout}")
        _print_llama_stack_diagnostics()


def before_feature(context: Context, feature: Feature) -> None:
17 changes: 17 additions & 0 deletions tests/e2e/features/query.feature
@@ -216,3 +216,20 @@ Scenario: Check if LLM responds for query request with error for missing query
}
"""
Then The status code of the response is 200

  Scenario: Check if query with shields returns 413 when question is too long for model context
    Given The system is in default state
    And I set the Authorization header to Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Ikpva
    When I use "query" to ask question with too-long query and authorization header
    Then The status code of the response is 413
    And The body of the response contains Prompt is too long

  # https://issues.redhat.com/browse/LCORE-1387
  @skip
  @disable-shields
  Scenario: Check if query without shields returns 413 when question is too long for model context
    Given The system is in default state
    And I set the Authorization header to Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Ikpva
    When I use "query" to ask question with too-long query and authorization header
    Then The status code of the response is 413
    And The body of the response contains Prompt is too long
82 changes: 75 additions & 7 deletions tests/e2e/features/steps/llm_query_response.py
@@ -15,7 +15,6 @@
def wait_for_complete_response(context: Context) -> None:
    """Wait for the response to be complete."""
    context.response_data = _parse_streaming_response(context.response.text)
-   print(context.response_data)
    context.response.raise_for_status()
    assert context.response_data["finished"] is True

@@ -31,10 +30,21 @@ def ask_question(context: Context, endpoint: str) -> None:
    json_str = replace_placeholders(context, context.text or "{}")

    data = json.loads(json_str)
-   print(f"Request data: {data}")
    context.response = requests.post(url, json=data, timeout=DEFAULT_LLM_TIMEOUT)


def _read_streamed_response(response: requests.Response) -> str:
    """Read a streaming response body, tolerating premature close (e.g. after error event)."""
    chunks = []
    try:
        for line in response.iter_lines(decode_unicode=True):
            if line is not None:
                chunks.append(line + "\n")
    except requests.exceptions.ChunkedEncodingError:
        pass  # Server may close stream after sending an error event
    return "".join(chunks)


@step('I use "{endpoint}" to ask question with authorization header')
def ask_question_authorized(context: Context, endpoint: str) -> None:
"""Call the service REST API endpoint with question."""
@@ -46,10 +56,40 @@ def ask_question_authorized(context: Context, endpoint: str) -> None:
    json_str = replace_placeholders(context, context.text or "{}")

    data = json.loads(json_str)
-   print(f"Request data: {data}")
-   context.response = requests.post(
-       url, json=data, headers=context.auth_headers, timeout=DEFAULT_LLM_TIMEOUT
-   )
    if endpoint == "streaming_query":
        resp = requests.post(
            url,
            json=data,
            headers=context.auth_headers,
            timeout=DEFAULT_LLM_TIMEOUT,
            stream=True,
        )
        # Consume stream so server close after error event does not raise
        body = _read_streamed_response(resp)
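        # Test-only trick (relies on requests internals): pre-filling the private
        # ._content cache lets later steps keep reading context.response.text even
        # though the stream has already been consumed.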
        resp._content = body.encode(resp.encoding or "utf-8")
        context.response = resp
    else:
        context.response = requests.post(
            url, json=data, headers=context.auth_headers, timeout=DEFAULT_LLM_TIMEOUT
        )


# Repetition count chosen so the query exceeds typical model context windows (e.g. 128k tokens)
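# Rough arithmetic: 80_000 repetitions × 18 chars ("what is openshift?") ≈ 1.4M
# characters — on the order of ~360k tokens at ~4 chars/token, well past 128k.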
_TOO_LONG_QUERY_LENGTH = 80_000


@step('I use "{endpoint}" to ask question with too-long query and authorization header')
def ask_question_too_long_authorized(context: Context, endpoint: str) -> None:
"""Call the query endpoint with a query string that exceeds model context (expect 413)."""
long_query = "what is openshift?" * _TOO_LONG_QUERY_LENGTH
payload = {
"query": long_query,
"model": context.default_model,
"provider": context.default_provider,
}
context.text = json.dumps(payload)
print(f"Request: query length={len(long_query)}, model={context.default_model}")
ask_question_authorized(context, endpoint)


@step("I store conversation details")
@@ -72,7 +112,6 @@ def ask_question_in_same_conversation(context: Context, endpoint: str) -> None:
    headers = context.auth_headers if hasattr(context, "auth_headers") else {}
    data["conversation_id"] = context.response_data["conversation_id"]

-   print(f"Request data: {data}")
    context.response = requests.post(
        url, json=data, headers=headers, timeout=DEFAULT_LLM_TIMEOUT
    )
@@ -142,6 +181,29 @@ def check_streamed_fragments_in_response(context: Context) -> None:
), f"Fragment '{expected}' not found in LLM response: '{response}'"


@then("The streamed response contains error message {message}")
def check_streamed_response_error_message(context: Context, message: str) -> None:
"""Check that the streamed SSE response contains an error event with the given message.

Parses the response body as SSE, asserts that an event with event type 'error' is
present, and that its 'response' or 'cause' field contains the given message.
Use for streaming endpoints when the error is delivered in the stream (e.g. 200 + error event).
"""
assert context.response is not None, "Request needs to be performed first"
print(context.response.text)
parsed = _parse_streaming_response(context.response.text)
stream_error = parsed.get("stream_error")
assert (
stream_error is not None
), "No error event in stream. Expected an SSE event with event type 'error'."
response_text = str(stream_error.get("response", ""))
cause_text = str(stream_error.get("cause", ""))
assert message in response_text or message in cause_text, (
f"Expected error message '{message}' not found in stream error event: "
f"response={response_text!r}, cause={cause_text!r}"
)


@then("The streamed response is equal to the full response")
def compare_streamed_responses(context: Context) -> None:
"""Check that streamed response is equal to complete response.
@@ -171,6 +233,9 @@ def _parse_streaming_response(response_text: str) -> dict:
    full_response_split = []
    finished = False
    first_token = True
    stream_error = (
        None  # {"status_code": int, "response": str, "cause": str} if event "error"
    )

    for line in lines:
        if line.startswith("data: "):
@@ -190,6 +255,8 @@
full_response = data["data"]["token"]
elif event == "end":
finished = True
elif event == "error":
stream_error = data.get("data") or {}
except json.JSONDecodeError:
continue # Skip malformed lines

@@ -198,4 +265,5 @@
"response": "".join(full_response_split),
"response_complete": full_response,
"finished": finished,
"stream_error": stream_error,
}
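
For reference, an in-stream error as consumed by this parser would appear as a data line of roughly this shape (field names follow the parser and the error step above; the cause text is illustrative):

    data: {"event": "error", "data": {"status_code": 413, "response": "Prompt is too long", "cause": "..."}}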
15 changes: 15 additions & 0 deletions tests/e2e/features/streaming_query.feature
@@ -178,3 +178,18 @@ Feature: streaming_query endpoint API tests
}
}
"""

  Scenario: Check if streaming_query with shields returns 413 when question is too long for model context
    Given The system is in default state
    And I set the Authorization header to Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Ikpva
    When I use "streaming_query" to ask question with too-long query and authorization header
    Then The status code of the response is 413
    And The body of the response contains Prompt is too long

  @disable-shields
  Scenario: Check if streaming_query without shields returns 200 and error in stream when question is too long for model context
    Given The system is in default state
    And I set the Authorization header to Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Ikpva
    When I use "streaming_query" to ask question with too-long query and authorization header
    Then The status code of the response is 200
    And The streamed response contains error message Prompt is too long
104 changes: 104 additions & 0 deletions tests/e2e/utils/llama_stack_shields.py
@@ -0,0 +1,104 @@
"""E2E helpers to unregister and re-register Llama Stack shields via the client API.

Used by the @disable-shields tag: before the scenario we call client.shields.delete()
to unregister the shield; after the scenario we call client.shields.register()
to restore it. Only applies in server mode (Llama Stack as a separate service).
Requires E2E_LLAMA_STACK_URL or E2E_LLAMA_HOSTNAME/E2E_LLAMA_PORT.
"""

import asyncio
import os
from typing import Optional

from llama_stack_client import (
    APIConnectionError,
    AsyncLlamaStackClient,
    APIStatusError,
)


def _get_llama_stack_client() -> AsyncLlamaStackClient:
    """Build an AsyncLlamaStackClient from env (for e2e test use)."""
    base_url = os.getenv("E2E_LLAMA_STACK_URL")
    if not base_url:
        host = os.getenv("E2E_LLAMA_HOSTNAME", "localhost")
        port = os.getenv("E2E_LLAMA_PORT", "8321")
        base_url = f"http://{host}:{port}"
    api_key = os.getenv("E2E_LLAMA_STACK_API_KEY", "xyzzy")
    timeout = int(os.getenv("E2E_LLAMA_STACK_TIMEOUT", "60"))
    return AsyncLlamaStackClient(base_url=base_url, api_key=api_key, timeout=timeout)
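
# Illustrative only (not part of this module): a quick connectivity check built
# on the factory above, assuming a reachable Llama Stack. client.shields.list()
# is the same call the unregister helper uses below.
#
#     async def _smoke_check() -> None:
#         client = _get_llama_stack_client()
#         try:
#             print([getattr(s, "identifier", None) for s in await client.shields.list()])
#         finally:
#             await client.close()
#
#     asyncio.run(_smoke_check())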


async def _unregister_shield_async(identifier: str) -> Optional[tuple[str, str]]:
    """Unregister a shield by identifier; return (provider_id, provider_shield_id) for restore."""
    client = _get_llama_stack_client()
    try:
        shields = await client.shields.list()
        provider_id = None
        provider_shield_id = None
        found = False
        for shield in shields:
            if getattr(shield, "identifier", None) == identifier:
                provider_id = getattr(shield, "provider_id", None)
                provider_shield_id = getattr(
                    shield, "provider_resource_id", None
                ) or getattr(shield, "provider_shield_id", None)
                found = True
                break
        if not found:
            # Shield not registered; nothing to delete, scenario can proceed
            return None
        try:
            await client.shields.delete(identifier)
        except APIConnectionError:
            raise
        except APIStatusError as e:
            # 400 "not found": shield already absent, scenario can proceed
            if e.status_code == 400 and "not found" in str(e).lower():
                return None
            raise
        if provider_id is not None and provider_shield_id is not None:
            return (provider_id, provider_shield_id)
        return None
    finally:
        await client.close()


async def _register_shield_async(
    shield_id: str,
    provider_id: str,
    provider_shield_id: str,
) -> None:
    """Register a shield (restore after unregister)."""
    client = _get_llama_stack_client()
    try:
        await client.shields.register(
            shield_id=shield_id,
            provider_id=provider_id,
            provider_shield_id=provider_shield_id,
        )
    finally:
        await client.close()


def unregister_shield(
    identifier: str = "llama-guard",
) -> Optional[tuple[str, str]]:
    """Unregister the shield via client.shields.delete(); return (provider_id, provider_shield_id)."""
    return asyncio.run(_unregister_shield_async(identifier))


def register_shield(
    shield_id: str = "llama-guard",
    provider_id: Optional[str] = None,
    provider_shield_id: Optional[str] = None,
) -> None:
    """Re-register the shield via client.shields.register()."""
    if not provider_id:
        provider_id = os.getenv("E2E_LLAMA_GUARD_PROVIDER_ID", "llama-guard")
    if not provider_shield_id:
        provider_shield_id = os.getenv(
            "E2E_LLAMA_GUARD_PROVIDER_SHIELD_ID",
            "openai/gpt-4o-mini",
        )
    asyncio.run(_register_shield_async(shield_id, provider_id, provider_shield_id))
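
Taken together, a minimal sketch of how the environment.py hooks drive these helpers in server mode, including the presence check recommended in the review above (names match the code in this PR; the restore flag is the reviewer's suggested addition):

    # before_scenario: remember whether the shield actually existed
    saved = unregister_shield("llama-guard")
    restore_required = saved is not None

    # after_scenario: only re-create the shield if it was registered before
    if restore_required:
        register_shield("llama-guard", provider_id=saved[0], provider_shield_id=saved[1])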