Skip to content
This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

Commit

Permalink
Addition of Simple Memory System Based on ChromaDB (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
Swiftyos authored Sep 5, 2023
1 parent 1277e36 commit cf2952f
Show file tree
Hide file tree
Showing 6 changed files with 840 additions and 862 deletions.
15 changes: 8 additions & 7 deletions forge/.pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@ repos:
hooks:
- id: autoflake
name: autoflake
entry: autoflake --in-place --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring autogpt
entry: autoflake --in-place --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring forge/autogpt
language: python
types: [ python ]
- id: pytest-check
name: pytest-check
entry: pytest
language: system
pass_filenames: false
always_run: true
# Mono repo has bronken this TODO: fix
# - id: pytest-check
# name: pytest-check
# entry: pytest
# language: system
# pass_filenames: false
# always_run: true
1 change: 1 addition & 0 deletions forge/autogpt/sdk/memory/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

159 changes: 159 additions & 0 deletions forge/autogpt/sdk/memory/memstore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import hashlib

import chromadb
from chromadb.config import Settings


class MemStore:
"""
A class used to represent a Memory Store
"""

def __init__(self, store_path: str):
"""
Initialize the MemStore with a given store path.
Args:
store_path (str): The path to the store.
"""
self.client = chromadb.PersistentClient(
path=store_path, settings=Settings(anonymized_telemetry=False)
)

def add(self, task_id: str, document: str, metadatas: dict) -> None:
"""
Add a document to the MemStore.
Args:
task_id (str): The ID of the task.
document (str): The document to be added.
metadatas (dict): The metadata of the document.
"""
doc_id = hashlib.sha256(document.encode()).hexdigest()[:20]
collection = self.client.get_or_create_collection(task_id)
collection.add(documents=[document], metadatas=[metadatas], ids=[doc_id])

def query(
self,
task_id: str,
query: str,
filters: dict = None,
document_search: dict = None,
) -> dict:
"""
Query the MemStore.
Args:
task_id (str): The ID of the task.
query (str): The query string.
filters (dict, optional): The filters to be applied. Defaults to None.
search_string (str, optional): The search string. Defaults to None.
Returns:
dict: The query results.
"""
collection = self.client.get_or_create_collection(task_id)

kwargs = {
"query_texts": [query],
"n_results": 10,
}

if filters:
kwargs["where"] = filters

if document_search:
kwargs["where_document"] = document_search

return collection.query(**kwargs)

def get(self, task_id: str, doc_ids: list = None, filters: dict = None) -> dict:
"""
Get documents from the MemStore.
Args:
task_id (str): The ID of the task.
doc_ids (list, optional): The IDs of the documents to be retrieved. Defaults to None.
filters (dict, optional): The filters to be applied. Defaults to None.
Returns:
dict: The retrieved documents.
"""
collection = self.client.get_or_create_collection(task_id)
kwargs = {}
if doc_ids:
kwargs["ids"] = doc_ids
if filters:
kwargs["where"] = filters
return collection.get(**kwargs)

def update(self, task_id: str, doc_ids: list, documents: list, metadatas: list):
"""
Update documents in the MemStore.
Args:
task_id (str): The ID of the task.
doc_ids (list): The IDs of the documents to be updated.
documents (list): The updated documents.
metadatas (list): The updated metadata.
"""
collection = self.client.get_or_create_collection(task_id)
collection.update(ids=doc_ids, documents=documents, metadatas=metadatas)

def delete(self, task_id: str, doc_id: str):
"""
Delete a document from the MemStore.
Args:
task_id (str): The ID of the task.
doc_id (str): The ID of the document to be deleted.
"""
collection = self.client.get_or_create_collection(task_id)
collection.delete(ids=[doc_id])


if __name__ == "__main__":
print("#############################################")
# Initialize MemStore
mem = MemStore(".agent_mem_store")

# Test add function
task_id = "test_task"
document = "This is a another new test document."
metadatas = {"metadata": "test_metadata"}
mem.add(task_id, document, metadatas)

task_id = "test_task"
document = "The quick brown fox jumps over the lazy dog."
metadatas = {"metadata": "test_metadata"}
mem.add(task_id, document, metadatas)

task_id = "test_task"
document = "AI is a new technology that will change the world."
metadatas = {"timestamp": 1623936000}
mem.add(task_id, document, metadatas)

doc_id = hashlib.sha256(document.encode()).hexdigest()[:20]
# Test query function
query = "test"
filters = {"metadata": {"$eq": "test"}}
search_string = {"$contains": "test"}
doc_ids = [doc_id]
documents = ["This is an updated test document."]
updated_metadatas = {"metadata": "updated_test_metadata"}

print("Query:")
print(mem.query(task_id, query))

# Test get function
print("Get:")

print(mem.get(task_id))

# Test update function
print("Update:")
print(mem.update(task_id, doc_ids, documents, updated_metadatas))

print("Delete:")
# Test delete function
print(mem.delete(task_id, doc_ids[0]))
58 changes: 58 additions & 0 deletions forge/autogpt/sdk/memory/memstore_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import hashlib
import shutil

import pytest

from autogpt.sdk.memory.memstore import MemStore


@pytest.fixture
def memstore():
mem = MemStore(".test_mem_store")
yield mem
shutil.rmtree(".test_mem_store")


def test_add(memstore):
task_id = "test_task"
document = "This is a test document."
metadatas = {"metadata": "test_metadata"}
memstore.add(task_id, document, metadatas)
doc_id = hashlib.sha256(document.encode()).hexdigest()[:20]
assert memstore.client.get_or_create_collection(task_id).count() == 1


def test_query(memstore):
task_id = "test_task"
document = "This is a test document."
metadatas = {"metadata": "test_metadata"}
memstore.add(task_id, document, metadatas)
query = "test"
assert len(memstore.query(task_id, query)["documents"]) == 1


def test_update(memstore):
task_id = "test_task"
document = "This is a test document."
metadatas = {"metadata": "test_metadata"}
memstore.add(task_id, document, metadatas)
doc_id = hashlib.sha256(document.encode()).hexdigest()[:20]
updated_document = "This is an updated test document."
updated_metadatas = {"metadata": "updated_test_metadata"}
memstore.update(task_id, [doc_id], [updated_document], [updated_metadatas])
assert memstore.get(task_id, [doc_id]) == {
"documents": [updated_document],
"metadatas": [updated_metadatas],
"embeddings": None,
"ids": [doc_id],
}


def test_delete(memstore):
task_id = "test_task"
document = "This is a test document."
metadatas = {"metadata": "test_metadata"}
memstore.add(task_id, document, metadatas)
doc_id = hashlib.sha256(document.encode()).hexdigest()[:20]
memstore.delete(task_id, doc_id)
assert memstore.client.get_or_create_collection(task_id).count() == 0
Loading

0 comments on commit cf2952f

Please sign in to comment.