Skip to content
This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

Addition of Simple Memory System Based on ChromaDB #28

Merged
merged 1 commit into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions forge/.pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@ repos:
hooks:
- id: autoflake
name: autoflake
entry: autoflake --in-place --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring autogpt
entry: autoflake --in-place --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring forge/autogpt
language: python
types: [ python ]
- id: pytest-check
name: pytest-check
entry: pytest
language: system
pass_filenames: false
always_run: true
# The mono repo has broken this hook. TODO: fix
# - id: pytest-check
# name: pytest-check
# entry: pytest
# language: system
# pass_filenames: false
# always_run: true
1 change: 1 addition & 0 deletions forge/autogpt/sdk/memory/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

159 changes: 159 additions & 0 deletions forge/autogpt/sdk/memory/memstore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import hashlib
from typing import Optional

import chromadb
from chromadb.config import Settings


class MemStore:
    """
    A persistent memory store backed by ChromaDB.

    Every task gets its own collection, named after the task ID. Documents
    added through `add` are stored under an ID derived from their content.
    """

    # Number of hex characters of the SHA-256 digest kept as a document ID.
    _DOC_ID_LENGTH = 20

    def __init__(self, store_path: str):
        """
        Initialize the MemStore with a given store path.

        Args:
            store_path (str): The path handed to the persistent ChromaDB client.
        """
        # Anonymized telemetry is explicitly switched off in the settings.
        self.client = chromadb.PersistentClient(
            path=store_path, settings=Settings(anonymized_telemetry=False)
        )

    @staticmethod
    def _doc_id(document: str) -> str:
        """Return the content-derived ID for a document (truncated SHA-256)."""
        digest = hashlib.sha256(document.encode()).hexdigest()
        return digest[: MemStore._DOC_ID_LENGTH]

    def add(self, task_id: str, document: str, metadatas: dict) -> None:
        """
        Add a document to the MemStore.

        The document ID is the first 20 hex characters of the SHA-256 digest
        of the document text, so identical text always yields the same ID.
        NOTE(review): what happens when the same text is added twice to one
        task is decided by ChromaDB's handling of repeated IDs -- confirm.

        Args:
            task_id (str): The ID of the task.
            document (str): The document to be added.
            metadatas (dict): The metadata of the document.
        """
        collection = self.client.get_or_create_collection(task_id)
        collection.add(
            documents=[document],
            metadatas=[metadatas],
            ids=[self._doc_id(document)],
        )

    def query(
        self,
        task_id: str,
        query: str,
        filters: Optional[dict] = None,
        document_search: Optional[dict] = None,
        n_results: int = 10,
    ) -> dict:
        """
        Query the MemStore.

        Args:
            task_id (str): The ID of the task.
            query (str): The query string.
            filters (dict, optional): Passed to the collection as `where`.
                Omitted from the call when empty or None. Defaults to None.
            document_search (dict, optional): Passed to the collection as
                `where_document`. Omitted from the call when empty or None.
                Defaults to None.
            n_results (int, optional): The number of results to request.
                Defaults to 10.

        Returns:
            dict: The query results.
        """
        collection = self.client.get_or_create_collection(task_id)

        kwargs = {
            "query_texts": [query],
            "n_results": n_results,
        }

        # Only forward the optional filters when the caller supplied them.
        if filters:
            kwargs["where"] = filters

        if document_search:
            kwargs["where_document"] = document_search

        return collection.query(**kwargs)

    def get(
        self,
        task_id: str,
        doc_ids: Optional[list] = None,
        filters: Optional[dict] = None,
    ) -> dict:
        """
        Get documents from the MemStore.

        Args:
            task_id (str): The ID of the task.
            doc_ids (list, optional): The IDs of the documents to be retrieved.
                An empty list is treated like None (no ID restriction is
                sent). Defaults to None.
            filters (dict, optional): Passed to the collection as `where`.
                Omitted from the call when empty or None. Defaults to None.

        Returns:
            dict: The retrieved documents.
        """
        collection = self.client.get_or_create_collection(task_id)
        kwargs = {}
        if doc_ids:
            kwargs["ids"] = doc_ids
        if filters:
            kwargs["where"] = filters
        return collection.get(**kwargs)

    def update(
        self, task_id: str, doc_ids: list, documents: list, metadatas: list
    ) -> None:
        """
        Update documents in the MemStore.

        The three lists are forwarded as-is to the collection's `update`.

        Args:
            task_id (str): The ID of the task.
            doc_ids (list): The IDs of the documents to be updated.
            documents (list): The updated documents.
            metadatas (list): The updated metadata.
        """
        collection = self.client.get_or_create_collection(task_id)
        collection.update(ids=doc_ids, documents=documents, metadatas=metadatas)

    def delete(self, task_id: str, doc_id: str) -> None:
        """
        Delete a document from the MemStore.

        Args:
            task_id (str): The ID of the task.
            doc_id (str): The ID of the document to be deleted.
        """
        collection = self.client.get_or_create_collection(task_id)
        collection.delete(ids=[doc_id])


if __name__ == "__main__":
    # Manual smoke run: exercises add, query, get, update and delete against
    # a store persisted in ".agent_mem_store" in the working directory.
    print("#############################################")
    mem = MemStore(".agent_mem_store")
    task_id = "test_task"

    # Seed the task's collection with three documents.
    mem.add(
        task_id,
        "This is a another new test document.",
        {"metadata": "test_metadata"},
    )
    mem.add(
        task_id,
        "The quick brown fox jumps over the lazy dog.",
        {"metadata": "test_metadata"},
    )
    document = "AI is a new technology that will change the world."
    mem.add(task_id, document, {"timestamp": 1623936000})

    # The last document is the one that gets updated and deleted below; its
    # ID is rebuilt the same way MemStore.add derives it.
    doc_ids = [hashlib.sha256(document.encode()).hexdigest()[:20]]
    documents = ["This is an updated test document."]
    updated_metadatas = {"metadata": "updated_test_metadata"}

    print("Query:")
    print(mem.query(task_id, "test"))

    print("Get:")
    print(mem.get(task_id))

    # update() and delete() return nothing, so show the stored state after
    # each call instead of printing their return value.
    print("Update:")
    mem.update(task_id, doc_ids, documents, [updated_metadatas])
    print(mem.get(task_id, doc_ids))

    print("Delete:")
    mem.delete(task_id, doc_ids[0])
    print(mem.get(task_id, doc_ids))
58 changes: 58 additions & 0 deletions forge/autogpt/sdk/memory/memstore_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import hashlib
import shutil

import pytest

from autogpt.sdk.memory.memstore import MemStore


@pytest.fixture
def memstore(tmp_path):
    """
    Provide a MemStore backed by a fresh directory for each test.

    The store lives under pytest's per-test `tmp_path`, so tests neither
    share state through a fixed directory nor write into the current working
    directory; pytest takes care of removing the temporary directory.
    """
    return MemStore(str(tmp_path / "mem_store"))


def test_add(memstore):
    """Adding one document stores exactly one entry under its content hash."""
    task_id = "test_task"
    document = "This is a test document."
    metadatas = {"metadata": "test_metadata"}
    memstore.add(task_id, document, metadatas)

    # Same derivation MemStore.add uses for the document ID.
    doc_id = hashlib.sha256(document.encode()).hexdigest()[:20]
    assert memstore.client.get_or_create_collection(task_id).count() == 1
    assert memstore.get(task_id, [doc_id])["ids"] == [doc_id]


def test_query(memstore):
    """Querying a task with a single stored document returns that document."""
    task_id = "test_task"
    document = "This is a test document."
    metadatas = {"metadata": "test_metadata"}
    memstore.add(task_id, document, metadatas)

    results = memstore.query(task_id, "test")
    # "documents" holds one list of hits per query text, so checking only its
    # length would pass even when nothing matched; compare the hits instead.
    assert results["documents"] == [[document]]


def test_update(memstore):
    """Updating a stored document replaces both its text and its metadata."""
    task_id = "test_task"
    document = "This is a test document."
    metadatas = {"metadata": "test_metadata"}
    memstore.add(task_id, document, metadatas)

    # The ID stays tied to the original text; update() does not re-derive it.
    doc_id = hashlib.sha256(document.encode()).hexdigest()[:20]
    updated_document = "This is an updated test document."
    updated_metadatas = {"metadata": "updated_test_metadata"}
    memstore.update(task_id, [doc_id], [updated_document], [updated_metadatas])

    # Check the relevant fields one by one rather than the whole result dict,
    # so the test does not depend on which other keys the backend returns.
    result = memstore.get(task_id, [doc_id])
    assert result["ids"] == [doc_id]
    assert result["documents"] == [updated_document]
    assert result["metadatas"] == [updated_metadatas]


def test_delete(memstore):
    """Deleting the only stored document leaves the task's collection empty."""
    task, text = "test_task", "This is a test document."
    memstore.add(task, text, {"metadata": "test_metadata"})

    # Rebuild the content-derived ID and remove the document by it.
    stored_id = hashlib.sha256(text.encode()).hexdigest()[:20]
    memstore.delete(task, stored_id)

    remaining = memstore.client.get_or_create_collection(task).count()
    assert remaining == 0
Loading
Loading