diff --git a/docs/patterns/memory.mdx b/docs/patterns/memory.mdx index 4942a5d9..ec4c9979 100644 --- a/docs/patterns/memory.mdx +++ b/docs/patterns/memory.mdx @@ -138,10 +138,11 @@ You can install the dependencies for a provider with pip, for example `pip insta For straightforward provider configurations, you can pass a string value to the `provider` parameter that will instantiate a provider with default settings. The following strings are recognized: -|Provider | Provider string | Description | +|Provider | Provider string | Description | | -------- | -------- | ----------------- | | Chroma | `chroma-ephemeral` | An ephemeral (in-memory) database. | -| Chroma | `chroma-db` | A persistent, local-file-based database, with a default path of `~/.controlflow/memory/chroma`. | +| Chroma | `chroma-db` | Uses a persistent, local-file-based database, with a default path of `~/.controlflow/memory/chroma`. | +| Chroma | `chroma-cloud` | Uses the Chroma Cloud service. The `CONTROLFLOW_CHROMA_CLOUD_API_KEY`, `CONTROLFLOW_CHROMA_CLOUD_TENANT`, and `CONTROLFLOW_CHROMA_CLOUD_DATABASE` settings are required. | For example, if `chromadb` is installed, the following code will create a memory module that uses an ephemeral Chroma database: diff --git a/src/controlflow/memory/memory.py b/src/controlflow/memory/memory.py index cccf55fa..78dc4eab 100644 --- a/src/controlflow/memory/memory.py +++ b/src/controlflow/memory/memory.py @@ -141,8 +141,10 @@ def get_memory_provider(provider: str) -> MemoryProvider: import controlflow.memory.providers.chroma as chroma_providers if provider == "chroma-ephemeral": - return chroma_providers.EphemeralChromaMemory() + return chroma_providers.ChromaEphemeralMemory() elif provider == "chroma-db": - return chroma_providers.PersistentChromaMemory() + return chroma_providers.ChromaPersistentMemory() + elif provider == "chroma-cloud": + return chroma_providers.ChromaCloudMemory() raise ValueError(f'Memory provider "{provider}" could not be loaded from a string.') diff --git a/src/controlflow/memory/providers/chroma.py b/src/controlflow/memory/providers/chroma.py index 17fd29bb..875dacb7 100644 --- a/src/controlflow/memory/providers/chroma.py +++ b/src/controlflow/memory/providers/chroma.py @@ -52,14 +52,30 @@ def search(self, memory_key: str, query: str, n: int = 20) -> Dict[str, str]: return dict(zip(results["ids"][0], results["documents"][0])) -def EphemeralChromaMemory(**kwargs) -> ChromaMemory: +def ChromaEphemeralMemory(**kwargs) -> ChromaMemory: return ChromaMemory(client=chromadb.EphemeralClient(**kwargs)) -def PersistentChromaMemory(path: str = None, **kwargs) -> ChromaMemory: +def ChromaPersistentMemory(path: str = None, **kwargs) -> ChromaMemory: return ChromaMemory( client=chromadb.PersistentClient( path=path or str(controlflow.settings.home_path / "memory/chroma"), **kwargs, ) ) + + +def ChromaCloudMemory( + tenant: Optional[str] = None, + database: Optional[str] = None, + api_key: Optional[str] = None, + **kwargs, +) -> ChromaMemory: + return ChromaMemory( + client=chromadb.CloudClient( + api_key=api_key or controlflow.settings.chroma_cloud_api_key, + tenant=tenant or controlflow.settings.chroma_cloud_tenant, + database=database or controlflow.settings.chroma_cloud_database, + **kwargs, + ) + ) diff --git a/src/controlflow/settings.py b/src/controlflow/settings.py index 3f962aa0..08a6450c 100644 --- a/src/controlflow/settings.py +++ b/src/controlflow/settings.py @@ -88,6 +88,20 @@ class Settings(ControlFlowSettings): description="The default memory provider for agents.", ) + # ------------ Memory settings: ChromaDB ------------ + chroma_cloud_tenant: Optional[str] = Field( + None, + description="The tenant for Chroma Cloud.", + ) + chroma_cloud_database: Optional[str] = Field( + None, + description="The database for Chroma Cloud.", + ) + chroma_cloud_api_key: Optional[str] = Field( + None, + description="The API key for Chroma Cloud.", + ) + # ------------ Debug settings ------------ debug_messages: bool = Field(