diff --git a/mindflow/__init__.py b/mindflow/__init__.py index 8a3be2e..dc1bba8 100644 --- a/mindflow/__init__.py +++ b/mindflow/__init__.py @@ -1 +1 @@ -__version__ = "0.3.13" +__version__ = "0.3.14" diff --git a/mindflow/core/index.py b/mindflow/core/index.py index 7bae2f3..2e314f1 100644 --- a/mindflow/core/index.py +++ b/mindflow/core/index.py @@ -49,8 +49,9 @@ def run_index(document_paths: List[str], refresh: bool, verbose: bool = True) -> indexable_document_references: List[DocumentReference] = return_if_indexable( document_references, refresh ) - if not indexable_document_references and verbose: - print("No documents to index") + if not indexable_document_references: + if verbose: + print("No documents to index") return print_total_size(indexable_document_references) diff --git a/mindflow/core/query.py b/mindflow/core/query.py index 928acaf..df84b27 100644 --- a/mindflow/core/query.py +++ b/mindflow/core/query.py @@ -69,18 +69,16 @@ def select_content( """ This function is used to generate a prompt based on a question or summarization task """ - embedding_ranked_document_chunks: List[ - DocumentChunk - ] = rank_document_chunks_by_embedding(query, resolved, embedding_model) - if len(embedding_ranked_document_chunks) == 0: + ranked_document_chunks: List[DocumentChunk] = rank_document_chunks_by_embedding( + query, resolved, embedding_model + ) + if len(ranked_document_chunks) == 0: print( "No index for requested hashes. Please generate index for passed content." ) sys.exit(1) - selected_content = trim_content( - embedding_ranked_document_chunks, completion_model, query - ) + selected_content = trim_content(ranked_document_chunks, completion_model, query) return selected_content @@ -90,13 +88,10 @@ class DocumentChunk: This class is used to store the chunks of a document. 
""" - def __init__( - self, path: str, start: int, end: int, embedding: Optional[np.ndarray] = None - ): + def __init__(self, path: str, start: int, end: int): self.path = path self.start = start self.end = end - self.embedding = embedding @classmethod def from_search_tree( @@ -109,21 +104,9 @@ def from_search_tree( """ stack = [document.search_tree] - chunks: List["DocumentChunk"] = [ - cls( - document.path, - document.search_tree["start"], - document.search_tree["end"], - ) - ] - embedding_response: Union[ModelError, np.ndarray] = embedding_model( - document.search_tree["summary"] - ) - if isinstance(embedding_response, ModelError): - print(embedding_response.embedding_message) - return [], [] + chunks: List["DocumentChunk"] = [] + embeddings: List[np.ndarray] = [] - embeddings: List[np.ndarray] = [embedding_response] rolling_summary: List[str] = [] while stack: node = stack.pop() @@ -131,17 +114,18 @@ def from_search_tree( if node["leaves"]: for leaf in node["leaves"]: stack.append(leaf) - chunks.append(cls(document.path, leaf["start"], leaf["end"])) - rolling_summary_embedding_response: Union[ - np.ndarray, ModelError - ] = embedding_model( - "\n\n".join(rolling_summary) + "\n\n" + leaf["summary"], - ) - if isinstance(rolling_summary_embedding_response, ModelError): - print(rolling_summary_embedding_response.embedding_message) - continue - embeddings.append(rolling_summary_embedding_response) - rolling_summary.pop() + else: + rolling_summary_embedding_response: Union[ + np.ndarray, ModelError + ] = embedding_model("\n\n".join(rolling_summary)) + if isinstance(rolling_summary_embedding_response, ModelError): + print(rolling_summary_embedding_response.embedding_message) + continue + + chunks.append(cls(document.path, node["start"], node["end"])) + embeddings.append(rolling_summary_embedding_response) + + rolling_summary.pop() return chunks, embeddings @@ -155,28 +139,50 @@ def trim_content( selected_content: str = "" for document_chunk in ranked_document_chunks: 
- if document_chunk: - with open(document_chunk.path, "r", encoding="utf-8") as file: - file.seek(document_chunk.start) - text = file.read(document_chunk.end - document_chunk.start) - - # Perform a binary search to find the maximum amount of text that fits within the token limit - left, right = 0, len(text) - while left <= right: - mid = (left + right) // 2 - if ( - get_token_count(model, query + selected_content + text[:mid]) - <= model.hard_token_limit - MinimumReservedLength.QUERY.value - ): - left = mid + 1 - else: - right = mid - 1 - - # Add the selected text to the selected content - selected_content += text[:right] + with open(document_chunk.path, "r", encoding="utf-8") as file: + file.seek(document_chunk.start) + text = file.read(document_chunk.end - document_chunk.start) + + selected_content += formatted_chunk(document_chunk, text) + + if ( + get_token_count(model, query + selected_content) + > model.hard_token_limit + ): + break + + # Perform a binary search to trim the selected content to fit within the token limit + left, right = 0, len(selected_content) + while left <= right: + mid = (left + right) // 2 + if ( + get_token_count(model, query + selected_content[:mid]) + <= model.hard_token_limit - MinimumReservedLength.QUERY.value + ): + left = mid + 1 + else: + right = mid - 1 + + # Trim the selected content to the new bounds + selected_content = selected_content[:right] + return selected_content +def formatted_chunk(document_chunk: DocumentChunk, text: str) -> str: + return ( + "Path: " + + document_chunk.path + + " Start: " + + str(document_chunk.start) + + " End: " + + str(document_chunk.end) + + " Text: " + + text + + "\n\n" + ) + + def rank_document_chunks_by_embedding( query: str, resolved: List[Dict], @@ -205,10 +211,9 @@ def rank_document_chunks_by_embedding( for document in filtered_documents ] for future in as_completed(futures): - document_chunks, document_chunk_embeddings = future.result() - similarities = cosine_similarity( -
prompt_embeddings, document_chunk_embeddings - )[0] + # Ordered together + document_chunks, embeddings = future.result() + similarities = cosine_similarity(prompt_embeddings, embeddings)[0] ranked_document_chunks.extend(list(zip(document_chunks, similarities))) ranked_document_chunks.sort(key=lambda x: x[1], reverse=True) diff --git a/mindflow/db/objects/document.py b/mindflow/db/objects/document.py index 602e693..f9877ed 100644 --- a/mindflow/db/objects/document.py +++ b/mindflow/db/objects/document.py @@ -86,7 +86,6 @@ def from_resolved( document_reference.id, document_reference.document_type ) if not document_text: - print(f"Unable to read document text: {document_reference.id}") continue document_text_bytes = document_text.encode("utf-8")