Skip to content

Commit 96dabb2

Browse files
authored
Merge pull request #20 from gjyotin305/to_check
Update TigerRag Langchain Folder to Make it Match Newest Movie_Recs Logic #13
2 parents 8eaadaa + 027f1d8 commit 96dabb2

File tree

2 files changed

+30
-55
lines changed

2 files changed

+30
-55
lines changed

TigerRag/demos/langchain/demo.py

Lines changed: 28 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -3,70 +3,25 @@
33
from langchain.text_splitter import RecursiveCharacterTextSplitter
44
from langchain.vectorstores import Chroma
55
from langchain.embeddings import GPT4AllEmbeddings
6-
import openai
6+
from tigerrag.rag.retrieval_augmenters import OpenAIRetrievalAugmenter
7+
from tigerrag.gar.query_augmenters import OpenAIQueryAugmenter
78
import os
89
import sys
910

11+
1012
# Please set env var OPENAI_API_KEY for GAR and RAG.
1113

1214
# Sample usage:
1315
# python demo.py
1416
# python demo.py -number_of_run 4
1517

16-
17-
def get_documents_embeddings(documents):
    """Load *documents*, split them into chunks, and index the chunks
    in a Chroma vector store backed by GPT4All embeddings.

    Returns the populated vector store, ready for similarity search.
    """
    # Fetch the raw pages, then break them into 500-character chunks
    # (no overlap) so each embedding covers a focused span of text.
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=500, chunk_overlap=0)
    chunks = splitter.split_documents(WebBaseLoader(documents).load())

    # Embed every chunk and persist the vectors for retrieval.
    return Chroma.from_documents(
        documents=chunks, embedding=GPT4AllEmbeddings())
31-
32-
33-
# EBR
3418
def ebr(question, vectorstore):
    """Embedding-Based Retrieval: return the document in *vectorstore*
    most similar to *question*."""
    # similarity_search returns matches ranked best-first; keep the top hit.
    ranked = vectorstore.similarity_search(question)
    return ranked[0]
3923

4024

41-
# RAG
def generate_answer_with_rag_gpt3(question, context, openai_text_model):
    """Retrieval-Augmented Generation: answer *question* grounded in the
    retrieved *context*, using the given OpenAI completion model."""
    # Put the retrieved context in front of the question so the model
    # answers from it rather than from its parametric memory alone.
    rag_prompt = (
        f"Context: {context} Question: {question}. "
        "Provide a summary or answer:"
    )

    # Single completion call, capped at 100 tokens.
    completion = openai.Completion.create(
        engine=openai_text_model, prompt=rag_prompt, max_tokens=100)
    return completion.choices[0].text.strip()
52-
53-
54-
# GAR
def generate_answer_with_gar_gpt3(question, context, openai_text_model, vectorstore):
    """Generation-Augmented Retrieval: expand *question* with the model,
    then retrieve the best-matching document for the expanded query.

    NOTE(review): *context* is accepted for signature parity with the RAG
    helper but is not used here.
    """
    # Ask the model to broaden the query before searching.
    expansion_prompt = f"Expand on the query: {question}"
    completion = openai.Completion.create(
        engine=openai_text_model, prompt=expansion_prompt, max_tokens=100)
    expanded_query = completion.choices[0].text.strip()

    # Embedding-based retrieval over the expanded query.
    return ebr(expanded_query, vectorstore)
68-
69-
7025
def is_intstring(s):
7126
try:
7227
int(s)
@@ -98,20 +53,39 @@ def main():
9853
for index in range(min(num_of_run, len(data["queries"]))):
9954
question = data["queries"][index]
10055
# Example usage of EBR
101-
vectorstore = get_documents_embeddings(documents)
56+
loader = WebBaseLoader(documents)
57+
text_splitter = RecursiveCharacterTextSplitter(
58+
chunk_size=500, chunk_overlap=0)
59+
splits = text_splitter.split_documents(loader.load())
60+
61+
vectorstore = Chroma.from_documents(
62+
documents=splits, embedding=GPT4AllEmbeddings())
10263
print("The following is EBR output for question: "+question)
10364
retrieved_context = ebr(question, vectorstore)
10465
print(retrieved_context)
10566

10667
# Example usage of RAG
107-
print("The following is RAG output for question: "+question)
108-
print(generate_answer_with_rag_gpt3(
109-
question, retrieved_context, 'text-davinci-003'))
11068

69+
print("The following is RAG output for question: "+question)
70+
# Retrieval Augmented Generation
71+
prompt_rag = f"""Context: {retrieved_context} Question: {question}.
72+
Provide a summary or answer:"""
73+
openai_generative_retrieval_augmenter = OpenAIRetrievalAugmenter(
74+
"text-davinci-003")
75+
answer_rag = openai_generative_retrieval_augmenter.get_augmented_retrieval(
76+
prompt_rag)
77+
print(answer_rag)
11178
# Example usage of GAR
79+
11280
print("The following is GAR output for question: "+question)
113-
print(generate_answer_with_gar_gpt3(
114-
question, retrieved_context, 'text-davinci-003', vectorstore))
81+
# print(generate_answer_with_gar_gpt3(
82+
# question, retrieved_context, 'text-davinci-003', vectorstore))
83+
prompt_gar = f"Expand on the query: {question}"
84+
openai_generative_query_augmenter = OpenAIQueryAugmenter(
85+
"text-davinci-003")
86+
augmented_query = openai_generative_query_augmenter.get_augmented_query(
87+
prompt_gar)
88+
print(ebr(augmented_query, vectorstore))
11589

11690

11791
if __name__ == "__main__":

TigerRag/tigerrag/base/models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from enum import Enum
2-
2+
from typing import List
33
import numpy as np
44
import numpy.typing as npt
55
import pandas as pd
66
import torch
7+
78
from transformers import (BertModel, BertTokenizer, RobertaModel,
89
RobertaTokenizer, XLNetModel, XLNetTokenizer)
910

0 commit comments

Comments
 (0)