Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
PromptTooLongResponse,
QueryResponse,
QuotaExceededResponse,
ReferencedDocument,
ServiceUnavailableResponse,
UnauthorizedResponse,
UnprocessableEntityResponse,
Expand Down Expand Up @@ -65,7 +64,6 @@
)
from utils.suid import normalize_conversation_id
from utils.types import (
RAGChunk,
ResponsesApiParams,
TurnSummary,
)
Expand Down Expand Up @@ -157,11 +155,8 @@ async def query_endpoint_handler(

client = AsyncLlamaStackClientHolder().get_client()

doc_ids_from_chunks: list[ReferencedDocument] = []
pre_rag_chunks: list[RAGChunk] = []

_, _, doc_ids_from_chunks, pre_rag_chunks = await perform_vector_search(
client, query_request, configuration
client, query_request.query, query_request.solr
)

rag_context = format_rag_context_for_injection(pre_rag_chunks)
Expand Down Expand Up @@ -209,7 +204,7 @@ async def query_endpoint_handler(

if doc_ids_from_chunks:
turn_summary.referenced_documents = deduplicate_referenced_documents(
doc_ids_from_chunks + (turn_summary.referenced_documents or [])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not necessary, attribute is never None

doc_ids_from_chunks + turn_summary.referenced_documents
)

# Get topic summary for new conversation
Expand Down
25 changes: 7 additions & 18 deletions src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
UnauthorizedResponse,
UnprocessableEntityResponse,
)
from utils.types import RAGChunk, ReferencedDocument
from utils.types import ReferencedDocument
from utils.endpoints import (
check_configuration_loaded,
validate_and_retrieve_conversation,
Expand All @@ -75,6 +75,7 @@
build_mcp_tool_call_from_arguments_done,
build_tool_call_summary,
build_tool_result_from_mcp_output_item_done,
deduplicate_referenced_documents,
extract_token_usage,
extract_vector_store_ids_from_tools,
get_topic_summary,
Expand Down Expand Up @@ -184,11 +185,8 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals

client = AsyncLlamaStackClientHolder().get_client()

pre_rag_chunks: list[RAGChunk] = []
doc_ids_from_chunks: list[ReferencedDocument] = []

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not necessary to initialize

_, _, doc_ids_from_chunks, pre_rag_chunks = await perform_vector_search(
client, query_request, configuration
client, query_request.query, query_request.solr
)

rag_context = format_rag_context_for_injection(pre_rag_chunks)
Expand Down Expand Up @@ -277,6 +275,7 @@ async def retrieve_response_generator(
Args:
responses_params: The Responses API parameters
context: The response generator context
doc_ids_from_chunks: List of ReferencedDocument objects extracted from static RAG

Returns:
tuple[AsyncIterator[str], TurnSummary]: The response generator and turn summary
Expand Down Expand Up @@ -748,19 +747,9 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
rag_id_mapping=context.rag_id_mapping,
)

# Merge pre-RAG documents with tool-based documents (similar to query.py)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Utility function for this purpose already exists

if turn_summary.pre_rag_documents:
all_documents = turn_summary.pre_rag_documents + tool_based_documents
seen = set()
deduplicated_documents = []
for doc in all_documents:
key = (doc.doc_url, doc.doc_title)
if key not in seen:
seen.add(key)
deduplicated_documents.append(doc)
turn_summary.referenced_documents = deduplicated_documents
else:
turn_summary.referenced_documents = tool_based_documents
turn_summary.referenced_documents = deduplicate_referenced_documents(
tool_based_documents + turn_summary.pre_rag_documents
)


def stream_http_error_event(
Expand Down
33 changes: 16 additions & 17 deletions src/utils/vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,20 @@
from typing import Any, Optional
from urllib.parse import urljoin

from pydantic import AnyUrl

from llama_stack_client import AsyncLlamaStackClient
from llama_stack_client.types.query_chunks_response import Chunk
from pydantic import AnyUrl

import constants
from configuration import AppConfig
from configuration import configuration
from log import get_logger
from models.requests import QueryRequest
from models.responses import ReferencedDocument
from utils.types import RAGChunk

logger = get_logger(__name__)


def _is_solr_enabled(configuration: AppConfig) -> bool:
def _is_solr_enabled() -> bool:
"""Check if Solr is enabled in configuration."""
return bool(configuration.solr and configuration.solr.enabled)

Expand All @@ -39,18 +38,18 @@ def _get_vector_store_ids(solr_enabled: bool) -> list[str]:
return []


def _build_query_params(query_request: QueryRequest) -> dict:
def _build_query_params(solr: Optional[dict[str, Any]] = None) -> dict[str, Any]:
"""Build query parameters for vector search."""
params = {
"k": constants.VECTOR_SEARCH_DEFAULT_K,
"score_threshold": constants.VECTOR_SEARCH_DEFAULT_SCORE_THRESHOLD,
"mode": constants.VECTOR_SEARCH_DEFAULT_MODE,
}
logger.info("Initial params: %s", params)
logger.info("query_request.solr: %s", query_request.solr)
logger.info("solr: %s", solr)

if query_request.solr:
params["solr"] = query_request.solr
if solr:
params["solr"] = solr
logger.info("Final params with solr filters: %s", params)
else:
logger.info("No solr filters provided")
Expand Down Expand Up @@ -131,16 +130,16 @@ def _process_chunks_for_documents(

async def perform_vector_search(
client: AsyncLlamaStackClient,
query_request: QueryRequest,
configuration: AppConfig,
query: str,
solr: Optional[dict[str, Any]] = None,
) -> tuple[list[Any], list[float], list[ReferencedDocument], list[RAGChunk]]:
"""
Perform vector search and extract RAG chunks and referenced documents.

Args:
client: The AsyncLlamaStackClient to use for the request
query_request: The user's query request
configuration: Application configuration
query: The user's query
solr: Solr query parameters

Returns:
Tuple containing:
Expand All @@ -149,13 +148,13 @@ async def perform_vector_search(
- doc_ids_from_chunks: Referenced documents extracted from chunks
- rag_chunks: Processed RAG chunks ready for use
"""
retrieved_chunks: list[Any] = []
retrieved_chunks: list[Chunk] = []
retrieved_scores: list[float] = []
doc_ids_from_chunks: list[ReferencedDocument] = []
rag_chunks: list[RAGChunk] = []

# Check if Solr is enabled in configuration
if not _is_solr_enabled(configuration):
if not _is_solr_enabled():
logger.info("Solr vector IO is disabled, skipping vector search")
return retrieved_chunks, retrieved_scores, doc_ids_from_chunks, rag_chunks

Expand All @@ -167,11 +166,11 @@ async def perform_vector_search(

if vector_store_ids:
vector_store_id = vector_store_ids[0]
params = _build_query_params(query_request)
params = _build_query_params(solr)

query_response = await client.vector_io.query(
vector_store_id=vector_store_id,
query=query_request.query,
query=query,
params=params,
)

Expand Down
Loading