Skip to content

Commit

Permalink
Merge pull request #99 from nollied/improve-index
Browse files Browse the repository at this point in the history
Modified indexing and querying functions and updated version number.
  • Loading branch information
steegecs authored Mar 10, 2023
2 parents fd4b1dd + 5b6e4c8 commit 064e22b
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 63 deletions.
2 changes: 1 addition & 1 deletion mindflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.3.13"
__version__ = "0.3.14"
5 changes: 3 additions & 2 deletions mindflow/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ def run_index(document_paths: List[str], refresh: bool, verbose: bool = True) ->
indexable_document_references: List[DocumentReference] = return_if_indexable(
document_references, refresh
)
if not indexable_document_references and verbose:
print("No documents to index")
if not indexable_document_references:
if verbose:
print("No documents to index")
return

print_total_size(indexable_document_references)
Expand Down
123 changes: 64 additions & 59 deletions mindflow/core/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,18 +69,16 @@ def select_content(
"""
This function is used to generate a prompt based on a question or summarization task
"""
embedding_ranked_document_chunks: List[
DocumentChunk
] = rank_document_chunks_by_embedding(query, resolved, embedding_model)
if len(embedding_ranked_document_chunks) == 0:
ranked_document_chunks: List[DocumentChunk] = rank_document_chunks_by_embedding(
query, resolved, embedding_model
)
if len(ranked_document_chunks) == 0:
print(
"No index for requested hashes. Please generate index for passed content."
)
sys.exit(1)

selected_content = trim_content(
embedding_ranked_document_chunks, completion_model, query
)
selected_content = trim_content(ranked_document_chunks, completion_model, query)

return selected_content

Expand All @@ -90,13 +88,10 @@ class DocumentChunk:
This class is used to store the chunks of a document.
"""

def __init__(
self, path: str, start: int, end: int, embedding: Optional[np.ndarray] = None
):
def __init__(self, path: str, start: int, end: int):
    """Record where one chunk of a document lives on disk.

    Args:
        path: Filesystem path of the source document the chunk came from.
        start: Offset of the first position of the chunk — presumably a
            ``file.seek``-compatible offset; confirm units (bytes vs.
            characters) against the reader that consumes these chunks.
        end: Offset one past the chunk's last position, so the chunk
            length is ``end - start``.
    """
    self.path = path
    self.start = start
    self.end = end
self.embedding = embedding

@classmethod
def from_search_tree(
Expand All @@ -109,39 +104,28 @@ def from_search_tree(
"""

stack = [document.search_tree]
chunks: List["DocumentChunk"] = [
cls(
document.path,
document.search_tree["start"],
document.search_tree["end"],
)
]
embedding_response: Union[ModelError, np.ndarray] = embedding_model(
document.search_tree["summary"]
)
if isinstance(embedding_response, ModelError):
print(embedding_response.embedding_message)
return [], []
chunks: List["DocumentChunk"] = []
embeddings: List[np.ndarray] = []

embeddings: List[np.ndarray] = [embedding_response]
rolling_summary: List[str] = []
while stack:
node = stack.pop()
rolling_summary.append(node["summary"])
if node["leaves"]:
for leaf in node["leaves"]:
stack.append(leaf)
chunks.append(cls(document.path, leaf["start"], leaf["end"]))
rolling_summary_embedding_response: Union[
np.ndarray, ModelError
] = embedding_model(
"\n\n".join(rolling_summary) + "\n\n" + leaf["summary"],
)
if isinstance(rolling_summary_embedding_response, ModelError):
print(rolling_summary_embedding_response.embedding_message)
continue
embeddings.append(rolling_summary_embedding_response)
rolling_summary.pop()
else:
rolling_summary_embedding_response: Union[
np.ndarray, ModelError
] = embedding_model("\n\n".join(rolling_summary))
if isinstance(rolling_summary_embedding_response, ModelError):
print(rolling_summary_embedding_response.embedding_message)
continue

chunks.append(cls(document.path, node["start"], node["end"]))
embeddings.append(rolling_summary_embedding_response)

rolling_summary.pop()

return chunks, embeddings

Expand All @@ -155,28 +139,50 @@ def trim_content(
selected_content: str = ""

for document_chunk in ranked_document_chunks:
if document_chunk:
with open(document_chunk.path, "r", encoding="utf-8") as file:
file.seek(document_chunk.start)
text = file.read(document_chunk.end - document_chunk.start)

# Perform a binary search to find the maximum amount of text that fits within the token limit
left, right = 0, len(text)
while left <= right:
mid = (left + right) // 2
if (
get_token_count(model, query + selected_content + text[:mid])
<= model.hard_token_limit - MinimumReservedLength.QUERY.value
):
left = mid + 1
else:
right = mid - 1

# Add the selected text to the selected content
selected_content += text[:right]
with open(document_chunk.path, "r", encoding="utf-8") as file:
file.seek(document_chunk.start)
text = file.read(document_chunk.end - document_chunk.start)

selected_content += formated_chunk(document_chunk, text)

if (
get_token_count(model, query + selected_content)
> model.hard_token_limit
):
break

# Perform a binary search to trim the selected content to fit within the token limit
left, right = 0, len(selected_content)
while left <= right:
mid = (left + right) // 2
if (
get_token_count(model, query + selected_content[:mid])
<= model.hard_token_limit - MinimumReservedLength.QUERY.value
):
left = mid + 1
else:
right = mid - 1

# Trim the selected content to the new bounds
selected_content = selected_content[:right]

return selected_content


def formated_chunk(document_chunk: "DocumentChunk", text: str) -> str:
    """Render a document chunk as one labeled prompt segment.

    Produces ``Path: <path> Start: <start> End: <end> Text: <text>``
    followed by a blank line, suitable for concatenating several chunks
    into a single prompt body.

    Args:
        document_chunk: Chunk whose ``path``, ``start``, and ``end``
            attributes label the text.
        text: The chunk's raw content.

    Returns:
        The labeled segment terminated by two newlines.
    """
    label = (
        f"Path: {document_chunk.path}"
        f" Start: {document_chunk.start}"
        f" End: {document_chunk.end}"
        f" Text: "
    )
    return label + text + "\n\n"


def rank_document_chunks_by_embedding(
query: str,
resolved: List[Dict],
Expand Down Expand Up @@ -205,10 +211,9 @@ def rank_document_chunks_by_embedding(
for document in filtered_documents
]
for future in as_completed(futures):
document_chunks, document_chunk_embeddings = future.result()
similarities = cosine_similarity(
prompt_embeddings, document_chunk_embeddings
)[0]
# Ordered together
document_chunks, embeddings = future.result()
similarities = cosine_similarity(prompt_embeddings, embeddings)[0]
ranked_document_chunks.extend(list(zip(document_chunks, similarities)))

ranked_document_chunks.sort(key=lambda x: x[1], reverse=True)
Expand Down
2 changes: 1 addition & 1 deletion mindflow/db/objects/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def from_resolved(
document_reference.id, document_reference.document_type
)
if not document_text:
print(f"Unable to read document text: {document_reference.id}")
## print(f"Unable to read document text: {document_reference.id}")
continue

document_text_bytes = document_text.encode("utf-8")
Expand Down

0 comments on commit 064e22b

Please sign in to comment.