From ecc5ebde22352a57533d6fc9b5bb98d0d890b8f3 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Mon, 23 Sep 2024 18:45:23 -0400 Subject: [PATCH] Add tests --- pyproject.toml | 9 +++-- src/controlflow/memory/providers/chroma.py | 4 +- tests/fixtures/controlflow.py | 17 ++++++++ tests/memory/__init__.py | 0 tests/memory/test_memory.py | 47 ++++++++++++++++++++++ 5 files changed, 71 insertions(+), 6 deletions(-) create mode 100644 tests/memory/__init__.py create mode 100644 tests/memory/test_memory.py diff --git a/pyproject.toml b/pyproject.toml index abf5c9b5..a4863ee2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,11 @@ Code = "https://github.com/PrefectHQ/ControlFlow" [project.optional-dependencies] tests = [ + "chromadb", + "duckduckgo-search", + "langchain_community", + "langchain_google_genai", + "langchain_groq", "pytest-asyncio>=0.18.2,!=0.22.0,<0.23.0", "pytest-env>=0.8,<2.0", "pytest-rerunfailures>=10,<14", @@ -51,10 +56,6 @@ tests = [ "pytest>=7.0", "pytest-timeout", "pytest-xdist", - "langchain_community", - "langchain_google_genai", - "langchain_groq", - "duckduckgo-search", ] dev = [ "controlflow[tests]", diff --git a/src/controlflow/memory/providers/chroma.py b/src/controlflow/memory/providers/chroma.py index 139cb590..17fd29bb 100644 --- a/src/controlflow/memory/providers/chroma.py +++ b/src/controlflow/memory/providers/chroma.py @@ -52,8 +52,8 @@ def search(self, memory_key: str, query: str, n: int = 20) -> Dict[str, str]: return dict(zip(results["ids"][0], results["documents"][0])) -def EphemeralChromaMemory() -> ChromaMemory: - return ChromaMemory(client=chromadb.EphemeralClient()) +def EphemeralChromaMemory(**kwargs) -> ChromaMemory: + return ChromaMemory(client=chromadb.EphemeralClient(**kwargs)) def PersistentChromaMemory(path: str = None, **kwargs) -> ChromaMemory: diff --git a/tests/fixtures/controlflow.py b/tests/fixtures/controlflow.py index e9c2f9bc..b3922ece 100644 --- a/tests/fixtures/controlflow.py +++ b/tests/fixtures/controlflow.py @@ -1,11 +1,17 @@ +import chromadb import pytest import controlflow from controlflow.events.history import InMemoryHistory from controlflow.llm.messages import BaseMessage +from controlflow.memory.providers.chroma import ChromaMemory, EphemeralChromaMemory from controlflow.settings import temporary_settings from controlflow.utilities.testing import FakeLLM +EPHEMERAL_CHROMA_CLIENT = chromadb.EphemeralClient( + settings=chromadb.Settings(allow_reset=True) +) + @pytest.fixture(autouse=True, scope="session") def temp_controlflow_settings(): @@ -25,6 +31,7 @@ def reset_settings_after_each_test(): yield +@pytest.fixture(autouse=True) def temp_controlflow_defaults(monkeypatch): # use in-memory history monkeypatch.setattr( @@ -33,6 +40,16 @@ def temp_controlflow_defaults(monkeypatch): InMemoryHistory(), ) + monkeypatch.setattr( + controlflow.defaults, + "memory_provider", + ChromaMemory(client=EPHEMERAL_CHROMA_CLIENT), + ) + + yield + + EPHEMERAL_CHROMA_CLIENT.reset() + @pytest.fixture(autouse=True) def reset_defaults_after_each_test(monkeypatch): diff --git a/tests/memory/__init__.py b/tests/memory/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/memory/test_memory.py b/tests/memory/test_memory.py new file mode 100644 index 00000000..e5e08aad --- /dev/null +++ b/tests/memory/test_memory.py @@ -0,0 +1,47 @@ +import chromadb +import pytest + +import controlflow +from controlflow.memory.providers.chroma import ChromaMemory + + +class TestMemory: + def test_store_and_retrieve(self): + m = controlflow.Memory(key="test", instructions="test") + m.add("The number is 42") + result = m.search("numbers") + assert len(result) == 1 + assert "The number is 42" in result.values() + + def test_delete(self): + m = controlflow.Memory(key="test", instructions="test") + m_id = m.add("The number is 42") + m.delete(m_id) + result = m.search("numbers") + assert len(result) == 0 + + def test_search(self): + m = controlflow.Memory(key="test", instructions="test") + m.add("The number is 42") + m.add("The number is 43") + result = m.search("numbers") + assert len(result) == 2 + assert "The number is 42" in result.values() + assert "The number is 43" in result.values() + + +class TestMemoryProvider: + def test_load_from_string_invalid(self): + with pytest.raises(ValueError): + controlflow.Memory(key="test", instructions="test", provider="invalid") + + def test_load_from_string_chroma_db(self): + m = controlflow.Memory(key="test", instructions="test", provider="chroma-db") + assert isinstance(m.provider, ChromaMemory) + assert m.provider.client.path == str( + controlflow.settings.home_path / "memory/chroma" + ) + + def test_load_from_instance(self): + mp = ChromaMemory(client=chromadb.PersistentClient(path="test_path")) + m = controlflow.Memory(key="test", instructions="test", provider=mp)