
Commit

Merge pull request sambanova#374 from sambanova/petrojm/ekr_decouple_streamlit

Decouple parse_doc() in EKR src from Streamlit dependencies
snova-petrojm authored Oct 7, 2024
2 parents 41c8a2e + caadc17 commit 3d933b9
Showing 7 changed files with 75 additions and 68 deletions.
4 changes: 2 additions & 2 deletions benchmarking/src/llmperf/sambanova_client.py
@@ -539,8 +539,8 @@ def llm_request(request_config: RequestConfig, tokenizer: AutoTokenizer) -> Tupl
error_code = getattr(
e,
'code',
'''Error while running LLM API requests.
Check your model name, LLM API type, env variables and endpoint status.'''
"""Error while running LLM API requests.
Check your model name, LLM API type, env variables and endpoint status.""",
)
error_message = str(e)
metrics[common_metrics.ERROR_MSG] = error_message
4 changes: 2 additions & 2 deletions benchmarking/src/performance_evaluation.py
@@ -880,8 +880,8 @@ def get_token_throughput_latencies(
)
nl = '\n'
raise Exception(
f"""Unexpected error happened when executing requests: {f'{nl}-'.join(unique_error_codes)}{nl}"""+
f"""Additional messages: {f'{nl}-'.join(unique_error_msgs)}"""
f"""Unexpected error happened when executing requests: {f'{nl}-'.join(unique_error_codes)}{nl}"""
+ f"""Additional messages: {f'{nl}-'.join(unique_error_msgs)}"""
)

# Capture end time and notify user
37 changes: 4 additions & 33 deletions enterprise_knowledge_retriever/src/document_retrieval.py
@@ -1,5 +1,4 @@
import os
import shutil
import sys
from typing import Any, Dict, List, Optional, Tuple

@@ -17,7 +16,6 @@
from langchain_core.output_parsers import StrOutputParser
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores.base import VectorStoreRetriever
from streamlit.runtime.uploaded_file_manager import UploadedFile
from transformers import AutoModelForSequenceClassification, AutoTokenizer

current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -184,14 +182,12 @@ def set_llm(self) -> LLM:
)
return llm

def parse_doc(
self, docs: List[UploadedFile], additional_metadata: Optional[Dict[str, Any]] = None
) -> List[Document]:
def parse_doc(self, doc_folder: str, additional_metadata: Optional[Dict[str, Any]] = None) -> List[Document]:
"""
Parse the uploaded documents and return a list of LangChain documents.
Parse specified documents and return a list of LangChain documents.
Args:
docs (List[UploadFile]): A list of uploaded files.
doc_folder (str): Path to the documents.
additional_metadata (Optional[Dict], optional): Additional metadata to include in the processed documents.
Defaults to an empty dictionary.
@@ -201,33 +197,8 @@ def parse_doc(
if additional_metadata is None:
additional_metadata = {}

# Create the data/tmp folder if it doesn't exist
temp_folder = os.path.join(kit_dir, 'data/tmp')
if not os.path.exists(temp_folder):
os.makedirs(temp_folder)
else:
# If there are already files there, delete them
for filename in os.listdir(temp_folder):
file_path = os.path.join(temp_folder, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
except Exception as e:
print(f'Failed to delete {file_path}. Reason: {e}')

# Save all selected files to the tmp dir with their file names
for doc in docs:
assert hasattr(doc, 'name'), 'doc has no attribute name.'
assert callable(doc.getvalue), 'doc has no method getvalue.'
temp_file = os.path.join(temp_folder, doc.name)
with open(temp_file, 'wb') as f:
f.write(doc.getvalue())

# Pass in the temp folder for processing into the parse_doc_universal function
_, _, langchain_docs = parse_doc_universal(
doc=temp_folder, additional_metadata=additional_metadata, lite_mode=self.pdf_only_mode
doc=doc_folder, additional_metadata=additional_metadata, lite_mode=self.pdf_only_mode
)

return langchain_docs
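
With this change, parse_doc() needs only a folder path on disk, so callers outside Streamlit can use it directly. A minimal sketch of the new call pattern, assuming an already-configured DocumentRetrieval instance and an illustrative local folder (neither is part of this diff):

from enterprise_knowledge_retriever.src.document_retrieval import DocumentRetrieval

def build_chunks(retrieval: DocumentRetrieval) -> list:
    # parse_doc now takes a plain folder path instead of Streamlit UploadedFile objects
    chunks = retrieval.parse_doc(doc_folder='./data/my_docs', additional_metadata={'origin': 'cli'})
    print(f'Parsed {len(chunks)} LangChain documents')
    return chunks
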
83 changes: 64 additions & 19 deletions enterprise_knowledge_retriever/streamlit/app.py
@@ -1,9 +1,12 @@
import logging
import os
import shutil
import sys
from typing import List, Optional

import streamlit as st
import yaml
from streamlit.runtime.uploaded_file_manager import UploadedFile

current_dir = os.path.dirname(os.path.abspath(__file__))
kit_dir = os.path.abspath(os.path.join(current_dir, '..'))
@@ -24,6 +27,44 @@
logging.info('URL: http://localhost:8501')


def save_files_user(docs: List[UploadedFile]) -> str:
"""
Save all user uploaded files in Streamlit to the tmp dir with their file names
Args:
docs (List[UploadedFile]): A list of files uploaded in Streamlit
Returns:
str: path where the files are saved.
"""

# Create the data/tmp folder if it doesn't exist
temp_folder = os.path.join(kit_dir, 'data/tmp')
if not os.path.exists(temp_folder):
os.makedirs(temp_folder)
else:
# If there are already files there, delete them
for filename in os.listdir(temp_folder):
file_path = os.path.join(temp_folder, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
except Exception as e:
print(f'Failed to delete {file_path}. Reason: {e}')

# Save all selected files to the tmp dir with their file names
for doc in docs:
assert hasattr(doc, 'name'), 'doc has no attribute name.'
assert callable(doc.getvalue), 'doc has no method getvalue.'
temp_file = os.path.join(temp_folder, doc.name)
with open(temp_file, 'wb') as f:
f.write(doc.getvalue())

return temp_folder


def handle_userinput(user_question: str) -> None:
if user_question:
try:
@@ -177,33 +218,37 @@ def main() -> None:
st.markdown('Create database')
if st.button('Process'):
with st.spinner('Processing'):
# try:
text_chunks = st.session_state.document_retrieval.parse_doc(docs)
if len(text_chunks) == 0:
st.error(
"""No able to get text from the documents. check your docs or try setting
pdf_only_mode to False"""
try:
if docs is not None:
temp_folder = save_files_user(docs)
text_chunks = st.session_state.document_retrieval.parse_doc(temp_folder)
if len(text_chunks) == 0:
st.error(
"""No able to get text from the documents. check your docs or try setting
pdf_only_mode to False"""
)
embeddings = st.session_state.document_retrieval.load_embedding_model()
collection_name = default_collection if not prod_mode else None
vectorstore = st.session_state.document_retrieval.create_vector_store(
text_chunks, embeddings, output_db=None, collection_name=collection_name
)
embeddings = st.session_state.document_retrieval.load_embedding_model()
collection_name = default_collection if not prod_mode else None
vectorstore = st.session_state.document_retrieval.create_vector_store(
text_chunks, embeddings, output_db=None, collection_name=collection_name
)
st.session_state.vectorstore = vectorstore
st.session_state.document_retrieval.init_retriever(vectorstore)
st.session_state.conversation = st.session_state.document_retrieval.get_qa_retrieval_chain()
st.toast(f'File uploaded! Go ahead and ask some questions', icon='🎉')
st.session_state.input_disabled = False
# except Exception as e:
# st.error(f'An error occurred while processing: {str(e)}')
st.session_state.vectorstore = vectorstore
st.session_state.document_retrieval.init_retriever(vectorstore)
st.session_state.conversation = st.session_state.document_retrieval.get_qa_retrieval_chain()
st.toast(f'File uploaded! Go ahead and ask some questions', icon='🎉')
st.session_state.input_disabled = False
except Exception as e:
st.error(f'An error occurred while processing: {str(e)}')

if not prod_mode:
st.markdown('[Optional] Save database for reuse')
save_location = st.text_input('Save location', './data/my-vector-db').strip()
if st.button('Process and Save database'):
with st.spinner('Processing'):
try:
text_chunks = st.session_state.document_retrieval.parse_doc(docs)
if docs is not None:
temp_folder = save_files_user(docs)
text_chunks = st.session_state.document_retrieval.parse_doc(temp_folder)
embeddings = st.session_state.document_retrieval.load_embedding_model()
vectorstore = st.session_state.document_retrieval.create_vector_store(
text_chunks, embeddings, output_db=save_location, collection_name=default_collection
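
On the Streamlit side, uploads are now staged on disk first and only the resulting folder is handed to the kit. A condensed sketch of that flow, with the widget label and session wiring simplified for illustration:

import streamlit as st

docs = st.file_uploader('Upload documents', accept_multiple_files=True)
if st.button('Process') and docs:
    temp_folder = save_files_user(docs)  # write uploads to data/tmp, as defined above
    text_chunks = st.session_state.document_retrieval.parse_doc(temp_folder)
    st.write(f'{len(text_chunks)} chunks ready for embedding')
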
12 changes: 1 addition & 11 deletions enterprise_knowledge_retriever/tests/ekr_test.py
@@ -20,8 +20,6 @@
import unittest
from typing import Any, Dict, List, Type

import yaml

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
@@ -40,16 +38,10 @@
from langchain_core.embeddings import Embeddings

from enterprise_knowledge_retriever.src.document_retrieval import DocumentRetrieval, RetrievalQAChain
from utils.parsing.sambaparse import parse_doc_universal

CONFIG_PATH = os.path.join(kit_dir, 'config.yaml')
PERSIST_DIRECTORY = os.path.join(kit_dir, 'tests', 'vectordata', 'my-vector-db')
TEST_DATA_PATH = os.path.join(kit_dir, 'tests', 'data', 'test')

with open(CONFIG_PATH, 'r') as yaml_file:
config = yaml.safe_load(yaml_file)
pdf_only_mode = config['pdf_only_mode']


# Let's use this as a template for further CLI tests. setup, tests, teardown and assert at the end.
class EKRTestCase(unittest.TestCase):
@@ -73,9 +65,7 @@ def setUpClass(cls: Type['EKRTestCase']) -> None:

@classmethod
def parse_documents(cls: Type['EKRTestCase']) -> List[Document]:
_, _, text_chunks = parse_doc_universal(
doc=TEST_DATA_PATH, additional_metadata=cls.additional_metadata, lite_mode=pdf_only_mode
)
text_chunks = cls.document_retrieval.parse_doc(doc_folder=TEST_DATA_PATH)
logger.info(f'Number of chunks: {len(text_chunks)}')
return text_chunks

1 change: 0 additions & 1 deletion multimodal_knowledge_retriever/streamlit/app.py
@@ -34,7 +34,6 @@
'meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo',
'meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo',
'llava-v1.5-7b-4096-preview',
'Llama-3.2-11B-Vision-Instruct'
]
# Available models in dropdown menu
LLM_MODELS = ['Meta-Llama-3.1-70B-Instruct', 'Meta-Llama-3.1-405B-Instruct', 'Meta-Llama-3.1-8B-Instruct']
2 changes: 2 additions & 0 deletions utils/vectordb/vector_db.py
@@ -1,3 +1,5 @@
from __future__ import annotations

# Define the script's usage example
USAGE_EXAMPLE = """
Example usage:
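
The new from __future__ import annotations line makes all annotations in the module lazily evaluated strings (PEP 563), which typically lets newer-style type hints parse on older interpreters. A small illustrative sketch, not taken from this file:

from __future__ import annotations

def lookup_scores(names: list[str] | None = None) -> dict[str, float]:
    # with postponed evaluation this signature parses on Python 3.8, even though
    # list[str] and the | union syntax are only valid at runtime on newer versions
    return {name: 0.0 for name in (names or [])}
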
