diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index b4bc94412..cbd06e0e7 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -32,7 +32,6 @@ PromptTooLongResponse, QueryResponse, QuotaExceededResponse, - ReferencedDocument, ServiceUnavailableResponse, UnauthorizedResponse, UnprocessableEntityResponse, @@ -65,7 +64,6 @@ ) from utils.suid import normalize_conversation_id from utils.types import ( - RAGChunk, ResponsesApiParams, TurnSummary, ) @@ -157,11 +155,8 @@ async def query_endpoint_handler( client = AsyncLlamaStackClientHolder().get_client() - doc_ids_from_chunks: list[ReferencedDocument] = [] - pre_rag_chunks: list[RAGChunk] = [] - _, _, doc_ids_from_chunks, pre_rag_chunks = await perform_vector_search( - client, query_request, configuration + client, query_request.query, query_request.solr ) rag_context = format_rag_context_for_injection(pre_rag_chunks) @@ -209,7 +204,7 @@ async def query_endpoint_handler( if doc_ids_from_chunks: turn_summary.referenced_documents = deduplicate_referenced_documents( - doc_ids_from_chunks + (turn_summary.referenced_documents or []) + doc_ids_from_chunks + turn_summary.referenced_documents ) # Get topic summary for new conversation diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 45998754e..81f0c1dcf 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -55,7 +55,7 @@ UnauthorizedResponse, UnprocessableEntityResponse, ) -from utils.types import RAGChunk, ReferencedDocument +from utils.types import ReferencedDocument from utils.endpoints import ( check_configuration_loaded, validate_and_retrieve_conversation, @@ -75,6 +75,7 @@ build_mcp_tool_call_from_arguments_done, build_tool_call_summary, build_tool_result_from_mcp_output_item_done, + deduplicate_referenced_documents, extract_token_usage, extract_vector_store_ids_from_tools, get_topic_summary, @@ -184,11 +185,8 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals client = AsyncLlamaStackClientHolder().get_client() - pre_rag_chunks: list[RAGChunk] = [] - doc_ids_from_chunks: list[ReferencedDocument] = [] - _, _, doc_ids_from_chunks, pre_rag_chunks = await perform_vector_search( - client, query_request, configuration + client, query_request.query, query_request.solr ) rag_context = format_rag_context_for_injection(pre_rag_chunks) @@ -277,6 +275,7 @@ async def retrieve_response_generator( Args: responses_params: The Responses API parameters context: The response generator context + doc_ids_from_chunks: List of ReferencedDocument objects extracted from static RAG Returns: tuple[AsyncIterator[str], TurnSummary]: The response generator and turn summary @@ -748,19 +747,9 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat rag_id_mapping=context.rag_id_mapping, ) - # Merge pre-RAG documents with tool-based documents (similar to query.py) - if turn_summary.pre_rag_documents: - all_documents = turn_summary.pre_rag_documents + tool_based_documents - seen = set() - deduplicated_documents = [] - for doc in all_documents: - key = (doc.doc_url, doc.doc_title) - if key not in seen: - seen.add(key) - deduplicated_documents.append(doc) - turn_summary.referenced_documents = deduplicated_documents - else: - turn_summary.referenced_documents = tool_based_documents + turn_summary.referenced_documents = deduplicate_referenced_documents( + tool_based_documents + turn_summary.pre_rag_documents + ) def stream_http_error_event( diff --git a/src/utils/vector_search.py b/src/utils/vector_search.py index ff7e73138..e39e9ec04 100644 --- a/src/utils/vector_search.py +++ b/src/utils/vector_search.py @@ -8,21 +8,20 @@ from typing import Any, Optional from urllib.parse import urljoin -from pydantic import AnyUrl - from llama_stack_client import AsyncLlamaStackClient +from llama_stack_client.types.query_chunks_response import Chunk +from pydantic import AnyUrl import constants -from configuration import AppConfig +from configuration import configuration from log import get_logger -from models.requests import QueryRequest from models.responses import ReferencedDocument from utils.types import RAGChunk logger = get_logger(__name__) -def _is_solr_enabled(configuration: AppConfig) -> bool: +def _is_solr_enabled() -> bool: """Check if Solr is enabled in configuration.""" return bool(configuration.solr and configuration.solr.enabled) @@ -39,7 +38,7 @@ def _get_vector_store_ids(solr_enabled: bool) -> list[str]: return [] -def _build_query_params(query_request: QueryRequest) -> dict: +def _build_query_params(solr: Optional[dict[str, Any]] = None) -> dict[str, Any]: """Build query parameters for vector search.""" params = { "k": constants.VECTOR_SEARCH_DEFAULT_K, @@ -47,10 +46,10 @@ def _build_query_params(query_request: QueryRequest) -> dict: "mode": constants.VECTOR_SEARCH_DEFAULT_MODE, } logger.info("Initial params: %s", params) - logger.info("query_request.solr: %s", query_request.solr) + logger.info("solr: %s", solr) - if query_request.solr: - params["solr"] = query_request.solr + if solr: + params["solr"] = solr logger.info("Final params with solr filters: %s", params) else: logger.info("No solr filters provided") @@ -131,16 +130,16 @@ def _process_chunks_for_documents( async def perform_vector_search( client: AsyncLlamaStackClient, - query_request: QueryRequest, - configuration: AppConfig, + query: str, + solr: Optional[dict[str, Any]] = None, ) -> tuple[list[Any], list[float], list[ReferencedDocument], list[RAGChunk]]: """ Perform vector search and extract RAG chunks and referenced documents. Args: client: The AsyncLlamaStackClient to use for the request - query_request: The user's query request - configuration: Application configuration + query: The user's query + solr: Solr query parameters Returns: Tuple containing: @@ -149,13 +148,13 @@ async def perform_vector_search( - doc_ids_from_chunks: Referenced documents extracted from chunks - rag_chunks: Processed RAG chunks ready for use """ - retrieved_chunks: list[Any] = [] + retrieved_chunks: list[Chunk] = [] retrieved_scores: list[float] = [] doc_ids_from_chunks: list[ReferencedDocument] = [] rag_chunks: list[RAGChunk] = [] # Check if Solr is enabled in configuration - if not _is_solr_enabled(configuration): + if not _is_solr_enabled(): logger.info("Solr vector IO is disabled, skipping vector search") return retrieved_chunks, retrieved_scores, doc_ids_from_chunks, rag_chunks @@ -167,11 +166,11 @@ async def perform_vector_search( if vector_store_ids: vector_store_id = vector_store_ids[0] - params = _build_query_params(query_request) + params = _build_query_params(solr) query_response = await client.vector_io.query( vector_store_id=vector_store_id, - query=query_request.query, + query=query, params=params, )