Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ for chunk in result.text_chunks:
## Retrieval and document lifecycle

New documents are published into a retrieval namespace. The server returns a
stable `document_id` when you create a job; persist that value if you need to
update or archive the same document later.
stable `document_id` after the job is published. `client.jobs.create(...)`
does not return a usable `document_id`; persist `job_result.document_id` if you
need to update or archive the same document later.

```python
job = client.jobs.create(
Expand All @@ -45,7 +46,11 @@ job = client.jobs.create(
namespace="support-center",
)

print(job.document_id) # "doc_..."
job_result = client.jobs.wait(job.job_id)
document_id = job_result.document_id

if document_id is None:
raise RuntimeError("Expected document_id after successful publication.")
```

After the job is done and published, query the canonical document content:
Expand All @@ -55,8 +60,13 @@ response = client.retrieval.query(
namespace="support-center",
query="How do I reset Bluetooth pairing?",
top_k=5,
channels=["path", "term"],
filter_mode="keep",
signal_paths=["Bluetooth", "Pairing"],
)

print(response.router_used)

for result in response.results:
print(result.content)
print(result.score)
Expand All @@ -69,13 +79,13 @@ Use `document_id` to update or archive a document:
update_job = client.jobs.create(
source_type="url",
source_url="https://example.com/manual-v2.pdf",
document_id=job.document_id,
document_id=document_id,
)

document = client.documents.get(job.document_id)
document = client.documents.get(document_id)
print(document.status)

client.documents.archive(job.document_id)
client.documents.archive(document_id)
```

You can also list documents in a namespace:
Expand Down Expand Up @@ -177,14 +187,14 @@ job = client.jobs.create(
parsing_params={"model": "advanced", "ocr_enabled": True},
)

print(job.document_id) # Persist this to update/archive the document later.

# Step 2: Upload file to presigned URL
client.jobs.upload(job, file=Path("report.pdf"))

# Step 3: Poll until done (adaptive backoff)
job_result = client.jobs.wait(job.job_id, poll_interval=10.0, poll_timeout=1800.0)

print(job_result.document_id) # Persist this to update/archive the document later.

# Step 4: Download and parse results
result = client.jobs.load(job_result)
print(result.statistics)
Expand Down
6 changes: 6 additions & 0 deletions src/knowhere/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
from knowhere.types.job import Job, JobError, JobProgress, JobResult
from knowhere.types.params import ParsingParams, WebhookConfig
from knowhere.types.retrieval import (
RetrievalChannel,
RetrievalFilterMode,
RetrievalSectionExclusion,
RetrievalSource,
RetrievalQueryResponse,
RetrievalResult,
Expand Down Expand Up @@ -97,6 +100,9 @@
"Document",
"DocumentListResponse",
# Retrieval types
"RetrievalChannel",
"RetrievalFilterMode",
"RetrievalSectionExclusion",
"RetrievalSource",
"RetrievalQueryResponse",
"RetrievalResult",
Expand Down
18 changes: 16 additions & 2 deletions src/knowhere/resources/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,16 +145,23 @@ def load(
if not job_result.result_url:
raise InvalidStateError("JobResult does not have a result_url.")
result_url: str = job_result.result_url
namespace: Optional[str] = job_result.namespace
document_id: Optional[str] = job_result.document_id
else:
result_url = job_result
namespace = None
document_id = None

response: httpx.Response = self._client._client.get(
result_url, timeout=self._client.upload_timeout
)
response.raise_for_status()
zip_bytes: bytes = response.content

return parseResultZip(zip_bytes, verify_checksum=verify_checksum)
parsed_result = parseResultZip(zip_bytes, verify_checksum=verify_checksum)
parsed_result.namespace = namespace
parsed_result.document_id = document_id
return parsed_result


class AsyncJobs(AsyncAPIResource):
Expand Down Expand Up @@ -251,13 +258,20 @@ async def load(
if not job_result.result_url:
raise InvalidStateError("JobResult does not have a result_url.")
result_url: str = job_result.result_url
namespace: Optional[str] = job_result.namespace
document_id: Optional[str] = job_result.document_id
else:
result_url = job_result
namespace = None
document_id = None

response: httpx.Response = await self._client._client.get(
result_url, timeout=self._client.upload_timeout
)
response.raise_for_status()
zip_bytes: bytes = response.content

return parseResultZip(zip_bytes, verify_checksum=verify_checksum)
parsed_result = parseResultZip(zip_bytes, verify_checksum=verify_checksum)
parsed_result.namespace = namespace
parsed_result.document_id = document_id
return parsed_result
59 changes: 56 additions & 3 deletions src/knowhere/resources/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
from typing import Any, Dict, Optional

from knowhere.resources._base import AsyncAPIResource, SyncAPIResource
from knowhere.types.retrieval import RetrievalQueryResponse
from knowhere.types.retrieval import (
RetrievalChannel,
RetrievalFilterMode,
RetrievalQueryResponse,
RetrievalSectionExclusion,
)


class Retrieval(SyncAPIResource):
Expand All @@ -17,15 +22,39 @@ def query(
query: str,
namespace: Optional[str] = None,
top_k: Optional[int] = None,
data_type: Optional[int] = None,
signal_paths: Optional[list[str]] = None,
filter_mode: Optional[RetrievalFilterMode] = None,
channels: Optional[list[RetrievalChannel]] = None,
channel_weights: Optional[dict[RetrievalChannel, float]] = None,
rerank: Optional[bool] = None,
threshold: Optional[float] = None,
internal_recall_k: Optional[int] = None,
exclude_document_ids: Optional[list[str]] = None,
exclude_sections: Optional[list[dict[str, str]]] = None,
exclude_sections: Optional[list[RetrievalSectionExclusion]] = None,
) -> RetrievalQueryResponse:
"""Query published documents in a namespace."""
body: Dict[str, Any] = {"query": query}
if namespace is not None:
body["namespace"] = namespace
if top_k is not None:
body["top_k"] = top_k
if data_type is not None:
body["data_type"] = data_type
if signal_paths is not None:
body["signal_paths"] = signal_paths
if filter_mode is not None:
body["filter_mode"] = filter_mode
if channels is not None:
body["channels"] = channels
if channel_weights is not None:
body["channel_weights"] = channel_weights
if rerank is not None:
body["rerank"] = rerank
if threshold is not None:
body["threshold"] = threshold
if internal_recall_k is not None:
body["internal_recall_k"] = internal_recall_k
if exclude_document_ids is not None:
body["exclude_document_ids"] = exclude_document_ids
if exclude_sections is not None:
Expand All @@ -48,15 +77,39 @@ async def query(
query: str,
namespace: Optional[str] = None,
top_k: Optional[int] = None,
data_type: Optional[int] = None,
signal_paths: Optional[list[str]] = None,
filter_mode: Optional[RetrievalFilterMode] = None,
channels: Optional[list[RetrievalChannel]] = None,
channel_weights: Optional[dict[RetrievalChannel, float]] = None,
rerank: Optional[bool] = None,
threshold: Optional[float] = None,
internal_recall_k: Optional[int] = None,
exclude_document_ids: Optional[list[str]] = None,
exclude_sections: Optional[list[dict[str, str]]] = None,
exclude_sections: Optional[list[RetrievalSectionExclusion]] = None,
) -> RetrievalQueryResponse:
"""Query published documents in a namespace."""
body: Dict[str, Any] = {"query": query}
if namespace is not None:
body["namespace"] = namespace
if top_k is not None:
body["top_k"] = top_k
if data_type is not None:
body["data_type"] = data_type
if signal_paths is not None:
body["signal_paths"] = signal_paths
if filter_mode is not None:
body["filter_mode"] = filter_mode
if channels is not None:
body["channels"] = channels
if channel_weights is not None:
body["channel_weights"] = channel_weights
if rerank is not None:
body["rerank"] = rerank
if threshold is not None:
body["threshold"] = threshold
if internal_recall_k is not None:
body["internal_recall_k"] = internal_recall_k
if exclude_document_ids is not None:
body["exclude_document_ids"] = exclude_document_ids
if exclude_sections is not None:
Expand Down
6 changes: 6 additions & 0 deletions src/knowhere/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from knowhere.types.job import Job, JobError, JobResult
from knowhere.types.params import ParsingParams, WebhookConfig
from knowhere.types.retrieval import (
RetrievalChannel,
RetrievalFilterMode,
RetrievalSectionExclusion,
RetrievalSource,
RetrievalQueryResponse,
RetrievalResult,
Expand Down Expand Up @@ -38,6 +41,9 @@
"Document",
"DocumentListResponse",
# retrieval
"RetrievalChannel",
"RetrievalFilterMode",
"RetrievalSectionExclusion",
"RetrievalSource",
"RetrievalQueryResponse",
"RetrievalResult",
Expand Down
1 change: 0 additions & 1 deletion src/knowhere/types/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ class Job(BaseModel):
status: str
source_type: str
namespace: Optional[str] = None
document_id: Optional[str] = None
data_id: Optional[str] = None
created_at: Optional[datetime] = None
upload_url: Optional[str] = None
Expand Down
6 changes: 6 additions & 0 deletions src/knowhere/types/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,8 @@ class ParseResult:
kb_csv: Optional[str]
hierarchy_view_html: Optional[str]
raw_zip: bytes
namespace: Optional[str]
document_id: Optional[str]

def __init__(
self,
Expand All @@ -285,6 +287,8 @@ def __init__(
kb_csv: Optional[str],
hierarchy_view_html: Optional[str],
raw_zip: bytes,
namespace: Optional[str] = None,
document_id: Optional[str] = None,
) -> None:
self.manifest = manifest
self.chunks = chunks
Expand All @@ -295,6 +299,8 @@ def __init__(
self.kb_csv = kb_csv
self.hierarchy_view_html = hierarchy_view_html
self.raw_zip = raw_zip
self.namespace = namespace
self.document_id = document_id

# -- convenience properties --

Expand Down
14 changes: 13 additions & 1 deletion src/knowhere/types/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,22 @@

from __future__ import annotations

from typing import Optional
from typing import Literal, Optional, TypedDict

from pydantic import BaseModel


RetrievalChannel = Literal["path", "content", "term"]
RetrievalFilterMode = Literal["delete", "keep"]


class RetrievalSectionExclusion(TypedDict):
"""Section exclusion for follow-up retrieval queries."""

document_id: str
section_path: str


class RetrievalSource(BaseModel):
"""Caller-facing source reference attached to a retrieval result."""

Expand All @@ -30,4 +41,5 @@ class RetrievalQueryResponse(BaseModel):

namespace: str
query: str
router_used: Optional[str] = None
results: list[RetrievalResult]
1 change: 0 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def mock_job_response() -> Dict[str, Any]:
"status": "waiting-file",
"source_type": "file",
"namespace": "default",
"document_id": "doc_test123",
"data_id": None,
"created_at": "2025-01-01T00:00:00Z",
"upload_url": "https://storage.example.com/upload?token=abc",
Expand Down
8 changes: 5 additions & 3 deletions tests/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def test_create_with_url_source(
"status": "pending",
"source_type": "url",
"namespace": "support-center",
"document_id": "doc_123",
}

route = respx.post(JOBS_URL).mock(
Expand All @@ -53,7 +52,7 @@ def test_create_with_url_source(
assert job.source_type == "url"
assert job.status == "pending"
assert job.namespace == "support-center"
assert job.document_id == "doc_123"
assert not hasattr(job, "document_id")

@respx.mock
def test_create_with_file_source(
Expand Down Expand Up @@ -87,7 +86,6 @@ def test_create_sends_correct_body(
"status": "pending",
"source_type": "url",
"namespace": "support-center",
"document_id": "doc_123",
}

route = respx.post(JOBS_URL).mock(
Expand Down Expand Up @@ -284,6 +282,8 @@ def test_load_with_job_result_object(
job_id="job_load",
status="done",
source_type="url",
namespace="support-center",
document_id="doc_123",
result_url=result_url,
)

Expand All @@ -293,3 +293,5 @@ def test_load_with_job_result_object(

assert route.called
assert parse_result.manifest is not None
assert parse_result.namespace == "support-center"
assert parse_result.document_id == "doc_123"
7 changes: 6 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_from_dict_with_document_scope(self) -> None:
}
job: Job = Job(**data)
assert job.namespace == "support-center"
assert job.document_id == "doc_123"
assert "document_id" not in job.model_dump()

def test_from_dict_with_upload(self) -> None:
data: Dict[str, Any] = {
Expand Down Expand Up @@ -717,6 +717,11 @@ def test_statistics_shortcut(self) -> None:
assert stats.total_chunks == 3
assert stats.text_chunks == 1

def test_document_scope_defaults_to_none(self) -> None:
result: ParseResult = _build_parse_result()
assert result.namespace is None
assert result.document_id is None

def test_raw_zip_accessible(self) -> None:
result: ParseResult = _build_parse_result()
assert result.raw_zip == b"fake zip bytes"
Expand Down
Loading
Loading