-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathTech_Agent.py
347 lines (282 loc) · 10.1 KB
/
Tech_Agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
import os
from semantic_router.encoders import OpenAIEncoder
from datasets import load_dataset
from pinecone import Pinecone
from pinecone import ServerlessSpec
import time
from tqdm.auto import tqdm
from typing import TypedDict, Annotated
from langchain_core.agents import AgentAction
from langchain_core.messages import BaseMessage
import operator
import re
from langchain_core.tools import tool
from serpapi import GoogleSearch
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import StateGraph, END
from langchain_openai import ChatOpenAI
# --- Embedding model, dataset, and Pinecone vector index setup ---
encoder = OpenAIEncoder(name="text-embedding-3-small")
dataset = load_dataset("vencortex/TechNews", split="train")

# Prefer the environment variable; the literal is only a placeholder.
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY", "your_api_key"))
spec = ServerlessSpec(
    cloud="aws", region="us-east-1"
)
# Probe the encoder once to discover the embedding dimensionality.
dims = len(encoder(["some random text"])[0])
index_name = "gpt-4o-research-agent"
if index_name not in pc.list_indexes().names():
    pc.create_index(
        index_name,
        dimension=dims,
        metric='dotproduct',
        spec=spec
    )
    # Block until the freshly created index is ready to accept requests.
    while not pc.describe_index(index_name).status['ready']:
        time.sleep(1)
index = pc.Index(index_name)
time.sleep(1)
index.describe_index_stats()

# Index the first 10k articles in batches.
data = dataset.to_pandas().iloc[:10000]
batch_size = 128
for i in tqdm(range(0, len(data), batch_size)):
    i_end = min(len(data), i + batch_size)
    batch = data[i:i_end].to_dict(orient="records")
    metadata = [{
        "title": r["title"],
        "content": r["text"],
        # Bug fix: rag_search_filter filters on the "symbol" metadata key,
        # which was never stored. Keep "arxiv_id" for backward compatibility.
        "symbol": r["symbol"],
        "arxiv_id": [r["symbol"]]
    } for r in batch]
    ids = [r["symbol"] for r in batch]
    content = [r["text"] for r in batch]
    embeds = encoder(content)
    index.upsert(vectors=zip(ids, embeds, metadata))
class AgentState(TypedDict):
    """LangGraph state shared by every node in the agent graph."""
    # The user's original query.
    input: str
    # Prior conversation turns injected into the prompt.
    chat_history: list[BaseMessage]
    # (action, observation) history; the operator.add annotation makes
    # LangGraph append each node's returned steps instead of overwriting.
    intermediate_steps: Annotated[list[tuple[AgentAction, str]], operator.add]
# Demo: collect Apple-related articles from the raw dataset, both by the
# dataset's own symbol column and by a whole-word match in the article text.
symbol = "AAPL"
articles_by_symbol = [row for row in dataset if row['symbol'] == symbol]
symbol_pattern = re.compile(r"\bAAPL\b")
articles_with_symbol = [row for row in dataset if symbol_pattern.search(row['text'])]
@tool("fetch_technews")
def fetch_technews(symbol: str):
    """Fetches tech news articles related to a specific stock/company symbol from the TechNews dataset.

    Returns the first matching article as a dict with ``title``, ``text``
    and ``symbol`` keys, or a message string when nothing matches.
    """
    # Bug fix: the original ignored `symbol` entirely and always read the
    # module-level AAPL article list; match the requested symbol instead.
    pattern = re.compile(rf"\b{re.escape(symbol)}\b")
    matches = [x for x in dataset if pattern.search(x['text'])]
    if not matches:
        return f"No articles found for symbol: {symbol}"
    first_article = matches[0]
    return {
        "title": first_article['title'],
        "text": first_article['text'],
        "symbol": symbol
    }
# Shared SerpAPI settings: Google engine, key from the environment,
# top-5 organic results per search.
serpapi_params = dict(
    engine="google",
    api_key=os.getenv("SERPAPI_KEY"),
    num=5,
)
@tool("web_search")
def web_search(query: str):
    """Finds tech knowledge information using Google search. Can also be used
    to augment more 'tech' knowledge to a previous specialist query."""
    # Restrict results to a curated list of tech-news sites.
    sites = ["https://techcrunch.com", "https://gizmodo.com", "https://www.engadget.com", "https://www.wired.com", "https://www.techradar.com", "https://www.cnet.com/", "https://www.digitaltrends.com/", "https://lifehacker.com/"]
    search = GoogleSearch({
        **serpapi_params,
        # Bug fix: a space must separate the site restriction from the query,
        # otherwise the last site URL fuses with the query text.
        "q": " OR ".join(f"site:{site}" for site in sites) + " " + query,
    })
    results = search.get_dict().get("organic_results", [])
    # Some organic results lack a snippet; fall back to an empty string
    # instead of raising KeyError.
    contexts = "\n---\n".join(
        "\n".join([x["title"], x.get("snippet", ""), x["link"]]) for x in results
    )
    return contexts
def format_rag_contexts(matches: list):
    """Render Pinecone query matches as a '---'-separated text block.

    Each match contributes one "Title: ..." line built from its metadata.
    """
    sections = [f"Title: {match['metadata']['title']}\n" for match in matches]
    return "\n---\n".join(sections)
@tool("rag_search_filter")
def rag_search_filter(query: str, symbol: str):
    """Finds information from our tech database using a natural language query
    and a specific symbol. Allows us to learn more details about a specific news."""
    xq = encoder([query])
    # Bug fix: encoder() returns a batch (list of vectors); Pinecone's
    # `vector` argument expects a single vector, so pass the first one.
    xc = index.query(vector=xq[0], top_k=6, include_metadata=True, filter={"symbol": symbol})
    context_str = format_rag_contexts(xc["matches"])
    return context_str
@tool("rag_search")
def rag_search(query: str):
    """Finds specialist information on AI using a natural language query."""
    xq = encoder([query])
    # Bug fix: encoder() returns a batch (list of vectors); Pinecone's
    # `vector` argument expects a single vector, so pass the first one.
    xc = index.query(vector=xq[0], top_k=2, include_metadata=True)
    context_str = format_rag_contexts(xc["matches"])
    return context_str
@tool("final_answer")
def final_answer(
    introduction: str,
    research_steps: str,
    main_body: str,
    conclusion: str,
    sources: str
):
    """Returns a natural language response to the user in the form of a research
    report. There are several sections to this report, those are:
    - `introduction`: a short paragraph introducing the user's question and the
    topic we are researching.
    - `research_steps`: a few bullet points explaining the steps that were taken
    to research your report.
    - `main_body`: this is where the bulk of high quality and concise
    information that answers the user's question belongs. It is 3-4 paragraphs
    long in length.
    - `conclusion`: this is a short single paragraph conclusion providing a
    concise but sophisticated view on what was found.
    - `sources`: a bulletpoint list provided detailed sources for all information
    referenced during the research process
    """
    # Normalise list-valued sections to bullet-point strings. The report
    # itself is read back out of the tool call's arguments by build_report,
    # so the tool's own return value is intentionally empty.
    if type(research_steps) is list:
        research_steps = "\n".join(f"- {step}" for step in research_steps)
    if type(sources) is list:
        sources = "\n".join(f"- {src}" for src in sources)
    return ""
# Routing prompt for the "oracle" decision node. The scratchpad (past tool
# calls and their outputs) is injected as a pseudo-assistant message so the
# model can see which tools it has already used.
system_prompt = """You are the oracle, the great AI decision maker.
Given the user's query you must decide what to do with it based on the
list of tools provided to you.
If you see that a tool has been used (in the scratchpad) with a particular
query, do NOT use that same tool with the same query again. Also, do NOT use
any tool more than twice (ie, if the tool appears in the scratchpad twice, do
not use it again).
You should aim to collect information from a diverse range of sources before
providing the answer to the user. Once you have collected plenty of information
to answer the user's question (stored in the scratchpad) use the final_answer
tool."""
prompt = ChatPromptTemplate.from_messages([
    ("system", system_prompt),
    MessagesPlaceholder(variable_name="chat_history"),
    ("user", "{input}"),
    ("assistant", "scratchpad: {scratchpad}"),
])
# Deterministic (temperature=0) GPT-4o used for tool routing.
# NOTE(review): "your_api_key" is a placeholder — the real key should be
# loaded from the OPENAI_API_KEY environment variable, not hard-coded.
llm = ChatOpenAI(
    model="gpt-4o",
    openai_api_key="your_api_key",
    temperature=0
)
# Every tool the oracle may choose from; final_answer terminates the run.
tools=[
    rag_search_filter,
    rag_search,
    fetch_technews,
    web_search,
    final_answer
]
def create_scratchpad(intermediate_steps: list[AgentAction]):
    """Serialise completed tool calls into a text scratchpad for the prompt.

    Actions whose log is still "TBD" (requested but not yet executed)
    are skipped.
    """
    completed = [
        f"Tool: {step.tool}, input: {step.tool_input}\n"
        f"Output: {step.log}"
        for step in intermediate_steps
        if step.log != "TBD"
    ]
    return "\n---\n".join(completed)
# The oracle runnable: map graph state into the prompt's variables, render
# the prompt, then call the LLM. tool_choice="any" forces the model to emit
# a tool call on every turn (it can never reply with plain text).
oracle = (
    {
        "input": lambda x: x["input"],
        "chat_history": lambda x: x["chat_history"],
        "scratchpad": lambda x: create_scratchpad(
            intermediate_steps=x["intermediate_steps"]
        ),
    }
    | prompt
    | llm.bind_tools(tools, tool_choice="any")
)
# Smoke test: invoke the oracle once outside the graph with an empty history.
inputs = {
    "input": "what is tesla?",
    "chat_history": [],
    "intermediate_steps": [],
}
out = oracle.invoke(inputs)
# Bare expression — a notebook-style inspection of the model's tool call.
out
def run_oracle(state: list):
    """Graph node: ask the oracle which tool to call next.

    Records the chosen tool as an AgentAction with log="TBD"; run_tool
    fills in the real output later.
    """
    print("run_oracle")
    print(f"intermediate_steps: {state['intermediate_steps']}")
    response = oracle.invoke(state)
    call = response.tool_calls[0]
    pending = AgentAction(
        tool=call["name"],
        tool_input=call["args"],
        log="TBD"
    )
    return {"intermediate_steps": [pending]}
def router(state: list):
    """Route to the node named by the most recent AgentAction's tool."""
    steps = state["intermediate_steps"]
    if not isinstance(steps, list):
        # Defensive fallback: malformed state ends the run.
        print("Router invalid format")
        return "final_answer"
    return steps[-1].tool
# Dispatch table: tool name (as chosen by the oracle) -> tool callable.
# Must stay in sync with the `tools` list and the graph's node names.
tool_str_to_func = {
    "rag_search_filter": rag_search_filter,
    "rag_search": rag_search,
    "fetch_technews": fetch_technews,
    "web_search": web_search,
    "final_answer": final_answer
}
def run_tool(state: list):
    """Graph node: execute the most recently requested tool and record
    its stringified output in the action's log."""
    last_action = state["intermediate_steps"][-1]
    tool_name = last_action.tool
    tool_args = last_action.tool_input
    print(f"{tool_name}.invoke(input={tool_args})")
    result = tool_str_to_func[tool_name].invoke(input=tool_args)
    executed = AgentAction(
        tool=tool_name,
        tool_input=tool_args,
        log=str(result)
    )
    return {"intermediate_steps": [executed]}
# Wire the agent graph: the oracle decides, conditional edges dispatch to
# the chosen tool node, and every tool except final_answer loops back to
# the oracle for another decision. final_answer terminates the run.
graph = StateGraph(AgentState)
graph.add_node("oracle", run_oracle)
graph.add_node("rag_search_filter", run_tool)
graph.add_node("rag_search", run_tool)
graph.add_node("fetch_technews", run_tool)
graph.add_node("web_search", run_tool)
graph.add_node("final_answer", run_tool)
graph.set_entry_point("oracle")
# router returns a tool name, which must match a node name above.
graph.add_conditional_edges(
    source="oracle",
    path=router,
)
for tool_obj in tools:
    if tool_obj.name != "final_answer":
        graph.add_edge(tool_obj.name, "oracle")
graph.add_edge("final_answer", END)
runnable = graph.compile()
# Run the full agent once end-to-end.
# NOTE(review): "intermediate_steps" is omitted here (unlike the earlier
# manual oracle.invoke) — presumably LangGraph initialises the operator.add
# channel to an empty list; passing [] explicitly would be clearer. Verify.
out = runnable.invoke({
    "input": "tell me something latest about tesla",
    "chat_history": [],
})
def build_report(output: dict):
    """Format the final_answer tool's arguments into a plain-text report.

    `output` must carry the keys "introduction", "research_steps",
    "main_body", "conclusion" and "sources". The research_steps and
    sources values may be either a preformatted string or a list of
    items; lists are rendered as "- item" bullet lines.
    """
    research_steps = output["research_steps"]
    # isinstance (not `type(...) is list`) also accepts list subclasses.
    if isinstance(research_steps, list):
        research_steps = "\n".join(f"- {r}" for r in research_steps)
    sources = output["sources"]
    if isinstance(sources, list):
        sources = "\n".join(f"- {s}" for s in sources)
    return f"""
INTRODUCTION
------------
{output["introduction"]}
RESEARCH STEPS
--------------
{research_steps}
REPORT
------
{output["main_body"]}
CONCLUSION
----------
{output["conclusion"]}
SOURCES
-------
{sources}
"""
# The report content lives in the final_answer action's tool_input (the
# tool itself returns ""), so render the last action's arguments.
print(build_report(
    output=out["intermediate_steps"][-1].tool_input
))