From bceef5cf379dba39543244bd6ca86262a536fb9b Mon Sep 17 00:00:00 2001 From: Will Date: Tue, 21 Apr 2026 09:55:45 +0000 Subject: [PATCH 1/2] feat: add retrieval service sdk clients --- README.md | 71 +++++++++++++++++ docs/usage.md | 115 ++++++++++++++++++++++++++++ src/knowhere/__init__.py | 13 ++++ src/knowhere/_client.py | 44 ++++++++++- src/knowhere/resources/__init__.py | 11 ++- src/knowhere/resources/documents.py | 74 ++++++++++++++++++ src/knowhere/resources/jobs.py | 14 ++++ src/knowhere/resources/retrieval.py | 70 +++++++++++++++++ src/knowhere/types/__init__.py | 13 ++++ src/knowhere/types/document.py | 28 +++++++ src/knowhere/types/job.py | 4 + src/knowhere/types/retrieval.py | 39 ++++++++++ tests/conftest.py | 5 +- tests/test_client.py | 34 ++++++++ tests/test_documents.py | 106 +++++++++++++++++++++++++ tests/test_jobs.py | 13 +++- tests/test_models.py | 16 ++++ tests/test_polling.py | 2 +- tests/test_retrieval.py | 112 +++++++++++++++++++++++++++ tests/test_retry.py | 1 - 20 files changed, 778 insertions(+), 7 deletions(-) create mode 100644 src/knowhere/resources/documents.py create mode 100644 src/knowhere/resources/retrieval.py create mode 100644 src/knowhere/types/document.py create mode 100644 src/knowhere/types/retrieval.py create mode 100644 tests/test_documents.py create mode 100644 tests/test_retrieval.py diff --git a/README.md b/README.md index a178de1..0cf6b21 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,74 @@ for chunk in result.text_chunks: print(chunk.content[:80]) ``` +## Retrieval and document lifecycle + +New documents are published into a retrieval namespace. The server returns a +stable `document_id` when you create a job; persist that value if you need to +update or archive the same document later. + +```python +job = client.jobs.create( + source_type="url", + source_url="https://example.com/manual.pdf", + namespace="support-center", +) + +print(job.document_id) # "doc_..." +``` + +After the job is done and published, query the canonical document content: + +```python +response = client.retrieval.query( + namespace="support-center", + query="How do I reset Bluetooth pairing?", + top_k=5, +) + +for result in response.results: + print(result.content) + if result.citation: + print(result.citation.source_file_name, result.citation.section_path) +``` + +Use `document_id` to update or archive a document: + +```python +update_job = client.jobs.create( + source_type="url", + source_url="https://example.com/manual-v2.pdf", + document_id=job.document_id, +) + +document = client.documents.get(job.document_id) +print(document.status) + +client.documents.archive(job.document_id) +``` + +You can also list documents in a namespace: + +```python +documents = client.documents.list(namespace="support-center") +for document in documents.documents: + print(document.document_id, document.status) +``` + +Retrieval supports exclusions when clients want follow-up results that avoid +previously used documents or sections: + +```python +response = client.retrieval.query( + namespace="support-center", + query="battery charging", + exclude_document_ids=["doc_old"], + exclude_sections=[ + {"document_id": "doc_123", "section_path": "Appendix / Legal"} + ], +) +``` + While you can provide an `api_key` keyword argument, we recommend using [python-dotenv](https://pypi.org/project/python-dotenv/) to add `KNOWHERE_API_KEY="sk_..."` to your `.env` file so that your API key is not stored in source control. ### Parse a local file @@ -105,9 +173,12 @@ from pathlib import Path job = client.jobs.create( source_type="file", file_name="report.pdf", + namespace="support-center", parsing_params={"model": "advanced", "ocr_enabled": True}, ) +print(job.document_id) # Persist this to update/archive the document later. + # Step 2: Upload file to presigned URL client.jobs.upload(job, file=Path("report.pdf")) diff --git a/docs/usage.md b/docs/usage.md index cf32420..a10504f 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -12,6 +12,7 @@ Comprehensive reference for every feature, parameter, and pattern in the SDK. - [Working with Results](#working-with-results) - [Chunk Types](#chunk-types) - [Step-by-Step Control (Jobs API)](#step-by-step-control-jobs-api) +- [Retrieval and Document Lifecycle](#retrieval-and-document-lifecycle) - [Async Usage](#async-usage) - [Progress Callbacks](#progress-callbacks) - [Error Handling](#error-handling) @@ -316,8 +317,10 @@ from pathlib import Path job = client.jobs.create( source_type="file", file_name="report.pdf", + namespace="support-center", parsing_params={"model": "advanced", "ocr_enabled": True}, ) +print(job.document_id) # Persist this value for update/archive flows. # Step 2: Upload file to the presigned URL client.jobs.upload(job, file=Path("report.pdf")) @@ -341,6 +344,8 @@ print(result.statistics) | `source_type` | `"url" \| "file"` | — | Required. Whether parsing from URL or uploaded file. | | `source_url` | `str \| None` | `None` | URL to parse (required when `source_type="url"`). | | `file_name` | `str \| None` | `None` | Original filename (used when `source_type="file"`). | +| `namespace` | `str \| None` | `None` | Retrieval namespace. The server defaults to `"default"` when omitted. | +| `document_id` | `str \| None` | `None` | Existing document ID when creating an update job. Omit for a new document. | | `data_id` | `str \| None` | `None` | Your own correlation/idempotency identifier. | | `parsing_params` | `ParsingParams \| None` | `None` | Parsing configuration. | | `webhook` | `WebhookConfig \| None` | `None` | Webhook for completion notification. | @@ -351,6 +356,8 @@ Returns a `Job` object: job.job_id # "abc-123" job.status # "pending" job.source_type # "file" +job.namespace # "support-center" +job.document_id # "doc_..." — persist this for updates and archive calls job.upload_url # presigned URL (for file uploads) job.upload_headers # headers to include in the upload request job.expires_in # seconds until upload URL expires @@ -407,6 +414,107 @@ result = client.jobs.load("https://storage.example.com/result.zip") --- +## Retrieval and Document Lifecycle + +The retrieval APIs operate on canonical documents that are published after a +job completes. For new documents, the server generates `document_id` during +`jobs.create()`. Store that ID in your application if you need to update or +archive the same document later. + +### Create a retrievable document + +```python +job = client.jobs.create( + source_type="url", + source_url="https://example.com/manual.pdf", + namespace="support-center", +) + +print(job.document_id) # "doc_..." +``` + +For file uploads, the flow is the same except that you upload the file before +polling: + +```python +job = client.jobs.create( + source_type="file", + file_name="manual.pdf", + namespace="support-center", +) +client.jobs.upload(job, file=Path("manual.pdf")) +job_result = client.jobs.wait(job.job_id) +``` + +### Update an existing document + +Pass the prior `document_id` to create an update job. If `namespace` is omitted, +the API resolves the namespace from the existing document. + +```python +update_job = client.jobs.create( + source_type="url", + source_url="https://example.com/manual-v2.pdf", + document_id=job.document_id, +) +``` + +The API rejects concurrent non-terminal jobs for the same document with a +retryable `ConflictError` using the server error code `ABORTED`. + +### Query retrieval results + +```python +response = client.retrieval.query( + namespace="support-center", + query="How do I pair a Bluetooth headset?", + top_k=5, +) + +for result in response.results: + print(result.content) + print(result.score) + if result.citation: + print(result.citation.source_file_name) + print(result.citation.section_path) +``` + +Retrieval results expose `content`, not the older parse-result `text` field. +Media results may include `asset_url` when the server can sign the referenced +artifact. + +### Exclude documents or sections + +Use exclusions for follow-up queries that should avoid already-used context. + +```python +response = client.retrieval.query( + namespace="support-center", + query="battery charging", + top_k=10, + exclude_document_ids=["doc_old"], + exclude_sections=[ + {"document_id": "doc_123", "section_path": "Appendix / Legal"} + ], +) +``` + +### List, get, and archive documents + +```python +document_list = client.documents.list(namespace="support-center") +for document in document_list.documents: + print(document.document_id, document.status, document.source_file_name) + +document = client.documents.get("doc_123") +print(document.current_job_result_id) + +archived = client.documents.archive("doc_123") +print(archived.status) # "archived" +``` + +--- + ## Async Usage Every method available on `Knowhere` has an async counterpart on `AsyncKnowhere`: @@ -429,6 +537,13 @@ async def main(): job_result = await client.jobs.wait(job.job_id) result = await client.jobs.load(job_result) + retrieval = await client.retrieval.query( + namespace="support-center", + query="refund policy", + top_k=5, + ) + print(retrieval.results[0].content) + asyncio.run(main()) ``` diff --git a/src/knowhere/__init__.py b/src/knowhere/__init__.py index 12b0360..cc36213 100644 --- a/src/knowhere/__init__.py +++ b/src/knowhere/__init__.py @@ -35,8 +35,14 @@ ) from knowhere._types import PollProgressCallback, UploadProgressCallback from knowhere._version import __version__ +from knowhere.types.document import Document, DocumentListResponse from knowhere.types.job import Job, JobError, JobProgress, JobResult from knowhere.types.params import ParsingParams, WebhookConfig +from knowhere.types.retrieval import ( + RetrievalCitation, + RetrievalQueryResponse, + RetrievalResult, +) from knowhere.types.result import ( BaseChunk, Checksum, @@ -87,6 +93,13 @@ "JobError", "JobProgress", "JobResult", + # Document types + "Document", + "DocumentListResponse", + # Retrieval types + "RetrievalCitation", + "RetrievalQueryResponse", + "RetrievalResult", # Result types "ParseResult", "Manifest", diff --git a/src/knowhere/_client.py b/src/knowhere/_client.py index b2cbc3e..b45bdc1 100644 --- a/src/knowhere/_client.py +++ b/src/knowhere/_client.py @@ -19,7 +19,9 @@ PollProgressCallback, UploadProgressCallback, ) +from knowhere.resources.documents import AsyncDocuments, Documents from knowhere.resources.jobs import AsyncJobs, Jobs +from knowhere.resources.retrieval import AsyncRetrieval, Retrieval from knowhere.types.job import Job, JobResult from knowhere.types.params import ParsingParams, WebhookConfig from knowhere.types.result import ParseResult @@ -42,6 +44,16 @@ def jobs(self) -> Jobs: """Access the jobs resource namespace.""" return Jobs(self) + @cached_property + def retrieval(self) -> Retrieval: + """Access the retrieval resource namespace.""" + return Retrieval(self) + + @cached_property + def documents(self) -> Documents: + """Access the documents resource namespace.""" + return Documents(self) + # -- overloaded parse signatures -- @overload @@ -50,6 +62,8 @@ def parse( *, url: str, data_id: Optional[str] = ..., + namespace: Optional[str] = ..., + document_id: Optional[str] = ..., parsing_params: Optional[ParsingParams] = ..., webhook: Optional[WebhookConfig] = ..., poll_interval: float = ..., @@ -66,6 +80,8 @@ def parse( file: Union[Path, BinaryIO, bytes], file_name: Optional[str] = ..., data_id: Optional[str] = ..., + namespace: Optional[str] = ..., + document_id: Optional[str] = ..., parsing_params: Optional[ParsingParams] = ..., webhook: Optional[WebhookConfig] = ..., poll_interval: float = ..., @@ -82,6 +98,8 @@ def parse( file: Optional[Union[Path, BinaryIO, bytes]] = None, file_name: Optional[str] = None, data_id: Optional[str] = None, + namespace: Optional[str] = None, + document_id: Optional[str] = None, parsing_params: Optional[ParsingParams] = None, webhook: Optional[WebhookConfig] = None, poll_interval: float = DEFAULT_POLL_INTERVAL, @@ -105,6 +123,8 @@ def parse( source_type="url", source_url=url, data_id=data_id, + namespace=namespace, + document_id=document_id, parsing_params=parsing_params, webhook=webhook, ) @@ -116,6 +136,8 @@ def parse( source_type="file", file_name=resolved_name, data_id=data_id, + namespace=namespace, + document_id=document_id, parsing_params=parsing_params, webhook=webhook, ) @@ -149,12 +171,24 @@ def jobs(self) -> AsyncJobs: """Access the async jobs resource namespace.""" return AsyncJobs(self) + @cached_property + def retrieval(self) -> AsyncRetrieval: + """Access the async retrieval resource namespace.""" + return AsyncRetrieval(self) + + @cached_property + def documents(self) -> AsyncDocuments: + """Access the async documents resource namespace.""" + return AsyncDocuments(self) + @overload async def parse( self, *, url: str, data_id: Optional[str] = ..., + namespace: Optional[str] = ..., + document_id: Optional[str] = ..., parsing_params: Optional[ParsingParams] = ..., webhook: Optional[WebhookConfig] = ..., poll_interval: float = ..., @@ -171,6 +205,8 @@ async def parse( file: Union[Path, BinaryIO, bytes], file_name: Optional[str] = ..., data_id: Optional[str] = ..., + namespace: Optional[str] = ..., + document_id: Optional[str] = ..., parsing_params: Optional[ParsingParams] = ..., webhook: Optional[WebhookConfig] = ..., poll_interval: float = ..., @@ -187,6 +223,8 @@ async def parse( file: Optional[Union[Path, BinaryIO, bytes]] = None, file_name: Optional[str] = None, data_id: Optional[str] = None, + namespace: Optional[str] = None, + document_id: Optional[str] = None, parsing_params: Optional[ParsingParams] = None, webhook: Optional[WebhookConfig] = None, poll_interval: float = DEFAULT_POLL_INTERVAL, @@ -206,6 +244,8 @@ async def parse( source_type="url", source_url=url, data_id=data_id, + namespace=namespace, + document_id=document_id, parsing_params=parsing_params, webhook=webhook, ) @@ -217,6 +257,8 @@ async def parse( source_type="file", file_name=resolved_name, data_id=data_id, + namespace=namespace, + document_id=document_id, parsing_params=parsing_params, webhook=webhook, ) @@ -232,4 +274,4 @@ async def parse( return await self.jobs.load( job_result, verify_checksum=verify_checksum - ) \ No newline at end of file + ) diff --git a/src/knowhere/resources/__init__.py b/src/knowhere/resources/__init__.py index be52770..e523c10 100644 --- a/src/knowhere/resources/__init__.py +++ b/src/knowhere/resources/__init__.py @@ -2,6 +2,15 @@ from __future__ import annotations +from knowhere.resources.documents import AsyncDocuments, Documents from knowhere.resources.jobs import AsyncJobs, Jobs +from knowhere.resources.retrieval import AsyncRetrieval, Retrieval -__all__: list[str] = ["Jobs", "AsyncJobs"] +__all__: list[str] = [ + "AsyncDocuments", + "AsyncJobs", + "AsyncRetrieval", + "Documents", + "Jobs", + "Retrieval", +] diff --git a/src/knowhere/resources/documents.py b/src/knowhere/resources/documents.py new file mode 100644 index 0000000..6c04fc9 --- /dev/null +++ b/src/knowhere/resources/documents.py @@ -0,0 +1,74 @@ +"""Documents resource for canonical document lifecycle operations.""" + +from __future__ import annotations + +from typing import Any, Dict, Optional + +from knowhere.resources._base import AsyncAPIResource, SyncAPIResource +from knowhere.types.document import Document, DocumentListResponse + + +class Documents(SyncAPIResource): + """Synchronous interface for ``/v1/documents`` endpoints.""" + + def list(self, *, namespace: Optional[str] = None) -> DocumentListResponse: + """List canonical documents in a namespace.""" + params: Dict[str, Any] = {} + if namespace is not None: + params["namespace"] = namespace + + return self._request( + "GET", + "v1/documents", + params=params or None, + cast_to=DocumentListResponse, + ) + + def get(self, document_id: str) -> Document: + """Get one canonical document by ID.""" + return self._request( + "GET", + f"v1/documents/{document_id}", + cast_to=Document, + ) + + def archive(self, document_id: str) -> Document: + """Archive one canonical document by ID.""" + return self._request( + "POST", + f"v1/documents/{document_id}:archive", + cast_to=Document, + ) + + +class AsyncDocuments(AsyncAPIResource): + """Asynchronous interface for ``/v1/documents`` endpoints.""" + + async def list(self, *, namespace: Optional[str] = None) -> DocumentListResponse: + """List canonical documents in a namespace.""" + params: Dict[str, Any] = {} + if namespace is not None: + params["namespace"] = namespace + + return await self._request( + "GET", + "v1/documents", + params=params or None, + cast_to=DocumentListResponse, + ) + + async def get(self, document_id: str) -> Document: + """Get one canonical document by ID.""" + return await self._request( + "GET", + f"v1/documents/{document_id}", + cast_to=Document, + ) + + async def archive(self, document_id: str) -> Document: + """Archive one canonical document by ID.""" + return await self._request( + "POST", + f"v1/documents/{document_id}:archive", + cast_to=Document, + ) diff --git a/src/knowhere/resources/jobs.py b/src/knowhere/resources/jobs.py index 11fdc21..e0920c5 100644 --- a/src/knowhere/resources/jobs.py +++ b/src/knowhere/resources/jobs.py @@ -34,6 +34,8 @@ def create( source_type: str, source_url: Optional[str] = None, file_name: Optional[str] = None, + namespace: Optional[str] = None, + document_id: Optional[str] = None, data_id: Optional[str] = None, parsing_params: Optional[ParsingParams] = None, webhook: Optional[WebhookConfig] = None, @@ -44,6 +46,8 @@ def create( source_type: ``"url"`` or ``"file"``. source_url: URL to parse (required when ``source_type="url"``). file_name: Original filename (used when ``source_type="file"``). + namespace: Retrieval namespace. Defaults to the server ``default``. + document_id: Existing document ID when creating an update job. data_id: Optional idempotency / correlation identifier. parsing_params: Optional parsing configuration. webhook: Optional webhook configuration. @@ -56,6 +60,10 @@ def create( body["source_url"] = source_url if file_name is not None: body["file_name"] = file_name + if namespace is not None: + body["namespace"] = namespace + if document_id is not None: + body["document_id"] = document_id if data_id is not None: body["data_id"] = data_id if parsing_params is not None: @@ -158,6 +166,8 @@ async def create( source_type: str, source_url: Optional[str] = None, file_name: Optional[str] = None, + namespace: Optional[str] = None, + document_id: Optional[str] = None, data_id: Optional[str] = None, parsing_params: Optional[ParsingParams] = None, webhook: Optional[WebhookConfig] = None, @@ -168,6 +178,10 @@ async def create( body["source_url"] = source_url if file_name is not None: body["file_name"] = file_name + if namespace is not None: + body["namespace"] = namespace + if document_id is not None: + body["document_id"] = document_id if data_id is not None: body["data_id"] = data_id if parsing_params is not None: diff --git a/src/knowhere/resources/retrieval.py b/src/knowhere/resources/retrieval.py new file mode 100644 index 0000000..f702ab5 --- /dev/null +++ b/src/knowhere/resources/retrieval.py @@ -0,0 +1,70 @@ +"""Retrieval resource for querying published documents.""" + +from __future__ import annotations + +from typing import Any, Dict, Optional + +from knowhere.resources._base import AsyncAPIResource, SyncAPIResource +from knowhere.types.retrieval import RetrievalQueryResponse + + +class Retrieval(SyncAPIResource): + """Synchronous interface for ``/v1/retrieval`` endpoints.""" + + def query( + self, + *, + query: str, + namespace: Optional[str] = None, + top_k: Optional[int] = None, + exclude_document_ids: Optional[list[str]] = None, + exclude_sections: Optional[list[dict[str, str]]] = None, + ) -> RetrievalQueryResponse: + """Query published documents in a namespace.""" + body: Dict[str, Any] = {"query": query} + if namespace is not None: + body["namespace"] = namespace + if top_k is not None: + body["top_k"] = top_k + if exclude_document_ids is not None: + body["exclude_document_ids"] = exclude_document_ids + if exclude_sections is not None: + body["exclude_sections"] = exclude_sections + + return self._request( + "POST", + "v1/retrieval/query", + body=body, + cast_to=RetrievalQueryResponse, + ) + + +class AsyncRetrieval(AsyncAPIResource): + """Asynchronous interface for ``/v1/retrieval`` endpoints.""" + + async def query( + self, + *, + query: str, + namespace: Optional[str] = None, + top_k: Optional[int] = None, + exclude_document_ids: Optional[list[str]] = None, + exclude_sections: Optional[list[dict[str, str]]] = None, + ) -> RetrievalQueryResponse: + """Query published documents in a namespace.""" + body: Dict[str, Any] = {"query": query} + if namespace is not None: + body["namespace"] = namespace + if top_k is not None: + body["top_k"] = top_k + if exclude_document_ids is not None: + body["exclude_document_ids"] = exclude_document_ids + if exclude_sections is not None: + body["exclude_sections"] = exclude_sections + + return await self._request( + "POST", + "v1/retrieval/query", + body=body, + cast_to=RetrievalQueryResponse, + ) diff --git a/src/knowhere/types/__init__.py b/src/knowhere/types/__init__.py index 09d5c6a..33e66b8 100644 --- a/src/knowhere/types/__init__.py +++ b/src/knowhere/types/__init__.py @@ -2,8 +2,14 @@ from __future__ import annotations +from knowhere.types.document import Document, DocumentListResponse from knowhere.types.job import Job, JobError, JobResult from knowhere.types.params import ParsingParams, WebhookConfig +from knowhere.types.retrieval import ( + RetrievalCitation, + RetrievalQueryResponse, + RetrievalResult, +) from knowhere.types.result import ( BaseChunk, Checksum, @@ -28,6 +34,13 @@ "Job", "JobError", "JobResult", + # document + "Document", + "DocumentListResponse", + # retrieval + "RetrievalCitation", + "RetrievalQueryResponse", + "RetrievalResult", # params "ParsingParams", "WebhookConfig", diff --git a/src/knowhere/types/document.py b/src/knowhere/types/document.py new file mode 100644 index 0000000..f41a438 --- /dev/null +++ b/src/knowhere/types/document.py @@ -0,0 +1,28 @@ +"""Pydantic models for canonical document lifecycle responses.""" + +from __future__ import annotations + +from datetime import datetime +from typing import Optional + +from pydantic import BaseModel + + +class Document(BaseModel): + """Canonical document state returned by ``/v1/documents`` endpoints.""" + + document_id: str + namespace: str + status: str + current_job_result_id: Optional[str] = None + source_file_name: Optional[str] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + archived_at: Optional[datetime] = None + + +class DocumentListResponse(BaseModel): + """Response from ``GET /v1/documents``.""" + + namespace: str + documents: list[Document] diff --git a/src/knowhere/types/job.py b/src/knowhere/types/job.py index b09a1ea..260c786 100644 --- a/src/knowhere/types/job.py +++ b/src/knowhere/types/job.py @@ -40,6 +40,8 @@ class Job(BaseModel): job_id: str status: str source_type: str + namespace: Optional[str] = None + document_id: Optional[str] = None data_id: Optional[str] = None created_at: Optional[datetime] = None upload_url: Optional[str] = None @@ -53,6 +55,8 @@ class JobResult(BaseModel): job_id: str status: str source_type: str + namespace: Optional[str] = None + document_id: Optional[str] = None data_id: Optional[str] = None created_at: Optional[datetime] = None progress: Optional[Union[float, JobProgress]] = None diff --git a/src/knowhere/types/retrieval.py b/src/knowhere/types/retrieval.py new file mode 100644 index 0000000..b09b946 --- /dev/null +++ b/src/knowhere/types/retrieval.py @@ -0,0 +1,39 @@ +"""Pydantic models for retrieval query responses.""" + +from __future__ import annotations + +from typing import Optional + +from pydantic import BaseModel + + +class RetrievalCitation(BaseModel): + """Source citation attached to a retrieval result.""" + + document_id: Optional[str] = None + chunk_id: Optional[str] = None + source_file_name: Optional[str] = None + section_path: Optional[str] = None + + +class RetrievalResult(BaseModel): + """Canonical chunk result returned by ``POST /v1/retrieval/query``.""" + + document_id: str + chunk_id: str + section_id: Optional[str] = None + section_path: Optional[str] = None + source_file_name: Optional[str] = None + chunk_type: str + content: str + score: float + asset_url: Optional[str] = None + citation: Optional[RetrievalCitation] = None + + +class RetrievalQueryResponse(BaseModel): + """Response from ``POST /v1/retrieval/query``.""" + + namespace: str + query: str + results: list[RetrievalResult] diff --git a/tests/conftest.py b/tests/conftest.py index a03325e..82bf58a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,6 @@ from __future__ import annotations -import hashlib import io import json import zipfile @@ -72,6 +71,8 @@ def mock_job_response() -> Dict[str, Any]: "job_id": "job_test123", "status": "waiting-file", "source_type": "file", + "namespace": "default", + "document_id": "doc_test123", "data_id": None, "created_at": "2025-01-01T00:00:00Z", "upload_url": "https://storage.example.com/upload?token=abc", @@ -87,6 +88,8 @@ def mock_job_result_response() -> Dict[str, Any]: "job_id": "job_test123", "status": "done", "source_type": "file", + "namespace": "default", + "document_id": "doc_test123", "data_id": "data_abc", "created_at": "2025-01-01T00:00:00Z", "progress": 1.0, diff --git a/tests/test_client.py b/tests/test_client.py index 0b074ac..2bef3e4 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -115,6 +115,24 @@ def test_jobs_property_returns_jobs_instance(self) -> None: assert hasattr(jobs, "load") client.close() + def test_retrieval_property_returns_retrieval_instance(self) -> None: + from knowhere import Knowhere + + client: Knowhere = Knowhere(api_key="sk_test") + retrieval: Any = client.retrieval + assert hasattr(retrieval, "query") + client.close() + + def test_documents_property_returns_documents_instance(self) -> None: + from knowhere import Knowhere + + client: Knowhere = Knowhere(api_key="sk_test") + documents: Any = client.documents + assert hasattr(documents, "list") + assert hasattr(documents, "get") + assert hasattr(documents, "archive") + client.close() + def test_base_url_trailing_slash_stripped(self) -> None: from knowhere import Knowhere @@ -200,3 +218,19 @@ def test_jobs_property_returns_async_jobs_instance(self) -> None: assert hasattr(jobs, "upload") assert hasattr(jobs, "wait") assert hasattr(jobs, "load") + + def test_retrieval_property_returns_async_retrieval_instance(self) -> None: + from knowhere import AsyncKnowhere + + client: AsyncKnowhere = AsyncKnowhere(api_key="sk_test") + retrieval: Any = client.retrieval + assert hasattr(retrieval, "query") + + def test_documents_property_returns_async_documents_instance(self) -> None: + from knowhere import AsyncKnowhere + + client: AsyncKnowhere = AsyncKnowhere(api_key="sk_test") + documents: Any = client.documents + assert hasattr(documents, "list") + assert hasattr(documents, "get") + assert hasattr(documents, "archive") diff --git a/tests/test_documents.py b/tests/test_documents.py new file mode 100644 index 0000000..88857b2 --- /dev/null +++ b/tests/test_documents.py @@ -0,0 +1,106 @@ +"""Tests for the documents resource.""" + +from __future__ import annotations + +from typing import Any, Dict + +import httpx +import pytest +import respx + +from tests.conftest import BASE_URL + + +DOCUMENTS_URL: str = f"{BASE_URL}/v1/documents" + + +def _make_document(status: str = "active") -> Dict[str, Any]: + return { + "document_id": "doc_123", + "namespace": "support-center", + "status": status, + "current_job_result_id": "result_123", + "source_file_name": "refund-policy.md", + "created_at": "2026-04-21T08:00:00Z", + "updated_at": "2026-04-21T08:30:00Z", + "archived_at": "2026-04-21T09:00:00Z" if status == "archived" else None, + } + + +class TestDocumentsResource: + """Verify document lifecycle calls.""" + + @respx.mock + def test_list_documents_sends_namespace_query(self, sync_client: Any) -> None: + route = respx.get(DOCUMENTS_URL).mock( + return_value=httpx.Response( + 200, + json={ + "namespace": "support-center", + "documents": [_make_document()], + }, + ) + ) + + response = sync_client.documents.list(namespace="support-center") + + assert route.called + assert route.calls[0].request.url.params["namespace"] == "support-center" + assert response.namespace == "support-center" + assert response.documents[0].document_id == "doc_123" + + @respx.mock + def test_list_documents_omits_namespace_when_defaulted(self, sync_client: Any) -> None: + route = respx.get(DOCUMENTS_URL).mock( + return_value=httpx.Response( + 200, + json={"namespace": "default", "documents": []}, + ) + ) + + response = sync_client.documents.list() + + assert route.called + assert dict(route.calls[0].request.url.params) == {} + assert response.namespace == "default" + assert response.documents == [] + + @respx.mock + def test_get_document_returns_document_state(self, sync_client: Any) -> None: + route = respx.get(f"{DOCUMENTS_URL}/doc_123").mock( + return_value=httpx.Response(200, json=_make_document()) + ) + + document = sync_client.documents.get("doc_123") + + assert route.called + assert document.document_id == "doc_123" + assert document.status == "active" + + @respx.mock + def test_archive_document_returns_archived_state(self, sync_client: Any) -> None: + route = respx.post(f"{DOCUMENTS_URL}/doc_123:archive").mock( + return_value=httpx.Response(200, json=_make_document(status="archived")) + ) + + document = sync_client.documents.archive("doc_123") + + assert route.called + assert document.document_id == "doc_123" + assert document.status == "archived" + assert document.archived_at is not None + + @respx.mock + @pytest.mark.asyncio + async def test_async_archive_document_returns_archived_state( + self, + async_client: Any, + ) -> None: + route = respx.post(f"{DOCUMENTS_URL}/doc_123:archive").mock( + return_value=httpx.Response(200, json=_make_document(status="archived")) + ) + + document = await async_client.documents.archive("doc_123") + + assert route.called + assert document.status == "archived" diff --git a/tests/test_jobs.py b/tests/test_jobs.py index fbcfbe6..a43d425 100644 --- a/tests/test_jobs.py +++ b/tests/test_jobs.py @@ -5,10 +5,9 @@ from typing import Any, Dict import httpx -import pytest import respx -from tests.conftest import API_KEY, BASE_URL +from tests.conftest import BASE_URL # --------------------------------------------------------------------------- @@ -36,6 +35,8 @@ def test_create_with_url_source( "job_id": "job_test123", "status": "pending", "source_type": "url", + "namespace": "support-center", + "document_id": "doc_123", } route = respx.post(JOBS_URL).mock( @@ -51,6 +52,8 @@ def test_create_with_url_source( assert job.job_id == "job_test123" assert job.source_type == "url" assert job.status == "pending" + assert job.namespace == "support-center" + assert job.document_id == "doc_123" @respx.mock def test_create_with_file_source( @@ -83,6 +86,8 @@ def test_create_sends_correct_body( "job_id": "job_body_check", "status": "pending", "source_type": "url", + "namespace": "support-center", + "document_id": "doc_123", } route = respx.post(JOBS_URL).mock( @@ -93,6 +98,8 @@ def test_create_sends_correct_body( source_type="url", source_url="https://example.com/doc.pdf", data_id="my_data_id", + namespace="support-center", + document_id="doc_123", ) assert route.called @@ -102,6 +109,8 @@ def test_create_sends_correct_body( assert body["source_type"] == "url" assert body["source_url"] == "https://example.com/doc.pdf" assert body["data_id"] == "my_data_id" + assert body["namespace"] == "support-center" + assert body["document_id"] == "doc_123" # --------------------------------------------------------------------------- diff --git a/tests/test_models.py b/tests/test_models.py index 66aadf8..c5989e5 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -45,6 +45,18 @@ def test_from_dict_minimal(self) -> None: assert job.source_type == "url" assert job.upload_url is None + def test_from_dict_with_document_scope(self) -> None: + data: Dict[str, Any] = { + "job_id": "job_scoped", + "status": "pending", + "source_type": "url", + "namespace": "support-center", + "document_id": "doc_123", + } + job: Job = Job(**data) + assert job.namespace == "support-center" + assert job.document_id == "doc_123" + def test_from_dict_with_upload(self) -> None: data: Dict[str, Any] = { "job_id": "job_2", @@ -148,10 +160,14 @@ def test_with_result_url(self) -> None: job_id="job_ok", status="done", source_type="file", + namespace="support-center", + document_id="doc_123", result_url="https://storage.example.com/result.zip", duration_seconds=3.5, credits_spent=1.0, ) + assert result.namespace == "support-center" + assert result.document_id == "doc_123" assert result.result_url == "https://storage.example.com/result.zip" assert result.duration_seconds == 3.5 diff --git a/tests/test_polling.py b/tests/test_polling.py index 0c02ffa..647f513 100644 --- a/tests/test_polling.py +++ b/tests/test_polling.py @@ -208,7 +208,7 @@ class TestPollOnProgressCallback: @respx.mock def test_callback_called_on_each_poll(self, sync_client: Any) -> None: job_id: str = "job_progress" - route = respx.get(f"{JOBS_URL}/{job_id}").mock( + respx.get(f"{JOBS_URL}/{job_id}").mock( side_effect=[ httpx.Response( 200, json=_make_status_response(job_id, "running") diff --git a/tests/test_retrieval.py b/tests/test_retrieval.py new file mode 100644 index 0000000..147c554 --- /dev/null +++ b/tests/test_retrieval.py @@ -0,0 +1,112 @@ +"""Tests for the retrieval resource.""" + +from __future__ import annotations + +import json +from typing import Any, Dict + +import httpx +import pytest +import respx + +from tests.conftest import BASE_URL + + +RETRIEVAL_QUERY_URL: str = f"{BASE_URL}/v1/retrieval/query" + + +def _make_retrieval_response() -> Dict[str, Any]: + return { + "namespace": "support-center", + "query": "refund policy", + "results": [ + { + "document_id": "doc_123", + "chunk_id": "chunk_456", + "section_id": "sec_12", + "section_path": "Policies / Billing / Refunds", + "source_file_name": "refund-policy.md", + "chunk_type": "text", + "content": "Annual plans may be refunded within 30 days.", + "score": 1.0, + "citation": { + "document_id": "doc_123", + "chunk_id": "chunk_456", + "source_file_name": "refund-policy.md", + "section_path": "Policies / Billing / Refunds", + }, + } + ], + } + + +class TestRetrievalQuery: + """Verify retrieval.query() sends the public retrieval contract.""" + + @respx.mock + def test_query_sends_request_and_returns_results(self, sync_client: Any) -> None: + route = respx.post(RETRIEVAL_QUERY_URL).mock( + return_value=httpx.Response(200, json=_make_retrieval_response()) + ) + + response = sync_client.retrieval.query( + query="refund policy", + namespace="support-center", + top_k=5, + exclude_document_ids=["doc_old"], + exclude_sections=[ + { + "document_id": "doc_123", + "section_path": "Policies / Draft", + } + ], + ) + + assert route.called + request_body: Dict[str, Any] = json.loads(route.calls[0].request.read()) + assert request_body == { + "query": "refund policy", + "namespace": "support-center", + "top_k": 5, + "exclude_document_ids": ["doc_old"], + "exclude_sections": [ + { + "document_id": "doc_123", + "section_path": "Policies / Draft", + } + ], + } + assert response.namespace == "support-center" + assert response.results[0].content == "Annual plans may be refunded within 30 days." + assert response.results[0].citation is not None + assert response.results[0].citation.section_path == "Policies / Billing / Refunds" + + @respx.mock + def test_query_omits_defaulted_optional_fields(self, sync_client: Any) -> None: + route = respx.post(RETRIEVAL_QUERY_URL).mock( + return_value=httpx.Response(200, json=_make_retrieval_response()) + ) + + sync_client.retrieval.query(query="refund policy") + + request_body: Dict[str, Any] = json.loads(route.calls[0].request.read()) + assert request_body == {"query": "refund policy"} + + @respx.mock + @pytest.mark.asyncio + async def test_async_query_sends_request_and_returns_results( + self, + async_client: Any, + ) -> None: + route = respx.post(RETRIEVAL_QUERY_URL).mock( + return_value=httpx.Response(200, json=_make_retrieval_response()) + ) + + response = await async_client.retrieval.query( + query="refund policy", + namespace="support-center", + top_k=5, + ) + + assert route.called + assert response.results[0].document_id == "doc_123" diff --git a/tests/test_retry.py b/tests/test_retry.py index f2b9329..716e951 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -11,7 +11,6 @@ from knowhere._exceptions import ( AuthenticationError, BadRequestError, - ConflictError, InternalServerError, RateLimitError, ServiceUnavailableError, From ad67576aeebdb48c1dcb7ee7a99cd9daf25616ac Mon Sep 17 00:00:00 2001 From: Will Date: Tue, 21 Apr 2026 15:40:15 +0000 Subject: [PATCH 2/2] Align SDK retrieval contract with API --- README.md | 4 ++-- docs/usage.md | 18 +++++++++++++++--- src/knowhere/__init__.py | 4 ++-- src/knowhere/resources/documents.py | 4 ++-- src/knowhere/types/__init__.py | 4 ++-- src/knowhere/types/retrieval.py | 12 +++--------- tests/test_documents.py | 4 ++-- tests/test_retrieval.py | 18 ++++++++---------- 8 files changed, 36 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 0cf6b21..294056a 100644 --- a/README.md +++ b/README.md @@ -59,8 +59,8 @@ response = client.retrieval.query( for result in response.results: print(result.content) - if result.citation: - print(result.citation.source_file_name, result.citation.section_path) + print(result.score) + print(result.source.source_file_name, result.source.section_path) ``` Use `document_id` to update or archive a document: diff --git a/docs/usage.md b/docs/usage.md index a10504f..507f5f1 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -474,15 +474,27 @@ response = client.retrieval.query( for result in response.results: print(result.content) print(result.score) - if result.citation: - print(result.citation.source_file_name) - print(result.citation.section_path) + print(result.source.document_id) + print(result.source.source_file_name) + print(result.source.section_path) ``` Retrieval results expose `content`, not the older parse-result `text` field. Media results may include `asset_url` when the server can sign the referenced artifact. +Each retrieval result uses one canonical source reference shape: + +```python +result.content +result.chunk_type +result.score +result.asset_url # Optional[str] +result.source.document_id +result.source.source_file_name +result.source.section_path +``` + ### Exclude documents or sections Use exclusions for follow-up queries that should avoid already-used context. diff --git a/src/knowhere/__init__.py b/src/knowhere/__init__.py index cc36213..ba2b37c 100644 --- a/src/knowhere/__init__.py +++ b/src/knowhere/__init__.py @@ -39,7 +39,7 @@ from knowhere.types.job import Job, JobError, JobProgress, JobResult from knowhere.types.params import ParsingParams, WebhookConfig from knowhere.types.retrieval import ( - RetrievalCitation, + RetrievalSource, RetrievalQueryResponse, RetrievalResult, ) @@ -97,7 +97,7 @@ "Document", "DocumentListResponse", # Retrieval types - "RetrievalCitation", + "RetrievalSource", "RetrievalQueryResponse", "RetrievalResult", # Result types diff --git a/src/knowhere/resources/documents.py b/src/knowhere/resources/documents.py index 6c04fc9..c826d64 100644 --- a/src/knowhere/resources/documents.py +++ b/src/knowhere/resources/documents.py @@ -36,7 +36,7 @@ def archive(self, document_id: str) -> Document: """Archive one canonical document by ID.""" return self._request( "POST", - f"v1/documents/{document_id}:archive", + f"v1/documents/{document_id}/archive", cast_to=Document, ) @@ -69,6 +69,6 @@ async def archive(self, document_id: str) -> Document: """Archive one canonical document by ID.""" return await self._request( "POST", - f"v1/documents/{document_id}:archive", + f"v1/documents/{document_id}/archive", cast_to=Document, ) diff --git a/src/knowhere/types/__init__.py b/src/knowhere/types/__init__.py index 33e66b8..7b14617 100644 --- a/src/knowhere/types/__init__.py +++ b/src/knowhere/types/__init__.py @@ -6,7 +6,7 @@ from knowhere.types.job import Job, JobError, JobResult from knowhere.types.params import ParsingParams, WebhookConfig from knowhere.types.retrieval import ( - RetrievalCitation, + RetrievalSource, RetrievalQueryResponse, RetrievalResult, ) @@ -38,7 +38,7 @@ "Document", "DocumentListResponse", # retrieval - "RetrievalCitation", + "RetrievalSource", "RetrievalQueryResponse", "RetrievalResult", # params diff --git a/src/knowhere/types/retrieval.py b/src/knowhere/types/retrieval.py index b09b946..c13b9d8 100644 --- a/src/knowhere/types/retrieval.py +++ b/src/knowhere/types/retrieval.py @@ -7,11 +7,10 @@ from pydantic import BaseModel -class RetrievalCitation(BaseModel): - """Source citation attached to a retrieval result.""" +class RetrievalSource(BaseModel): + """Caller-facing source reference attached to a retrieval result.""" document_id: Optional[str] = None - chunk_id: Optional[str] = None source_file_name: Optional[str] = None section_path: Optional[str] = None @@ -19,16 +18,11 @@ class RetrievalCitation(BaseModel): class RetrievalResult(BaseModel): """Canonical chunk result returned by ``POST /v1/retrieval/query``.""" - document_id: str - chunk_id: str - section_id: Optional[str] = None - section_path: Optional[str] = None - source_file_name: Optional[str] = None chunk_type: str content: str score: float asset_url: Optional[str] = None - citation: Optional[RetrievalCitation] = None + source: RetrievalSource class RetrievalQueryResponse(BaseModel): diff --git a/tests/test_documents.py b/tests/test_documents.py index 88857b2..8869642 100644 --- a/tests/test_documents.py +++ b/tests/test_documents.py @@ -79,7 +79,7 @@ def test_get_document_returns_document_state(self, sync_client: Any) -> None: @respx.mock def test_archive_document_returns_archived_state(self, sync_client: Any) -> None: - route = respx.post(f"{DOCUMENTS_URL}/doc_123:archive").mock( + route = respx.post(f"{DOCUMENTS_URL}/doc_123/archive").mock( return_value=httpx.Response(200, json=_make_document(status="archived")) ) @@ -96,7 +96,7 @@ async def test_async_archive_document_returns_archived_state( self, async_client: Any, ) -> None: - route = respx.post(f"{DOCUMENTS_URL}/doc_123:archive").mock( + route = respx.post(f"{DOCUMENTS_URL}/doc_123/archive").mock( return_value=httpx.Response(200, json=_make_document(status="archived")) ) diff --git a/tests/test_retrieval.py b/tests/test_retrieval.py index 147c554..400d77f 100644 --- a/tests/test_retrieval.py +++ b/tests/test_retrieval.py @@ -21,17 +21,11 @@ def _make_retrieval_response() -> Dict[str, Any]: "query": "refund policy", "results": [ { - "document_id": "doc_123", - "chunk_id": "chunk_456", - "section_id": "sec_12", - "section_path": "Policies / Billing / Refunds", - "source_file_name": "refund-policy.md", "chunk_type": "text", "content": "Annual plans may be refunded within 30 days.", "score": 1.0, - "citation": { + "source": { "document_id": "doc_123", - "chunk_id": "chunk_456", "source_file_name": "refund-policy.md", "section_path": "Policies / Billing / Refunds", }, @@ -78,8 +72,12 @@ def test_query_sends_request_and_returns_results(self, sync_client: Any) -> None } assert response.namespace == "support-center" assert response.results[0].content == "Annual plans may be refunded within 30 days." - assert response.results[0].citation is not None - assert response.results[0].citation.section_path == "Policies / Billing / Refunds" + assert response.results[0].source.document_id == "doc_123" + assert response.results[0].source.source_file_name == "refund-policy.md" + assert response.results[0].source.section_path == "Policies / Billing / Refunds" + assert not hasattr(response.results[0], "citation") + assert not hasattr(response.results[0], "chunk_id") + assert not hasattr(response.results[0], "section_id") @respx.mock def test_query_omits_defaulted_optional_fields(self, sync_client: Any) -> None: @@ -109,4 +107,4 @@ async def test_async_query_sends_request_and_returns_results( ) assert route.called - assert response.results[0].document_id == "doc_123" + assert response.results[0].source.document_id == "doc_123"