diff --git a/README.md b/README.md index 294056a..7e81ddd 100644 --- a/README.md +++ b/README.md @@ -35,8 +35,9 @@ for chunk in result.text_chunks: ## Retrieval and document lifecycle New documents are published into a retrieval namespace. The server returns a -stable `document_id` when you create a job; persist that value if you need to -update or archive the same document later. +stable `document_id` after the job is published. `client.jobs.create(...)` +does not return a usable `document_id`; persist `job_result.document_id` if you +need to update or archive the same document later. ```python job = client.jobs.create( @@ -45,7 +46,11 @@ job = client.jobs.create( namespace="support-center", ) -print(job.document_id) # "doc_..." +job_result = client.jobs.wait(job.job_id) +document_id = job_result.document_id + +if document_id is None: + raise RuntimeError("Expected document_id after successful publication.") ``` After the job is done and published, query the canonical document content: @@ -55,8 +60,13 @@ response = client.retrieval.query( namespace="support-center", query="How do I reset Bluetooth pairing?", top_k=5, + channels=["path", "term"], + filter_mode="keep", + signal_paths=["Bluetooth", "Pairing"], ) +print(response.router_used) + for result in response.results: print(result.content) print(result.score) @@ -69,13 +79,13 @@ Use `document_id` to update or archive a document: update_job = client.jobs.create( source_type="url", source_url="https://example.com/manual-v2.pdf", - document_id=job.document_id, + document_id=document_id, ) -document = client.documents.get(job.document_id) +document = client.documents.get(document_id) print(document.status) -client.documents.archive(job.document_id) +client.documents.archive(document_id) ``` You can also list documents in a namespace: @@ -177,14 +187,14 @@ job = client.jobs.create( parsing_params={"model": "advanced", "ocr_enabled": True}, ) -print(job.document_id) # Persist this to update/archive the document later. - # Step 2: Upload file to presigned URL client.jobs.upload(job, file=Path("report.pdf")) # Step 3: Poll until done (adaptive backoff) job_result = client.jobs.wait(job.job_id, poll_interval=10.0, poll_timeout=1800.0) +print(job_result.document_id) # Persist this to update/archive the document later. + # Step 4: Download and parse results result = client.jobs.load(job_result) print(result.statistics) diff --git a/src/knowhere/__init__.py b/src/knowhere/__init__.py index ba2b37c..b136805 100644 --- a/src/knowhere/__init__.py +++ b/src/knowhere/__init__.py @@ -39,6 +39,9 @@ from knowhere.types.job import Job, JobError, JobProgress, JobResult from knowhere.types.params import ParsingParams, WebhookConfig from knowhere.types.retrieval import ( + RetrievalChannel, + RetrievalFilterMode, + RetrievalSectionExclusion, RetrievalSource, RetrievalQueryResponse, RetrievalResult, @@ -97,6 +100,9 @@ "Document", "DocumentListResponse", # Retrieval types + "RetrievalChannel", + "RetrievalFilterMode", + "RetrievalSectionExclusion", "RetrievalSource", "RetrievalQueryResponse", "RetrievalResult", diff --git a/src/knowhere/resources/jobs.py b/src/knowhere/resources/jobs.py index e0920c5..f8e184b 100644 --- a/src/knowhere/resources/jobs.py +++ b/src/knowhere/resources/jobs.py @@ -145,8 +145,12 @@ def load( if not job_result.result_url: raise InvalidStateError("JobResult does not have a result_url.") result_url: str = job_result.result_url + namespace: Optional[str] = job_result.namespace + document_id: Optional[str] = job_result.document_id else: result_url = job_result + namespace = None + document_id = None response: httpx.Response = self._client._client.get( result_url, timeout=self._client.upload_timeout @@ -154,7 +158,10 @@ def load( response.raise_for_status() zip_bytes: bytes = response.content - return parseResultZip(zip_bytes, verify_checksum=verify_checksum) + parsed_result = parseResultZip(zip_bytes, verify_checksum=verify_checksum) + parsed_result.namespace = namespace + parsed_result.document_id = document_id + return parsed_result class AsyncJobs(AsyncAPIResource): @@ -251,8 +258,12 @@ async def load( if not job_result.result_url: raise InvalidStateError("JobResult does not have a result_url.") result_url: str = job_result.result_url + namespace: Optional[str] = job_result.namespace + document_id: Optional[str] = job_result.document_id else: result_url = job_result + namespace = None + document_id = None response: httpx.Response = await self._client._client.get( result_url, timeout=self._client.upload_timeout @@ -260,4 +271,7 @@ async def load( response.raise_for_status() zip_bytes: bytes = response.content - return parseResultZip(zip_bytes, verify_checksum=verify_checksum) + parsed_result = parseResultZip(zip_bytes, verify_checksum=verify_checksum) + parsed_result.namespace = namespace + parsed_result.document_id = document_id + return parsed_result diff --git a/src/knowhere/resources/retrieval.py b/src/knowhere/resources/retrieval.py index f702ab5..3b6b36c 100644 --- a/src/knowhere/resources/retrieval.py +++ b/src/knowhere/resources/retrieval.py @@ -5,7 +5,12 @@ from typing import Any, Dict, Optional from knowhere.resources._base import AsyncAPIResource, SyncAPIResource -from knowhere.types.retrieval import RetrievalQueryResponse +from knowhere.types.retrieval import ( + RetrievalChannel, + RetrievalFilterMode, + RetrievalQueryResponse, + RetrievalSectionExclusion, +) class Retrieval(SyncAPIResource): @@ -17,8 +22,16 @@ def query( query: str, namespace: Optional[str] = None, top_k: Optional[int] = None, + data_type: Optional[int] = None, + signal_paths: Optional[list[str]] = None, + filter_mode: Optional[RetrievalFilterMode] = None, + channels: Optional[list[RetrievalChannel]] = None, + channel_weights: Optional[dict[RetrievalChannel, float]] = None, + rerank: Optional[bool] = None, + threshold: Optional[float] = None, + internal_recall_k: Optional[int] = None, exclude_document_ids: Optional[list[str]] = None, - exclude_sections: Optional[list[dict[str, str]]] = None, + exclude_sections: Optional[list[RetrievalSectionExclusion]] = None, ) -> RetrievalQueryResponse: """Query published documents in a namespace.""" body: Dict[str, Any] = {"query": query} @@ -26,6 +39,22 @@ def query( body["namespace"] = namespace if top_k is not None: body["top_k"] = top_k + if data_type is not None: + body["data_type"] = data_type + if signal_paths is not None: + body["signal_paths"] = signal_paths + if filter_mode is not None: + body["filter_mode"] = filter_mode + if channels is not None: + body["channels"] = channels + if channel_weights is not None: + body["channel_weights"] = channel_weights + if rerank is not None: + body["rerank"] = rerank + if threshold is not None: + body["threshold"] = threshold + if internal_recall_k is not None: + body["internal_recall_k"] = internal_recall_k if exclude_document_ids is not None: body["exclude_document_ids"] = exclude_document_ids if exclude_sections is not None: @@ -48,8 +77,16 @@ async def query( query: str, namespace: Optional[str] = None, top_k: Optional[int] = None, + data_type: Optional[int] = None, + signal_paths: Optional[list[str]] = None, + filter_mode: Optional[RetrievalFilterMode] = None, + channels: Optional[list[RetrievalChannel]] = None, + channel_weights: Optional[dict[RetrievalChannel, float]] = None, + rerank: Optional[bool] = None, + threshold: Optional[float] = None, + internal_recall_k: Optional[int] = None, exclude_document_ids: Optional[list[str]] = None, - exclude_sections: Optional[list[dict[str, str]]] = None, + exclude_sections: Optional[list[RetrievalSectionExclusion]] = None, ) -> RetrievalQueryResponse: """Query published documents in a namespace.""" body: Dict[str, Any] = {"query": query} @@ -57,6 +94,22 @@ async def query( body["namespace"] = namespace if top_k is not None: body["top_k"] = top_k + if data_type is not None: + body["data_type"] = data_type + if signal_paths is not None: + body["signal_paths"] = signal_paths + if filter_mode is not None: + body["filter_mode"] = filter_mode + if channels is not None: + body["channels"] = channels + if channel_weights is not None: + body["channel_weights"] = channel_weights + if rerank is not None: + body["rerank"] = rerank + if threshold is not None: + body["threshold"] = threshold + if internal_recall_k is not None: + body["internal_recall_k"] = internal_recall_k if exclude_document_ids is not None: body["exclude_document_ids"] = exclude_document_ids if exclude_sections is not None: diff --git a/src/knowhere/types/__init__.py b/src/knowhere/types/__init__.py index 7b14617..a492955 100644 --- a/src/knowhere/types/__init__.py +++ b/src/knowhere/types/__init__.py @@ -6,6 +6,9 @@ from knowhere.types.job import Job, JobError, JobResult from knowhere.types.params import ParsingParams, WebhookConfig from knowhere.types.retrieval import ( + RetrievalChannel, + RetrievalFilterMode, + RetrievalSectionExclusion, RetrievalSource, RetrievalQueryResponse, RetrievalResult, @@ -38,6 +41,9 @@ "Document", "DocumentListResponse", # retrieval + "RetrievalChannel", + "RetrievalFilterMode", + "RetrievalSectionExclusion", "RetrievalSource", "RetrievalQueryResponse", "RetrievalResult", diff --git a/src/knowhere/types/job.py b/src/knowhere/types/job.py index 260c786..ea86556 100644 --- a/src/knowhere/types/job.py +++ b/src/knowhere/types/job.py @@ -41,7 +41,6 @@ class Job(BaseModel): status: str source_type: str namespace: Optional[str] = None - document_id: Optional[str] = None data_id: Optional[str] = None created_at: Optional[datetime] = None upload_url: Optional[str] = None diff --git a/src/knowhere/types/result.py b/src/knowhere/types/result.py index f58cdab..df83c19 100644 --- a/src/knowhere/types/result.py +++ b/src/knowhere/types/result.py @@ -272,6 +272,8 @@ class ParseResult: kb_csv: Optional[str] hierarchy_view_html: Optional[str] raw_zip: bytes + namespace: Optional[str] + document_id: Optional[str] def __init__( self, @@ -285,6 +287,8 @@ def __init__( kb_csv: Optional[str], hierarchy_view_html: Optional[str], raw_zip: bytes, + namespace: Optional[str] = None, + document_id: Optional[str] = None, ) -> None: self.manifest = manifest self.chunks = chunks @@ -295,6 +299,8 @@ def __init__( self.kb_csv = kb_csv self.hierarchy_view_html = hierarchy_view_html self.raw_zip = raw_zip + self.namespace = namespace + self.document_id = document_id # -- convenience properties -- diff --git a/src/knowhere/types/retrieval.py b/src/knowhere/types/retrieval.py index c13b9d8..47b07a8 100644 --- a/src/knowhere/types/retrieval.py +++ b/src/knowhere/types/retrieval.py @@ -2,11 +2,22 @@ from __future__ import annotations -from typing import Optional +from typing import Literal, Optional, TypedDict from pydantic import BaseModel +RetrievalChannel = Literal["path", "content", "term"] +RetrievalFilterMode = Literal["delete", "keep"] + + +class RetrievalSectionExclusion(TypedDict): + """Section exclusion for follow-up retrieval queries.""" + + document_id: str + section_path: str + + class RetrievalSource(BaseModel): """Caller-facing source reference attached to a retrieval result.""" @@ -30,4 +41,5 @@ class RetrievalQueryResponse(BaseModel): namespace: str query: str + router_used: Optional[str] = None results: list[RetrievalResult] diff --git a/tests/conftest.py b/tests/conftest.py index 82bf58a..c742f61 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -72,7 +72,6 @@ def mock_job_response() -> Dict[str, Any]: "status": "waiting-file", "source_type": "file", "namespace": "default", - "document_id": "doc_test123", "data_id": None, "created_at": "2025-01-01T00:00:00Z", "upload_url": "https://storage.example.com/upload?token=abc", diff --git a/tests/test_jobs.py b/tests/test_jobs.py index a43d425..85669f1 100644 --- a/tests/test_jobs.py +++ b/tests/test_jobs.py @@ -36,7 +36,6 @@ def test_create_with_url_source( "status": "pending", "source_type": "url", "namespace": "support-center", - "document_id": "doc_123", } route = respx.post(JOBS_URL).mock( @@ -53,7 +52,7 @@ def test_create_with_url_source( assert job.source_type == "url" assert job.status == "pending" assert job.namespace == "support-center" - assert job.document_id == "doc_123" + assert not hasattr(job, "document_id") @respx.mock def test_create_with_file_source( @@ -87,7 +86,6 @@ def test_create_sends_correct_body( "status": "pending", "source_type": "url", "namespace": "support-center", - "document_id": "doc_123", } route = respx.post(JOBS_URL).mock( @@ -284,6 +282,8 @@ def test_load_with_job_result_object( job_id="job_load", status="done", source_type="url", + namespace="support-center", + document_id="doc_123", result_url=result_url, ) @@ -293,3 +293,5 @@ def test_load_with_job_result_object( assert route.called assert parse_result.manifest is not None + assert parse_result.namespace == "support-center" + assert parse_result.document_id == "doc_123" diff --git a/tests/test_models.py b/tests/test_models.py index c5989e5..92b9732 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -55,7 +55,7 @@ def test_from_dict_with_document_scope(self) -> None: } job: Job = Job(**data) assert job.namespace == "support-center" - assert job.document_id == "doc_123" + assert "document_id" not in job.model_dump() def test_from_dict_with_upload(self) -> None: data: Dict[str, Any] = { @@ -717,6 +717,11 @@ def test_statistics_shortcut(self) -> None: assert stats.total_chunks == 3 assert stats.text_chunks == 1 + def test_document_scope_defaults_to_none(self) -> None: + result: ParseResult = _build_parse_result() + assert result.namespace is None + assert result.document_id is None + def test_raw_zip_accessible(self) -> None: result: ParseResult = _build_parse_result() assert result.raw_zip == b"fake zip bytes" diff --git a/tests/test_parse.py b/tests/test_parse.py index 8d545b0..c2f6c84 100644 --- a/tests/test_parse.py +++ b/tests/test_parse.py @@ -42,6 +42,8 @@ def _make_done_response(job_id: str, result_url: str) -> Dict[str, Any]: "job_id": job_id, "status": "done", "source_type": "url", + "namespace": "support-center", + "document_id": "doc_123", "result_url": result_url, } @@ -96,6 +98,8 @@ def test_parse_url_full_flow( assert parse_result.manifest is not None assert parse_result.manifest.job_id == "job_test123" + assert parse_result.namespace == "support-center" + assert parse_result.document_id == "doc_123" # --------------------------------------------------------------------------- diff --git a/tests/test_retrieval.py b/tests/test_retrieval.py index 400d77f..4925e30 100644 --- a/tests/test_retrieval.py +++ b/tests/test_retrieval.py @@ -19,6 +19,7 @@ def _make_retrieval_response() -> Dict[str, Any]: return { "namespace": "support-center", "query": "refund policy", + "router_used": "discovery+agent", "results": [ { "chunk_type": "text", @@ -47,6 +48,14 @@ def test_query_sends_request_and_returns_results(self, sync_client: Any) -> None query="refund policy", namespace="support-center", top_k=5, + data_type=6, + signal_paths=["Billing", "Refunds"], + filter_mode="keep", + channels=["path", "term"], + channel_weights={"path": 2.0, "term": 0.5}, + rerank=True, + threshold=0.2, + internal_recall_k=25, exclude_document_ids=["doc_old"], exclude_sections=[ { @@ -62,6 +71,14 @@ def test_query_sends_request_and_returns_results(self, sync_client: Any) -> None "query": "refund policy", "namespace": "support-center", "top_k": 5, + "data_type": 6, + "signal_paths": ["Billing", "Refunds"], + "filter_mode": "keep", + "channels": ["path", "term"], + "channel_weights": {"path": 2.0, "term": 0.5}, + "rerank": True, + "threshold": 0.2, + "internal_recall_k": 25, "exclude_document_ids": ["doc_old"], "exclude_sections": [ { @@ -71,6 +88,7 @@ def test_query_sends_request_and_returns_results(self, sync_client: Any) -> None ], } assert response.namespace == "support-center" + assert response.router_used == "discovery+agent" assert response.results[0].content == "Annual plans may be refunded within 30 days." assert response.results[0].source.document_id == "doc_123" assert response.results[0].source.source_file_name == "refund-policy.md" @@ -107,4 +125,5 @@ async def test_async_query_sends_request_and_returns_results( ) assert route.called + assert response.router_used == "discovery+agent" assert response.results[0].source.document_id == "doc_123"