|
3 | 3 | from langchain.text_splitter import RecursiveCharacterTextSplitter
|
4 | 4 | from langchain.vectorstores import Chroma
|
5 | 5 | from langchain.embeddings import GPT4AllEmbeddings
|
6 |
| -import openai |
| 6 | +from tigerrag.rag.retrieval_augmenters import OpenAIRetrievalAugmenter |
| 7 | +from tigerrag.gar.query_augmenters import OpenAIQueryAugmenter |
7 | 8 | import os
|
8 | 9 | import sys
|
9 | 10 |
|
| 11 | + |
10 | 12 | # Please set env var OPENAI_API_KEY for GAR and RAG.
|
11 | 13 |
|
12 | 14 | # Sample usage:
|
13 | 15 | # python demo.py
|
14 | 16 | # python demo.py -number_of_run 4
|
15 | 17 |
|
16 |
| - |
17 |
| -def get_documents_embeddings(documents): |
18 |
| - # Load documents |
19 |
| - loader = WebBaseLoader(documents) |
20 |
| - |
21 |
| - # Split documents |
22 |
| - text_splitter = RecursiveCharacterTextSplitter( |
23 |
| - chunk_size=500, chunk_overlap=0) |
24 |
| - splits = text_splitter.split_documents(loader.load()) |
25 |
| - |
26 |
| - # Embed and store splits |
27 |
| - vectorstore = Chroma.from_documents( |
28 |
| - documents=splits, embedding=GPT4AllEmbeddings()) |
29 |
| - |
30 |
| - return vectorstore |
31 |
| - |
32 |
| - |
33 |
| -# EBR |
34 | 18 | def ebr(question, vectorstore):
|
35 | 19 | # Perform similarity search
|
36 | 20 | docs = vectorstore.similarity_search(question)
|
37 | 21 |
|
38 | 22 | return docs[0]
|
39 | 23 |
|
40 | 24 |
|
41 |
| -# RAG |
42 |
| -def generate_answer_with_rag_gpt3(question, context, openai_text_model): |
43 |
| - # Retrivel Augmented Generation |
44 |
| - prompt = f"Context: {context} Question: {question}. Provide a summary or answer:" |
45 |
| - |
46 |
| - # Generation using GPT-3 |
47 |
| - response = openai.Completion.create( |
48 |
| - engine=openai_text_model, prompt=prompt, max_tokens=100) |
49 |
| - answer = response.choices[0].text.strip() |
50 |
| - |
51 |
| - return answer |
52 |
| - |
53 |
| - |
54 |
| -# GAR |
55 |
| -def generate_answer_with_gar_gpt3(question, context, openai_text_model, vectorstore): |
56 |
| - # Generation Augmented Retrieval |
57 |
| - prompt = f"Expand on the query: {question}" |
58 |
| - |
59 |
| - # Generation using GPT-3 |
60 |
| - response = openai.Completion.create( |
61 |
| - engine=openai_text_model, prompt=prompt, max_tokens=100) |
62 |
| - augmented_query = response.choices[0].text.strip() |
63 |
| - |
64 |
| - # Retrieval |
65 |
| - answer = ebr(augmented_query, vectorstore) |
66 |
| - |
67 |
| - return answer |
68 |
| - |
69 |
| - |
70 | 25 | def is_intstring(s):
|
71 | 26 | try:
|
72 | 27 | int(s)
|
@@ -98,20 +53,39 @@ def main():
|
98 | 53 | for index in range(min(num_of_run, len(data["queries"]))):
|
99 | 54 | question = data["queries"][index]
|
100 | 55 | # Example usage of EBR
|
101 |
| - vectorstore = get_documents_embeddings(documents) |
| 56 | + loader = WebBaseLoader(documents) |
| 57 | + text_splitter = RecursiveCharacterTextSplitter( |
| 58 | + chunk_size=500, chunk_overlap=0) |
| 59 | + splits = text_splitter.split_documents(loader.load()) |
| 60 | + |
| 61 | + vectorstore = Chroma.from_documents( |
| 62 | + documents=splits, embedding=GPT4AllEmbeddings()) |
102 | 63 | print("The following is EBR output for question: "+question)
|
103 | 64 | retrieved_context = ebr(question, vectorstore)
|
104 | 65 | print(retrieved_context)
|
105 | 66 |
|
106 | 67 | # Example usage of RAG
|
107 |
| - print("The following is RAG output for question: "+question) |
108 |
| - print(generate_answer_with_rag_gpt3( |
109 |
| - question, retrieved_context, 'text-davinci-003')) |
110 | 68 |
|
| 69 | + print("The following is RAG output for question: "+question) |
| 70 | + # Retrivel Augmented Generation |
| 71 | + prompt_rag = f"""Context: {retrieved_context} Question: {question}. |
| 72 | + Provide a summary or answer:""" |
| 73 | + openai_generative_retrieval_augmenter = OpenAIRetrievalAugmenter( |
| 74 | + "text-davinci-003") |
| 75 | + answer_rag = openai_generative_retrieval_augmenter.get_augmented_retrieval( |
| 76 | + prompt_rag) |
| 77 | + print(answer_rag) |
111 | 78 | # Example usage of GAR
|
| 79 | + |
112 | 80 | print("The following is GAR output for question: "+question)
|
113 |
| - print(generate_answer_with_gar_gpt3( |
114 |
| - question, retrieved_context, 'text-davinci-003', vectorstore)) |
| 81 | + # print(generate_answer_with_gar_gpt3( |
| 82 | + # question, retrieved_context, 'text-davinci-003', vectorstore)) |
| 83 | + prompt_gar = f"Expand on the query: {question}" |
| 84 | + openai_generative_query_augmenter = OpenAIQueryAugmenter( |
| 85 | + "text-davinci-003") |
| 86 | + augmented_query = openai_generative_query_augmenter.get_augmented_query( |
| 87 | + prompt_gar) |
| 88 | + print(ebr(augmented_query, vectorstore)) |
115 | 89 |
|
116 | 90 |
|
117 | 91 | if __name__ == "__main__":
|
|
0 commit comments