Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin committed Sep 23, 2024
1 parent a802baa commit ecc5ebd
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 6 deletions.
9 changes: 5 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,18 @@ Code = "https://github.com/PrefectHQ/ControlFlow"

[project.optional-dependencies]
tests = [
"chromadb",
"duckduckgo-search",
"langchain_community",
"langchain_google_genai",
"langchain_groq",
"pytest-asyncio>=0.18.2,!=0.22.0,<0.23.0",
"pytest-env>=0.8,<2.0",
"pytest-rerunfailures>=10,<14",
"pytest-sugar>=0.9,<2.0",
"pytest>=7.0",
"pytest-timeout",
"pytest-xdist",
"langchain_community",
"langchain_google_genai",
"langchain_groq",
"duckduckgo-search",
]
dev = [
"controlflow[tests]",
Expand Down
4 changes: 2 additions & 2 deletions src/controlflow/memory/providers/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def search(self, memory_key: str, query: str, n: int = 20) -> Dict[str, str]:
return dict(zip(results["ids"][0], results["documents"][0]))


def EphemeralChromaMemory() -> ChromaMemory:
return ChromaMemory(client=chromadb.EphemeralClient())
def EphemeralChromaMemory(**kwargs) -> ChromaMemory:
return ChromaMemory(client=chromadb.EphemeralClient(**kwargs))


def PersistentChromaMemory(path: str = None, **kwargs) -> ChromaMemory:
Expand Down
17 changes: 17 additions & 0 deletions tests/fixtures/controlflow.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import chromadb
import pytest

import controlflow
from controlflow.events.history import InMemoryHistory
from controlflow.llm.messages import BaseMessage
from controlflow.memory.providers.chroma import ChromaMemory, EphemeralChromaMemory
from controlflow.settings import temporary_settings
from controlflow.utilities.testing import FakeLLM

EPHEMERAL_CHROMA_CLIENT = chromadb.EphemeralClient(
settings=chromadb.Settings(allow_reset=True)
)


@pytest.fixture(autouse=True, scope="session")
def temp_controlflow_settings():
Expand All @@ -25,6 +31,7 @@ def reset_settings_after_each_test():
yield


@pytest.fixture(autouse=True)
def temp_controlflow_defaults(monkeypatch):
# use in-memory history
monkeypatch.setattr(
Expand All @@ -33,6 +40,16 @@ def temp_controlflow_defaults(monkeypatch):
InMemoryHistory(),
)

monkeypatch.setattr(
controlflow.defaults,
"memory_provider",
ChromaMemory(client=EPHEMERAL_CHROMA_CLIENT),
)

yield

EPHEMERAL_CHROMA_CLIENT.reset()


@pytest.fixture(autouse=True)
def reset_defaults_after_each_test(monkeypatch):
Expand Down
Empty file added tests/memory/__init__.py
Empty file.
47 changes: 47 additions & 0 deletions tests/memory/test_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import chromadb
import pytest

import controlflow
from controlflow.memory.providers.chroma import ChromaMemory


class TestMemory:
def test_store_and_retrieve(self):
m = controlflow.Memory(key="test", instructions="test")
m.add("The number is 42")
result = m.search("numbers")
assert len(result) == 1
assert "The number is 42" in result.values()

def test_delete(self):
m = controlflow.Memory(key="test", instructions="test")
m_id = m.add("The number is 42")
m.delete(m_id)
result = m.search("numbers")
assert len(result) == 0

def test_search(self):
m = controlflow.Memory(key="test", instructions="test")
m.add("The number is 42")
m.add("The number is 43")
result = m.search("numbers")
assert len(result) == 2
assert "The number is 42" in result.values()
assert "The number is 43" in result.values()


class TestMemoryProvider:
def test_load_from_string_invalid(self):
with pytest.raises(ValueError):
controlflow.Memory(key="test", instructions="test", provider="invalid")

def test_load_from_string_chroma_db(self):
m = controlflow.Memory(key="test", instructions="test", provider="chroma-db")
assert isinstance(m.provider, ChromaMemory)
assert m.provider.client.path == str(
controlflow.settings.home_path / "memory/chroma"
)

def test_load_from_instance(self):
mp = ChromaMemory(client=chromadb.PersistentClient(path="test_path"))
m = controlflow.Memory(key="test", instructions="test", provider=mp)

0 comments on commit ecc5ebd

Please sign in to comment.