Merge pull request #176 from awslabs/feat/issue_137
feat(qa construct): add optional long context window approach
MichaelWalker-git authored Jan 16, 2024
2 parents 94dd364 + 3a9597e commit 2e499df
Showing 4 changed files with 56 additions and 35 deletions.
@@ -29,7 +29,6 @@
logger = Logger(service="QUESTION_ANSWERING")
tracer = Tracer(service="QUESTION_ANSWERING")
metrics = Metrics(namespace="question_answering", service="QUESTION_ANSWERING")

class StreamingCallbackHandler(BaseCallbackHandler):
def __init__(self, status_variables: Dict):
self.status_variables = status_variables
@@ -58,7 +57,6 @@ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
self.status_variables['answer'] = error.decode("utf-8")
send_job_status(self.status_variables)


def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
logger.info(f"[StreamingCallbackHandler::on_llm_end] Streaming ended. Response: {response}")
@@ -85,23 +83,31 @@ def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any

@tracer.capture_method
def run_question_answering(arguments):
response_generation_method = arguments.get('responseGenerationMethod', 'LONG_CONTEXT')

try:
filename = arguments['filename']
except:

filename = ''
arguments['filename'] = ''
if not filename: # user didn't provide a specific file as input, we use the RAG source against the entire knowledge base

if not filename: # user didn't provide a specific file as input, we use the RAG source against the entire knowledge base
if response_generation_method == 'LONG_CONTEXT':
error = 'Error: Filename required for LONG_CONTEXT approach, defaulting to RAG.'
logger.error(error)

llm_response = run_qa_agent_rag_no_memory(arguments)
return llm_response

bucket_name = os.environ['INPUT_BUCKET']

# select the methodology based on the input size
document_number_of_tokens = S3FileLoaderInMemory(bucket_name, filename).get_document_tokens()

if document_number_of_tokens is None:
logger.exception(f'Failed to compute the number of tokens for file {filename} in bucket {bucket_name}, returning')
logger.exception(
f'Failed to compute the number of tokens for file {filename} in bucket {bucket_name}, returning')
error = JobStatus.ERROR_LOAD_INFO.get_message()
status_variables = {
'jobstatus': JobStatus.ERROR_LOAD_INFO.status,
@@ -115,21 +121,24 @@ def run_question_answering(arguments):
return ''

model_max_tokens = get_max_tokens()
logger.info(f'For the current question, we have a max model length of {model_max_tokens} and a document containing {document_number_of_tokens} tokens')
logger.info(
f'For the current question, we have a max model length of {model_max_tokens} and a document containing {document_number_of_tokens} tokens')

# why add 250? on top of the document content, we add the prompt, so we keep an extra 250 tokens of space just in case
if (document_number_of_tokens + 250) < model_max_tokens:
logger.info('Running qa agent with full document in context')
llm_response = run_qa_agent_from_single_document_no_memory(arguments)
else:
if response_generation_method == 'RAG':
logger.info('Running qa agent with a RAG approach')
llm_response = run_qa_agent_rag_no_memory(arguments)

else:
# LONG CONTEXT
# why add 250? on top of the document content, we add the prompt, so we keep an extra 250 tokens of space just in case
if (document_number_of_tokens + 250) < model_max_tokens:
logger.info('Running qa agent with full document in context')
llm_response = run_qa_agent_from_single_document_no_memory(arguments)
else:
logger.info('Running qa agent with a RAG approach due to document size')
llm_response = run_qa_agent_rag_no_memory(arguments)
return llm_response

_doc_index = None
_current_doc_index = None

def run_qa_agent_rag_no_memory(input_params):
logger.info("starting qa agent with rag approach without memory")

@@ -157,11 +166,11 @@ def run_qa_agent_rag_no_memory(input_params):
if _doc_index is None:
logger.info("loading opensearch retriever")
doc_index = load_vector_db_opensearch(boto3.Session().region_name,
os.environ.get('OPENSEARCH_DOMAIN_ENDPOINT'),
os.environ.get('OPENSEARCH_INDEX'),
os.environ.get('OPENSEARCH_SECRET_ID'))
os.environ.get('OPENSEARCH_DOMAIN_ENDPOINT'),
os.environ.get('OPENSEARCH_INDEX'),
os.environ.get('OPENSEARCH_SECRET_ID'))

else :
else:
logger.info("_retriever already exists")

_current_doc_index = _doc_index
@@ -171,16 +180,16 @@ def run_qa_agent_rag_no_memory(input_params):
output_file_name = input_params['filename']

source_documents = doc_index.similarity_search(decoded_question, k=max_docs)
#--------------------------------------------------------------------------
# --------------------------------------------------------------------------
# If an output file is specified, filter the response to only include chunks
# related to that file. The source metadata is added when embeddings are
# created in the ingestion pipeline.
#
# TODO: Evaluate if this filter can be optimized by using the
# OpenSearchVectorSearch.max_marginal_relevance_search() method instead.
# See https://github.com/langchain-ai/langchain/issues/10524
#--------------------------------------------------------------------------
if output_file_name:
# --------------------------------------------------------------------------
if output_file_name:
source_documents = [doc for doc in source_documents if doc.metadata['source'] == output_file_name]
logger.info(source_documents)
status_variables['sources'] = list(set(doc.metadata['source'] for doc in source_documents))
@@ -208,7 +217,7 @@ def run_qa_agent_rag_no_memory(input_params):
try:
tmp = chain.predict(context=source_documents, question=decoded_question)
answer = tmp.removeprefix(' ')

logger.info(f'answer is: {answer}')
llm_answer_bytes = answer.encode("utf-8")
base64_bytes = base64.b64encode(llm_answer_bytes)
@@ -224,12 +233,14 @@ def run_qa_agent_rag_no_memory(input_params):
error = JobStatus.ERROR_PREDICTION.get_message()
status_variables['answer'] = error.decode("utf-8")
send_job_status(status_variables)

return status_variables


_file_content = None
_current_file_name = None


def run_qa_agent_from_single_document_no_memory(input_params):
logger.info("starting qa agent without memory single document")

@@ -267,7 +278,7 @@ def run_qa_agent_from_single_document_no_memory(input_params):
else:
logger.info('same file as before, but nothing cached')
_file_content = S3FileLoaderInMemory(bucket_name, filename).load()

_current_file_name = filename
status_variables['sources'] = [filename]
if _file_content is None:
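For orientation, the routing that this commit adds to run_question_answering can be condensed into the standalone sketch below. It is an illustration only: the run_qa_agent_* helpers are stubs standing in for the construct's real functions, and the 250-token headroom mirrors the constant used in the diff above.

```
# Illustrative sketch of the response-generation routing (stub helpers, not the real construct code).
PROMPT_TOKEN_HEADROOM = 250  # space reserved for the prompt on top of the document tokens


def run_qa_agent_rag_no_memory(arguments):
    return "answer generated with the RAG approach"  # stub


def run_qa_agent_from_single_document_no_memory(arguments):
    return "answer generated with the full document in context"  # stub


def route_question(arguments, document_number_of_tokens, model_max_tokens):
    method = arguments.get('responseGenerationMethod', 'LONG_CONTEXT')
    if not arguments.get('filename'):
        # No specific file provided: only RAG over the whole knowledge base is possible.
        return run_qa_agent_rag_no_memory(arguments)
    if method == 'RAG':
        return run_qa_agent_rag_no_memory(arguments)
    # LONG_CONTEXT: use the full document if it fits the model window, otherwise fall back to RAG.
    if document_number_of_tokens + PROMPT_TOKEN_HEADROOM < model_max_tokens:
        return run_qa_agent_from_single_document_no_memory(arguments)
    return run_qa_agent_rag_no_memory(arguments)


print(route_question({'filename': 'document.txt', 'responseGenerationMethod': 'LONG_CONTEXT'},
                     document_number_of_tokens=3000, model_max_tokens=8000))
```

The real handler additionally validates the computed token count, reports job status back over AppSync, and streams tokens when streaming is enabled.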
@@ -92,7 +92,7 @@ def handler(event, context: LambdaContext) -> dict:

# if the secret id is not provided
# uses username password
if opensearch_secret_id != 'NONE': # nosec
if opensearch_secret_id != 'NONE': # nosec
creds = get_credentials(opensearch_secret_id, aws_region)
http_auth = (creds['username'], creds['password'])
else: #
@@ -130,7 +130,7 @@ def handler(event, context: LambdaContext) -> dict:
if not docs:
return {
'status':'nothing to ingest'
}
}

text_splitter = RecursiveCharacterTextSplitter(
# Set a really small chunk size, just to show.
@@ -149,14 +149,14 @@ def handler(event, context: LambdaContext) -> dict:
chunks = text_splitter.create_documents([doc.page_content for doc in docs], metadatas=[doc.metadata for doc in docs])

db_shards = (len(chunks) // MAX_OS_DOCS_PER_PUT) + 1
print(f'Loading chunks into vector store ... using {db_shards} shards')
print(f'Loading chunks into vector store ... using {db_shards} shards')
shards = np.array_split(chunks, db_shards)

# first check if index exists, if it does then call the add_documents function
# otherwise call the from_documents function which would first create the index
# and then do a bulk add. Both add_documents and from_documents do a bulk add
# but it is important to call from_documents first so that the index is created
# correctly for K-NN
# correctly for K-NN
try:
index_exists = check_if_index_exists(opensearch_index,
aws_region,
@@ -186,7 +186,7 @@ def handler(event, context: LambdaContext) -> dict:
opensearch_url=opensearch_domain,
http_auth=http_auth)
# we now need to start the loop below for the second shard
shard_start_index = 1
shard_start_index = 1
else:
print(f"index {opensearch_index} does not exist and {path} file is not present, "
f"will wait for some other node to create the index")
@@ -207,24 +207,24 @@ def handler(event, context: LambdaContext) -> dict:
if time_slept >= TOTAL_INDEX_CREATION_WAIT_TIME:
print(f"time_slept={time_slept} >= {TOTAL_INDEX_CREATION_WAIT_TIME}, not waiting anymore for index creation")
break

else:
print(f"index={opensearch_index} does exists, going to call add_documents")
shard_start_index = 0

for shard in shards[shard_start_index:]:
results = process_shard(shard=shard,
results = process_shard(shard=shard,
os_index_name=opensearch_index,
os_domain_ep=opensearch_domain,
os_http_auth=http_auth)

for file in files:
if file['status'] == 'File transformed':
file['status'] = 'Ingested'
else:
file['status'] = 'Error_'+file['status']
updateIngestionJobStatus({'jobid': job_id, 'files': files})

return {
'status':'succeed'
}
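As a side note on the ingestion handler above (whose visible changes in this commit are whitespace-only), the sharding arithmetic splits the embedded chunks into roughly equal batches before they are bulk-loaded into OpenSearch. A minimal sketch, with MAX_OS_DOCS_PER_PUT assumed to be 500 purely for illustration:

```
import numpy as np

MAX_OS_DOCS_PER_PUT = 500      # assumed value, for illustration only
chunks = list(range(1200))     # stand-ins for the embedded document chunks

db_shards = (len(chunks) // MAX_OS_DOCS_PER_PUT) + 1  # 1200 // 500 + 1 = 3
shards = np.array_split(chunks, db_shards)
print([len(shard) for shard in shards])               # [400, 400, 400]
```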
10 changes: 9 additions & 1 deletion resources/gen-ai/aws-qa-appsync-opensearch/schema.graphql
@@ -1,3 +1,8 @@
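# RAG: answer from the most relevant chunks retrieved from the OpenSearch index.
# LONG_CONTEXT: answer from the full document placed in the model context window
# (the question-answering Lambda falls back to RAG when no filename is given or the document is too large).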
enum ResponseGenerationMethod {
LONG_CONTEXT
RAG
}

type QADocs @aws_iam @aws_cognito_user_pools {
jobid: ID
question: String
@@ -8,6 +13,7 @@ type QADocs @aws_iam @aws_cognito_user_pools {
answer: String
sources: [String]
jobstatus: String
responseGenerationMethod: ResponseGenerationMethod
}

input QADocsInput {
@@ -17,6 +23,7 @@ input QADocsInput {
max_docs: Int
filename: String
verbose: Boolean
responseGenerationMethod: ResponseGenerationMethod
}

type Mutation @aws_iam @aws_cognito_user_pools {
@@ -27,7 +34,8 @@ type Mutation @aws_iam @aws_cognito_user_pools {
question: String,
max_docs: Int,
verbose: Boolean,
streaming: Boolean
streaming: Boolean,
responseGenerationMethod: ResponseGenerationMethod
): QADocs
updateQAJobStatus(
jobid: ID,
4 changes: 3 additions & 1 deletion src/patterns/gen-ai/aws-qa-appsync-opensearch/README.md
@@ -125,14 +125,15 @@ Mutation call to trigger the question:

```
mutation MyMutation {
postQuestion(jobid: "123", jobstatus: "", max_docs: 10, question: "d2hhdCBpcyB0aGlzIGRvY3VtZW50IGFib3V0ID8=", streaming:true, verbose: true, filename: "document.txt") {
postQuestion(jobid: "123", jobstatus: "", max_docs: 10, question: "d2hhdCBpcyB0aGlzIGRvY3VtZW50IGFib3V0ID8=", streaming:true, verbose: true, filename: "document.txt", responseGenerationMethod: "RAG") {
answer
jobid
jobstatus
max_docs
question
verbose
streaming
responseGenerationMethod
}
}
____________________________________________________________________
@@ -162,6 +163,7 @@ Where:
- verbose: boolean indicating if the [Langchain chain call verbosity](https://python.langchain.com/docs/guides/debugging#chain-verbosetrue) should be enabled or not
- streaming: boolean indicating if the streaming capability of Bedrock is used. If set to true, tokens will be sent back to the subscriber as they are generated. If set to false, the entire response will be sent back to the subscriber once generated.
- filename: optional. Name of the file stored in the input S3 bucket, in txt format.
- responseGenerationMethod: optional. Method used to generate the response: either RAG or LONG_CONTEXT. If not provided, the default value is LONG_CONTEXT. The LONG_CONTEXT method requires a filename and falls back to RAG when the document does not fit in the model's context window (see the example below).
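
For illustration, a request using the LONG_CONTEXT method differs only in the enum value and requires a filename; the values below are placeholders:

```
mutation MyMutation {
  postQuestion(jobid: "124", jobstatus: "", max_docs: 10, question: "d2hhdCBpcyB0aGlzIGRvY3VtZW50IGFib3V0ID8=", streaming: false, verbose: true, filename: "document.txt", responseGenerationMethod: LONG_CONTEXT) {
    answer
    jobid
    jobstatus
    sources
  }
}
```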

## Initializer

