diff --git a/python/packages/azure-cosmos/AGENTS.md b/python/packages/azure-cosmos/AGENTS.md index 9bb7f76da9..263ea1a2f4 100644 --- a/python/packages/azure-cosmos/AGENTS.md +++ b/python/packages/azure-cosmos/AGENTS.md @@ -1,30 +1,65 @@ # Azure Cosmos DB Package (agent-framework-azure-cosmos) -Azure Cosmos DB history provider integration for Agent Framework. +Azure Cosmos DB history, retrieval, and workflow checkpointing integration for Agent Framework. ## Main Classes - **`CosmosHistoryProvider`** - Persistent conversation history storage backed by Azure Cosmos DB +- **`CosmosContextProvider`** - Cosmos DB context provider for injecting relevant documents before a model run and writing request/response messages back into the same container after the run +- **`CosmosCheckpointStorage`** - Cosmos DB-backed workflow checkpoint storage for durable workflow execution ## Usage ```python -from agent_framework.azure import CosmosHistoryProvider +from agent_framework_azure_cosmos import ( + CosmosCheckpointStorage, + CosmosContextProvider, + CosmosContextSearchMode, + CosmosHistoryProvider, +) -provider = CosmosHistoryProvider( +history_provider = CosmosHistoryProvider( endpoint="https://.documents.azure.com:443/", credential="", database_name="agent-framework", container_name="chat-history", ) + +context_provider = CosmosContextProvider( + endpoint="https://.documents.azure.com:443/", + credential="", + database_name="agent-framework", + container_name="knowledge", + embedding_function=my_embedding_function, +) + +checkpoint_storage = CosmosCheckpointStorage( + endpoint="https://.documents.azure.com:443/", + credential="", + database_name="agent-framework", + container_name="workflow-checkpoints", +) ``` -Container name is configured on the provider. `session_id` is used as the partition key. +Container name is configured on each provider. `CosmosHistoryProvider` uses `session_id` as the partition key for reads/writes. `CosmosContextProvider` can optionally scope retrieval with `partition_key`. + +`CosmosContextProvider` joins the filtered `user` and `assistant` messages from the current run into one retrieval query string, and writes request/response messages back into the same Cosmos knowledge container after each run. All configuration — including search mode, weights, top_k, scan_limit, and partition key — is set on the constructor. + +The default search mode is `VECTOR`. Full-text and hybrid modes are also supported via the `search_mode` constructor parameter. Optional hybrid RRF weights can be provided through `weights=[...]` on the constructor. + +The application owner is responsible for making sure the Cosmos account, database, container, partitioning strategy, and any required full-text/vector/hybrid indexing configuration already exist. The provider does not create or manage Cosmos resources or search policies. + +`CosmosCheckpointStorage` creates the configured database and container on first use when needed, and stores workflow checkpoints using `/workflow_name` as the partition key. ## Import Path ```python +from agent_framework_azure_cosmos import ( + CosmosCheckpointStorage, + CosmosContextProvider, + CosmosHistoryProvider, +) + +# `CosmosHistoryProvider` is also available from the Azure namespace: from agent_framework.azure import CosmosHistoryProvider -# or directly: -from agent_framework_azure_cosmos import CosmosHistoryProvider ``` diff --git a/python/packages/azure-cosmos/README.md b/python/packages/azure-cosmos/README.md index a03c5c6f93..5274ff7df3 100644 --- a/python/packages/azure-cosmos/README.md +++ b/python/packages/azure-cosmos/README.md @@ -37,6 +37,73 @@ Container naming behavior: See `samples/02-agents/conversations/cosmos_history_provider.py` for a runnable example. +## Azure Cosmos DB Context Provider + +The Azure Cosmos DB integration also provides `CosmosContextProvider` for context injection before model invocation. It also writes input and response messages back into the same Cosmos container after each run so the knowledge container can accumulate additional context over time. + +### Basic Usage Example + +```python +from azure.identity.aio import DefaultAzureCredential +from agent_framework_azure_cosmos import CosmosContextProvider, CosmosContextSearchMode + +provider = CosmosContextProvider( + endpoint="https://.documents.azure.com:443/", + credential=DefaultAzureCredential(), + database_name="agent-framework", + container_name="knowledge", + embedding_function=my_embedding_function, + content_field_names=("content", "text"), +) +``` + +Supported retrieval configuration includes: + +- `search_mode`: `CosmosContextSearchMode.VECTOR` (default), `.FULL_TEXT`, or `.HYBRID` +- `weights` for hybrid RRF runs (optional, omitted by default) +- `top_k` for controlling the number of context messages injected +- `scan_limit` for controlling the number of Cosmos candidate items scanned +- `partition_key` for scoping Cosmos retrieval + +All configuration is set on the constructor. The default search mode is `VECTOR`, which requires an `embedding_function`. For full-text mode, set `search_mode=CosmosContextSearchMode.FULL_TEXT`: + +```python +provider = CosmosContextProvider( + endpoint="https://.documents.azure.com:443/", + credential=DefaultAzureCredential(), + database_name="agent-framework", + container_name="knowledge", + search_mode=CosmosContextSearchMode.FULL_TEXT, +) +``` + +For hybrid retrieval with optional weights: + +```python +provider = CosmosContextProvider( + endpoint="https://.documents.azure.com:443/", + credential=DefaultAzureCredential(), + database_name="agent-framework", + container_name="knowledge", + embedding_function=my_embedding_function, + search_mode=CosmosContextSearchMode.HYBRID, + weights=[2.0, 1.0], + top_k=3, + scan_limit=10, + partition_key="tenant-a", +) +``` + +`CosmosContextProvider` contributes retrieval context in `before_run(...)` and persists input/response messages in `after_run(...)`. + +The provider builds retrieval input by joining the filtered `user` and `assistant` messages from the current run into a single query string. That joined query text is then used for full-text tokenization, vector embedding generation, or hybrid retrieval depending on the configured search mode. + +The provider writes the request/response messages back into the same knowledge container configured by `container_name`. + +The provider assumes the Cosmos account, database, container, partitioning strategy, and any required Cosmos full-text/vector/hybrid indexing policies already exist and are correctly configured by the application owner. It does not create or manage Cosmos resources, schema, or search policies for you. + +See `packages/azure-cosmos/samples/cosmos_context_provider.py` for a package-local context provider example. + ## Cosmos DB Workflow Checkpoint Storage `CosmosCheckpointStorage` implements the `CheckpointStorage` protocol, enabling @@ -84,7 +151,7 @@ workflow = WorkflowBuilder( checkpoint_storage=checkpoint_storage, ).build() -# Run the workflow — checkpoints are automatically saved after each superstep +# Run the workflow - checkpoints are automatically saved after each superstep result = await workflow.run(message="input data") # Resume from a checkpoint @@ -124,3 +191,4 @@ portal with this partition key configuration. See `samples/03-workflows/checkpoint/cosmos_workflow_checkpointing.py` for a standalone example, or `samples/03-workflows/checkpoint/cosmos_workflow_checkpointing_foundry.py` for an end-to-end example with Azure AI Foundry agents. + diff --git a/python/packages/azure-cosmos/agent_framework_azure_cosmos/__init__.py b/python/packages/azure-cosmos/agent_framework_azure_cosmos/__init__.py index 66373b0f1d..e0f42ae9a6 100644 --- a/python/packages/azure-cosmos/agent_framework_azure_cosmos/__init__.py +++ b/python/packages/azure-cosmos/agent_framework_azure_cosmos/__init__.py @@ -1,8 +1,18 @@ # Copyright (c) Microsoft. All rights reserved. +"""Azure Cosmos DB provider exports. + +Supported classes: + - ``CosmosContextProvider`` + - ``CosmosCheckpointStorage`` + - ``CosmosContextSearchMode`` + - ``CosmosHistoryProvider`` +""" + import importlib.metadata from ._checkpoint_storage import CosmosCheckpointStorage +from ._context_provider import CosmosContextProvider, CosmosContextSearchMode from ._history_provider import CosmosHistoryProvider try: @@ -12,6 +22,8 @@ __all__ = [ "CosmosCheckpointStorage", + "CosmosContextProvider", + "CosmosContextSearchMode", "CosmosHistoryProvider", "__version__", ] diff --git a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_context_provider.py b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_context_provider.py new file mode 100644 index 0000000000..12d5f1ae96 --- /dev/null +++ b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_context_provider.py @@ -0,0 +1,363 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Azure Cosmos DB context provider.""" + +from __future__ import annotations + +import logging +import re +import time +import uuid +from collections.abc import Awaitable, Callable, Sequence +from enum import Enum +from typing import TYPE_CHECKING, Any, TypedDict + +from agent_framework import ( + AGENT_FRAMEWORK_USER_AGENT, + AgentSession, + ContextProvider, + Message, + SessionContext, + SupportsGetEmbeddings, +) +from agent_framework._settings import SecretString, load_settings +from azure.core.credentials import TokenCredential +from azure.core.credentials_async import AsyncTokenCredential +from azure.cosmos.aio import ContainerProxy, CosmosClient, DatabaseProxy + +if TYPE_CHECKING: + from agent_framework._agents import SupportsAgentRun + + +logger = logging.getLogger(__name__) + +AzureCredentialTypes = TokenCredential | AsyncTokenCredential + +COSMOS_USER_AGENT_SUFFIX = f"{AGENT_FRAMEWORK_USER_AGENT} CosmosContextProvider" + + +class CosmosContextSearchMode(str, Enum): + """Supported Azure Cosmos DB retrieval modes for the context provider.""" + + VECTOR = "vector" + FULL_TEXT = "full_text" + HYBRID = "hybrid" + + +class CosmosContextSettings(TypedDict, total=False): + """Settings for CosmosContextProvider resolved from args and environment.""" + + endpoint: str | None + database_name: str | None + container_name: str | None + key: SecretString | None + top_k: int | None + scan_limit: int | None + + +class CosmosContextProvider(ContextProvider): + """Azure Cosmos DB-backed context provider. + + Queries a Cosmos DB knowledge container for relevant context before + agent model invocation, and writes request/response messages back + into the same container after each run. + """ + + def __init__( + self, + source_id: str = "azure_cosmos_context", + *, + endpoint: str | None = None, + database_name: str | None = None, + container_name: str | None = None, + credential: str | AzureCredentialTypes | None = None, + cosmos_client: CosmosClient | None = None, + container_client: ContainerProxy | None = None, + top_k: int | None = None, + scan_limit: int | None = None, + search_mode: CosmosContextSearchMode = CosmosContextSearchMode.VECTOR, + content_field_names: Sequence[str] = ("content", "text"), + message_field_name: str | None = "message", + vector_field_name: str = "embedding", + embedding_function: Callable[[str], Awaitable[list[float]]] + | SupportsGetEmbeddings[str, list[float], Any] + | None = None, + partition_key: str | None = None, + weights: Sequence[float] | None = None, + context_prompt: str = "Use the following context to answer the question:", + env_file_path: str | None = None, + env_file_encoding: str | None = None, + ) -> None: + super().__init__(source_id) + + self.top_k = top_k or 5 + self.scan_limit = scan_limit or 25 + self.search_mode = search_mode + self.content_field_names = tuple(content_field_names) + self.message_field_name = message_field_name + self.vector_field_name = vector_field_name + self.embedding_function = embedding_function + self.partition_key = partition_key + self.weights = tuple(float(w) for w in weights) if weights is not None else None + self.context_prompt = context_prompt + + self._cosmos_client: CosmosClient | None = cosmos_client + self._container_proxy: ContainerProxy | None = container_client + self._database_client: DatabaseProxy | None = None + self._owns_client = False + + if self._container_proxy is not None: + self.database_name: str = database_name or "" + self.container_name: str = container_name or "" + return + + required_fields: list[str] = ["database_name", "container_name"] + if cosmos_client is None: + required_fields.append("endpoint") + if credential is None: + required_fields.append("key") + + settings = load_settings( + CosmosContextSettings, + env_prefix="AZURE_COSMOS_", + required_fields=required_fields, + endpoint=endpoint, + database_name=database_name, + container_name=container_name, + key=credential if isinstance(credential, str) else None, + top_k=top_k, + scan_limit=scan_limit, + env_file_path=env_file_path, + env_file_encoding=env_file_encoding, + ) + + self.database_name = settings["database_name"] # type: ignore[assignment,reportTypedDictNotRequiredAccess] + self.container_name = settings["container_name"] # type: ignore[assignment,reportTypedDictNotRequiredAccess] + env_top_k = settings.get("top_k") + if env_top_k is not None: + self.top_k = env_top_k + env_scan_limit = settings.get("scan_limit") + if env_scan_limit is not None: + self.scan_limit = env_scan_limit + + if self._cosmos_client is None: + self._cosmos_client = CosmosClient( + url=settings["endpoint"], # type: ignore[arg-type] + credential=credential or settings["key"].get_secret_value(), # type: ignore[arg-type,union-attr] + user_agent_suffix=COSMOS_USER_AGENT_SUFFIX, + ) + self._owns_client = True + + self._database_client = self._cosmos_client.get_database_client(self.database_name) + + async def before_run( + self, + *, + agent: SupportsAgentRun, + session: AgentSession, + context: SessionContext, + state: dict[str, Any], + ) -> None: + """Retrieve relevant context from Cosmos DB before model invocation.""" + filtered = [ + msg + for msg in context.input_messages + if msg and msg.text and msg.text.strip() and msg.role in {"user", "assistant"} + ] + if not filtered: + return + + query_text = "\n".join(msg.text.strip() for msg in filtered).strip() + if not query_text: + return + + query_terms = tuple(dict.fromkeys(m.casefold() for m in re.findall(r"\w+", query_text, flags=re.UNICODE))) + if self.search_mode in {CosmosContextSearchMode.FULL_TEXT, CosmosContextSearchMode.HYBRID} and not query_terms: + return + + state["query_text"] = query_text + + items = await self._execute_retrieval_query(query_text, query_terms) + + result_messages: list[Message] = [] + for item in items: + msg = self._shape_context_message(item) + if msg is not None: + result_messages.append(msg) + if len(result_messages) >= self.top_k: + break + + if result_messages: + context.extend_messages( + self.source_id, + [Message(role="user", contents=[self.context_prompt]), *result_messages], + ) + + async def after_run( + self, + *, + agent: SupportsAgentRun, + session: AgentSession, + context: SessionContext, + state: dict[str, Any], + ) -> None: + """Persist conversation messages to the knowledge container after each run. + + Stores user and assistant messages with embeddings (when available) so + they are retrievable by ``before_run`` on subsequent invocations. + """ + messages_to_store: list[Message] = list(context.input_messages) + if context.response and context.response.messages: + messages_to_store.extend(context.response.messages) + + writeback = [m for m in messages_to_store if m.role in {"user", "assistant"} and m.text and m.text.strip()] + if not writeback: + return + + container = await self._get_container() + session_key = context.session_id or str(uuid.uuid4()) + if not context.session_id: + logger.warning("No session_id; generated '%s' for Cosmos writeback partition key.", session_key) + + agent_name = getattr(agent, "name", None) + user_id = context.metadata.get("user_id") if context.metadata else None + + base_sort_key = time.time_ns() + for index, message in enumerate(writeback): + content_text = message.text.strip() + role_value = str(message.role.value) if hasattr(message.role, "value") else str(message.role) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportAttributeAccessIssue] + document: dict[str, Any] = { + "id": str(uuid.uuid4()), + "session_id": session_key, + "sort_key": base_sort_key + index, + "source_id": self.source_id, + "role": role_value, + "content": content_text, + "message": message.to_dict(), + } + if agent_name: + document["agent_name"] = agent_name + if user_id: + document["user_id"] = user_id + if message.author_name: + document["author_name"] = message.author_name + if self.partition_key is not None: + document["partition_key"] = self.partition_key + + if self.embedding_function is not None: + try: + embedding = await self._get_query_vector(content_text) + document[self.vector_field_name] = embedding + except Exception: + logger.warning("Failed to generate embedding for writeback document; skipping vector field.") + + await container.upsert_item(document) + + async def close(self) -> None: + """Close the underlying Cosmos client when this provider owns it.""" + if self._owns_client and self._cosmos_client is not None: + await self._cosmos_client.close() + + async def __aenter__(self) -> CosmosContextProvider: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any, + ) -> None: + try: + await self.close() + except Exception: + if exc_type is None: + raise + + # --- Private --- + + async def _get_container(self) -> ContainerProxy: + """Return the Cosmos container proxy, resolving lazily from database client.""" + if self._container_proxy is not None: + return self._container_proxy + if self._database_client is None: + raise RuntimeError("Cosmos database client is not initialized.") + self._container_proxy = self._database_client.get_container_client(self.container_name) + return self._container_proxy + + async def _execute_retrieval_query(self, query_text: str, query_terms: tuple[str, ...]) -> list[dict[str, Any]]: + """Build and execute the Cosmos retrieval query for the configured search mode.""" + fields = list(self.content_field_names) + if self.message_field_name and self.message_field_name not in fields: + fields.append(self.message_field_name) + select = ", ".join(f"c.{f}" for f in fields) + base = f"SELECT TOP {self.scan_limit} {select} FROM c" # noqa: S608 # nosec B608 + + parameters: list[dict[str, object]] = [] + + if self.search_mode is CosmosContextSearchMode.FULL_TEXT: + search_field = self.content_field_names[0] + query = f"{base} ORDER BY RANK FullTextScore(c.{search_field}, @query_text)" + parameters.append({"name": "@query_text", "value": " ".join(query_terms)}) + + elif self.search_mode is CosmosContextSearchMode.VECTOR: + query_vector = await self._get_query_vector(query_text) + query = f"{base} ORDER BY VectorDistance(c.{self.vector_field_name}, @query_vector) ASC" + parameters.append({"name": "@query_vector", "value": query_vector}) + + elif self.search_mode is CosmosContextSearchMode.HYBRID: + query_vector = await self._get_query_vector(query_text) + search_field = self.content_field_names[0] + ft = f"FullTextScore(c.{search_field}, @query_text)" + vd = f"VectorDistance(c.{self.vector_field_name}, @query_vector)" + if self.weights is not None: + wl = "[" + ", ".join(f"{w:g}" for w in self.weights) + "]" + rrf = f"RRF({ft}, {vd}, {wl})" + else: + rrf = f"RRF({ft}, {vd})" + query = f"{base} ORDER BY RANK {rrf}" + parameters.append({"name": "@query_text", "value": " ".join(query_terms)}) + parameters.append({"name": "@query_vector", "value": query_vector}) + + else: + raise ValueError(f"Unsupported search_mode: {self.search_mode}") + + container = await self._get_container() + query_kwargs: dict[str, Any] = {"query": query, "max_item_count": self.scan_limit} + if parameters: + query_kwargs["parameters"] = parameters + if self.partition_key is not None: + query_kwargs["partition_key"] = self.partition_key + return [item async for item in container.query_items(**query_kwargs)] + + def _shape_context_message(self, item: dict[str, Any]) -> Message | None: + """Convert a Cosmos item into a context Message.""" + payload = item.get(self.message_field_name) if self.message_field_name else None + if isinstance(payload, dict): + try: + return Message.from_dict(payload) # pyright: ignore[reportUnknownArgumentType] + except (TypeError, ValueError): + pass + + content = next( + (v.strip() for f in self.content_field_names if isinstance(v := item.get(f), str) and v.strip()), + None, + ) + if not content: + return None + return Message(role="user", contents=[content]) + + async def _get_query_vector(self, query_text: str) -> list[float]: + """Get a query embedding from the configured embedding provider.""" + if self.embedding_function is None: + raise ValueError("embedding_function is required for vector and hybrid retrieval") + + if isinstance(self.embedding_function, SupportsGetEmbeddings): + embeddings = await self.embedding_function.get_embeddings([query_text]) # type: ignore[reportUnknownVariableType] + if not embeddings: + raise ValueError("embedding_function returned no embeddings") + return [float(v) for v in embeddings[0].vector] # type: ignore[reportUnknownVariableType] + + return [float(v) for v in await self.embedding_function(query_text)] + + +__all__ = ["CosmosContextProvider", "CosmosContextSearchMode"] diff --git a/python/packages/azure-cosmos/pyproject.toml b/python/packages/azure-cosmos/pyproject.toml index 4193b07014..295c151ef4 100644 --- a/python/packages/azure-cosmos/pyproject.toml +++ b/python/packages/azure-cosmos/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "agent-framework-azure-cosmos" -description = "Azure Cosmos DB history provider integration for Microsoft Agent Framework." +description = "Azure Cosmos DB history and context provider integration for Microsoft Agent Framework." authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] readme = "README.md" requires-python = ">=3.10" diff --git a/python/packages/azure-cosmos/tests/test_cosmos_context_provider.py b/python/packages/azure-cosmos/tests/test_cosmos_context_provider.py new file mode 100644 index 0000000000..ad8cf45b32 --- /dev/null +++ b/python/packages/azure-cosmos/tests/test_cosmos_context_provider.py @@ -0,0 +1,501 @@ +# Copyright (c) Microsoft. All rights reserved. + +from __future__ import annotations + +import uuid +from collections.abc import AsyncIterator +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest +from agent_framework import AgentResponse, Message +from agent_framework._sessions import AgentSession, SessionContext + +import agent_framework_azure_cosmos._context_provider as context_provider_module +from agent_framework_azure_cosmos import CosmosContextProvider, CosmosContextSearchMode + + +def _to_async_iter(items: list[Any]) -> AsyncIterator[Any]: + async def _iterator() -> AsyncIterator[Any]: + for item in items: + yield item + + return _iterator() + + +async def _stub_embed(_: str) -> list[float]: + return [1.0, 0.0] + + +def test_provider_uses_existing_container_client() -> None: + container = MagicMock() + provider = CosmosContextProvider( + source_id="ctx", + container_client=container, + search_mode=CosmosContextSearchMode.FULL_TEXT, + ) + assert provider.source_id == "ctx" + assert provider.search_mode is CosmosContextSearchMode.FULL_TEXT + + +def test_provider_default_search_mode_is_vector() -> None: + provider = CosmosContextProvider(container_client=MagicMock(), embedding_function=_stub_embed) + assert provider.search_mode is CosmosContextSearchMode.VECTOR + assert provider.vector_field_name == "embedding" + + +def test_provider_constructs_client_from_environment(monkeypatch: pytest.MonkeyPatch) -> None: + database_client = MagicMock() + cosmos_client = MagicMock() + cosmos_client.get_database_client.return_value = database_client + cosmos_client_factory = MagicMock(return_value=cosmos_client) + + monkeypatch.setattr(context_provider_module, "CosmosClient", cosmos_client_factory) + monkeypatch.setenv("AZURE_COSMOS_ENDPOINT", "https://account.documents.azure.com:443/") + monkeypatch.setenv("AZURE_COSMOS_DATABASE_NAME", "agent-framework") + monkeypatch.setenv("AZURE_COSMOS_CONTAINER_NAME", "knowledge") + monkeypatch.setenv("AZURE_COSMOS_KEY", "test-key") + monkeypatch.setenv("AZURE_COSMOS_TOP_K", "4") + monkeypatch.setenv("AZURE_COSMOS_SCAN_LIMIT", "9") + + provider = CosmosContextProvider(search_mode=CosmosContextSearchMode.FULL_TEXT) + + cosmos_client_factory.assert_called_once() + kwargs = cosmos_client_factory.call_args.kwargs + assert kwargs["url"] == "https://account.documents.azure.com:443/" + assert kwargs["credential"] == "test-key" + assert "CosmosContextProvider" in kwargs["user_agent_suffix"] + assert provider.database_name == "agent-framework" + assert provider.container_name == "knowledge" + assert provider.top_k == 4 + assert provider.scan_limit == 9 + + +class TestBeforeRun: + async def test_skips_when_no_user_or_assistant_messages(self) -> None: + container = MagicMock() + container.query_items = MagicMock(return_value=_to_async_iter([])) + provider = CosmosContextProvider(container_client=container, search_mode=CosmosContextSearchMode.FULL_TEXT) + session = AgentSession(session_id="s") + context = SessionContext(input_messages=[Message(role="system", contents=["ignore"])], session_id="s1") + + await provider.before_run( + agent=None, session=session, context=context, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore[arg-type] + + container.query_items.assert_not_called() + assert context.context_messages.get(provider.source_id) is None + + async def test_full_text_queries_cosmos_and_adds_context(self) -> None: + container = MagicMock() + container.query_items = MagicMock( + return_value=_to_async_iter([ + {"content": "Cosmos DB supports vector search."}, + {"content": "Full text search is also available."}, + ]) + ) + provider = CosmosContextProvider(container_client=container, search_mode=CosmosContextSearchMode.FULL_TEXT) + session = AgentSession(session_id="s") + context = SessionContext( + input_messages=[Message(role="user", contents=["How does search work?"])], session_id="s1" + ) + + await provider.before_run( + agent=None, session=session, context=context, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore[arg-type] + + container.query_items.assert_called_once() + provider_messages = context.context_messages[provider.source_id] + assert provider_messages[0].text == provider.context_prompt + assert len(provider_messages) >= 2 + query_kwargs = container.query_items.call_args.kwargs + assert "ORDER BY RANK FullTextScore(" in query_kwargs["query"] + assert "WHERE" not in query_kwargs["query"] + + async def test_vector_mode_builds_vector_distance_query(self) -> None: + container = MagicMock() + container.query_items = MagicMock(return_value=_to_async_iter([{"content": "Vector search result."}])) + provider = CosmosContextProvider(container_client=container, embedding_function=_stub_embed) + session = AgentSession(session_id="s") + context = SessionContext(input_messages=[Message(role="user", contents=["Find similar docs"])], session_id="s1") + + await provider.before_run( + agent=None, session=session, context=context, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore[arg-type] + + query_kwargs = container.query_items.call_args.kwargs + assert "ORDER BY VectorDistance(c.embedding, @query_vector) ASC" in query_kwargs["query"] + assert query_kwargs["parameters"] == [{"name": "@query_vector", "value": [1.0, 0.0]}] + + async def test_hybrid_mode_builds_rrf_query_with_weights(self) -> None: + container = MagicMock() + container.query_items = MagicMock(return_value=_to_async_iter([{"content": "Hybrid result."}])) + provider = CosmosContextProvider( + container_client=container, + embedding_function=_stub_embed, + search_mode=CosmosContextSearchMode.HYBRID, + weights=[2.0, 1.0], + ) + session = AgentSession(session_id="s") + context = SessionContext( + input_messages=[Message(role="user", contents=["Explain hybrid search"])], session_id="s1" + ) + + await provider.before_run( + agent=None, session=session, context=context, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore[arg-type] + + query_kwargs = container.query_items.call_args.kwargs + assert "ORDER BY RANK RRF(" in query_kwargs["query"] + assert "[2, 1]" in query_kwargs["query"] + + async def test_hybrid_mode_omits_weights_when_none(self) -> None: + container = MagicMock() + container.query_items = MagicMock(return_value=_to_async_iter([{"content": "Hybrid result."}])) + provider = CosmosContextProvider( + container_client=container, + embedding_function=_stub_embed, + search_mode=CosmosContextSearchMode.HYBRID, + ) + session = AgentSession(session_id="s") + context = SessionContext( + input_messages=[Message(role="user", contents=["Explain hybrid ranking"])], session_id="s1" + ) + + await provider.before_run( + agent=None, session=session, context=context, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore[arg-type] + + query_kwargs = container.query_items.call_args.kwargs + assert "RRF(FullTextScore(" in query_kwargs["query"] + assert "[" not in query_kwargs["query"].split("RRF(", 1)[1] + + async def test_respects_top_k(self) -> None: + container = MagicMock() + container.query_items = MagicMock( + return_value=_to_async_iter([{"content": "Result 1."}, {"content": "Result 2."}, {"content": "Result 3."}]) + ) + provider = CosmosContextProvider( + container_client=container, top_k=1, search_mode=CosmosContextSearchMode.FULL_TEXT + ) + session = AgentSession(session_id="s") + context = SessionContext(input_messages=[Message(role="user", contents=["search query"])], session_id="s1") + + await provider.before_run( + agent=None, session=session, context=context, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore[arg-type] + + # prompt + 1 result + assert len(context.context_messages[provider.source_id]) == 2 + + async def test_respects_partition_key(self) -> None: + container = MagicMock() + container.query_items = MagicMock(return_value=_to_async_iter([{"content": "Result."}])) + provider = CosmosContextProvider( + container_client=container, + partition_key="tenant-a", + search_mode=CosmosContextSearchMode.FULL_TEXT, + ) + session = AgentSession(session_id="s") + context = SessionContext(input_messages=[Message(role="user", contents=["search"])], session_id="s1") + + await provider.before_run( + agent=None, session=session, context=context, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore[arg-type] + + assert container.query_items.call_args.kwargs["partition_key"] == "tenant-a" + + async def test_joins_user_and_assistant_messages_for_query(self) -> None: + container = MagicMock() + container.query_items = MagicMock(return_value=_to_async_iter([{"content": "Result."}])) + provider = CosmosContextProvider(container_client=container, search_mode=CosmosContextSearchMode.FULL_TEXT) + session = AgentSession(session_id="s") + state = session.state.setdefault(provider.source_id, {}) + context = SessionContext( + input_messages=[ + Message(role="user", contents=["Tell me about Cosmos"]), + Message(role="system", contents=["ignored"]), + Message(role="assistant", contents=["Vector or hybrid?"]), + Message(role="user", contents=["Hybrid"]), + ], + session_id="s1", + ) + + await provider.before_run(agent=None, session=session, context=context, state=state) # type: ignore[arg-type] + + assert state["query_text"] == "Tell me about Cosmos\nVector or hybrid?\nHybrid" + + async def test_vector_mode_works_with_non_lexical_input(self) -> None: + container = MagicMock() + container.query_items = MagicMock(return_value=_to_async_iter([{"content": "Emoji result"}])) + provider = CosmosContextProvider(container_client=container, embedding_function=_stub_embed) + session = AgentSession(session_id="s") + context = SessionContext(input_messages=[Message(role="user", contents=["🔎"])], session_id="s1") + + await provider.before_run( + agent=None, session=session, context=context, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore[arg-type] + + container.query_items.assert_called_once() + + async def test_hybrid_skips_when_no_text_terms(self) -> None: + container = MagicMock() + provider = CosmosContextProvider( + container_client=container, + embedding_function=_stub_embed, + search_mode=CosmosContextSearchMode.HYBRID, + ) + session = AgentSession(session_id="s") + context = SessionContext(input_messages=[Message(role="user", contents=["🔎"])], session_id="s1") + + await provider.before_run( + agent=None, session=session, context=context, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore[arg-type] + + container.query_items.assert_not_called() + + async def test_message_field_deserialized_when_valid(self) -> None: + container = MagicMock() + container.query_items = MagicMock( + return_value=_to_async_iter([{"message": {"bad": "payload"}, "content": "Fallback content."}]) + ) + provider = CosmosContextProvider(container_client=container, search_mode=CosmosContextSearchMode.FULL_TEXT) + session = AgentSession(session_id="s") + context = SessionContext(input_messages=[Message(role="user", contents=["find stuff"])], session_id="s1") + + await provider.before_run( + agent=None, session=session, context=context, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore[arg-type] + + assert "Fallback content." in context.context_messages[provider.source_id][1].text + + async def test_container_resolved_from_database_client(self) -> None: + container = MagicMock() + container.query_items = MagicMock(return_value=_to_async_iter([{"text": "Result."}])) + database_client = MagicMock() + database_client.get_container_client.return_value = container + cosmos_client = MagicMock() + cosmos_client.get_database_client.return_value = database_client + + provider = CosmosContextProvider( + cosmos_client=cosmos_client, + database_name="db1", + container_name="knowledge", + search_mode=CosmosContextSearchMode.FULL_TEXT, + ) + session = AgentSession(session_id="s") + context = SessionContext(input_messages=[Message(role="user", contents=["search"])], session_id="s1") + + await provider.before_run( + agent=None, session=session, context=context, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore[arg-type] + + database_client.get_container_client.assert_called_once_with("knowledge") + + +class TestAfterRun: + async def test_writeback_stores_messages(self) -> None: + container = MagicMock() + container.upsert_item = AsyncMock() + provider = CosmosContextProvider(container_client=container, search_mode=CosmosContextSearchMode.FULL_TEXT) + session = AgentSession(session_id="s") + context = SessionContext(input_messages=[Message(role="user", contents=["hello"])], session_id="s1") + context._response = AgentResponse(messages=[Message(role="assistant", contents=["hi"])]) + + await provider.after_run( + agent=None, session=session, context=context, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore[arg-type] + + assert container.upsert_item.await_count == 2 + first = container.upsert_item.await_args_list[0].args[0] + second = container.upsert_item.await_args_list[1].args[0] + assert first["session_id"] == "s1" + assert first["content"] == "hello" + assert second["content"] == "hi" + assert "document_type" not in first + + async def test_excludes_system_messages(self) -> None: + container = MagicMock() + container.upsert_item = AsyncMock() + provider = CosmosContextProvider(container_client=container, search_mode=CosmosContextSearchMode.FULL_TEXT) + session = AgentSession(session_id="s") + context = SessionContext( + input_messages=[ + Message(role="system", contents=["You are helpful."]), + Message(role="user", contents=["hello"]), + ], + session_id="s1", + ) + context._response = AgentResponse(messages=[Message(role="assistant", contents=["hi"])]) + + await provider.after_run( + agent=None, session=session, context=context, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore[arg-type] + + assert container.upsert_item.await_count == 2 + roles = [call.args[0]["role"] for call in container.upsert_item.await_args_list] + assert "system" not in roles + assert roles == ["user", "assistant"] + + async def test_writeback_includes_embedding_when_function_available(self) -> None: + container = MagicMock() + container.upsert_item = AsyncMock() + provider = CosmosContextProvider( + container_client=container, + embedding_function=_stub_embed, + search_mode=CosmosContextSearchMode.VECTOR, + ) + session = AgentSession(session_id="s") + context = SessionContext(input_messages=[Message(role="user", contents=["hello"])], session_id="s1") + + await provider.after_run( + agent=None, session=session, context=context, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore[arg-type] + + doc = container.upsert_item.await_args.args[0] + assert "embedding" in doc + assert doc["embedding"] == [1.0, 0.0] + + async def test_writeback_skips_embedding_when_no_function(self) -> None: + container = MagicMock() + container.upsert_item = AsyncMock() + provider = CosmosContextProvider(container_client=container, search_mode=CosmosContextSearchMode.FULL_TEXT) + session = AgentSession(session_id="s") + context = SessionContext(input_messages=[Message(role="user", contents=["hello"])], session_id="s1") + + await provider.after_run( + agent=None, session=session, context=context, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore[arg-type] + + doc = container.upsert_item.await_args.args[0] + assert "embedding" not in doc + + async def test_writeback_includes_agent_and_user_metadata(self) -> None: + container = MagicMock() + container.upsert_item = AsyncMock() + provider = CosmosContextProvider(container_client=container, search_mode=CosmosContextSearchMode.FULL_TEXT) + session = AgentSession(session_id="s") + agent = MagicMock() + agent.name = "test-agent" + context = SessionContext( + input_messages=[Message(role="user", contents=["hello"])], + session_id="s1", + metadata={"user_id": "user-42"}, + ) + + await provider.after_run( + agent=agent, session=session, context=context, state=session.state.setdefault(provider.source_id, {}) + ) + + doc = container.upsert_item.await_args.args[0] + assert doc["agent_name"] == "test-agent" + assert doc["user_id"] == "user-42" + + async def test_writeback_omits_metadata_when_not_available(self) -> None: + container = MagicMock() + container.upsert_item = AsyncMock() + provider = CosmosContextProvider(container_client=container, search_mode=CosmosContextSearchMode.FULL_TEXT) + session = AgentSession(session_id="s") + + context = SessionContext(input_messages=[Message(role="user", contents=["hello"])], session_id="s1") + + await provider.after_run( + agent=None, session=session, context=context, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore[arg-type] + + doc = container.upsert_item.await_args.args[0] + assert "agent_name" not in doc + assert "user_id" not in doc + + async def test_writeback_includes_partition_key(self) -> None: + container = MagicMock() + container.upsert_item = AsyncMock() + provider = CosmosContextProvider( + container_client=container, + partition_key="tenant-a", + search_mode=CosmosContextSearchMode.FULL_TEXT, + ) + session = AgentSession(session_id="s") + context = SessionContext(input_messages=[Message(role="user", contents=["hello"])], session_id="s1") + + await provider.after_run( + agent=None, session=session, context=context, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore[arg-type] + + doc = container.upsert_item.await_args.args[0] + assert doc["partition_key"] == "tenant-a" + + async def test_writeback_omits_partition_key_when_not_set(self) -> None: + container = MagicMock() + container.upsert_item = AsyncMock() + provider = CosmosContextProvider(container_client=container, search_mode=CosmosContextSearchMode.FULL_TEXT) + session = AgentSession(session_id="s") + context = SessionContext(input_messages=[Message(role="user", contents=["hello"])], session_id="s1") + + await provider.after_run( + agent=None, session=session, context=context, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore[arg-type] + + doc = container.upsert_item.await_args.args[0] + assert "partition_key" not in doc + + async def test_writeback_continues_when_embedding_fails(self) -> None: + async def _failing_embed(_: str) -> list[float]: + raise RuntimeError("embedding service down") + + container = MagicMock() + container.upsert_item = AsyncMock() + provider = CosmosContextProvider( + container_client=container, + embedding_function=_failing_embed, + search_mode=CosmosContextSearchMode.VECTOR, + ) + session = AgentSession(session_id="s") + context = SessionContext(input_messages=[Message(role="user", contents=["hello"])], session_id="s1") + + await provider.after_run( + agent=None, session=session, context=context, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore[arg-type] + + assert container.upsert_item.await_count == 1 + doc = container.upsert_item.await_args.args[0] + assert doc["content"] == "hello" + assert "embedding" not in doc + + async def test_generates_session_id_when_missing(self, caplog: pytest.LogCaptureFixture) -> None: + container = MagicMock() + container.upsert_item = AsyncMock() + provider = CosmosContextProvider(container_client=container, search_mode=CosmosContextSearchMode.FULL_TEXT) + session = AgentSession(session_id="s") + context = SessionContext(input_messages=[Message(role="user", contents=["hello"])], session_id=None) + + with caplog.at_level("WARNING"): + await provider.after_run( + agent=None, session=session, context=context, state=session.state.setdefault(provider.source_id, {}) + ) # type: ignore[arg-type] + + doc = container.upsert_item.await_args.args[0] + uuid.UUID(doc["session_id"]) + assert "session_id" in caplog.text + + +class TestLifecycle: + async def test_close_closes_owned_client(self, monkeypatch: pytest.MonkeyPatch) -> None: + database_client = MagicMock() + cosmos_client = MagicMock() + cosmos_client.get_database_client.return_value = database_client + cosmos_client.close = AsyncMock() + cosmos_client_factory = MagicMock(return_value=cosmos_client) + + monkeypatch.setattr(context_provider_module, "CosmosClient", cosmos_client_factory) + + provider = CosmosContextProvider( + endpoint="https://account.documents.azure.com:443/", + credential="test-key", + database_name="db1", + container_name="knowledge", + search_mode=CosmosContextSearchMode.FULL_TEXT, + ) + + await provider.close() + cosmos_client.close.assert_awaited_once()