Skip to content

Commit

Permalink
Merge pull request #129 from mit-submit/feature/remove-lock-from-critical-path
Browse files Browse the repository at this point in the history

move lock off of critical path containing call to OpenAI
  • Loading branch information
julius-heitkoetter authored Oct 31, 2023
2 parents 6816cb6 + 39eb0a6 commit 2457ebc
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 57 deletions.
8 changes: 1 addition & 7 deletions A2rchi/chains/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from chromadb.config import Settings
from langchain.vectorstores import Chroma
from threading import Lock, Thread

import chromadb
import time
Expand All @@ -16,7 +15,6 @@ def __init__(self):
into a format that the chain can use. Then, it creates the chain using
those documents.
"""
self.lock = Lock()
self.kill = False

from A2rchi.utils.config_loader import Config_Loader
Expand Down Expand Up @@ -55,8 +53,7 @@ def update_vectorstore_and_create_chain(self):
settings=Settings(allow_reset=True, anonymized_telemetry=False), # NOTE: anonymized_telemetry doesn't actually do anything; need to build Chroma on our own without it
)

# acquire lock and construct chain
self.lock.acquire()
# construct chain
vectorstore = Chroma(
client=client,
collection_name=self.collection_name,
Expand All @@ -65,7 +62,6 @@ def update_vectorstore_and_create_chain(self):
chain = BaseChain.from_llm(self.llm, vectorstore.as_retriever(), return_source_documents=True)
print(f"N entries: {client.get_collection(self.collection_name).count()}")
print("Updated chain with new vectorstore")
self.lock.release()

return chain

Expand Down Expand Up @@ -135,10 +131,8 @@ def __call__(self, history):
chat_history = history[:-1] if history is not None else None

# make the request to the chain
self.lock.acquire()
answer = chain({"question": question, "chat_history": chat_history})
print(f" INFO - answer: {answer}")
self.lock.release()

# delete chain object to release chain, vectorstore, and client for garbage collection
del chain
Expand Down
114 changes: 64 additions & 50 deletions A2rchi/interfaces/chat_app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
import time

# DEFINITIONS
QUERY_LIMIT = 1000 # max number of queries
# TODO: remove this logic and eventually replace with per-user (or per-conversation) rate limits
QUERY_LIMIT = 10000 # max number of queries


class ChatWrapper:
Expand Down Expand Up @@ -80,7 +81,9 @@ def update_or_add_discussion(data_path, json_file, discussion_id, discussion_con
print(" INFO - json_file found.")

except FileNotFoundError:
# create data path if it doesn't exist
print(" ERROR - json_file not found. Creating a new one")
os.makedirs(data_path, exist_ok=True)

# update or add discussion
discussion_dict = data.get(str(discussion_id), {})
Expand All @@ -100,9 +103,6 @@ def update_or_add_discussion(data_path, json_file, discussion_id, discussion_con

data[str(discussion_id)] = discussion_dict

# create data path if it doesn't exist
os.makedirs(data_path, exist_ok=True)

# write the updated JSON data back to the file
with open(os.path.join(data_path, json_file), 'w') as f:
json.dump(data, f)
Expand All @@ -128,62 +128,76 @@ def __call__(self, history: Optional[List[Tuple[str, str]]], discussion_id: Opti
Execute the chat functionality.
"""
self.lock.acquire()
print("INFO - acquired lock file")
try:
# update vector store through data manager; will only do something if new files have been added
print("INFO - acquired lock file update vectorstore")

self.data_manager.update_vectorstore()

# convert the history to native A2rchi form (because javascript does not have tuples)
history = self.convert_to_chain_history(history)

# get discussion ID so that the conversation can be saved (It seems that random is no good... TODO)
discussion_id = discussion_id or np.random.randint(100000, 999999)

# run chain to get result
if self.number_of_queries < QUERY_LIMIT:
result = self.chain(history)
else:
# the case where we have exceeded the QUERY LIMIT (built so that we do not overuse the chain)
output = "Sorry, our service is currently down due to exceptional demand. Please come again later."
return output, discussion_id
self.number_of_queries += 1
print(f"number of queries is: {self.number_of_queries}")

# get similarity score to see how close the input is to the source
# - low score means very close (it's a distance between embedding vectors approximated
# by an approximate k-nearest neighbors algorithm called HNSW)
inp = history[-1][1]
score = self.chain.similarity_search(inp)

# load the present list of sources
try:
with open(os.path.join(self.data_path, 'sources.yml'), 'r') as file:
sources = yaml.load(file, Loader=yaml.FullLoader)
except FileNotFoundError:
sources = dict()

# get the closest source to the document
source = None
if len(result['source_documents']) > 0:
source_hash = result['source_documents'][0].metadata['source']
if '/' in source_hash and '.' in source_hash:
source = source_hash.split('/')[-1].split('.')[0]

# if the score is low enough, include the source as a link, otherwise give just the answer
embedding_name = self.config["utils"]["embeddings"]["EMBEDDING_NAME"]
similarity_score_reference = self.config["utils"]["embeddings"]["EMBEDDING_CLASS_MAP"][embedding_name]["similarity_score_reference"]
if score < similarity_score_reference and source in sources.keys():
output = "<p>" + self.format_code_in_text(result["answer"]) + "</p>" + "\n\n<br /><br /><p><a href= " + sources[source] + ">Click here to read more</a></p>"
else:
output = "<p>" + self.format_code_in_text(result["answer"]) + "</p>"
except Exception as e:
print(f"ERROR - {str(e)}")

finally:
self.lock.release()
print("INFO - released lock file update vectorstore")

# convert the history to native A2rchi form (because javascript does not have tuples)
history = self.convert_to_chain_history(history)

# get discussion ID so that the conversation can be saved (It seems that random is no good... TODO)
discussion_id = discussion_id or np.random.randint(100000, 999999)

# run chain to get result
if self.number_of_queries < QUERY_LIMIT:
result = self.chain(history)
else:
# the case where we have exceeded the QUERY LIMIT (built so that we do not overuse the chain)
output = "Sorry, our service is currently down due to exceptional demand. Please come again later."
return output, discussion_id
self.number_of_queries += 1
print(f"number of queries is: {self.number_of_queries}")

# get similarity score to see how close the input is to the source
# - low score means very close (it's a distance between embedding vectors approximated
# by an approximate k-nearest neighbors algorithm called HNSW)
inp = history[-1][1]
score = self.chain.similarity_search(inp)

# load the present list of sources
try:
with open(os.path.join(self.data_path, 'sources.yml'), 'r') as file:
sources = yaml.load(file, Loader=yaml.FullLoader)
except FileNotFoundError:
sources = dict()

# get the closest source to the document
source = None
if len(result['source_documents']) > 0:
source_hash = result['source_documents'][0].metadata['source']
if '/' in source_hash and '.' in source_hash:
source = source_hash.split('/')[-1].split('.')[0]

# if the score is low enough, include the source as a link, otherwise give just the answer
embedding_name = self.config["utils"]["embeddings"]["EMBEDDING_NAME"]
similarity_score_reference = self.config["utils"]["embeddings"]["EMBEDDING_CLASS_MAP"][embedding_name]["similarity_score_reference"]
if score < similarity_score_reference and source in sources.keys():
output = "<p>" + result["answer"] + "</p>" + "\n\n<br /><br /><p><a href= " + sources[source] + ">Click here to read more</a></p>"
else:
output = "<p>" + result["answer"] + "</p>"

self.lock.acquire()
try:
print("INFO - acquired lock file write json")

ChatWrapper.update_or_add_discussion(self.data_path, "conversations_test.json", discussion_id, discussion_contents = history + [("A2rchi", output)])

except Exception as e:
raise e
print(f"ERROR - {str(e)}")

finally:
self.lock.release()
print("INFO - released lock file")
print("INFO - released lock file write json")

return output, discussion_id


Expand Down

0 comments on commit 2457ebc

Please sign in to comment.