# rag_engine.py
import os
import time
from typing import Dict, List

from dotenv import load_dotenv
from langchain.chains import RetrievalQA
from langchain.chat_models import AzureChatOpenAI
from langchain.embeddings import AzureOpenAIEmbeddings
from langchain.vectorstores import Chroma

# Load environment variables from a local .env file
load_dotenv()
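
# Expected .env keys (the same names checked in required_vars below);
# values here are illustrative placeholders, not real settings:
#   AZURE_OPENAI_ENDPOINT=https://<your-resource>.openai.azure.com/
#   AZURE_OPENAI_KEY=<api-key>
#   AZURE_OPENAI_DEPLOYMENT_NAME=<chat-deployment>
#   AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME=<embedding-deployment>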
class RAGEngine:
    def __init__(self):
        # Verify Azure OpenAI settings are set
        required_vars = [
            'AZURE_OPENAI_ENDPOINT',
            'AZURE_OPENAI_KEY',
            'AZURE_OPENAI_DEPLOYMENT_NAME',
            'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME'
        ]

        # Debug: report which variables are present (without printing values)
        print("Checking environment variables...")
        for var in required_vars:
            if not os.getenv(var):
                print(f"Missing {var}")
            else:
                print(f"Found {var}")

        missing_vars = [var for var in required_vars if not os.getenv(var)]
        if missing_vars:
            raise ValueError(f"Missing required Azure OpenAI settings: {', '.join(missing_vars)}")

        self.vector_store = None
        self.qa_chain = None

        # Initialize the embeddings client with a simple retry mechanism
        max_retries = 3
        for attempt in range(max_retries):
            try:
                self.embeddings = AzureOpenAIEmbeddings(
                    azure_endpoint=os.getenv('AZURE_OPENAI_ENDPOINT'),
                    azure_deployment=os.getenv('AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME'),
                    api_key=os.getenv('AZURE_OPENAI_KEY')
                )
                # Test the connection with a trivial embedding call
                self.embeddings.embed_query("test")
                break
            except Exception as e:
                if attempt == max_retries - 1:
                    raise ConnectionError(f"Failed to connect to Azure OpenAI API after {max_retries} attempts. Error: {str(e)}")
                time.sleep(2)  # Wait before retrying
    def initialize_vector_store(self, chunks: List[Dict]):
        """
        Initialize the vector store with document chunks.

        Args:
            chunks (List[Dict]): List of dictionaries, each with a 'text'
                string and a 'metadata' dict (e.g. {'page': 1})
        """
        print(f"Initializing vector store with {len(chunks)} chunks")
        if not chunks:
            raise ValueError("No text chunks provided. PDF processing may have failed.")

        texts = [chunk['text'] for chunk in chunks]
        metadatas = [chunk['metadata'] for chunk in chunks]
        print(f"First chunk preview: {texts[0][:200]}...")
        print(f"First chunk metadata: {metadatas[0]}")

        # Create vector store
        print("Creating Chroma vector store...")
        self.vector_store = Chroma.from_texts(
            texts=texts,
            embedding=self.embeddings,
            metadatas=metadatas,
            persist_directory="./chroma_db"  # Persist the index to disk
        )
        print("Vector store created successfully")

        # Initialize QA chain; with Azure, the deployment determines the
        # underlying model, so no separate model_name is needed
        print("Initializing QA chain...")
        llm = AzureChatOpenAI(
            temperature=0,
            azure_deployment=os.getenv('AZURE_OPENAI_DEPLOYMENT_NAME'),
            azure_endpoint=os.getenv('AZURE_OPENAI_ENDPOINT'),
            api_key=os.getenv('AZURE_OPENAI_KEY')
        )
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=self.vector_store.as_retriever(
                search_kwargs={"k": 3}  # Retrieve the top 3 chunks per query
            )
        )
        print("QA chain initialized successfully")
    def answer_question(self, question: str) -> Dict:
        """
        Answer a question using the RAG system.

        Args:
            question (str): User's question

        Returns:
            Dict: {'answer': str, 'sources': [{'page': ..., 'text': ...}, ...]}
        """
        if not self.qa_chain:
            raise ValueError("Vector store not initialized. Please process documents first.")

        # Create a prompt that emphasizes definition extraction
        prompt = f"""
        Question: {question}
        Please provide a clear and concise answer based on the provided context.
        If the question asks for a definition or explanation of a concept,
        make sure to provide that specifically. Include relevant examples or
        additional context only if they help clarify the concept.
        """

        # Get answer from QA chain
        result = self.qa_chain({"query": prompt})

        # Look up the top 2 source documents separately for citation previews
        source_docs = self.vector_store.similarity_search(question, k=2)
        sources = [
            {
                'page': doc.metadata.get('page'),
                'text': doc.page_content[:200] + "..."  # Preview of source text
            }
            for doc in source_docs
        ]

        return {
            'answer': result['result'],
            'sources': sources
        }
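

# A minimal usage sketch. The sample chunks here are hypothetical; in the
# real pipeline they would come from a PDF-processing step not shown in
# this file.
if __name__ == "__main__":
    engine = RAGEngine()
    sample_chunks = [
        {'text': 'Retrieval-augmented generation (RAG) combines a retriever '
                 'with a language model to ground answers in source documents.',
         'metadata': {'page': 1}},
        {'text': 'Chroma is a vector database used here to store and search '
                 'embedded text chunks.',
         'metadata': {'page': 2}},
    ]
    engine.initialize_vector_store(sample_chunks)
    response = engine.answer_question("What is RAG?")
    print(response['answer'])
    for source in response['sources']:
        print(f"Page {source['page']}: {source['text']}")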