Python: Add Cosmos DB NoSQL Checkpoint Storage for Python Workflows #4916

aayush3011 wants to merge 3 commits into microsoft:main from
Conversation
Add native Cosmos DB NoSQL support for workflow checkpoint storage in the Python agent-framework-azure-cosmos package, achieving parity with the existing .NET CosmosCheckpointStore.

New files:
- _checkpoint_storage.py: CosmosCheckpointStorage implementing the CheckpointStorage protocol with 6 methods (save, load, list_checkpoints, delete, get_latest, list_checkpoint_ids)
- test_cosmos_checkpoint_storage.py: Unit and integration tests
- workflow_checkpointing.py: Sample demonstrating Cosmos DB-backed workflow checkpoint/resume

Auth support:
- Managed identity / RBAC via Azure credential objects (DefaultAzureCredential, ManagedIdentityCredential, etc.)
- Key-based auth via account key string or AZURE_COSMOS_KEY env var
- Pre-created CosmosClient or ContainerProxy

Key design decisions:
- Partition key: /workflow_name for efficient per-workflow queries
- Serialization: Reuses encode/decode_checkpoint_value for full Python object fidelity (hybrid JSON + pickle approach)
- Container auto-creation via create_container_if_not_exists

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Pull request overview
Adds a Cosmos DB (NoSQL) checkpoint storage backend to the Python agent-framework-azure-cosmos package to enable durable workflow pause/resume (feature-parity with the .NET Cosmos checkpoint store).
Changes:
- Introduces CosmosCheckpointStorage implementing workflow checkpoint persistence in Cosmos DB (auto-creates DB/container, partitions by workflow_name).
- Adds unit + integration tests covering the checkpoint storage behavior.
- Adds runnable samples + README updates showing Cosmos-backed workflow checkpointing (standalone and Azure AI Foundry).
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py | Implements Cosmos-backed checkpoint storage (save/load/list/delete/latest/ids). |
| python/packages/azure-cosmos/agent_framework_azure_cosmos/__init__.py | Exposes CosmosCheckpointStorage from the package. |
| python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py | Adds unit tests and an integration round-trip test for the new storage. |
| python/packages/azure-cosmos/samples/cosmos_workflow_checkpointing.py | Standalone workflow sample using Cosmos-backed checkpointing. |
| python/packages/azure-cosmos/samples/cosmos_workflow_checkpointing_foundry.py | Foundry multi-agent workflow sample using Cosmos checkpoint storage. |
| python/packages/azure-cosmos/samples/README.md | Documents the new samples. |
| python/packages/azure-cosmos/README.md | Documents CosmosCheckpointStorage usage and configuration. |
| python/packages/azure-cosmos/pyproject.toml | Extends the integration test task to include the new integration test. |
```python
document: dict[str, Any] = {
    "id": checkpoint.checkpoint_id,
    "workflow_name": checkpoint.workflow_name,
    **encoded,
}
```
Cosmos id uniqueness is scoped to a partition key value. Since the container partitions on workflow_name, the same checkpoint.checkpoint_id can exist in multiple workflows without overwriting, which can make load() / delete() ambiguous (they query by id across partitions and return the first match). Consider enforcing global uniqueness by storing a composite id (e.g., include workflow_name), or by proactively deleting any existing documents with the same checkpoint_id across partitions before upserting, and/or detecting multiple matches and raising a clear error.
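The composite-id option could look like the following sketch. The helper names and the `:` separator are illustrative, not part of this PR, and the sketch assumes workflow names do not themselves contain the separator:

```python
def composite_document_id(workflow_name: str, checkpoint_id: str) -> str:
    """Combine the partition key value and the checkpoint id so the
    document id is unique across all workflow_name partitions."""
    # ":" keeps ids readable; Cosmos DB forbids "/", "\", "?", and "#" in ids,
    # so ":" is a safe choice of separator.
    return f"{workflow_name}:{checkpoint_id}"


def split_document_id(document_id: str) -> tuple[str, str]:
    """Recover (workflow_name, checkpoint_id) from a composite id."""
    workflow_name, _, checkpoint_id = document_id.partition(":")
    return workflow_name, checkpoint_id
```

With composite ids, `load()` and `delete()` can address a document by exact id plus partition key instead of a cross-partition query, which removes the ambiguity entirely.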
```python
"id": checkpoint.checkpoint_id,
"workflow_name": checkpoint.workflow_name,
```
Cosmos DB places constraints on the id field (e.g., certain characters are not allowed and there are length limits). Since checkpoint.checkpoint_id is written directly into document['id'], a caller-supplied checkpoint_id that violates Cosmos constraints will fail at runtime with a Cosmos SDK error. Consider validating/sanitizing checkpoint IDs up-front and raising a WorkflowCheckpointException with a helpful message to match the framework’s checkpoint error semantics.
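An up-front check along these lines would surface the problem before the SDK call. The limits reflect Cosmos DB's documented id rules (no `/`, `\`, `?`, or `#`; at most 255 characters); the function name is illustrative, and `ValueError` stands in for the framework's `WorkflowCheckpointException` to keep the sketch self-contained:

```python
_FORBIDDEN_ID_CHARS = frozenset("/\\?#")
_MAX_ID_LENGTH = 255


def validate_checkpoint_id(checkpoint_id: str) -> str:
    """Raise early, with a clear message, if the id would be rejected by Cosmos DB."""
    if not checkpoint_id:
        raise ValueError("checkpoint_id must be a non-empty string")
    if len(checkpoint_id) > _MAX_ID_LENGTH:
        raise ValueError(f"checkpoint_id exceeds {_MAX_ID_LENGTH} characters")
    bad = _FORBIDDEN_ID_CHARS.intersection(checkpoint_id)
    if bad:
        raise ValueError(f"checkpoint_id contains forbidden characters: {sorted(bad)}")
    return checkpoint_id
```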
```python
# Authentication: use key if available, otherwise fall back to Azure credential (RBAC)
credential: Any
if cosmos_key:
    credential = cosmos_key
else:
    credential = AzureCliCredential()

async with CosmosCheckpointStorage(
    endpoint=cosmos_endpoint,
    credential=credential,
    database_name=cosmos_database_name,
    container_name=cosmos_container_name,
) as checkpoint_storage:
    # Create Azure AI Foundry agents
    client = AzureOpenAIResponsesClient(
        project_endpoint=project_endpoint,
        deployment_name=deployment_name,
        credential=AzureCliCredential(),
    )

    assistant = client.as_agent(
        name="assistant",
        instructions="You are a helpful assistant. Keep responses brief.",
    )
```
AzureCliCredential (aio) should be closed to avoid unclosed session warnings/leaks. This sample constructs AzureCliCredential() instances without an async with (and creates two separate credentials). Consider using async with AzureCliCredential() as credential: and reusing that instance for both Cosmos checkpoint storage (when no key is provided) and the AzureOpenAIResponsesClient.
| # Authentication: use key if available, otherwise fall back to Azure credential (RBAC) | |
| credential: Any | |
| if cosmos_key: | |
| credential = cosmos_key | |
| else: | |
| credential = AzureCliCredential() | |
| async with CosmosCheckpointStorage( | |
| endpoint=cosmos_endpoint, | |
| credential=credential, | |
| database_name=cosmos_database_name, | |
| container_name=cosmos_container_name, | |
| ) as checkpoint_storage: | |
| # Create Azure AI Foundry agents | |
| client = AzureOpenAIResponsesClient( | |
| project_endpoint=project_endpoint, | |
| deployment_name=deployment_name, | |
| credential=AzureCliCredential(), | |
| ) | |
| assistant = client.as_agent( | |
| name="assistant", | |
| instructions="You are a helpful assistant. Keep responses brief.", | |
| ) | |
| # Authentication: use key for Cosmos if available, otherwise fall back to Azure credential (RBAC) | |
| async with AzureCliCredential() as azure_credential: | |
| cosmos_credential: Any | |
| if cosmos_key: | |
| cosmos_credential = cosmos_key | |
| else: | |
| cosmos_credential = azure_credential | |
| async with CosmosCheckpointStorage( | |
| endpoint=cosmos_endpoint, | |
| credential=cosmos_credential, | |
| database_name=cosmos_database_name, | |
| container_name=cosmos_container_name, | |
| ) as checkpoint_storage: | |
| # Create Azure AI Foundry agents | |
| client = AzureOpenAIResponsesClient( | |
| project_endpoint=project_endpoint, | |
| deployment_name=deployment_name, | |
| credential=azure_credential, | |
| ) | |
| assistant = client.as_agent( | |
| name="assistant", | |
| instructions="You are a helpful assistant. Keep responses brief.", | |
| ) |
```python
credential: Any
if cosmos_key:
    credential = cosmos_key
else:
    from azure.identity.aio import DefaultAzureCredential

    credential = DefaultAzureCredential()

async with CosmosCheckpointStorage(
    endpoint=cosmos_endpoint,
    credential=credential,
    database_name=cosmos_database_name,
    container_name=cosmos_container_name,
) as checkpoint_storage:
    # Build workflow with Cosmos DB checkpointing
    start = StartExecutor(id="start")
    worker = WorkerExecutor(id="worker")
    workflow_builder = (
        WorkflowBuilder(start_executor=start, checkpoint_storage=checkpoint_storage)
        .add_edge(start, worker)
        .add_edge(worker, worker)
    )

    # --- First run: execute the workflow ---
    print("\n=== First Run ===\n")
    workflow = workflow_builder.build()

    output = None
    async for event in workflow.run(message=8, stream=True):
        if event.type == "output":
            output = event.data

    print(f"Factor pairs computed: {output}")

    # List checkpoints saved in Cosmos DB
    checkpoint_ids = await checkpoint_storage.list_checkpoint_ids(
        workflow_name=workflow.name,
    )
    print(f"\nCheckpoints in Cosmos DB: {len(checkpoint_ids)}")
    for cid in checkpoint_ids:
        print(f"  - {cid}")

    # Get the latest checkpoint
    latest: WorkflowCheckpoint | None = await checkpoint_storage.get_latest(
        workflow_name=workflow.name,
    )

    if latest is None:
        print("No checkpoint found to resume from.")
        return

    print(f"\nLatest checkpoint: {latest.checkpoint_id}")
    print(f"  iteration_count: {latest.iteration_count}")
    print(f"  timestamp: {latest.timestamp}")

    # --- Second run: resume from the latest checkpoint ---
    print("\n=== Resuming from Checkpoint ===\n")
    workflow2 = workflow_builder.build()

    output2 = None
    async for event in workflow2.run(checkpoint_id=latest.checkpoint_id, stream=True):
        if event.type == "output":
            output2 = event.data

    if output2:
        print(f"Resumed workflow produced: {output2}")
    else:
        print("Resumed workflow completed (no remaining work — already finished).")
```
DefaultAzureCredential from azure.identity.aio holds network resources that should be closed. This sample creates DefaultAzureCredential() and passes it into CosmosCheckpointStorage without ensuring it’s closed at the end of the run. Consider wrapping the credential in async with DefaultAzureCredential() as credential: (similar to the cosmos_history_provider.py sample) so the underlying transport is cleaned up reliably.
Suggested change:

```python
if cosmos_key:
    async with CosmosCheckpointStorage(
        endpoint=cosmos_endpoint,
        credential=cosmos_key,
        database_name=cosmos_database_name,
        container_name=cosmos_container_name,
    ) as checkpoint_storage:
        # Build workflow with Cosmos DB checkpointing
        start = StartExecutor(id="start")
        worker = WorkerExecutor(id="worker")
        workflow_builder = (
            WorkflowBuilder(start_executor=start, checkpoint_storage=checkpoint_storage)
            .add_edge(start, worker)
            .add_edge(worker, worker)
        )

        # --- First run: execute the workflow ---
        print("\n=== First Run ===\n")
        workflow = workflow_builder.build()

        output = None
        async for event in workflow.run(message=8, stream=True):
            if event.type == "output":
                output = event.data

        print(f"Factor pairs computed: {output}")

        # List checkpoints saved in Cosmos DB
        checkpoint_ids = await checkpoint_storage.list_checkpoint_ids(
            workflow_name=workflow.name,
        )
        print(f"\nCheckpoints in Cosmos DB: {len(checkpoint_ids)}")
        for cid in checkpoint_ids:
            print(f"  - {cid}")

        # Get the latest checkpoint
        latest: WorkflowCheckpoint | None = await checkpoint_storage.get_latest(
            workflow_name=workflow.name,
        )

        if latest is None:
            print("No checkpoint found to resume from.")
            return

        print(f"\nLatest checkpoint: {latest.checkpoint_id}")
        print(f"  iteration_count: {latest.iteration_count}")
        print(f"  timestamp: {latest.timestamp}")

        # --- Second run: resume from the latest checkpoint ---
        print("\n=== Resuming from Checkpoint ===\n")
        workflow2 = workflow_builder.build()

        output2 = None
        async for event in workflow2.run(
            checkpoint_id=latest.checkpoint_id,
            stream=True,
        ):
            if event.type == "output":
                output2 = event.data

        if output2:
            print(f"Resumed workflow produced: {output2}")
        else:
            print("Resumed workflow completed (no remaining work — already finished).")
else:
    from azure.identity.aio import DefaultAzureCredential

    async with DefaultAzureCredential() as credential:
        async with CosmosCheckpointStorage(
            endpoint=cosmos_endpoint,
            credential=credential,
            database_name=cosmos_database_name,
            container_name=cosmos_container_name,
        ) as checkpoint_storage:
            # Build workflow with Cosmos DB checkpointing
            start = StartExecutor(id="start")
            worker = WorkerExecutor(id="worker")
            workflow_builder = (
                WorkflowBuilder(start_executor=start, checkpoint_storage=checkpoint_storage)
                .add_edge(start, worker)
                .add_edge(worker, worker)
            )

            # --- First run: execute the workflow ---
            print("\n=== First Run ===\n")
            workflow = workflow_builder.build()

            output = None
            async for event in workflow.run(message=8, stream=True):
                if event.type == "output":
                    output = event.data

            print(f"Factor pairs computed: {output}")

            # List checkpoints saved in Cosmos DB
            checkpoint_ids = await checkpoint_storage.list_checkpoint_ids(
                workflow_name=workflow.name,
            )
            print(f"\nCheckpoints in Cosmos DB: {len(checkpoint_ids)}")
            for cid in checkpoint_ids:
                print(f"  - {cid}")

            # Get the latest checkpoint
            latest: WorkflowCheckpoint | None = await checkpoint_storage.get_latest(
                workflow_name=workflow.name,
            )

            if latest is None:
                print("No checkpoint found to resume from.")
                return

            print(f"\nLatest checkpoint: {latest.checkpoint_id}")
            print(f"  iteration_count: {latest.iteration_count}")
            print(f"  timestamp: {latest.timestamp}")

            # --- Second run: resume from the latest checkpoint ---
            print("\n=== Resuming from Checkpoint ===\n")
            workflow2 = workflow_builder.build()

            output2 = None
            async for event in workflow2.run(
                checkpoint_id=latest.checkpoint_id,
                stream=True,
            ):
                if event.type == "output":
                    output2 = event.data

            if output2:
                print(f"Resumed workflow produced: {output2}")
            else:
                print("Resumed workflow completed (no remaining work — already finished).")
```
```python
with pytest.raises(WorkflowCheckpointException, match="No checkpoint found"):
    await storage.load("nonexistent-id")
```
There’s no test covering the Cosmos-specific edge case where the same checkpoint_id is saved under different workflow_name partition keys. Because Cosmos id uniqueness is per-partition, this can lead to multiple documents matching a load(checkpoint_id) query. Consider adding a unit test that saves two checkpoints with the same checkpoint_id but different workflow_name and asserts the desired behavior (overwrite, deterministic selection, or explicit error).
Suggested test:

```python
async def test_load_multiple_workflows_same_checkpoint_id_raises(
    mock_container: MagicMock,
) -> None:
    checkpoint_id = "shared-id"
    cp1 = _make_checkpoint(checkpoint_id=checkpoint_id, workflow_name="workflow-a")
    cp2 = _make_checkpoint(checkpoint_id=checkpoint_id, workflow_name="workflow-b")
    doc1 = _checkpoint_to_cosmos_document(cp1)
    doc2 = _checkpoint_to_cosmos_document(cp2)
    mock_container.query_items.return_value = _to_async_iter([doc1, doc2])
    storage = CosmosCheckpointStorage(container_client=mock_container)
    # When multiple documents share the same checkpoint_id across different
    # workflow_name partitions, load should not silently pick one; it should
    # raise a WorkflowCheckpointException.
    with pytest.raises(WorkflowCheckpointException):
        await storage.load(checkpoint_id)
```
Motivation and Context
The .NET implementation of the Agent Framework already ships a native CosmosCheckpointStore for workflow checkpointing, but the Python side only supports in-memory and file-based storage. Cosmos DB customers building agents on Azure AI Foundry have been asking for native Cosmos DB checkpoint support so they can durably pause and resume workflows across process restarts without writing custom storage adapters.
This PR adds CosmosCheckpointStorage to the existing agent-framework-azure-cosmos Python package, achieving feature parity with .NET and enabling Cosmos DB customers to use workflow checkpointing out-of-the-box.
Description
Core implementation (_checkpoint_storage.py):
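The hybrid JSON + pickle serialization described in the PR summary can be sketched as follows. This is an illustrative standalone version, not the PR's actual encode/decode_checkpoint_value helpers, whose names and envelope format may differ:

```python
import base64
import json
import pickle
from typing import Any


def encode_checkpoint_field(value: Any) -> dict[str, Any]:
    """Store JSON-serializable values as plain JSON; fall back to
    base64-encoded pickle for arbitrary Python objects."""
    try:
        json.dumps(value)  # probe: is this value JSON-serializable?
        return {"kind": "json", "data": value}
    except (TypeError, ValueError):
        blob = base64.b64encode(pickle.dumps(value)).decode("ascii")
        return {"kind": "pickle", "data": blob}


def decode_checkpoint_field(encoded: dict[str, Any]) -> Any:
    """Invert encode_checkpoint_field, restoring the original Python object."""
    if encoded["kind"] == "json":
        return encoded["data"]
    return pickle.loads(base64.b64decode(encoded["data"]))
```

The JSON branch keeps simple state queryable in Cosmos DB, while the pickle branch preserves full Python object fidelity for anything JSON cannot represent.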
Tests (test_cosmos_checkpoint_storage.py):
Samples:
Contribution Checklist