From dd5e82ffd1489d556bc12f79048ed6af146ea7ee Mon Sep 17 00:00:00 2001
From: michaelwalker
Date: Thu, 4 Jan 2024 21:41:59 -0800
Subject: [PATCH] feat: add optional long context window approach feat: add RAG shortcut

---
 .../question_answering/src/qa_agent/chain.py | 52 +++++++++++--------
 .../embeddings_job/src/lambda.py             | 29 +++++++----
 .../schema.graphql                           |  6 +++
 3 files changed, 57 insertions(+), 30 deletions(-)

diff --git a/lambda/aws-qa-appsync-opensearch/question_answering/src/qa_agent/chain.py b/lambda/aws-qa-appsync-opensearch/question_answering/src/qa_agent/chain.py
index 7d94944e..2ffd6532 100644
--- a/lambda/aws-qa-appsync-opensearch/question_answering/src/qa_agent/chain.py
+++ b/lambda/aws-qa-appsync-opensearch/question_answering/src/qa_agent/chain.py
@@ -30,6 +30,7 @@
 tracer = Tracer(service="QUESTION_ANSWERING")
 metrics = Metrics(namespace="question_answering", service="QUESTION_ANSWERING")
 
+
 class StreamingCallbackHandler(BaseCallbackHandler):
     def __init__(self, status_variables: Dict):
         self.status_variables = status_variables
@@ -58,7 +59,6 @@ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
             self.status_variables['answer'] = error.decode("utf-8")
             send_job_status(self.status_variables)
 
-
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         """Run when LLM ends running."""
         logger.info(f"[StreamingCallbackHandler::on_llm_end] Streaming ended. Response: {response}")
@@ -85,23 +85,25 @@ def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
 
 
 @tracer.capture_method
 def run_question_answering(arguments):
+    response_generation_method = arguments.get('responseGenerationMethod', 'LONG_CONTEXT')
     try:
         filename = arguments['filename']
     except:
         filename = ''
         arguments['filename'] = ''
-    if not filename: # user didn't provide a specific file as input, we use the RAG source against the entire knowledge base
+    if not filename:  # user didn't provide a specific file as input, we use the RAG source against the entire knowledge base
         llm_response = run_qa_agent_rag_no_memory(arguments)
         return llm_response
-    
+
     bucket_name = os.environ['INPUT_BUCKET']
     # select the methodology based on the input size
 
     document_number_of_tokens = S3FileLoaderInMemory(bucket_name, filename).get_document_tokens()
 
     if document_number_of_tokens is None:
-        logger.exception(f'Failed to compute the number of tokens for file {filename} in bucket {bucket_name}, returning')
+        logger.exception(
+            f'Failed to compute the number of tokens for file {filename} in bucket {bucket_name}, returning')
         error = JobStatus.ERROR_LOAD_INFO.get_message()
         status_variables = {
             'jobstatus': JobStatus.ERROR_LOAD_INFO.status,
@@ -115,21 +117,27 @@ def run_question_answering(arguments):
         return ''
 
     model_max_tokens = get_max_tokens()
-    logger.info(f'For the current question, we have a max model length of {model_max_tokens} and a document containing {document_number_of_tokens} tokens')
+    logger.info(
+        f'For the current question, we have a max model length of {model_max_tokens} and a document containing {document_number_of_tokens} tokens')
 
-    # why add 500 ? on top of the document content, we add the prompt. So we keep an extra 500 tokens of space just in case
-    if (document_number_of_tokens + 250) < model_max_tokens:
-        logger.info('Running qa agent with full document in context')
-        llm_response = run_qa_agent_from_single_document_no_memory(arguments)
-    else:
+    if response_generation_method == 'RAG':
         logger.info('Running qa agent with a RAG approach')
         llm_response = run_qa_agent_rag_no_memory(arguments)
-    
+    else:
+        # why add 500 ? on top of the document content, we add the prompt. So we keep an extra 500 tokens of space just in case
+        if (document_number_of_tokens + 250) < model_max_tokens:
+            logger.info('Running qa agent with full document in context')
+            llm_response = run_qa_agent_from_single_document_no_memory(arguments)
+        else:
+            logger.info('Running qa agent with a RAG approach due to document size')
+            llm_response = run_qa_agent_rag_no_memory(arguments)
     return llm_response
 
+
 _doc_index = None
 _current_doc_index = None
 
+
 def run_qa_agent_rag_no_memory(input_params):
     logger.info("starting qa agent with rag approach without memory")
 
@@ -157,11 +165,11 @@ def run_qa_agent_rag_no_memory(input_params):
     if _doc_index is None:
         logger.info("loading opensearch retriever")
         doc_index = load_vector_db_opensearch(boto3.Session().region_name,
-                                             os.environ.get('OPENSEARCH_DOMAIN_ENDPOINT'),
-                                             os.environ.get('OPENSEARCH_INDEX'),
-                                             os.environ.get('OPENSEARCH_SECRET_ID'))
+                                              os.environ.get('OPENSEARCH_DOMAIN_ENDPOINT'),
+                                              os.environ.get('OPENSEARCH_INDEX'),
+                                              os.environ.get('OPENSEARCH_SECRET_ID'))
 
-    else :
+    else:
         logger.info("_retriever already exists")
 
     _current_doc_index = _doc_index
@@ -171,7 +179,7 @@ def run_qa_agent_rag_no_memory(input_params):
     output_file_name = input_params['filename']
     source_documents = doc_index.similarity_search(decoded_question, k=max_docs)
-    #--------------------------------------------------------------------------
+    # --------------------------------------------------------------------------
     # If an output file is specified, filter the response to only include chunks
     # related to that file. The source metadata is added when embeddings are
     # created in the ingestion pipeline.
     #
@@ -179,8 +187,8 @@ def run_qa_agent_rag_no_memory(input_params):
     # TODO: Evaluate if this filter can be optimized by using the
     # OpenSearchVectorSearch.max_marginal_relevance_search() method instead.
     # See https://github.com/langchain-ai/langchain/issues/10524
-    #--------------------------------------------------------------------------
-    if output_file_name: 
+    # --------------------------------------------------------------------------
+    if output_file_name:
         source_documents = [doc for doc in source_documents if doc.metadata['source'] == output_file_name]
     logger.info(source_documents)
     status_variables['sources'] = list(set(doc.metadata['source'] for doc in source_documents))
@@ -208,7 +216,7 @@ def run_qa_agent_rag_no_memory(input_params):
     try:
         tmp = chain.predict(context=source_documents, question=decoded_question)
         answer = tmp.removeprefix(' ')
-        
+
         logger.info(f'answer is: {answer}')
         llm_answer_bytes = answer.encode("utf-8")
         base64_bytes = base64.b64encode(llm_answer_bytes)
@@ -224,12 +232,14 @@ def run_qa_agent_rag_no_memory(input_params):
         error = JobStatus.ERROR_PREDICTION.get_message()
         status_variables['answer'] = error.decode("utf-8")
         send_job_status(status_variables)
-    
+
     return status_variables
 
+
 _file_content = None
 _current_file_name = None
 
+
 def run_qa_agent_from_single_document_no_memory(input_params):
     logger.info("starting qa agent without memory single document")
 
@@ -267,7 +277,7 @@ def run_qa_agent_from_single_document_no_memory(input_params):
     else:
         logger.info('same file as before, but nothing cached')
         _file_content = S3FileLoaderInMemory(bucket_name, filename).load()
-        
+
     _current_file_name = filename
     status_variables['sources'] = [filename]
     if _file_content is None:
diff --git a/lambda/aws-rag-appsync-stepfn-opensearch/embeddings_job/src/lambda.py b/lambda/aws-rag-appsync-stepfn-opensearch/embeddings_job/src/lambda.py
index 0afd05a7..da85cb5e 100644
--- a/lambda/aws-rag-appsync-stepfn-opensearch/embeddings_job/src/lambda.py
+++ b/lambda/aws-rag-appsync-stepfn-opensearch/embeddings_job/src/lambda.py
@@ -71,6 +71,8 @@ def get_bedrock_client(service_name="bedrock-runtime"):
 
     return boto3.client(**bedrock_config_data)
 
+
+
 opensearch_secret_id = os.environ['OPENSEARCH_SECRET_ID']
 bucket_name = os.environ['OUTPUT_BUCKET']
 opensearch_index = os.environ['OPENSEARCH_INDEX']
@@ -92,7 +94,7 @@ def handler(event, context: LambdaContext) -> dict:
 
     # if the secret id is not provided
     # uses username password
-    if opensearch_secret_id != 'NONE': # nosec
+    if opensearch_secret_id != 'NONE':  # nosec
        creds = get_credentials(opensearch_secret_id, aws_region)
        http_auth = (creds['username'], creds['password'])
     else: #
@@ -127,10 +129,12 @@ def handler(event, context: LambdaContext) -> dict:
             doc.metadata['source'] = filename
         docs.extend(sub_docs)
 
+    response_generation_method = event.get('responseGenerationMethod', 'LONG_CONTEXT')
+
     if not docs:
         return {
             'status':'nothing to ingest'
-        }    
+        }
 
     text_splitter = RecursiveCharacterTextSplitter(
         # Set a really small chunk size, just to show.
@@ -149,14 +153,14 @@ def handler(event, context: LambdaContext) -> dict:
     chunks = text_splitter.create_documents([doc.page_content for doc in docs], metadatas=[doc.metadata for doc in docs])
 
     db_shards = (len(chunks) // MAX_OS_DOCS_PER_PUT) + 1
-    print(f'Loading chunks into vector store ... using {db_shards} shards') 
+    print(f'Loading chunks into vector store ... using {db_shards} shards')
     shards = np.array_split(chunks, db_shards)
 
     # first check if index exists, if it does then call the add_documents function
     # otherwise call the from_documents function which would first create the index
     # and then do a bulk add. Both add_documents and from_documents do a bulk add
     # but it is important to call from_documents first so that the index is created
-    # correctly for K-NN 
+    # correctly for K-NN
     try:
         index_exists = check_if_index_exists(opensearch_index,
                                              aws_region,
@@ -174,6 +178,13 @@ def handler(event, context: LambdaContext) -> dict:
     bedrock_client = get_bedrock_client()
     embeddings = BedrockEmbeddings(client=bedrock_client)
 
+    if response_generation_method == 'RAG':
+        # Call a function to handle RAG logic
+        question = event.get('question', '')  # Assuming the question is also part of the event
+        return handle_rag_approach(docs, question, embeddings)
+    else:
+        # Existing logic for Long Context Window approach
+
     if index_exists is False:
         # create an index if the create index hint file exists
         path = os.path.join(DATA_DIR, INDEX_FILE)
@@ -186,7 +197,7 @@ def handler(event, context: LambdaContext) -> dict:
                                          opensearch_url=opensearch_domain,
                                          http_auth=http_auth)
             # we now need to start the loop below for the second shard
-            shard_start_index = 1 
+            shard_start_index = 1
         else:
             print(f"index {opensearch_index} does not exist and {path} file is not present, "
                   f"will wait for some other node to create the index")
@@ -207,24 +218,24 @@ def handler(event, context: LambdaContext) -> dict:
                 if time_slept >= TOTAL_INDEX_CREATION_WAIT_TIME:
                     print(f"time_slept={time_slept} >= {TOTAL_INDEX_CREATION_WAIT_TIME}, not waiting anymore for index creation")
                     break
-                
+
     else:
         print(f"index={opensearch_index} does exists, going to call add_documents")
         shard_start_index = 0
 
     for shard in shards[shard_start_index:]:
-        results = process_shard(shard=shard, 
+        results = process_shard(shard=shard,
                                 os_index_name=opensearch_index,
                                 os_domain_ep=opensearch_domain,
                                 os_http_auth=http_auth)
-    
+
     for file in files:
         if file['status'] == 'File transformed':
             file['status'] = 'Ingested'
         else:
             file['status'] = 'Error_'+file['status']
     updateIngestionJobStatus({'jobid': job_id, 'files': files})
-    
+
     return {
         'status':'succeed'
     }
\ No newline at end of file
diff --git a/resources/gen-ai/aws-rag-appsync-stepfn-opensearch/schema.graphql b/resources/gen-ai/aws-rag-appsync-stepfn-opensearch/schema.graphql
index 38d257fa..85045d91 100644
--- a/resources/gen-ai/aws-rag-appsync-stepfn-opensearch/schema.graphql
+++ b/resources/gen-ai/aws-rag-appsync-stepfn-opensearch/schema.graphql
@@ -1,3 +1,8 @@
+enum ResponseGenerationMethod {
+  LONG_CONTEXT
+  RAG
+}
+
 type FileStatus @aws_iam @aws_cognito_user_pools {
   name: String
   status: String
@@ -17,6 +22,7 @@ input IngestionDocsInput {
   ingestionjobid: ID
   files: [FileStatusInput]
   ignore_existing: Boolean
+  responseGenerationMethod: ResponseGenerationMethod
 }
 
 type Mutation @aws_iam @aws_cognito_user_pools {
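
Editor's note, not part of the patch: a minimal, self-contained Python sketch of the routing rule the patched run_question_answering() applies, shown here only to make the new responseGenerationMethod field easier to exercise. It uses the same keys and the same 250-token headroom constant as chain.py above; the example filename and token counts are made up, and the helper name select_method is hypothetical.

# Illustrative sketch only -- mirrors the routing in the patched run_question_answering():
# an explicit 'RAG' request short-circuits to the RAG path, otherwise the full document
# is kept in context only when it fits within the model's maximum token length.

def select_method(arguments: dict, document_tokens: int, model_max_tokens: int) -> str:
    method = arguments.get('responseGenerationMethod', 'LONG_CONTEXT')  # new optional field
    if method == 'RAG':
        return 'RAG'
    if (document_tokens + 250) < model_max_tokens:  # same headroom constant as chain.py
        return 'LONG_CONTEXT'
    return 'RAG'  # document too large for the context window, fall back to RAG


# Hypothetical payload, shaped like the arguments the patched Lambda reads.
qa_arguments = {
    'filename': 'example-report.pdf',    # made-up file name
    'responseGenerationMethod': 'RAG',   # or 'LONG_CONTEXT' (the default when omitted)
}
assert select_method(qa_arguments, document_tokens=30000, model_max_tokens=8000) == 'RAG'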