Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions tests/integration/endpoints/test_query_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,111 @@ async def test_query_v2_endpoint_handles_connection_error(
assert "cause" in exc_info.value.detail


@pytest.mark.asyncio
async def test_query_v2_endpoint_returns_401_with_www_authenticate_when_mcp_oauth_required(
test_config: AppConfig,
mock_llama_stack_client: AsyncMockType,
test_request: Request,
test_auth: AuthTuple,
mocker: MockerFixture,
) -> None:
"""Test query endpoint returns 401 with WWW-Authenticate when MCP server requires OAuth.

When prepare_tools calls get_mcp_tools and an MCP server is configured for OAuth
without client-provided headers, get_mcp_tools raises 401 with WWW-Authenticate.
This test verifies the query handler propagates that response to the client.

Parameters:
test_config: Test configuration
mock_llama_stack_client: Mocked Llama Stack client
test_request: FastAPI request
test_auth: noop authentication tuple
mocker: pytest-mock fixture
"""
_ = test_config
_ = mock_llama_stack_client

expected_www_auth = 'Bearer realm="oauth"'
oauth_401 = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail={"cause": "MCP server at http://example.com requires OAuth"},
headers={"WWW-Authenticate": expected_www_auth},
)
mocker.patch(
"utils.responses.get_mcp_tools",
new_callable=mocker.AsyncMock,
side_effect=oauth_401,
)

query_request = QueryRequest(query="What is Ansible?")

with pytest.raises(HTTPException) as exc_info:
await query_endpoint_handler(
request=test_request,
query_request=query_request,
auth=test_auth,
mcp_headers={},
)

assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert exc_info.value.headers is not None
assert exc_info.value.headers.get("WWW-Authenticate") == expected_www_auth


@pytest.mark.asyncio
async def test_query_v2_endpoint_returns_401_when_oauth_probe_times_out(
test_config: AppConfig,
mock_llama_stack_client: AsyncMockType,
test_request: Request,
test_auth: AuthTuple,
mocker: MockerFixture,
) -> None:
"""Test query endpoint returns 401 when OAuth probe times out.

When prepare_responses_params calls get_mcp_tools and the MCP OAuth probe
times out (TimeoutError), get_mcp_tools raises 401 without a
WWW-Authenticate header. This test verifies the query handler propagates
that response.

Parameters:
test_config: Test configuration
mock_llama_stack_client: Mocked Llama Stack client
test_request: FastAPI request
test_auth: noop authentication tuple
mocker: pytest-mock fixture
"""
_ = test_config
_ = mock_llama_stack_client

# Probe timed out: 401 without WWW-Authenticate (same as real probe on TimeoutError)
oauth_probe_timeout_401 = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail={"cause": "MCP server at http://example.com requires OAuth"},
headers=None,
)
mocker.patch(
"utils.responses.get_mcp_tools",
new_callable=mocker.AsyncMock,
side_effect=oauth_probe_timeout_401,
)

query_request = QueryRequest(query="What is Ansible?")

with pytest.raises(HTTPException) as exc_info:
await query_endpoint_handler(
request=test_request,
query_request=query_request,
auth=test_auth,
mcp_headers={},
)

assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert (
exc_info.value.headers is None
or exc_info.value.headers.get("WWW-Authenticate") is None
)


@pytest.mark.asyncio
async def test_query_v2_endpoint_empty_query(
test_config: AppConfig,
Expand Down
138 changes: 138 additions & 0 deletions tests/integration/endpoints/test_streaming_query_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""Integration tests for the /streaming_query endpoint."""

from typing import Any, Generator

import pytest
from fastapi import HTTPException, Request, status
from pytest_mock import MockerFixture

from app.endpoints.streaming_query import streaming_query_endpoint_handler
from authentication.interface import AuthTuple
from configuration import AppConfig
from models.requests import QueryRequest


@pytest.fixture(name="mock_llama_stack_streaming")
def mock_llama_stack_streaming_fixture(
mocker: MockerFixture,
) -> Generator[Any, None, None]:
"""Mock the Llama Stack client for streaming_query endpoint.

Configures models.list, vector_stores.list, and conversations.create so
prepare_responses_params can run until get_mcp_tools.
"""
mock_holder_class = mocker.patch(
"app.endpoints.streaming_query.AsyncLlamaStackClientHolder"
)
mock_client = mocker.AsyncMock()

mock_model = mocker.MagicMock()
mock_model.id = "test-provider/test-model"
mock_model.custom_metadata = {
"provider_id": "test-provider",
"model_type": "llm",
}
mock_client.models.list.return_value = [mock_model]

mock_vector_stores_response = mocker.MagicMock()
mock_vector_stores_response.data = []
mock_client.vector_stores.list.return_value = mock_vector_stores_response

mock_conversation = mocker.MagicMock()
mock_conversation.id = "conv_" + "a" * 48
mock_client.conversations.create = mocker.AsyncMock(return_value=mock_conversation)

mock_holder_class.return_value.get_client.return_value = mock_client
yield mock_client


@pytest.mark.asyncio
async def test_streaming_query_endpoint_returns_401_with_www_authenticate_when_mcp_oauth_required(
test_config: AppConfig,
mock_llama_stack_streaming: Any,
test_request: Request,
test_auth: AuthTuple,
mocker: MockerFixture,
) -> None:
"""Test streaming_query returns 401 with WWW-Authenticate when MCP server requires OAuth.

When prepare_responses_params calls get_mcp_tools and an MCP server is
configured for OAuth without client-provided headers, get_mcp_tools raises
401 with WWW-Authenticate. This test verifies the streaming handler
propagates that response to the client.
"""
_ = test_config
_ = mock_llama_stack_streaming

expected_www_auth = 'Bearer realm="oauth"'
oauth_401 = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail={"cause": "MCP server at http://example.com requires OAuth"},
headers={"WWW-Authenticate": expected_www_auth},
)
mocker.patch(
"utils.responses.get_mcp_tools",
new_callable=mocker.AsyncMock,
side_effect=oauth_401,
)

query_request = QueryRequest(query="What is Ansible?")

with pytest.raises(HTTPException) as exc_info:
await streaming_query_endpoint_handler(
request=test_request,
query_request=query_request,
auth=test_auth,
mcp_headers={},
)

assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert exc_info.value.headers is not None
assert exc_info.value.headers.get("WWW-Authenticate") == expected_www_auth


@pytest.mark.asyncio
async def test_streaming_query_returns_401_when_oauth_probe_times_out(
test_config: AppConfig,
mock_llama_stack_streaming: Any,
test_request: Request,
test_auth: AuthTuple,
mocker: MockerFixture,
) -> None:
"""Test streaming_query returns 401 when OAuth probe times out.

When prepare_responses_params calls get_mcp_tools and the MCP OAuth probe
times out (TimeoutError), get_mcp_tools raises 401 without a
WWW-Authenticate header. This test verifies the streaming handler
propagates that response.
"""
_ = test_config
_ = mock_llama_stack_streaming

# Probe timed out: 401 without WWW-Authenticate (same as real probe on TimeoutError)
oauth_probe_timeout_401 = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail={"cause": "MCP server at http://example.com requires OAuth"},
headers=None,
)
mocker.patch(
"utils.responses.get_mcp_tools",
new_callable=mocker.AsyncMock,
side_effect=oauth_probe_timeout_401,
)

query_request = QueryRequest(query="What is Ansible?")

with pytest.raises(HTTPException) as exc_info:
await streaming_query_endpoint_handler(
request=test_request,
query_request=query_request,
auth=test_auth,
mcp_headers={},
)

assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert (
exc_info.value.headers is None
or exc_info.value.headers.get("WWW-Authenticate") is None
)
134 changes: 134 additions & 0 deletions tests/integration/endpoints/test_tools_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""Integration tests for the /tools endpoint."""

from typing import Any, Generator

import pytest
from fastapi import HTTPException, Request, status
from llama_stack_client import AuthenticationError
from pytest_mock import MockerFixture

from app.endpoints import tools
from authentication.interface import AuthTuple
from configuration import AppConfig


@pytest.fixture(name="mock_llama_stack_tools")
def mock_llama_stack_tools_fixture(
mocker: MockerFixture,
) -> Generator[Any, None, None]:
"""Mock the Llama Stack client for tools endpoint.

Returns:
Mock client with toolgroups.list and tools.list configured.
"""
mock_holder_class = mocker.patch("app.endpoints.tools.AsyncLlamaStackClientHolder")
mock_client = mocker.AsyncMock()
mock_holder_class.return_value.get_client.return_value = mock_client
yield mock_client


@pytest.mark.asyncio
async def test_tools_endpoint_returns_401_with_www_authenticate_when_mcp_oauth_required(
test_config: AppConfig,
mock_llama_stack_tools: Any,
test_request: Request,
test_auth: AuthTuple,
mocker: MockerFixture,
) -> None:
"""Test GET /tools returns 401 with WWW-Authenticate when MCP server requires OAuth.

When tools.list raises AuthenticationError and the toolgroup has an
mcp_endpoint, the handler calls probe_mcp_oauth_and_raise_401 and
raises 401 with WWW-Authenticate so the client can perform OAuth.

Verifies:
- AuthenticationError from first toolgroup triggers OAuth probe
- Response is 401 with WWW-Authenticate header
"""
_ = test_config

mock_toolgroup = mocker.Mock()
mock_toolgroup.identifier = "server1"
mock_toolgroup.mcp_endpoint = mocker.Mock()
mock_toolgroup.mcp_endpoint.uri = "http://url.com:1"
mock_llama_stack_tools.toolgroups.list.return_value = [mock_toolgroup]

auth_error = AuthenticationError(
message="MCP server requires OAuth",
response=mocker.Mock(request=None),
body=None,
)
mock_llama_stack_tools.tools.list.side_effect = auth_error

expected_www_auth = 'Bearer realm="oauth"'
probe_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail={"cause": "MCP server at http://url.com:1 requires OAuth"},
headers={"WWW-Authenticate": expected_www_auth},
)
mocker.patch(
"app.endpoints.tools.probe_mcp_oauth_and_raise_401",
new_callable=mocker.AsyncMock,
side_effect=probe_exception,
)

with pytest.raises(HTTPException) as exc_info:
await tools.tools_endpoint_handler(request=test_request, auth=test_auth)

assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert exc_info.value.headers is not None
assert exc_info.value.headers.get("WWW-Authenticate") == expected_www_auth


@pytest.mark.asyncio
async def test_tools_endpoint_returns_401_when_oauth_probe_times_out(
test_config: AppConfig,
mock_llama_stack_tools: Any,
test_request: Request,
test_auth: AuthTuple,
mocker: MockerFixture,
) -> None:
"""Test GET /tools returns 401 when OAuth probe times out.

When tools.list raises AuthenticationError and the toolgroup has an
mcp_endpoint, the handler calls probe_mcp_oauth_and_raise_401. If the probe
times out (TimeoutError), the probe raises 401 without a WWW-Authenticate
header.

Verifies:
- Real probe runs and hits a timeout (aiohttp session.get raises TimeoutError)
- 401 is returned with no WWW-Authenticate header
"""
_ = test_config

mock_toolgroup = mocker.Mock()
mock_toolgroup.identifier = "server1"
mock_toolgroup.mcp_endpoint = mocker.Mock()
mock_toolgroup.mcp_endpoint.uri = "http://url.com:1"
mock_llama_stack_tools.toolgroups.list.return_value = [mock_toolgroup]

auth_error = AuthenticationError(
message="MCP server requires OAuth",
response=mocker.Mock(request=None),
body=None,
)
mock_llama_stack_tools.tools.list.side_effect = auth_error

# Simulate timeout: session.get() raises TimeoutError; real probe catches it and raises 401.
mock_session = mocker.Mock()
mock_session.get = mocker.Mock(side_effect=TimeoutError("OAuth probe timed out"))
mock_session_cm = mocker.AsyncMock()
mock_session_cm.__aenter__.return_value = mock_session
mock_session_cm.__aexit__.return_value = None
mocker.patch(
"utils.mcp_oauth_probe.aiohttp.ClientSession", return_value=mock_session_cm
)

with pytest.raises(HTTPException) as exc_info:
await tools.tools_endpoint_handler(request=test_request, auth=test_auth)

assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert (
exc_info.value.headers is None
or exc_info.value.headers.get("WWW-Authenticate") is None
)
Loading