diff --git a/backend/src/agents/retriever_typing.py b/backend/src/agents/retriever_typing.py index f8c7d5db..60720d3b 100644 --- a/backend/src/agents/retriever_typing.py +++ b/backend/src/agents/retriever_typing.py @@ -6,6 +6,7 @@ class AgentState(TypedDict): messages: Annotated[list[AnyMessage], add_messages] context: Annotated[list[AnyMessage], add_messages] + context_list: Annotated[list[str], add_messages] tools: list[str] sources: Annotated[list[str], add_messages] urls: Annotated[list[str], add_messages] diff --git a/backend/src/api/routers/helpers.py b/backend/src/api/routers/helpers.py index 73c2c983..78c97905 100644 --- a/backend/src/api/routers/helpers.py +++ b/backend/src/api/routers/helpers.py @@ -11,7 +11,22 @@ if not GOOGLE_API_KEY: raise RuntimeError("GOOGLE_API_KEY is not set") -model = "gemini-2.0-flash" +_GEMINI_MODEL_MAP = { + "2.0_flash": "gemini-2.0-flash", + "2.5_flash": "gemini-2.5-flash", + "2.5_pro": "gemini-2.5-pro", +} + + +def _resolve_gemini_model(gemini_version: str | None) -> str: + """Map a GOOGLE_GEMINI version string to a Gemini model name. + + Falls back to ``gemini-2.0-flash`` when the value is unset or unknown. + """ + return _GEMINI_MODEL_MAP.get(gemini_version or "", "gemini-2.0-flash") + + +model = _resolve_gemini_model(os.getenv("GOOGLE_GEMINI")) client = OpenAI( base_url="https://generativelanguage.googleapis.com/v1beta/openai/", api_key=GOOGLE_API_KEY, diff --git a/backend/src/prompts/prompt_templates.py b/backend/src/prompts/prompt_templates.py index 0cd436e4..156628d7 100644 --- a/backend/src/prompts/prompt_templates.py +++ b/backend/src/prompts/prompt_templates.py @@ -11,7 +11,7 @@ You must not ask the user to refer to the context in any part of your answer. You must not ask the user to refer to a link that is not a part of your answer. -If there is nothing in the context relevant to the question, simply say "Sorry its not avaiable in my knowledge base." +If there is nothing in the context relevant to the question, simply say "Sorry, it's not available in my knowledge base." Do not try to make up an answer. Anything between the following `context` html blocks is retrieved from a knowledge bank, not part of the conversation with the user. diff --git a/backend/src/tools/process_pdf.py b/backend/src/tools/process_pdf.py index d5cfeac3..b2c17802 100644 --- a/backend/src/tools/process_pdf.py +++ b/backend/src/tools/process_pdf.py @@ -31,6 +31,7 @@ def process_pdf_docs(file_path: str) -> list[Document]: documents = loader.load_and_split(text_splitter=text_splitter) except PdfStreamError: logging.error(f"Error processing PDF: {file_path} is corrupted or incomplete.") + return [] for doc in documents: try: diff --git a/backend/src/vectorstores/faiss.py b/backend/src/vectorstores/faiss.py index 52cdc15f..858f652f 100644 --- a/backend/src/vectorstores/faiss.py +++ b/backend/src/vectorstores/faiss.py @@ -214,6 +214,9 @@ def add_documents( return None def get_db_path(self) -> str: + env_path = os.getenv("FAISS_DB_PATH") + if env_path: + return os.path.abspath(env_path) cur_path = os.path.abspath(__file__) path = os.path.join(cur_path, "../../../", "faiss_db") path = os.path.abspath(path) # Ensure proper parent directory diff --git a/backend/tests/test_api_helpers.py b/backend/tests/test_api_helpers.py index 25208459..c9c9e077 100644 --- a/backend/tests/test_api_helpers.py +++ b/backend/tests/test_api_helpers.py @@ -81,6 +81,22 @@ def test_constants_defined(self): assert model == "gemini-2.0-flash" # GOOGLE_API_KEY should be set or raise error during module import + def test_resolve_gemini_model_maps_known_versions(self): + """_resolve_gemini_model maps GOOGLE_GEMINI values to model names (issue #259).""" + from src.api.routers.helpers import _resolve_gemini_model + + assert _resolve_gemini_model("2.0_flash") == "gemini-2.0-flash" + assert _resolve_gemini_model("2.5_flash") == "gemini-2.5-flash" + assert _resolve_gemini_model("2.5_pro") == "gemini-2.5-pro" + + def test_resolve_gemini_model_defaults_for_unknown_or_unset(self): + """_resolve_gemini_model falls back to gemini-2.0-flash for unset/unknown values.""" + from src.api.routers.helpers import _resolve_gemini_model + + assert _resolve_gemini_model(None) == "gemini-2.0-flash" + assert _resolve_gemini_model("") == "gemini-2.0-flash" + assert _resolve_gemini_model("nonexistent") == "gemini-2.0-flash" + def test_router_configuration(self): """Test that router is properly configured.""" from src.api.routers.helpers import router diff --git a/backend/tests/test_faiss_vectorstore.py b/backend/tests/test_faiss_vectorstore.py index 5202c83c..08fd194f 100644 --- a/backend/tests/test_faiss_vectorstore.py +++ b/backend/tests/test_faiss_vectorstore.py @@ -184,7 +184,7 @@ def test_add_md_docs_invalid_folder_paths(self): db.add_md_docs(folder_paths="not_a_list") def test_get_db_path(self): - """Test get_db_path returns correct path.""" + """get_db_path falls back to the default ./faiss_db when FAISS_DB_PATH is unset.""" with patch("src.vectorstores.faiss.HuggingFaceEmbeddings") as mock_hf: mock_hf.return_value = Mock() @@ -192,10 +192,27 @@ def test_get_db_path(self): embeddings_type="HF", embeddings_model_name="test-model" ) - path = db.get_db_path() + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("FAISS_DB_PATH", None) + path = db.get_db_path() + assert path.endswith("faiss_db") assert os.path.isabs(path) + def test_get_db_path_respects_env_var(self): + """get_db_path honors the FAISS_DB_PATH environment variable (issue #259).""" + with patch("src.vectorstores.faiss.HuggingFaceEmbeddings") as mock_hf: + mock_hf.return_value = Mock() + + db = FAISSVectorDatabase( + embeddings_type="HF", embeddings_model_name="test-model" + ) + + with patch.dict(os.environ, {"FAISS_DB_PATH": "custom_dir/faiss_index"}): + path = db.get_db_path() + + assert path == os.path.abspath("custom_dir/faiss_index") + def test_save_db_without_documents_raises_error(self): """Test save_db raises error when no documents in database.""" with patch("src.vectorstores.faiss.HuggingFaceEmbeddings") as mock_hf: diff --git a/backend/tests/test_process_pdf.py b/backend/tests/test_process_pdf.py index 144fd212..91eae1a6 100644 --- a/backend/tests/test_process_pdf.py +++ b/backend/tests/test_process_pdf.py @@ -77,15 +77,27 @@ def test_process_pdf_docs_multiple_pages(self, mock_file, mock_loader): assert all(doc.metadata["url"] == "https://example1.com" for doc in result) assert all(doc.metadata["source"] == "doc1.pdf" for doc in result) - # Note: Commented out due to bug in process_pdf_docs function - # The function doesn't properly handle PdfStreamError - it logs but then - # tries to use undefined 'documents' variable - # @patch('src.tools.process_pdf.logging') - # @patch('src.tools.process_pdf.PyPDFLoader') - # @patch('builtins.open', new_callable=mock_open, read_data='{"corrupted.pdf": "https://example.com"}') - # def test_process_pdf_docs_corrupted_file(self, mock_file, mock_loader, mock_logging): - # """Test PDF processing with corrupted file.""" - # pass + @patch("src.tools.process_pdf.PyPDFLoader") + @patch( + "builtins.open", + new_callable=mock_open, + read_data='{"corrupted.pdf": "https://example.com"}', + ) + def test_process_pdf_docs_corrupted_file_returns_empty( + self, mock_file, mock_loader + ): + """A corrupted PDF (PdfStreamError) should return [] instead of crashing.""" + from pypdf.errors import PdfStreamError + + mock_loader_instance = Mock() + mock_loader_instance.load_and_split.side_effect = PdfStreamError( + "corrupted stream" + ) + mock_loader.return_value = mock_loader_instance + + result = process_pdf_docs("./corrupted.pdf") + + assert result == [] @patch("src.tools.process_pdf.logging") @patch("src.tools.process_pdf.PyPDFLoader") diff --git a/backend/tests/test_prompt_templates.py b/backend/tests/test_prompt_templates.py index 9e32d622..d67743c0 100644 --- a/backend/tests/test_prompt_templates.py +++ b/backend/tests/test_prompt_templates.py @@ -48,3 +48,13 @@ def test_suggested_questions_prompt_template_is_string(self): assert isinstance(suggested_questions_prompt_template, str) assert suggested_questions_prompt_template != "" assert suggested_questions_prompt_template is not None + + def test_summarise_prompt_template_has_no_typo(self): + """summarise_prompt_template should use correct spelling (issue #259).""" + from src.prompts.prompt_templates import summarise_prompt_template + + assert "avaiable" not in summarise_prompt_template + assert ( + "Sorry, it's not available in my knowledge base." + in summarise_prompt_template + ) diff --git a/backend/tests/test_retriever_typing.py b/backend/tests/test_retriever_typing.py new file mode 100644 index 00000000..e53aa8f2 --- /dev/null +++ b/backend/tests/test_retriever_typing.py @@ -0,0 +1,14 @@ +from src.agents.retriever_typing import AgentState + + +class TestAgentState: + """Test suite for the AgentState TypedDict.""" + + def test_agent_state_includes_context_list(self): + """AgentState must declare context_list so LangGraph propagates it (issue #259). + + ToolNode.get_node returns a ``context_list`` key, but LangGraph drops any + state key that is not declared on the graph's state schema. Without this + annotation the retrieved context list is silently lost downstream. + """ + assert "context_list" in AgentState.__annotations__