Skip to content

Commit

Permalink
Merge pull request #331 from PrefectHQ/chroma-cloud
Browse files Browse the repository at this point in the history
Add chroma cloud configs
  • Loading branch information
jlowin authored Sep 24, 2024
2 parents 176a4bd + c3408b0 commit ea1696e
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 6 deletions.
5 changes: 3 additions & 2 deletions docs/patterns/memory.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,11 @@ You can install the dependencies for a provider with pip, for example `pip insta

For straightforward provider configurations, you can pass a string value to the `provider` parameter that will instantiate a provider with default settings. The following strings are recognized:

|Provider | Provider string | Description |
|Provider | Provider string | Description |
| -------- | -------- | ----------------- |
| Chroma | `chroma-ephemeral` | An ephemeral (in-memory) database. |
| Chroma | `chroma-db` | A persistent, local-file-based database, with a default path of `~/.controlflow/memory/chroma`. |
| Chroma | `chroma-db` | Uses a persistent, local-file-based database, with a default path of `~/.controlflow/memory/chroma`. |
| Chroma | `chroma-cloud` | Uses the Chroma Cloud service. The `CONTROLFLOW_CHROMA_CLOUD_API_KEY`, `CONTROLFLOW_CHROMA_CLOUD_TENANT`, and `CONTROLFLOW_CHROMA_CLOUD_DATABASE` settings are required. |

For example, if `chromadb` is installed, the following code will create a memory module that uses an ephemeral Chroma database:

Expand Down
6 changes: 4 additions & 2 deletions src/controlflow/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,10 @@ def get_memory_provider(provider: str) -> MemoryProvider:
import controlflow.memory.providers.chroma as chroma_providers

if provider == "chroma-ephemeral":
return chroma_providers.EphemeralChromaMemory()
return chroma_providers.ChromaEphemeralMemory()
elif provider == "chroma-db":
return chroma_providers.PersistentChromaMemory()
return chroma_providers.ChromaPersistentMemory()
elif provider == "chroma-cloud":
return chroma_providers.ChromaCloudMemory()

raise ValueError(f'Memory provider "{provider}" could not be loaded from a string.')
20 changes: 18 additions & 2 deletions src/controlflow/memory/providers/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,30 @@ def search(self, memory_key: str, query: str, n: int = 20) -> Dict[str, str]:
return dict(zip(results["ids"][0], results["documents"][0]))


def EphemeralChromaMemory(**kwargs) -> ChromaMemory:
def ChromaEphemeralMemory(**kwargs) -> ChromaMemory:
return ChromaMemory(client=chromadb.EphemeralClient(**kwargs))


def PersistentChromaMemory(path: str = None, **kwargs) -> ChromaMemory:
def ChromaPersistentMemory(path: str = None, **kwargs) -> ChromaMemory:
return ChromaMemory(
client=chromadb.PersistentClient(
path=path or str(controlflow.settings.home_path / "memory/chroma"),
**kwargs,
)
)


def ChromaCloudMemory(
tenant: Optional[str] = None,
database: Optional[str] = None,
api_key: Optional[str] = None,
**kwargs,
) -> ChromaMemory:
return ChromaMemory(
client=chromadb.CloudClient(
api_key=api_key or controlflow.settings.chroma_cloud_api_key,
tenant=tenant or controlflow.settings.chroma_cloud_tenant,
database=database or controlflow.settings.chroma_cloud_database,
**kwargs,
)
)
14 changes: 14 additions & 0 deletions src/controlflow/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,20 @@ class Settings(ControlFlowSettings):
description="The default memory provider for agents.",
)

# ------------ Memory settings: ChromaDB ------------
chroma_cloud_tenant: Optional[str] = Field(
None,
description="The tenant for Chroma Cloud.",
)
chroma_cloud_database: Optional[str] = Field(
None,
description="The database for Chroma Cloud.",
)
chroma_cloud_api_key: Optional[str] = Field(
None,
description="The API key for Chroma Cloud.",
)

# ------------ Debug settings ------------

debug_messages: bool = Field(
Expand Down

0 comments on commit ea1696e

Please sign in to comment.