Skip to content

Commit b028ab7

Browse files
authored
upgrade: add web search! (#40)
1 parent cd0d0c6 commit b028ab7

File tree

6 files changed

+230
-28
lines changed

6 files changed

+230
-28
lines changed

1_🏠_Home.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
"To build a new agent, please make sure that 'Create a new agent' is selected.",
3232
icon="ℹ️",
3333
)
34+
if "metaphor_key" in st.secrets:
35+
st.info("**NOTE**: The ability to add web search is enabled.")
3436

3537

3638
add_sidebar()

agent_utils.py

Lines changed: 129 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@
3232
from constants import AGENT_CACHE_DIR
3333
import shutil
3434

35+
from llama_index.callbacks import CallbackManager
36+
from callback_manager import StreamlitFunctionsCallbackHandler
37+
3538

3639
def _resolve_llm(llm_str: str) -> LLM:
3740
"""Resolve LLM."""
@@ -153,9 +156,25 @@ def load_agent(
153156
"""Load agent."""
154157
extra_kwargs = extra_kwargs or {}
155158
if isinstance(llm, OpenAI) and is_function_calling_model(llm.model):
159+
# TODO: use default msg handler
160+
# TODO: separate this from agent_utils.py...
161+
def _msg_handler(msg: str) -> None:
162+
"""Message handler."""
163+
st.info(msg)
164+
st.session_state.agent_messages.append(
165+
{"role": "assistant", "content": msg, "msg_type": "info"}
166+
)
167+
168+
# add streamlit callbacks (to inject events)
169+
handler = StreamlitFunctionsCallbackHandler(_msg_handler)
170+
callback_manager = CallbackManager([handler])
156171
# get OpenAI Agent
157172
agent: BaseChatEngine = OpenAIAgent.from_tools(
158-
tools=tools, llm=llm, system_prompt=system_prompt, **kwargs
173+
tools=tools,
174+
llm=llm,
175+
system_prompt=system_prompt,
176+
**kwargs,
177+
callback_manager=callback_manager,
159178
)
160179
else:
161180
if "vector_index" not in extra_kwargs:
@@ -189,8 +208,12 @@ def load_meta_agent(
189208
extra_kwargs = extra_kwargs or {}
190209
if isinstance(llm, OpenAI) and is_function_calling_model(llm.model):
191210
# get OpenAI Agent
211+
192212
agent: BaseAgent = OpenAIAgent.from_tools(
193-
tools=tools, llm=llm, system_prompt=system_prompt, **kwargs
213+
tools=tools,
214+
llm=llm,
215+
system_prompt=system_prompt,
216+
**kwargs,
194217
)
195218
else:
196219
agent = ReActAgent.from_tools(
@@ -285,6 +308,66 @@ def construct_agent(
285308
return agent, extra_info
286309

287310

311+
def get_web_agent_tool() -> QueryEngineTool:
312+
"""Get web agent tool.
313+
314+
Wrap with our load and search tool spec.
315+
316+
"""
317+
from llama_hub.tools.metaphor.base import MetaphorToolSpec
318+
319+
# TODO: set metaphor API key
320+
metaphor_tool = MetaphorToolSpec(
321+
api_key=st.secrets.metaphor_key,
322+
)
323+
metaphor_tool_list = metaphor_tool.to_tool_list()
324+
325+
# TODO: LoadAndSearch doesn't work yet
326+
# The search_and_retrieve_documents tool is the third in the tool list,
327+
# as seen above
328+
# wrapped_retrieve = LoadAndSearchToolSpec.from_defaults(
329+
# metaphor_tool_list[2],
330+
# )
331+
332+
# NOTE: requires openai right now
333+
# We don't give the Agent our unwrapped retrieve document tools
334+
# instead passing the wrapped tools
335+
web_agent = OpenAIAgent.from_tools(
336+
# [*wrapped_retrieve.to_tool_list(), metaphor_tool_list[4]],
337+
metaphor_tool_list,
338+
llm=BUILDER_LLM,
339+
verbose=True,
340+
)
341+
342+
# return agent as a tool
343+
# TODO: tune description
344+
web_agent_tool = QueryEngineTool.from_defaults(
345+
web_agent,
346+
name="web_agent",
347+
description="""
348+
This agent can answer questions by searching the web. \
349+
Use this tool if the answer is ONLY likely to be found by searching \
350+
the internet, especially for queries about recent events.
351+
""",
352+
)
353+
354+
return web_agent_tool
355+
356+
357+
def get_tool_objects(tool_names: List[str]) -> List:
358+
"""Get tool objects from tool names."""
359+
# construct additional tools
360+
tool_objs = []
361+
for tool_name in tool_names:
362+
if tool_name == "web_search":
363+
# build web agent
364+
tool_objs.append(get_web_agent_tool())
365+
else:
366+
raise ValueError(f"Tool {tool_name} not recognized.")
367+
368+
return tool_objs
369+
370+
288371
class ParamCache(BaseModel):
289372
"""Cache for RAG agent builder.
290373
@@ -338,7 +421,7 @@ def save_to_disk(self, save_dir: str) -> None:
338421
"file_names": self.file_names,
339422
"urls": self.urls,
340423
# TODO: figure out tools
341-
# "tools": [],
424+
"tools": self.tools,
342425
"rag_params": self.rag_params.dict(),
343426
"agent_id": self.agent_id,
344427
}
@@ -376,11 +459,13 @@ def load_from_disk(
376459
file_names=cache_dict["file_names"], urls=cache_dict["urls"]
377460
)
378461
# load agent from index
462+
additional_tools = get_tool_objects(cache_dict["tools"])
379463
agent, _ = construct_agent(
380464
cache_dict["system_prompt"],
381465
cache_dict["rag_params"],
382466
cache_dict["docs"],
383467
vector_index=vector_index,
468+
additional_tools=additional_tools,
384469
# TODO: figure out tools
385470
)
386471
cache_dict["vector_index"] = vector_index
@@ -505,20 +590,14 @@ def load_data(
505590
self._cache.urls = urls
506591
return "Data loaded successfully."
507592

508-
# NOTE: unused
509593
def add_web_tool(self) -> str:
510594
"""Add a web tool to enable agent to solve a task."""
511595
# TODO: make this not hardcoded to a web tool
512596
# Set up Metaphor tool
513-
from llama_hub.tools.metaphor.base import MetaphorToolSpec
514-
515-
# TODO: set metaphor API key
516-
metaphor_tool = MetaphorToolSpec(
517-
api_key=os.environ["METAPHOR_API_KEY"],
518-
)
519-
metaphor_tool_list = metaphor_tool.to_tool_list()
520-
521-
self._cache.tools.extend(metaphor_tool_list)
597+
if "web_search" in self._cache.tools:
598+
return "Web tool already added."
599+
else:
600+
self._cache.tools.append("web_search")
522601
return "Web tool added successfully."
523602

524603
def get_rag_params(self) -> Dict:
@@ -557,11 +636,13 @@ def create_agent(self, agent_id: Optional[str] = None) -> str:
557636
if self._cache.system_prompt is None:
558637
raise ValueError("Must set system prompt before creating agent.")
559638

639+
# construct additional tools
640+
additional_tools = get_tool_objects(self.cache.tools)
560641
agent, extra_info = construct_agent(
561642
cast(str, self._cache.system_prompt),
562643
cast(RAGParams, self._cache.rag_params),
563644
self._cache.docs,
564-
additional_tools=self._cache.tools,
645+
additional_tools=additional_tools,
565646
)
566647

567648
# if agent_id not specified, randomly generate one
@@ -587,6 +668,7 @@ def update_agent(
587668
chunk_size: Optional[int] = None,
588669
embed_model: Optional[str] = None,
589670
llm: Optional[str] = None,
671+
additional_tools: Optional[List] = None,
590672
) -> None:
591673
"""Update agent.
592674
@@ -609,7 +691,6 @@ def update_agent(
609691
# We call set_rag_params and create_agent, which will
610692
# update the cache
611693
# TODO: decouple functions from tool functions exposed to the agent
612-
613694
rag_params_dict: Dict[str, Any] = {}
614695
if include_summarization is not None:
615696
rag_params_dict["include_summarization"] = include_summarization
@@ -623,6 +704,11 @@ def update_agent(
623704
rag_params_dict["llm"] = llm
624705

625706
self.set_rag_params(**rag_params_dict)
707+
708+
# update tools
709+
if additional_tools is not None:
710+
self.cache.tools = additional_tools
711+
626712
# this will update the agent in the cache
627713
self.create_agent()
628714

@@ -655,6 +741,33 @@ def update_agent(
655741
# please make sure to update the LLM above if you change the function below
656742

657743

744+
def _get_builder_agent_tools(agent_builder: RAGAgentBuilder) -> List[FunctionTool]:
745+
"""Get list of builder agent tools to pass to the builder agent."""
746+
# see if metaphor api key is set, otherwise don't add web tool
747+
# TODO: refactor this later
748+
749+
if "metaphor_key" in st.secrets:
750+
fns: List[Callable] = [
751+
agent_builder.create_system_prompt,
752+
agent_builder.load_data,
753+
agent_builder.add_web_tool,
754+
agent_builder.get_rag_params,
755+
agent_builder.set_rag_params,
756+
agent_builder.create_agent,
757+
]
758+
else:
759+
fns = [
760+
agent_builder.create_system_prompt,
761+
agent_builder.load_data,
762+
agent_builder.get_rag_params,
763+
agent_builder.set_rag_params,
764+
agent_builder.create_agent,
765+
]
766+
767+
fn_tools: List[FunctionTool] = [FunctionTool.from_defaults(fn=fn) for fn in fns]
768+
return fn_tools
769+
770+
658771
# define agent
659772
# @st.cache_resource
660773
def load_meta_agent_and_tools(
@@ -664,15 +777,7 @@ def load_meta_agent_and_tools(
664777
# think of this as tools for the agent to use
665778
agent_builder = RAGAgentBuilder(cache)
666779

667-
fns: List[Callable] = [
668-
agent_builder.create_system_prompt,
669-
agent_builder.load_data,
670-
# add_web_tool,
671-
agent_builder.get_rag_params,
672-
agent_builder.set_rag_params,
673-
agent_builder.create_agent,
674-
]
675-
fn_tools = [FunctionTool.from_defaults(fn=fn) for fn in fns]
780+
fn_tools = _get_builder_agent_tools(agent_builder)
676781

677782
builder_agent = load_meta_agent(
678783
fn_tools, llm=BUILDER_LLM, system_prompt=RAG_BUILDER_SYS_STR, verbose=True

callback_manager.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""Streaming callback manager."""
2+
from llama_index.callbacks.base_handler import BaseCallbackHandler
3+
from llama_index.callbacks.schema import CBEventType
4+
5+
from typing import Optional, Dict, Any, List, Callable
6+
7+
STORAGE_DIR = "./storage" # directory to cache the generated index
8+
DATA_DIR = "./data" # directory containing the documents to index
9+
10+
11+
class StreamlitFunctionsCallbackHandler(BaseCallbackHandler):
12+
"""Callback handler that outputs streamlit components given events."""
13+
14+
def __init__(self, msg_handler: Callable[[str], Any]) -> None:
15+
"""Initialize the base callback handler."""
16+
self.msg_handler = msg_handler
17+
super().__init__([], [])
18+
19+
def on_event_start(
20+
self,
21+
event_type: CBEventType,
22+
payload: Optional[Dict[str, Any]] = None,
23+
event_id: str = "",
24+
parent_id: str = "",
25+
**kwargs: Any,
26+
) -> str:
27+
"""Run when an event starts and return id of event."""
28+
if event_type == CBEventType.FUNCTION_CALL:
29+
if payload is None:
30+
raise ValueError("Payload cannot be None")
31+
arguments_str = payload["function_call"]
32+
tool_str = payload["tool"].name
33+
print_str = f"Calling function: {tool_str} with args: {arguments_str}\n\n"
34+
self.msg_handler(print_str)
35+
else:
36+
pass
37+
return event_id
38+
39+
def on_event_end(
40+
self,
41+
event_type: CBEventType,
42+
payload: Optional[Dict[str, Any]] = None,
43+
event_id: str = "",
44+
**kwargs: Any,
45+
) -> None:
46+
"""Run when an event ends."""
47+
pass
48+
# TODO: currently we don't need to do anything here
49+
# if event_type == CBEventType.FUNCTION_CALL:
50+
# response = payload["function_call_response"]
51+
# # Add this to queue
52+
# print_str = (
53+
# f"\n\nGot output: {response}\n"
54+
# "========================\n\n"
55+
# )
56+
# elif event_type == CBEventType.AGENT_STEP:
57+
# # put response into queue
58+
# self._queue.put(payload["response"])
59+
60+
def start_trace(self, trace_id: Optional[str] = None) -> None:
61+
"""Run when an overall trace is launched."""
62+
pass
63+
64+
def end_trace(
65+
self,
66+
trace_id: Optional[str] = None,
67+
trace_map: Optional[Dict[str, List[str]]] = None,
68+
) -> None:
69+
"""Run when an overall trace is exited."""
70+
pass

pages/2_⚙️_RAG_Config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def update_agent() -> None:
2424
"config_agent_builder" in st.session_state.keys()
2525
and st.session_state.config_agent_builder is not None
2626
):
27+
additional_tools = st.session_state.additional_tools_st.split(",")
2728
agent_builder = cast(RAGAgentBuilder, st.session_state.config_agent_builder)
2829
### Update the agent
2930
agent_builder.update_agent(
@@ -34,6 +35,7 @@ def update_agent() -> None:
3435
chunk_size=st.session_state.chunk_size_st,
3536
embed_model=st.session_state.embed_model_st,
3637
llm=st.session_state.llm_st,
38+
additional_tools=additional_tools,
3739
)
3840

3941
# Update Radio Buttons: update selected agent to the new id
@@ -114,6 +116,14 @@ def delete_agent() -> None:
114116
value=rag_params.include_summarization,
115117
key="include_summarization_st",
116118
)
119+
120+
# add web tool
121+
additional_tools_st = st.text_input(
122+
"Additional tools (currently only supports 'web_search')",
123+
value=",".join(agent_builder.cache.tools),
124+
key="additional_tools_st",
125+
)
126+
117127
top_k_st = st.number_input("Top K", value=rag_params.top_k, key="top_k_st")
118128
chunk_size_st = st.number_input(
119129
"Chunk Size", value=rag_params.chunk_size, key="chunk_size_st"

0 commit comments

Comments
 (0)