32
32
from constants import AGENT_CACHE_DIR
33
33
import shutil
34
34
35
+ from llama_index .callbacks import CallbackManager
36
+ from callback_manager import StreamlitFunctionsCallbackHandler
37
+
35
38
36
39
def _resolve_llm (llm_str : str ) -> LLM :
37
40
"""Resolve LLM."""
@@ -153,9 +156,25 @@ def load_agent(
153
156
"""Load agent."""
154
157
extra_kwargs = extra_kwargs or {}
155
158
if isinstance (llm , OpenAI ) and is_function_calling_model (llm .model ):
159
+ # TODO: use default msg handler
160
+ # TODO: separate this from agent_utils.py...
161
+ def _msg_handler (msg : str ) -> None :
162
+ """Message handler."""
163
+ st .info (msg )
164
+ st .session_state .agent_messages .append (
165
+ {"role" : "assistant" , "content" : msg , "msg_type" : "info" }
166
+ )
167
+
168
+ # add streamlit callbacks (to inject events)
169
+ handler = StreamlitFunctionsCallbackHandler (_msg_handler )
170
+ callback_manager = CallbackManager ([handler ])
156
171
# get OpenAI Agent
157
172
agent : BaseChatEngine = OpenAIAgent .from_tools (
158
- tools = tools , llm = llm , system_prompt = system_prompt , ** kwargs
173
+ tools = tools ,
174
+ llm = llm ,
175
+ system_prompt = system_prompt ,
176
+ ** kwargs ,
177
+ callback_manager = callback_manager ,
159
178
)
160
179
else :
161
180
if "vector_index" not in extra_kwargs :
@@ -189,8 +208,12 @@ def load_meta_agent(
189
208
extra_kwargs = extra_kwargs or {}
190
209
if isinstance (llm , OpenAI ) and is_function_calling_model (llm .model ):
191
210
# get OpenAI Agent
211
+
192
212
agent : BaseAgent = OpenAIAgent .from_tools (
193
- tools = tools , llm = llm , system_prompt = system_prompt , ** kwargs
213
+ tools = tools ,
214
+ llm = llm ,
215
+ system_prompt = system_prompt ,
216
+ ** kwargs ,
194
217
)
195
218
else :
196
219
agent = ReActAgent .from_tools (
@@ -285,6 +308,66 @@ def construct_agent(
285
308
return agent , extra_info
286
309
287
310
311
+ def get_web_agent_tool () -> QueryEngineTool :
312
+ """Get web agent tool.
313
+
314
+ Wrap with our load and search tool spec.
315
+
316
+ """
317
+ from llama_hub .tools .metaphor .base import MetaphorToolSpec
318
+
319
+ # TODO: set metaphor API key
320
+ metaphor_tool = MetaphorToolSpec (
321
+ api_key = st .secrets .metaphor_key ,
322
+ )
323
+ metaphor_tool_list = metaphor_tool .to_tool_list ()
324
+
325
+ # TODO: LoadAndSearch doesn't work yet
326
+ # The search_and_retrieve_documents tool is the third in the tool list,
327
+ # as seen above
328
+ # wrapped_retrieve = LoadAndSearchToolSpec.from_defaults(
329
+ # metaphor_tool_list[2],
330
+ # )
331
+
332
+ # NOTE: requires openai right now
333
+ # We don't give the Agent our unwrapped retrieve document tools
334
+ # instead passing the wrapped tools
335
+ web_agent = OpenAIAgent .from_tools (
336
+ # [*wrapped_retrieve.to_tool_list(), metaphor_tool_list[4]],
337
+ metaphor_tool_list ,
338
+ llm = BUILDER_LLM ,
339
+ verbose = True ,
340
+ )
341
+
342
+ # return agent as a tool
343
+ # TODO: tune description
344
+ web_agent_tool = QueryEngineTool .from_defaults (
345
+ web_agent ,
346
+ name = "web_agent" ,
347
+ description = """
348
+ This agent can answer questions by searching the web. \
349
+ Use this tool if the answer is ONLY likely to be found by searching \
350
+ the internet, especially for queries about recent events.
351
+ """ ,
352
+ )
353
+
354
+ return web_agent_tool
355
+
356
+
357
+ def get_tool_objects (tool_names : List [str ]) -> List :
358
+ """Get tool objects from tool names."""
359
+ # construct additional tools
360
+ tool_objs = []
361
+ for tool_name in tool_names :
362
+ if tool_name == "web_search" :
363
+ # build web agent
364
+ tool_objs .append (get_web_agent_tool ())
365
+ else :
366
+ raise ValueError (f"Tool { tool_name } not recognized." )
367
+
368
+ return tool_objs
369
+
370
+
288
371
class ParamCache (BaseModel ):
289
372
"""Cache for RAG agent builder.
290
373
@@ -338,7 +421,7 @@ def save_to_disk(self, save_dir: str) -> None:
338
421
"file_names" : self .file_names ,
339
422
"urls" : self .urls ,
340
423
# TODO: figure out tools
341
- # "tools": [] ,
424
+ "tools" : self . tools ,
342
425
"rag_params" : self .rag_params .dict (),
343
426
"agent_id" : self .agent_id ,
344
427
}
@@ -376,11 +459,13 @@ def load_from_disk(
376
459
file_names = cache_dict ["file_names" ], urls = cache_dict ["urls" ]
377
460
)
378
461
# load agent from index
462
+ additional_tools = get_tool_objects (cache_dict ["tools" ])
379
463
agent , _ = construct_agent (
380
464
cache_dict ["system_prompt" ],
381
465
cache_dict ["rag_params" ],
382
466
cache_dict ["docs" ],
383
467
vector_index = vector_index ,
468
+ additional_tools = additional_tools ,
384
469
# TODO: figure out tools
385
470
)
386
471
cache_dict ["vector_index" ] = vector_index
@@ -505,20 +590,14 @@ def load_data(
505
590
self ._cache .urls = urls
506
591
return "Data loaded successfully."
507
592
508
- # NOTE: unused
509
593
def add_web_tool (self ) -> str :
510
594
"""Add a web tool to enable agent to solve a task."""
511
595
# TODO: make this not hardcoded to a web tool
512
596
# Set up Metaphor tool
513
- from llama_hub .tools .metaphor .base import MetaphorToolSpec
514
-
515
- # TODO: set metaphor API key
516
- metaphor_tool = MetaphorToolSpec (
517
- api_key = os .environ ["METAPHOR_API_KEY" ],
518
- )
519
- metaphor_tool_list = metaphor_tool .to_tool_list ()
520
-
521
- self ._cache .tools .extend (metaphor_tool_list )
597
+ if "web_search" in self ._cache .tools :
598
+ return "Web tool already added."
599
+ else :
600
+ self ._cache .tools .append ("web_search" )
522
601
return "Web tool added successfully."
523
602
524
603
def get_rag_params (self ) -> Dict :
@@ -557,11 +636,13 @@ def create_agent(self, agent_id: Optional[str] = None) -> str:
557
636
if self ._cache .system_prompt is None :
558
637
raise ValueError ("Must set system prompt before creating agent." )
559
638
639
+ # construct additional tools
640
+ additional_tools = get_tool_objects (self .cache .tools )
560
641
agent , extra_info = construct_agent (
561
642
cast (str , self ._cache .system_prompt ),
562
643
cast (RAGParams , self ._cache .rag_params ),
563
644
self ._cache .docs ,
564
- additional_tools = self . _cache . tools ,
645
+ additional_tools = additional_tools ,
565
646
)
566
647
567
648
# if agent_id not specified, randomly generate one
@@ -587,6 +668,7 @@ def update_agent(
587
668
chunk_size : Optional [int ] = None ,
588
669
embed_model : Optional [str ] = None ,
589
670
llm : Optional [str ] = None ,
671
+ additional_tools : Optional [List ] = None ,
590
672
) -> None :
591
673
"""Update agent.
592
674
@@ -609,7 +691,6 @@ def update_agent(
609
691
# We call set_rag_params and create_agent, which will
610
692
# update the cache
611
693
# TODO: decouple functions from tool functions exposed to the agent
612
-
613
694
rag_params_dict : Dict [str , Any ] = {}
614
695
if include_summarization is not None :
615
696
rag_params_dict ["include_summarization" ] = include_summarization
@@ -623,6 +704,11 @@ def update_agent(
623
704
rag_params_dict ["llm" ] = llm
624
705
625
706
self .set_rag_params (** rag_params_dict )
707
+
708
+ # update tools
709
+ if additional_tools is not None :
710
+ self .cache .tools = additional_tools
711
+
626
712
# this will update the agent in the cache
627
713
self .create_agent ()
628
714
@@ -655,6 +741,33 @@ def update_agent(
655
741
# please make sure to update the LLM above if you change the function below
656
742
657
743
744
+ def _get_builder_agent_tools (agent_builder : RAGAgentBuilder ) -> List [FunctionTool ]:
745
+ """Get list of builder agent tools to pass to the builder agent."""
746
+ # see if metaphor api key is set, otherwise don't add web tool
747
+ # TODO: refactor this later
748
+
749
+ if "metaphor_key" in st .secrets :
750
+ fns : List [Callable ] = [
751
+ agent_builder .create_system_prompt ,
752
+ agent_builder .load_data ,
753
+ agent_builder .add_web_tool ,
754
+ agent_builder .get_rag_params ,
755
+ agent_builder .set_rag_params ,
756
+ agent_builder .create_agent ,
757
+ ]
758
+ else :
759
+ fns = [
760
+ agent_builder .create_system_prompt ,
761
+ agent_builder .load_data ,
762
+ agent_builder .get_rag_params ,
763
+ agent_builder .set_rag_params ,
764
+ agent_builder .create_agent ,
765
+ ]
766
+
767
+ fn_tools : List [FunctionTool ] = [FunctionTool .from_defaults (fn = fn ) for fn in fns ]
768
+ return fn_tools
769
+
770
+
658
771
# define agent
659
772
# @st.cache_resource
660
773
def load_meta_agent_and_tools (
@@ -664,15 +777,7 @@ def load_meta_agent_and_tools(
664
777
# think of this as tools for the agent to use
665
778
agent_builder = RAGAgentBuilder (cache )
666
779
667
- fns : List [Callable ] = [
668
- agent_builder .create_system_prompt ,
669
- agent_builder .load_data ,
670
- # add_web_tool,
671
- agent_builder .get_rag_params ,
672
- agent_builder .set_rag_params ,
673
- agent_builder .create_agent ,
674
- ]
675
- fn_tools = [FunctionTool .from_defaults (fn = fn ) for fn in fns ]
780
+ fn_tools = _get_builder_agent_tools (agent_builder )
676
781
677
782
builder_agent = load_meta_agent (
678
783
fn_tools , llm = BUILDER_LLM , system_prompt = RAG_BUILDER_SYS_STR , verbose = True
0 commit comments