feat: add optional long context window approach
feat: add RAG shortcut
MichaelWalker-git committed Jan 5, 2024
1 parent e9b0ac8 commit dd5e82f
Showing 3 changed files with 57 additions and 30 deletions.
File 1 of 3 (question answering Lambda handler)
@@ -30,6 +30,7 @@
tracer = Tracer(service="QUESTION_ANSWERING")
metrics = Metrics(namespace="question_answering", service="QUESTION_ANSWERING")


class StreamingCallbackHandler(BaseCallbackHandler):
def __init__(self, status_variables: Dict):
self.status_variables = status_variables
@@ -58,7 +59,6 @@ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
self.status_variables['answer'] = error.decode("utf-8")
send_job_status(self.status_variables)


def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
logger.info(f"[StreamingCallbackHandler::on_llm_end] Streaming ended. Response: {response}")
@@ -85,23 +85,25 @@ def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any

@tracer.capture_method
def run_question_answering(arguments):
response_generation_method = arguments.get('responseGenerationMethod', 'LONG_CONTEXT')

try:
filename = arguments['filename']
except:
filename = ''
arguments['filename'] = ''
if not filename: # user didn't provide a specific file as input, we use the RAG source against the entire knowledge base
if not filename: # user didn't provide a specific file as input, we use the RAG source against the entire knowledge base
llm_response = run_qa_agent_rag_no_memory(arguments)
return llm_response

bucket_name = os.environ['INPUT_BUCKET']

# select the methodology based on the input size
document_number_of_tokens = S3FileLoaderInMemory(bucket_name, filename).get_document_tokens()

if document_number_of_tokens is None:
logger.exception(f'Failed to compute the number of tokens for file {filename} in bucket {bucket_name}, returning')
logger.exception(
f'Failed to compute the number of tokens for file {filename} in bucket {bucket_name}, returning')
error = JobStatus.ERROR_LOAD_INFO.get_message()
status_variables = {
'jobstatus': JobStatus.ERROR_LOAD_INFO.status,
@@ -115,21 +117,27 @@ def run_question_answering(arguments):
return ''

model_max_tokens = get_max_tokens()
logger.info(f'For the current question, we have a max model length of {model_max_tokens} and a document containing {document_number_of_tokens} tokens')
logger.info(
f'For the current question, we have a max model length of {model_max_tokens} and a document containing {document_number_of_tokens} tokens')

# keep a buffer of extra tokens on top of the document content to leave room for the prompt
if (document_number_of_tokens + 250) < model_max_tokens:
logger.info('Running qa agent with full document in context')
llm_response = run_qa_agent_from_single_document_no_memory(arguments)
else:
if response_generation_method == 'RAG':
logger.info('Running qa agent with a RAG approach')
llm_response = run_qa_agent_rag_no_memory(arguments)

else:
# keep a buffer of extra tokens on top of the document content to leave room for the prompt
if (document_number_of_tokens + 250) < model_max_tokens:
logger.info('Running qa agent with full document in context')
llm_response = run_qa_agent_from_single_document_no_memory(arguments)
else:
logger.info('Running qa agent with a RAG approach due to document size')
llm_response = run_qa_agent_rag_no_memory(arguments)
return llm_response


_doc_index = None
_current_doc_index = None


def run_qa_agent_rag_no_memory(input_params):
logger.info("starting qa agent with rag approach without memory")

@@ -157,11 +165,11 @@ def run_qa_agent_rag_no_memory(input_params):
if _doc_index is None:
logger.info("loading opensearch retriever")
doc_index = load_vector_db_opensearch(boto3.Session().region_name,
os.environ.get('OPENSEARCH_DOMAIN_ENDPOINT'),
os.environ.get('OPENSEARCH_INDEX'),
os.environ.get('OPENSEARCH_SECRET_ID'))
os.environ.get('OPENSEARCH_DOMAIN_ENDPOINT'),
os.environ.get('OPENSEARCH_INDEX'),
os.environ.get('OPENSEARCH_SECRET_ID'))

else :
else:
logger.info("_retriever already exists")

_current_doc_index = _doc_index
@@ -171,16 +179,16 @@ def run_qa_agent_rag_no_memory(input_params):
output_file_name = input_params['filename']

source_documents = doc_index.similarity_search(decoded_question, k=max_docs)
#--------------------------------------------------------------------------
# --------------------------------------------------------------------------
# If an output file is specified, filter the response to only include chunks
# related to that file. The source metadata is added when embeddings are
# created in the ingestion pipeline.
#
# TODO: Evaluate if this filter can be optimized by using the
# OpenSearchVectorSearch.max_marginal_relevance_search() method instead.
# See https://github.com/langchain-ai/langchain/issues/10524
#--------------------------------------------------------------------------
if output_file_name:
# --------------------------------------------------------------------------
if output_file_name:
source_documents = [doc for doc in source_documents if doc.metadata['source'] == output_file_name]
logger.info(source_documents)
status_variables['sources'] = list(set(doc.metadata['source'] for doc in source_documents))
@@ -208,7 +216,7 @@ def run_qa_agent_rag_no_memory(input_params):
try:
tmp = chain.predict(context=source_documents, question=decoded_question)
answer = tmp.removeprefix(' ')

logger.info(f'answer is: {answer}')
llm_answer_bytes = answer.encode("utf-8")
base64_bytes = base64.b64encode(llm_answer_bytes)
@@ -224,12 +232,14 @@ def run_qa_agent_rag_no_memory(input_params):
error = JobStatus.ERROR_PREDICTION.get_message()
status_variables['answer'] = error.decode("utf-8")
send_job_status(status_variables)

return status_variables


_file_content = None
_current_file_name = None


def run_qa_agent_from_single_document_no_memory(input_params):
logger.info("starting qa agent without memory single document")

@@ -267,7 +277,7 @@ def run_qa_agent_from_single_document_no_memory(input_params):
else:
logger.info('same file as before, but nothing cached')
_file_content = S3FileLoaderInMemory(bucket_name, filename).load()

_current_file_name = filename
status_variables['sources'] = [filename]
if _file_content is None:
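
Note on the hunks above: run_question_answering now honors an optional responseGenerationMethod argument ('RAG' or 'LONG_CONTEXT', defaulting to 'LONG_CONTEXT'); only when the caller did not explicitly request RAG does it fall back to the token-count heuristic that compares the document size against the model context window. A minimal caller sketch for illustration; the two field names come from this diff, everything else about the payload shape is an assumption:

    # Hypothetical invocation (not part of the commit).
    arguments = {
        "filename": "report.pdf",            # optional; an empty filename always routes to the RAG agent
        "responseGenerationMethod": "RAG",   # new optional field; treated as "LONG_CONTEXT" when absent
        # ...other fields (question, job identifiers, ...) supplied by the GraphQL resolver are assumed
    }
    llm_response = run_question_answering(arguments)  # dispatches to the RAG or single-document agent
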
File 2 of 3 (document ingestion Lambda handler)
@@ -71,6 +71,8 @@ def get_bedrock_client(service_name="bedrock-runtime"):

return boto3.client(**bedrock_config_data)



opensearch_secret_id = os.environ['OPENSEARCH_SECRET_ID']
bucket_name = os.environ['OUTPUT_BUCKET']
opensearch_index = os.environ['OPENSEARCH_INDEX']
@@ -92,7 +94,7 @@ def handler(event, context: LambdaContext) -> dict:

# if the secret id is not provided
# uses username password
if opensearch_secret_id != 'NONE': # nosec
if opensearch_secret_id != 'NONE': # nosec
creds = get_credentials(opensearch_secret_id, aws_region)
http_auth = (creds['username'], creds['password'])
else: #
@@ -127,10 +129,12 @@ def handler(event, context: LambdaContext) -> dict:
doc.metadata['source'] = filename
docs.extend(sub_docs)

response_generation_method = event.get('responseGenerationMethod', 'LONG_CONTEXT')

if not docs:
return {
'status':'nothing to ingest'
}
}

text_splitter = RecursiveCharacterTextSplitter(
# Set a really small chunk size, just to show.
@@ -149,14 +153,14 @@ def handler(event, context: LambdaContext) -> dict:
chunks = text_splitter.create_documents([doc.page_content for doc in docs], metadatas=[doc.metadata for doc in docs])

db_shards = (len(chunks) // MAX_OS_DOCS_PER_PUT) + 1
print(f'Loading chunks into vector store ... using {db_shards} shards')
print(f'Loading chunks into vector store ... using {db_shards} shards')
shards = np.array_split(chunks, db_shards)

# first check if index exists, if it does then call the add_documents function
# otherwise call the from_documents function which would first create the index
# and then do a bulk add. Both add_documents and from_documents do a bulk add
# but it is important to call from_documents first so that the index is created
# correctly for K-NN
# correctly for K-NN
try:
index_exists = check_if_index_exists(opensearch_index,
aws_region,
@@ -174,6 +178,13 @@ def handler(event, context: LambdaContext) -> dict:
bedrock_client = get_bedrock_client()
embeddings = BedrockEmbeddings(client=bedrock_client)

if response_generation_method == 'RAG':
# Delegate to the RAG shortcut (handle_rag_approach is not defined in this diff; a hedged sketch follows this file's changes)
question = event.get('question', '') # Assuming the question is also part of the event
return handle_rag_approach(docs, question, embeddings)
# No else: needed here; the RAG branch returns early, so execution falls through
# to the existing long context window ingestion logic below.

if index_exists is False:
# create an index if the create index hint file exists
path = os.path.join(DATA_DIR, INDEX_FILE)
@@ -186,7 +197,7 @@ def handler(event, context: LambdaContext) -> dict:
opensearch_url=opensearch_domain,
http_auth=http_auth)
# we now need to start the loop below for the second shard
shard_start_index = 1
shard_start_index = 1
else:
print(f"index {opensearch_index} does not exist and {path} file is not present, "
f"will wait for some other node to create the index")
@@ -207,24 +218,24 @@ def handler(event, context: LambdaContext) -> dict:
if time_slept >= TOTAL_INDEX_CREATION_WAIT_TIME:
print(f"time_slept={time_slept} >= {TOTAL_INDEX_CREATION_WAIT_TIME}, not waiting anymore for index creation")
break

else:
print(f"index={opensearch_index} does exists, going to call add_documents")
shard_start_index = 0

for shard in shards[shard_start_index:]:
results = process_shard(shard=shard,
results = process_shard(shard=shard,
os_index_name=opensearch_index,
os_domain_ep=opensearch_domain,
os_http_auth=http_auth)

for file in files:
if file['status'] == 'File transformed':
file['status'] = 'Ingested'
else:
file['status'] = 'Error_'+file['status']
updateIngestionJobStatus({'jobid': job_id, 'files': files})

return {
'status':'succeed'
}
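
Note on the handler above: when responseGenerationMethod is 'RAG', the ingestion handler returns early through handle_rag_approach, but that function is not defined in any of the three changed files. Purely as an illustrative sketch of what such a helper might look like (the in-memory FAISS store, the k value, and the return payload are assumptions, not part of the commit):

    from langchain.vectorstores import FAISS  # requires faiss-cpu; any vector store with from_documents() would do

    def handle_rag_approach(docs, question, embeddings):
        """Hypothetical RAG shortcut: embed the freshly loaded documents in memory
        and retrieve the chunks most relevant to the question instead of persisting
        shards to OpenSearch."""
        store = FAISS.from_documents(docs, embeddings)       # docs and embeddings as built in handler()
        matches = store.similarity_search(question, k=3)     # k=3 chosen arbitrarily for the sketch
        return {
            'status': 'succeed',
            'sources': list({doc.metadata.get('source') for doc in matches}),
        }
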
File 3 of 3 (GraphQL schema)
@@ -1,3 +1,8 @@
enum ResponseGenerationMethod {
LONG_CONTEXT
RAG
}

type FileStatus @aws_iam @aws_cognito_user_pools {
name: String
status: String
@@ -17,6 +22,7 @@ input IngestionDocsInput {
ingestionjobid: ID
files: [FileStatusInput]
ignore_existing: Boolean
responseGenerationMethod: ResponseGenerationMethod
}

type Mutation @aws_iam @aws_cognito_user_pools {
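
Note on the schema change above: the new ResponseGenerationMethod enum is threaded through IngestionDocsInput, so clients can request the RAG shortcut per ingestion job. A sketch of the variables a client might send; the mutation name is truncated out of this diff and FileStatusInput is assumed to mirror FileStatus, so treat the exact shape as illustrative:

    # Hypothetical GraphQL mutation variables; only "responseGenerationMethod" and the
    # FileStatus-style fields are grounded in this diff.
    ingestion_input = {
        "ingestionjobid": "job-123",                                        # placeholder ID
        "files": [{"name": "report.pdf", "status": "File transformed"}],    # assumes FileStatusInput mirrors FileStatus
        "ignore_existing": False,
        "responseGenerationMethod": "RAG",                                  # or "LONG_CONTEXT"
    }
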
