Commit

fix
hjvogel committed May 17, 2024
1 parent ae61139 commit 6cb1259
Showing 1 changed file with 130 additions and 89 deletions.
219 changes: 130 additions & 89 deletions src/dspygen/rm/chatgpt_chromadb_retriever.py
@@ -1,30 +1,29 @@
import hashlib

import dspy
import ijson
from pathlib import Path
from typing import List, Optional, Union
from typing import List, Optional, Union, Any

from loguru import logger

import chromadb
import chromadb.utils.embedding_functions as embedding_functions
from munch import Munch
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel, ValidationError, Field

from dspygen.modules.python_source_code_module import python_source_code_call
from dspygen.utils.file_tools import data_dir, count_tokens


# Configure loguru logger
# logger.add("chatgpt_chromadb_retriever.log", rotation="10 MB", level="ERROR")
#logger.add("chatgpt_chromadb_retriever.log", rotation="10 MB", level="ERROR")


def calculate_file_checksum(file_path: str) -> str:
hash_md5 = hashlib.md5()
print("Chromadb path: ", file_path)
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
print(chunk)
return hash_md5.hexdigest()


@@ -37,7 +36,7 @@ class Author(BaseModel):

class ContentPart(BaseModel):
content_type: str
parts: List[str] | None
parts: Optional[List[Union[str, dict]]] = None # Allow parts to be either strings or dicts


class Message(BaseModel):
@@ -50,8 +49,8 @@ class Message(BaseModel):

class Data(BaseModel):
id: str
message: Message | None
parent: str | None
message: Optional[Message] = None # Allow message to be None
parent: Optional[str] = None
children: List[str]


@@ -61,18 +60,21 @@ class Conversation(BaseModel):


default_embed_fn = embedding_functions.OllamaEmbeddingFunction(
url="http://localhost:11434/api/embeddings",
model_name="llama3",)
url="http://localhost:11434/api/embeddings",
model_name="llama3",
)


class ChatGPTChromaDBRetriever(dspy.Retrieve):
def __init__(self,
json_file_path: str = data_dir() / "chatgpt_logs" / "conversations.json",
collection_name: str = "chatgpt",
persist_directory: str = data_dir(),
check_for_updates: bool = True,
embed_fn=default_embed_fn,
k=5):
def __init__(
self,
json_file_path: str = data_dir() / "chatgpt_logs" / "conversations.json",
collection_name: str = "chatgpt",
persist_directory: str = data_dir(),
check_for_updates: bool = True,
embed_fn=default_embed_fn,
k=5,
):
"""Initialize the ChatGPTChromaDBRetriever."""
super().__init__(k)
self.json_file_path = json_file_path
@@ -81,8 +83,10 @@ def __init__(self,
self.persist_directory = Path(persist_directory)
self.client = chromadb.PersistentClient(path=str(self.persist_directory))
self.embedding_function = embed_fn
self.collection = self.client.get_or_create_collection(name=self.collection_name,
embedding_function=self.embedding_function)
self.collection = self.client.get_or_create_collection(
name=self.collection_name,
embedding_function=self.embedding_function,
)
self.persist_directory.mkdir(parents=True, exist_ok=True)

if not check_for_updates:
@@ -110,61 +114,82 @@ def _save_last_processed_checksum(self):
def _process_and_store_conversations(self):
with open(self.json_file_path, "rb") as json_file:
count = -1

for conversation in ijson.items(json_file, "item"):
count += 1
print(f"Processing conversation #{count} {conversation['title']}")
while True:
try:
validated_conversation = Conversation(**conversation)
for _, data in validated_conversation.mapping.items():
validated_data = Data(**data)

# Search if document already exists
search_results = self.collection.get(ids=[validated_data.id])
if len(search_results["ids"]) > 0:
logger.info(f"Skipping already existing document #{count} with ID: {validated_data.id}")
continue

if validated_data.message:
document_text = ' '.join(part for part in validated_data.message.content.parts if part)

if len(document_text) < 200:
continue

self.collection.add(documents=[document_text], metadatas=[{"id": validated_data.id}],
ids=[validated_data.id])
logger.debug(f"Added document with ID: {validated_data.id}")

except ValidationError as e:
logger.error(f"Validation error: {e}")
for conversation in ijson.items(json_file, "item"):
count += 1
print(f"Processing conversation #{count} {conversation['title']}")
try:
validated_conversation = Conversation(**conversation)
for _, data in validated_conversation.mapping.items():
validated_data = Data(**data)

# Search if document already exists
search_results = self.collection.get(ids=[validated_data.id])
if len(search_results["ids"]) > 0:
logger.info(f"Skipping already existing document #{count} with ID: {validated_data.id}")
continue

if validated_data.message and validated_data.message.content.parts:
# Filter and process text parts only
document_text = ' '.join(
part for part in validated_data.message.content.parts if isinstance(part, str)
)

if len(document_text) < 200:
continue

self.collection.add(
documents=[document_text],
metadatas=[{"id": validated_data.id}],
ids=[validated_data.id],
)
logger.debug(f"Added document with ID: {validated_data.id}")

except ValidationError as e:
logger.error(f"Validation error: {e}")
break
except ijson.JSONError as e:
logger.error(f"JSON parsing error: {e}")
break # Exit the loop if we encounter a JSON parsing error

def _update_collection_metadata(self):
with open(self.json_file_path, "rb") as json_file:
for conversation in ijson.items(json_file, "item"):
while True:
try:
validated_conversation = Conversation(**conversation)
for _, data in validated_conversation.mapping.items():
validated_data = Data(**data)

if validated_data.message:
document_text = ' '.join(part for part in validated_data.message.content.parts if part)

meta = Munch()
meta.id = validated_data.id
meta.role = validated_data.message.author.role
meta.title = validated_conversation.title

self.collection.update(metadatas=[meta], ids=[validated_data.id])
logger.debug(f"Updated document with ID: {validated_data.id}")

except ValidationError as e:
logger.error(f"Validation error: {e}")
for conversation in ijson.items(json_file, "item"):
try:
validated_conversation = Conversation(**conversation)
for _, data in validated_conversation.mapping.items():
validated_data = Data(**data)

if validated_data.message and validated_data.message.content.parts:
# Filter and process text parts only
document_text = ' '.join(
part for part in validated_data.message.content.parts if isinstance(part, str)
)

meta = Munch()
meta.id = validated_data.id
meta.role = validated_data.message.author.role
meta.title = validated_conversation.title

self.collection.update(metadatas=[meta], ids=[validated_data.id])
logger.debug(f"Updated document with ID: {validated_data.id}")

except ValidationError as e:
logger.error(f"Validation error: {e}")
break
except ijson.JSONError as e:
logger.error(f"JSON parsing error: {e}")
break # Exit the loop if we encounter a JSON parsing error

def forward(
self, query_or_queries: Union[str, List[str]],
k: Optional[int] = None,
contains: Optional[str] = None,
role: str = "assistant"
self,
query_or_queries: Union[str, List[str]],
k: Optional[int] = None,
contains: Optional[str] = None,
role: str = "assistant",
) -> list[str]:
"""Search with ChromaDB for top passages for the provided query/queries.
@@ -173,35 +198,51 @@ def forward(
k (Optional[int], optional): The number of top passages to retrieve. Defaults to None, which will use the value in self.k.
contains (Optional[str], optional): The string that the retrieved passages must contain. Defaults to None.
role: The role of the author of the message. Defaults to "assistant".
Returns:
list[str]: The retrieved passages.
"""

queries = (
[query_or_queries]
if isinstance(query_or_queries, str)
else query_or_queries
)
queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries
queries = [q for q in queries if q] # Filter empty queries
embeddings = self.embedding_function(queries)

k = self.k if k is None else k
# Check if queries is empty after filtering
if not queries:
logger.error("No valid queries provided")
return []

if contains is not None:
results = self.collection.query(
query_embeddings=embeddings,
n_results=k,
where={"role": role},
where_document={"$contains": contains}
)
else:
results = self.collection.query(query_embeddings=embeddings,
where={"role": role},
n_results=k)
try:
embeddings = self.embedding_function(queries)
except Exception as e:
logger.error(f"Error generating embeddings: {e}")
return []

# super().forward(query_or_queries)
# Ensure embeddings are not empty
if not embeddings or not embeddings[0]:
logger.error("No embeddings generated")
return []

return results["documents"][0]
k = self.k if k is None else k

try:
if contains is not None:
results = self.collection.query(
query_embeddings=embeddings,
n_results=k,
where={"role": role},
where_document={"$contains": contains},
)
else:
results = self.collection.query(
query_embeddings=embeddings,
where={"role": role},
n_results=k,
)
except Exception as e:
logger.error(f"Error querying the collection: {e}")
return []

return results.get("documents", [[]])[0]


def main():
@@ -212,7 +253,7 @@ def main():
retriever = ChatGPTChromaDBRetriever(check_for_updates=True)
retriever._update_collection_metadata()

query = ""
query = "Fixed and running Tetris pygame"
matched_conversations = retriever.forward(query, k=5)
# print(count_tokens(str(matched_conversations) + "\nI want a DSPy module that generates Python source code."))
for conversation in matched_conversations:
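
For reference, a minimal usage sketch mirroring the main() shown in this diff; it assumes the dspygen package is installed, a local Ollama server is serving llama3 embeddings at http://localhost:11434, and an exported conversations.json exists under data_dir()/chatgpt_logs.

from dspygen.rm.chatgpt_chromadb_retriever import ChatGPTChromaDBRetriever

# Sketch: build or update the ChromaDB collection from the exported ChatGPT
# log, then fetch the top-5 assistant passages for a query.
retriever = ChatGPTChromaDBRetriever(check_for_updates=True)
passages = retriever.forward("Fixed and running Tetris pygame", k=5)
for passage in passages:
    print(passage)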