use new hashing primitives
The `st.experimental_memo()` and `st.cache()` methods are deprecated,
so I replaced them with the new `st.cache_data()` caching method.

However, the `Document` class is not hashable, so I added a custom
hash function for it. Also, as a temporary solution, the `search_docs()`
helper was removed from `utils.py` and its one-line body inlined at the
call site in `main.py`, since the `VectorStore` class cannot be hashed either.
mmz-001 committed Jul 1, 2023
1 parent 0453b02 commit dd7cf1a
Showing 2 changed files with 15 additions and 19 deletions.
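
For context, the `hash_funcs` escape hatch the commit message describes works like this. This is a minimal sketch rather than the repository's code; the `word_count` function and the langchain import path are illustrative assumptions:

# Minimal sketch of st.cache_data with a custom hash function for an
# unhashable argument type. Streamlit looks up the argument's type in
# hash_funcs and uses the mapped callable to compute the cache key.
from hashlib import md5

import streamlit as st
from langchain.docstore.document import Document  # assumed import path


def hash_func(doc: Document) -> str:
    """Hash a Document by its page content, as this commit does."""
    return md5(doc.page_content.encode("utf-8")).hexdigest()


@st.cache_data(show_spinner=False, hash_funcs={Document: hash_func})
def word_count(doc: Document) -> int:  # illustrative cached function
    # Re-runs only when doc.page_content (and hence its md5) changes.
    return len(doc.page_content.split())
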
3 changes: 1 addition & 2 deletions knowledge_gpt/main.py
@@ -9,7 +9,6 @@
     parse_docx,
     parse_pdf,
     parse_txt,
-    search_docs,
     text_to_docs,
     wrap_text_in_html,
 )
@@ -72,7 +71,7 @@ def clear_submit():
         st.session_state["submit"] = True
         # Output Columns
         answer_col, sources_col = st.columns(2)
-        sources = search_docs(index, query)
+        sources = index.similarity_search(query, k=5)
 
         try:
             answer = get_answer(sources, query)
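
As an aside, an alternative to un-caching the search would have been to exclude the unhashable `VectorStore` from the cache key: `st.cache_data` skips hashing any parameter whose name starts with an underscore. A sketch of that option, which is not what this commit does (import paths are assumptions for the langchain of that era):

# Sketch only: keeping search_docs() cached despite the unhashable index.
# Parameters prefixed with "_" are excluded from st.cache_data's hashing,
# so only `query` would key the cache. Trade-off: a changed index would
# NOT invalidate cached results, which may be why the call was inlined.
from typing import List

import streamlit as st
from langchain.docstore.document import Document  # assumed import path
from langchain.vectorstores.base import VectorStore  # assumed import path


@st.cache_data(show_spinner=False)
def search_docs(_index: VectorStore, query: str) -> List[Document]:
    """Searches a FAISS index for the chunks most similar to the query."""
    return _index.similarity_search(query, k=5)
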
31 changes: 14 additions & 17 deletions knowledge_gpt/utils.py
@@ -16,16 +16,23 @@
 from knowledge_gpt.embeddings import OpenAIEmbeddings
 from knowledge_gpt.prompts import STUFF_PROMPT
 
+from hashlib import md5
+
+
+def hash_func(doc: Document) -> str:
+    """Hash function for caching Documents"""
+    return md5(doc.page_content.encode("utf-8")).hexdigest()
+
 
-@st.experimental_memo()
+@st.cache_data()
 def parse_docx(file: BytesIO) -> str:
     text = docx2txt.process(file)
     # Remove multiple newlines
     text = re.sub(r"\n\s*\n", "\n\n", text)
     return text
 
 
-@st.experimental_memo()
+@st.cache_data()
 def parse_pdf(file: BytesIO) -> List[str]:
     pdf = PdfReader(file)
     output = []
@@ -43,15 +50,15 @@ def parse_pdf(file: BytesIO) -> List[str]:
     return output
 
 
-@st.experimental_memo()
+@st.cache_data()
 def parse_txt(file: BytesIO) -> str:
     text = file.read().decode("utf-8")
     # Remove multiple newlines
     text = re.sub(r"\n\s*\n", "\n\n", text)
     return text
 
 
-@st.cache(allow_output_mutation=True)
+@st.cache_data()
 def text_to_docs(text: str | List[str]) -> List[Document]:
     """Converts a string or list of strings to a list of Documents
     with metadata."""
@@ -84,7 +91,7 @@ def text_to_docs(text: str | List[str]) -> List[Document]:
     return doc_chunks
 
 
-@st.cache(allow_output_mutation=True, show_spinner=False)
+@st.cache_data(show_spinner=False, hash_funcs={Document: hash_func})
 def embed_docs(docs: List[Document]) -> VectorStore:
     """Embeds a list of Documents and returns a FAISS index"""
 
@@ -103,17 +110,7 @@ def embed_docs(docs: List[Document]) -> VectorStore:
     return index
 
 
-@st.cache(allow_output_mutation=True)
-def search_docs(index: VectorStore, query: str) -> List[Document]:
-    """Searches a FAISS index for similar chunks to the query
-    and returns a list of Documents."""
-
-    # Search for similar chunks
-    docs = index.similarity_search(query, k=5)
-    return docs
-
-
-@st.cache(allow_output_mutation=True)
+@st.cache_data(show_spinner=False, hash_funcs={Document: hash_func})
 def get_answer(docs: List[Document], query: str) -> Dict[str, Any]:
     """Gets an answer to a question from a list of Documents."""
 
@@ -137,7 +134,7 @@ def get_answer(docs: List[Document], query: str) -> Dict[str, Any]:
     return answer
 
 
-@st.cache(allow_output_mutation=True)
+@st.cache_data(show_spinner=False, hash_funcs={Document: hash_func})
 def get_sources(answer: Dict[str, Any], docs: List[Document]) -> List[Document]:
     """Gets the source documents for an answer."""
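
One behavioral note on this migration: `st.cache` needed `allow_output_mutation=True` because it handed every caller the same cached object and re-hashed it to detect mutation, while `st.cache_data` stores a pickled snapshot and gives each caller a fresh copy, so the flag has no equivalent. A small sketch of that copy semantics, independent of this repository:

import streamlit as st


@st.cache_data()
def get_items() -> list:
    return [1, 2, 3]


a = get_items()
a.append(4)       # mutates only this call's copy of the cached value
b = get_items()   # b == [1, 2, 3]: each call deserializes a fresh copy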
