Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Changed
- **agent-framework-azure-cosmos**: [BREAKING] `CosmosCheckpointStorage` now uses restricted pickle deserialization by default, matching `FileCheckpointStorage` behavior. If your checkpoints contain application-defined types, pass them via `allowed_checkpoint_types=["my_app.models:MyState"]`. ([#5200](https://github.com/microsoft/agent-framework/issues/5200))

## [1.0.1] - 2026-04-09

### Added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,34 @@ class CosmosCheckpointStorage:
``FileCheckpointStorage``, allowing full Python object fidelity for
complex workflow state while keeping the document structure human-readable.

SECURITY WARNING: Checkpoints use pickle for data serialization. Only load
checkpoints from trusted sources. Loading a malicious checkpoint can execute
arbitrary code.
Security warning: checkpoints use pickle for non-JSON-native values. Loading
checkpoints from untrusted sources is unsafe and can execute arbitrary code
during deserialization. The built-in deserialization restrictions reduce risk,
but they do not make untrusted checkpoints safe to load. Extending
``allowed_checkpoint_types`` may further increase risk and should only be done
for trusted application types.

By default, checkpoint deserialization is restricted to a built-in set of safe
Python types (primitives, datetime, uuid, ...) and all ``agent_framework``
internal types. To allow additional application-specific types, pass them via
the ``allowed_checkpoint_types`` parameter using ``"module:qualname"`` format.

Example:

.. code-block:: python

from azure.identity.aio import DefaultAzureCredential
from agent_framework_azure_cosmos import CosmosCheckpointStorage

storage = CosmosCheckpointStorage(
endpoint="https://my-account.documents.azure.com:443/",
credential=DefaultAzureCredential(),
database_name="agent-db",
container_name="checkpoints",
allowed_checkpoint_types=[
"my_app.models:MyState",
],
)

The database and container are created automatically on first use
if they do not already exist. The container uses partition key
Expand Down Expand Up @@ -97,6 +122,7 @@ def __init__(
container_client: ContainerProxy | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
allowed_checkpoint_types: list[str] | None = None,
) -> None:
"""Initialize the Azure Cosmos DB checkpoint storage.

Expand Down Expand Up @@ -129,10 +155,15 @@ def __init__(
container_client: Pre-created Cosmos container client.
env_file_path: Path to environment file for loading settings.
env_file_encoding: Encoding of the environment file.
allowed_checkpoint_types: Additional types (beyond the built-in safe set
and framework types) that are permitted during checkpoint
deserialization. Each entry should be a ``"module:qualname"``
string (e.g., ``"my_app.models:MyState"``).
"""
self._cosmos_client: CosmosClient | None = cosmos_client
self._container_proxy: ContainerProxy | None = container_client
self._owns_client = False
self._allowed_types: frozenset[str] = frozenset(allowed_checkpoint_types or [])

if self._container_proxy is not None:
self.database_name: str = database_name or ""
Expand Down Expand Up @@ -401,8 +432,7 @@ async def _ensure_container_proxy(self) -> None:
partition_key=PartitionKey(path="/workflow_name"),
)

@staticmethod
def _document_to_checkpoint(document: dict[str, Any]) -> WorkflowCheckpoint:
def _document_to_checkpoint(self, document: dict[str, Any]) -> WorkflowCheckpoint:
"""Convert a Cosmos DB document back to a WorkflowCheckpoint.

Strips Cosmos DB system properties (``_rid``, ``_self``, ``_etag``,
Expand All @@ -413,7 +443,7 @@ def _document_to_checkpoint(document: dict[str, Any]) -> WorkflowCheckpoint:
cosmos_keys = {"id", "_rid", "_self", "_etag", "_attachments", "_ts"}
cleaned = {k: v for k, v in document.items() if k not in cosmos_keys}

decoded = decode_checkpoint_value(cleaned)
decoded = decode_checkpoint_value(cleaned, allowed_types=self._allowed_types)
return WorkflowCheckpoint.from_dict(decoded)

@staticmethod
Expand Down
140 changes: 140 additions & 0 deletions python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import uuid
from collections.abc import AsyncIterator
from contextlib import suppress
from dataclasses import dataclass
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

Expand Down Expand Up @@ -595,3 +596,142 @@ async def test_cosmos_checkpoint_storage_roundtrip_with_emulator() -> None:
finally:
with suppress(Exception):
await cosmos_client.delete_database(database_name)


# --- Tests for allowed_checkpoint_types ---


@dataclass
class _AppState:
"""Application-defined state type used to test allowed_checkpoint_types."""

label: str
count: int


_APP_STATE_TYPE_KEY = f"{_AppState.__module__}:{_AppState.__qualname__}"


def _make_checkpoint_with_state(state: dict[str, Any]) -> WorkflowCheckpoint:
"""Create a checkpoint with custom state for serialization tests."""
return WorkflowCheckpoint(
workflow_name="test-workflow",
graph_signature_hash="abc123",
timestamp="2025-01-01T00:00:00+00:00",
state=state,
iteration_count=1,
)


async def test_init_accepts_allowed_checkpoint_types(mock_container: MagicMock) -> None:
"""CosmosCheckpointStorage.__init__ accepts allowed_checkpoint_types."""
storage = CosmosCheckpointStorage(
container_client=mock_container,
allowed_checkpoint_types=["some.module:SomeType"],
)
assert storage is not None


async def test_load_allows_builtin_safe_types(mock_container: MagicMock) -> None:
"""Built-in safe types load without opt-in via allowed_checkpoint_types."""
from datetime import datetime, timezone

checkpoint = _make_checkpoint_with_state({
"ts": datetime(2025, 1, 1, tzinfo=timezone.utc),
"tags": {1, 2, 3},
})
doc = _checkpoint_to_cosmos_document(checkpoint)
mock_container.query_items.return_value = _to_async_iter([doc])

storage = CosmosCheckpointStorage(container_client=mock_container)
loaded = await storage.load(checkpoint.checkpoint_id)

assert loaded.state["ts"] == datetime(2025, 1, 1, tzinfo=timezone.utc)
assert loaded.state["tags"] == {1, 2, 3}


async def test_load_blocks_unlisted_app_type(mock_container: MagicMock) -> None:
"""Application types are blocked when not listed in allowed_checkpoint_types."""
checkpoint = _make_checkpoint_with_state({"data": _AppState(label="x", count=1)})
doc = _checkpoint_to_cosmos_document(checkpoint)
mock_container.query_items.return_value = _to_async_iter([doc])

storage = CosmosCheckpointStorage(container_client=mock_container)

with pytest.raises(WorkflowCheckpointException, match="deserialization blocked"):
await storage.load(checkpoint.checkpoint_id)


async def test_load_allows_listed_app_type(mock_container: MagicMock) -> None:
"""Application types are allowed when listed in allowed_checkpoint_types."""
checkpoint = _make_checkpoint_with_state({"data": _AppState(label="ok", count=7)})
doc = _checkpoint_to_cosmos_document(checkpoint)
mock_container.query_items.return_value = _to_async_iter([doc])

storage = CosmosCheckpointStorage(
container_client=mock_container,
allowed_checkpoint_types=[_APP_STATE_TYPE_KEY],
)
loaded = await storage.load(checkpoint.checkpoint_id)

assert isinstance(loaded.state["data"], _AppState)
assert loaded.state["data"].label == "ok"
assert loaded.state["data"].count == 7


async def test_list_checkpoints_blocks_unlisted_app_type(mock_container: MagicMock) -> None:
"""list_checkpoints skips documents with unlisted application types."""
checkpoint = _make_checkpoint_with_state({"data": _AppState(label="x", count=1)})
doc = _checkpoint_to_cosmos_document(checkpoint)
mock_container.query_items.return_value = _to_async_iter([doc])

storage = CosmosCheckpointStorage(container_client=mock_container)
results = await storage.list_checkpoints(workflow_name="test-workflow")

# The document is skipped (logged as warning) because the type is blocked
assert len(results) == 0


async def test_list_checkpoints_allows_listed_app_type(mock_container: MagicMock) -> None:
"""list_checkpoints decodes documents with listed application types."""
checkpoint = _make_checkpoint_with_state({"data": _AppState(label="ok", count=3)})
doc = _checkpoint_to_cosmos_document(checkpoint)
mock_container.query_items.return_value = _to_async_iter([doc])

storage = CosmosCheckpointStorage(
container_client=mock_container,
allowed_checkpoint_types=[_APP_STATE_TYPE_KEY],
)
results = await storage.list_checkpoints(workflow_name="test-workflow")

assert len(results) == 1
assert isinstance(results[0].state["data"], _AppState)


async def test_get_latest_blocks_unlisted_app_type(mock_container: MagicMock) -> None:
"""get_latest raises when the checkpoint contains an unlisted application type."""
checkpoint = _make_checkpoint_with_state({"data": _AppState(label="x", count=1)})
doc = _checkpoint_to_cosmos_document(checkpoint)
mock_container.query_items.return_value = _to_async_iter([doc])

storage = CosmosCheckpointStorage(container_client=mock_container)

with pytest.raises(WorkflowCheckpointException, match="deserialization blocked"):
await storage.get_latest(workflow_name="test-workflow")


async def test_get_latest_allows_listed_app_type(mock_container: MagicMock) -> None:
"""get_latest decodes checkpoints with listed application types."""
checkpoint = _make_checkpoint_with_state({"data": _AppState(label="latest", count=9)})
doc = _checkpoint_to_cosmos_document(checkpoint)
mock_container.query_items.return_value = _to_async_iter([doc])

storage = CosmosCheckpointStorage(
container_client=mock_container,
allowed_checkpoint_types=[_APP_STATE_TYPE_KEY],
)
result = await storage.get_latest(workflow_name="test-workflow")

assert result is not None
assert isinstance(result.state["data"], _AppState)
assert result.state["data"].label == "latest"
Loading