# langgraph-advanced-rag-flow.py
# Load environment variables
import os
from dotenv import load_dotenv
from typing import List
from typing_extensions import TypedDict
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.graph import END, StateGraph
from langchain.schema import Document
import json

# Helper function for environment variables
def get_env_variable(var_name):
    value = os.getenv(var_name)
    if not value:
        raise ValueError(f"Missing environment variable: {var_name}")
    return value

load_dotenv()  # Load environment variables from .env file

qdrant_instance_url = get_env_variable('QDRANT_INSTANCE_URL')
qdrant_api_key = get_env_variable('QDRANT_API_KEY')
tavily_api_key = get_env_variable('TAVILY_API_KEY')

# Prepare LLM
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.1, max_tokens=2500)
llm_json_mode = llm.bind(response_format={"type": "json_object"})
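
# Note: binding response_format {"type": "json_object"} makes the model emit a
# single JSON object, which is why the grader and router below can parse replies
# with json.loads(). Illustrative sanity check (an assumption for demonstration,
# not part of the original flow), left commented out:
# demo = llm_json_mode.invoke('Return JSON with a single key "ok" set to true.')
# print(json.loads(demo.content))  # expected: {'ok': True}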

# Prepare Embeddings - use the same embedding model as for ingestion
from langchain_mistralai import MistralAIEmbeddings

embed_model = MistralAIEmbeddings()

# Attach our Qdrant vector store
from langchain_qdrant import QdrantVectorStore

store_wiki = QdrantVectorStore.from_existing_collection(
    collection_name="wiki",
    embedding=embed_model,
    url=qdrant_instance_url,
    api_key=qdrant_api_key,
)

# Create retriever
wiki_retriever = store_wiki.as_retriever(search_kwargs={"k": 1})
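
# Optional sanity check (illustrative sketch, not part of the original flow;
# the sample query is an assumption), left commented out:
# docs = wiki_retriever.invoke("Angular component lifecycle")
# print(docs[0].metadata, docs[0].page_content[:200])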

# Set up graph state
class GraphState(TypedDict):
    """
    Graph state is a dictionary that contains information we want to
    propagate to, and modify in, each graph node.
    """
    question: str              # User question
    generation: str            # LLM generation (final answer text)
    answer_grade: str          # Grade of retrieved docs: relevant or not_relevant
    documents: List[Document]  # List of retrieved documents
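
# Note (explanatory addition): each node below returns only the keys it
# updates, e.g. {"documents": [...]}; LangGraph merges that partial dict into
# the shared GraphState and leaves the other keys untouched.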

### Nodes

def retrieve(state):
    """
    Retrieve documents from the vectorstore.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE---")
    question = state["question"]
    # Write retrieved documents to the documents key in state;
    # fall back to a placeholder document if nothing was found
    documents = wiki_retriever.invoke(question) or [Document(page_content="No content found")]
    return {"documents": documents}

def grade(state):
    """
    Grade retrieved documents.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, answer_grade, that contains the grade as relevant or not_relevant
    """
    print("---GRADE---")
    question = state["question"]
    documents = state["documents"]

    # Doc grader instructions
    doc_grader_instructions = """You are a grader assessing carefully and objectively the relevance of a retrieved document to a user question.
If the document contains keywords or semantic meaning related to the question, grade it as relevant."""

    # Grader prompt
    doc_grader_prompt = """Here is the retrieved document: \n\n {document} \n\n Here is the user question: \n\n {question}.
Return JSON with a single key, binary_score, that is 'relevant' or 'not_relevant' to indicate whether the document contains at least some information that is relevant to the question."""

    # Prepare prompt and run grader; only the top document is graded (the retriever uses k=1)
    doc_grader_prompt_formatted = doc_grader_prompt.format(document=documents[0].page_content, question=question)
    result = llm_json_mode.invoke(
        [SystemMessage(content=doc_grader_instructions)]
        + [HumanMessage(content=doc_grader_prompt_formatted)]
    )
    return {"answer_grade": json.loads(result.content)["binary_score"]}

def web_search_angular(state):
    """
    Run web search for Angular content on angular.dev.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---WEB SEARCH ANGULAR---")
    question = state["question"]

    # Instantiate the Tavily client
    from tavily import TavilyClient
    search_client = TavilyClient(api_key=tavily_api_key)

    # Run web search restricted to angular.dev
    results = search_client.search(
        question,
        search_depth="advanced",
        max_results=3,
        include_domains=["angular.dev"],
        include_raw_content=True,
    )

    # Build a Document object for each entry in the results
    documents = []
    for entry in results["results"]:
        # Prefer raw_content, fall back to content, and use a default if both are empty
        page_content = entry.get('raw_content') or entry.get('content') or "No content found"
        metadata = {
            "title": entry.get('title', 'No Title'),
            "link": entry.get('url', 'No Link'),
            "score": entry.get('score', '0'),
        }
        documents.append(Document(page_content=page_content, metadata=metadata))

    # Write retrieved documents to the documents key in state
    return {"documents": documents}

def web_search_full(state):
    """
    Run web search for any content for the given question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---WEB SEARCH FULL---")
    question = state["question"]

    # Instantiate the Tavily client
    from tavily import TavilyClient
    search_client = TavilyClient(api_key=tavily_api_key)

    # Run open web search
    results = search_client.search(question, search_depth="advanced", max_results=2)

    # Build a Document object for each entry in the results
    documents = []
    for entry in results["results"]:
        page_content = entry.get('content', 'No content')
        metadata = {
            "title": entry.get('title', 'No Title'),
            "link": entry.get('url', 'No Link'),
            "score": entry.get('score', '0'),
        }
        documents.append(Document(page_content=page_content, metadata=metadata))

    # Write retrieved documents to the documents key in state
    return {"documents": documents}

def generate(state):
    """
    Generate answer using RAG on the retrieved documents.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation, that contains the LLM generation
    """
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]
    context = "\n\n".join(doc.page_content for doc in documents) if documents else "No content found"

    # Define answer prompt
    prompt_template = """You are an assistant for question-answering tasks at ACME GmbH.
Think carefully about the context.
Just say 'Diese Frage kann ich nicht beantworten' if there is not enough or no context given.
Provide an answer to the user question using only the given context.
Use three sentences maximum and keep the answer concise.
If the context mentions ACME guidelines, try to include them in the answer.
Here is the context to use to answer the question:
{context}
Now, review the user question:
{question}
Write the answer in German. Don't output an English translation.
Answer:"""

    # RAG generation; store the plain answer text to match GraphState's str annotation
    rag_prompt_formatted = prompt_template.format(context=context, question=question)
    generation = llm.invoke([HumanMessage(content=rag_prompt_formatted)])
    return {"generation": generation.content}

### Conditional edges

def route_question(state):
    """
    Route question to web search or RAG.

    Args:
        state (dict): The current graph state

    Returns:
        str: Next node to call
    """
    print("---ROUTE QUESTION---")
    router_instructions = """You are an expert at routing a user question to a vectorstore or websearch.
The vectorstore contains documents related to coding, programming, development practices, single page applications, the Angular framework, and coding guidelines for the company ACME.
Use the vectorstore for any questions containing coding terms, code snippets, programming languages, or technologies relevant to development practices (even in other languages like German).
If the question is related to coding or development but not specifically covered by the vectorstore, still return 'vectorstore'. Use 'websearch' for non-coding questions.
Return JSON with a single key, "datasource", that is 'websearch' or 'vectorstore' depending on the question."""

    route_question_result = llm_json_mode.invoke(
        [SystemMessage(content=router_instructions)]
        + [HumanMessage(content=state["question"])]
    )
    source = json.loads(route_question_result.content)["datasource"]
    if source == "websearch":
        print("---ROUTE QUESTION TO WEB_SEARCH_FULL---")
        return "websearch"
    elif source == "vectorstore":
        print("---ROUTE QUESTION TO RETRIEVER---")
        return "vectorstore"

def decide_retriever_ok(state):
    """
    Determine whether the retrieved content is good enough to generate an answer,
    or whether to run a web search instead.

    Args:
        state (dict): The current graph state

    Returns:
        str: Binary decision for next node to call
    """
    print("---DECIDE RETRIEVER OK---")
    answer_grade = state["answer_grade"]
    if answer_grade.lower() == "not_relevant":
        print("---DECISION: DOCUMENTS ARE NOT RELEVANT TO QUESTION, INCLUDE WEB SEARCH---")
        return "websearch"
    else:
        # We have relevant documents, so generate the answer
        print("---DECISION: GENERATE---")
        return "generate"

# Build the graph
workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("retrieve", retrieve)                      # retrieve from vectorstore
workflow.add_node("generate", generate)                      # generate answer
workflow.add_node("grade", grade)                            # grade retrieved documents
workflow.add_node("web_search_angular", web_search_angular)  # web search on angular.dev
workflow.add_node("web_search_full", web_search_full)        # full web search

# Define the edges
workflow.set_conditional_entry_point(
    route_question,
    {
        "websearch": "web_search_full",
        "vectorstore": "retrieve",
    },
)
workflow.add_edge("web_search_full", "generate")
workflow.add_edge("retrieve", "grade")
workflow.add_conditional_edges(
    "grade",
    decide_retriever_ok,
    {
        "websearch": "web_search_angular",
        "generate": "generate",
    },
)
workflow.add_edge("web_search_angular", "generate")
workflow.add_edge("generate", END)

# Compile
graph = workflow.compile()
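
# Minimal usage sketch (an illustrative addition; the sample question and the
# __main__ guard are assumptions, not part of the original flow):
if __name__ == "__main__":
    result = graph.invoke({"question": "Wie deklariere ich ein Angular Signal?"})
    # generate() stores the plain answer text under the generation key
    print(result["generation"])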