From 6cb1259bc210dafd736e09d1dbd5f5e075a9dd6b Mon Sep 17 00:00:00 2001
From: Holger Vogel
Date: Fri, 17 May 2024 14:52:04 +0100
Subject: [PATCH] fix

---
 src/dspygen/rm/chatgpt_chromadb_retriever.py | 219 +++++++++++--------
 1 file changed, 130 insertions(+), 89 deletions(-)

diff --git a/src/dspygen/rm/chatgpt_chromadb_retriever.py b/src/dspygen/rm/chatgpt_chromadb_retriever.py
index 82e09c5..2696635 100644
--- a/src/dspygen/rm/chatgpt_chromadb_retriever.py
+++ b/src/dspygen/rm/chatgpt_chromadb_retriever.py
@@ -1,30 +1,29 @@
 import hashlib
-
 import dspy
 import ijson
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Any
 
 from loguru import logger
-
 import chromadb
 import chromadb.utils.embedding_functions as embedding_functions
 from munch import Munch
-from pydantic import BaseModel, ValidationError
+from pydantic import BaseModel, ValidationError, Field
 
 from dspygen.modules.python_source_code_module import python_source_code_call
 from dspygen.utils.file_tools import data_dir, count_tokens
-
 # Configure loguru logger
-# logger.add("chatgpt_chromadb_retriever.log", rotation="10 MB", level="ERROR")
+#logger.add("chatgpt_chromadb_retriever.log", rotation="10 MB", level="ERROR")
 
 
 def calculate_file_checksum(file_path: str) -> str:
     hash_md5 = hashlib.md5()
+    print("Chromadb path: ", file_path)
     with open(file_path, "rb") as f:
         for chunk in iter(lambda: f.read(4096), b""):
             hash_md5.update(chunk)
+            print(chunk)
     return hash_md5.hexdigest()
@@ -37,7 +36,7 @@ class Author(BaseModel):
 
 class ContentPart(BaseModel):
     content_type: str
-    parts: List[str] | None
+    parts: Optional[List[Union[str, dict]]] = None  # Allow parts to be either strings or dicts
 
 
 class Message(BaseModel):
@@ -50,8 +49,8 @@ class Message(BaseModel):
 
 class Data(BaseModel):
     id: str
-    message: Message | None
-    parent: str | None
+    message: Optional[Message] = None  # Allow message to be None
+    parent: Optional[str] = None
     children: List[str]
 
 
@@ -61,18 +60,21 @@ class Conversation(BaseModel):
 
 
 default_embed_fn = embedding_functions.OllamaEmbeddingFunction(
-    url="http://localhost:11434/api/embeddings",
-    model_name="llama3",)
+    url="http://localhost:11434/api/embeddings",
+    model_name="llama3",
+)
 
 
 class ChatGPTChromaDBRetriever(dspy.Retrieve):
-    def __init__(self,
-                 json_file_path: str = data_dir() / "chatgpt_logs" / "conversations.json",
-                 collection_name: str = "chatgpt",
-                 persist_directory: str = data_dir(),
-                 check_for_updates: bool = True,
-                 embed_fn=default_embed_fn,
-                 k=5):
+    def __init__(
+        self,
+        json_file_path: str = data_dir() / "chatgpt_logs" / "conversations.json",
+        collection_name: str = "chatgpt",
+        persist_directory: str = data_dir(),
+        check_for_updates: bool = True,
+        embed_fn=default_embed_fn,
+        k=5,
+    ):
         """Initialize the ChatGPTChromaDBRetriever."""
         super().__init__(k)
         self.json_file_path = json_file_path
@@ -81,8 +83,10 @@ def __init__(self,
         self.persist_directory = Path(persist_directory)
         self.client = chromadb.PersistentClient(path=str(self.persist_directory))
         self.embedding_function = embed_fn
-        self.collection = self.client.get_or_create_collection(name=self.collection_name,
-                                                                embedding_function=self.embedding_function)
+        self.collection = self.client.get_or_create_collection(
+            name=self.collection_name,
+            embedding_function=self.embedding_function,
+        )
         self.persist_directory.mkdir(parents=True, exist_ok=True)
 
         if not check_for_updates:
@@ -110,61 +114,82 @@ def _save_last_processed_checksum(self):
     def _process_and_store_conversations(self):
         with open(self.json_file_path, "rb") as json_file:
             count = -1
-
-            for conversation in ijson.items(json_file, "item"):
-                count += 1
-                print(f"Processing conversation #{count} {conversation['title']}")
+            while True:
                 try:
-                    validated_conversation = Conversation(**conversation)
-                    for _, data in validated_conversation.mapping.items():
-                        validated_data = Data(**data)
-
-                        # Search if document already exists
-                        search_results = self.collection.get(ids=[validated_data.id])
-                        if len(search_results["ids"]) > 0:
-                            logger.info(f"Skipping already existing document #{count} with ID: {validated_data.id}")
-                            continue
-
-                        if validated_data.message:
-                            document_text = ' '.join(part for part in validated_data.message.content.parts if part)
-
-                            if len(document_text) < 200:
-                                continue
-
-                            self.collection.add(documents=[document_text], metadatas=[{"id": validated_data.id}],
-                                                ids=[validated_data.id])
-                            logger.debug(f"Added document with ID: {validated_data.id}")
-
-                except ValidationError as e:
-                    logger.error(f"Validation error: {e}")
+                    for conversation in ijson.items(json_file, "item"):
+                        count += 1
+                        print(f"Processing conversation #{count} {conversation['title']}")
+                        try:
+                            validated_conversation = Conversation(**conversation)
+                            for _, data in validated_conversation.mapping.items():
+                                validated_data = Data(**data)
+
+                                # Search if document already exists
+                                search_results = self.collection.get(ids=[validated_data.id])
+                                if len(search_results["ids"]) > 0:
+                                    logger.info(f"Skipping already existing document #{count} with ID: {validated_data.id}")
+                                    continue
+
+                                if validated_data.message and validated_data.message.content.parts:
+                                    # Filter and process text parts only
+                                    document_text = ' '.join(
+                                        part for part in validated_data.message.content.parts if isinstance(part, str)
+                                    )
+
+                                    if len(document_text) < 200:
+                                        continue
+
+                                    self.collection.add(
+                                        documents=[document_text],
+                                        metadatas=[{"id": validated_data.id}],
+                                        ids=[validated_data.id],
+                                    )
+                                    logger.debug(f"Added document with ID: {validated_data.id}")
+
+                        except ValidationError as e:
+                            logger.error(f"Validation error: {e}")
+                    break
+                except ijson.JSONError as e:
+                    logger.error(f"JSON parsing error: {e}")
+                    break  # Exit the loop if we encounter a JSON parsing error
 
     def _update_collection_metadata(self):
         with open(self.json_file_path, "rb") as json_file:
-            for conversation in ijson.items(json_file, "item"):
+            while True:
                 try:
-                    validated_conversation = Conversation(**conversation)
-                    for _, data in validated_conversation.mapping.items():
-                        validated_data = Data(**data)
-
-                        if validated_data.message:
-                            document_text = ' '.join(part for part in validated_data.message.content.parts if part)
-
-                            meta = Munch()
-                            meta.id = validated_data.id
-                            meta.role = validated_data.message.author.role
-                            meta.title = validated_conversation.title
-
-                            self.collection.update(metadatas=[meta], ids=[validated_data.id])
-                            logger.debug(f"Updated document with ID: {validated_data.id}")
-
-                except ValidationError as e:
-                    logger.error(f"Validation error: {e}")
+                    for conversation in ijson.items(json_file, "item"):
+                        try:
+                            validated_conversation = Conversation(**conversation)
+                            for _, data in validated_conversation.mapping.items():
+                                validated_data = Data(**data)
+
+                                if validated_data.message and validated_data.message.content.parts:
+                                    # Filter and process text parts only
+                                    document_text = ' '.join(
+                                        part for part in validated_data.message.content.parts if isinstance(part, str)
+                                    )
+
+                                    meta = Munch()
+                                    meta.id = validated_data.id
+                                    meta.role = validated_data.message.author.role
+                                    meta.title = validated_conversation.title
+
+                                    self.collection.update(metadatas=[meta], ids=[validated_data.id])
+                                    logger.debug(f"Updated document with ID: {validated_data.id}")
+
+                        except ValidationError as e:
+                            logger.error(f"Validation error: {e}")
+                    break
+                except ijson.JSONError as e:
+                    logger.error(f"JSON parsing error: {e}")
+                    break  # Exit the loop if we encounter a JSON parsing error
 
     def forward(
-            self, query_or_queries: Union[str, List[str]],
-            k: Optional[int] = None,
-            contains: Optional[str] = None,
-            role: str = "assistant"
+        self,
+        query_or_queries: Union[str, List[str]],
+        k: Optional[int] = None,
+        contains: Optional[str] = None,
+        role: str = "assistant",
     ) -> list[str]:
         """Search with ChromaDB for top passages for the provided query/queries.
 
@@ -173,35 +198,51 @@ def forward(
             k (Optional[int], optional): The number of top passages to retrieve. Defaults to None, which will use the value in self.k.
             contains (Optional[str], optional): The string that the retrieved passages must contain. Defaults to None.
             role: The role of the author of the message. Defaults to "assistant".
+
         Returns:
            dspy.Prediction: An object containing the retrieved passages.
        """
-        queries = (
-            [query_or_queries]
-            if isinstance(query_or_queries, str)
-            else query_or_queries
-        )
+        queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries
        queries = [q for q in queries if q]  # Filter empty queries
-        embeddings = self.embedding_function(queries)
-        k = self.k if k is None else k
+        # Check if queries is empty after filtering
+        if not queries:
+            logger.error("No valid queries provided")
+            return []
 
-        if contains is not None:
-            results = self.collection.query(
-                query_embeddings=embeddings,
-                n_results=k,
-                where={"role": role},
-                where_document={"$contains": contains}
-            )
-        else:
-            results = self.collection.query(query_embeddings=embeddings,
-                                            where={"role": role},
-                                            n_results=k)
+        try:
+            embeddings = self.embedding_function(queries)
+        except Exception as e:
+            logger.error(f"Error generating embeddings: {e}")
+            return []
 
-        # super().forward(query_or_queries)
+        # Ensure embeddings are not empty
+        if not embeddings or not embeddings[0]:
+            logger.error("No embeddings generated")
+            return []
 
-        return results["documents"][0]
+        k = self.k if k is None else k
+
+        try:
+            if contains is not None:
+                results = self.collection.query(
+                    query_embeddings=embeddings,
+                    n_results=k,
+                    where={"role": role},
+                    where_document={"$contains": contains},
+                )
+            else:
+                results = self.collection.query(
+                    query_embeddings=embeddings,
+                    where={"role": role},
+                    n_results=k,
+                )
+        except Exception as e:
+            logger.error(f"Error querying the collection: {e}")
+            return []
+
+        return results.get("documents", [[]])[0]
 
 
 def main():
@@ -212,7 +253,7 @@ def main():
 
     retriever = ChatGPTChromaDBRetriever(check_for_updates=True)
     retriever._update_collection_metadata()
-    query = ""
+    query = "Fixed and running Tetris pygame"
     matched_conversations = retriever.forward(query, k=5)
     # print(count_tokens(str(matched_conversations) + "\nI want a DSPy module that generates Python source code."))
    for conversation in matched_conversations:
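
Not part of the patch: a minimal usage sketch of the retriever after this change, assuming dspygen is installed, the conversations export exists at the default path, and an Ollama embedding endpoint is reachable at http://localhost:11434; everything beyond the names visible in the diff (queries, variable names) is illustrative only.

# Usage sketch (illustrative, not part of the patch).
from dspygen.rm.chatgpt_chromadb_retriever import ChatGPTChromaDBRetriever

retriever = ChatGPTChromaDBRetriever(check_for_updates=True)

# With the patched forward(), empty queries, embedding failures, and
# ChromaDB query failures return [] instead of raising.
passages = retriever.forward("Fixed and running Tetris pygame", k=5)
for passage in passages:
    print(passage)

# Restrict results to assistant messages containing a given substring.
filtered = retriever.forward("python tetris", k=3, contains="pygame", role="assistant")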