From 1d713fd98648fb3075316b5bafb1060f081e0929 Mon Sep 17 00:00:00 2001 From: Philipp Rudiger Date: Wed, 16 Oct 2024 06:26:48 -0400 Subject: [PATCH] Add planning mode for AI Assistant (#719) --- lumen/ai/__init__.py | 2 +- lumen/ai/agents.py | 20 +-- lumen/ai/assistant.py | 229 ++++++++++++++++++++--------- lumen/ai/models.py | 59 +++++++- lumen/ai/prompts/pick_agent.jinja2 | 20 ++- lumen/ai/prompts/plan_agent.jinja2 | 40 +++++ lumen/ai/utils.py | 2 + 7 files changed, 281 insertions(+), 91 deletions(-) create mode 100644 lumen/ai/prompts/plan_agent.jinja2 diff --git a/lumen/ai/__init__.py b/lumen/ai/__init__.py index 26e2ff82d..1d27fd2cb 100644 --- a/lumen/ai/__init__.py +++ b/lumen/ai/__init__.py @@ -2,7 +2,7 @@ from . import agents, embeddings, llm # noqa from .agents import Analysis # noqa -from .assistant import Assistant # noqa +from .assistant import Assistant, PlanningAssistant # noqa from .memory import memory # noqa pn.chat.message.DEFAULT_AVATARS.update({ diff --git a/lumen/ai/agents.py b/lumen/ai/agents.py index ec8d47d0e..80443bfd2 100644 --- a/lumen/ai/agents.py +++ b/lumen/ai/agents.py @@ -198,10 +198,10 @@ async def invoke(self, messages: list | str): class SourceAgent(Agent): """ - The SourceAgent allows a user to provide an input source. + The SourceAgent allows a user to upload datasets. - Should only be used if the user explicitly requests adding a source - or no source is in memory. + Use this if the user is requesting to add a dataset or you think + additional information is required to solve the user query. """ requires = param.List(default=[], readonly=True) @@ -456,8 +456,7 @@ async def invoke(self, messages: list | str): class TableListAgent(LumenBaseAgent): """ - List all of the available tables or datasets inventory. Not useful - if the user requests a specific table. + Provides a list of all availables tables/datasets. 
""" system_prompt = param.String( @@ -507,7 +506,9 @@ async def invoke(self, messages: list | str): class SQLAgent(LumenBaseAgent): """ Responsible for generating and modifying SQL queries to answer user queries about the data, - such querying subsets of the data, aggregating the data and calculating results. + such querying subsets of the data, aggregating the data and calculating results. If the + current table does not contain all the available data the SQL agent is also capable of + joining it with other tables. """ system_prompt = param.String( @@ -959,9 +960,10 @@ async def answer(self, messages: list | str) -> hvPlotUIView: response_model=self._get_model(schema), ) spec = await self._extract_spec(output) - chain_of_thought = spec.pop("chain_of_thought") - with self.interface.add_step(title="Generating view...") as step: - step.stream(chain_of_thought) + chain_of_thought = spec.pop("chain_of_thought", None) + if chain_of_thought: + with self.interface.add_step(title="Generating view...") as step: + step.stream(chain_of_thought) print(f"{self.name} settled on {spec=!r}.") memory["current_view"] = dict(spec, type=self.view_type) return self.view_type(pipeline=pipeline, **spec) diff --git a/lumen/ai/assistant.py b/lumen/ai/assistant.py index 38398d444..610a7853f 100644 --- a/lumen/ai/assistant.py +++ b/lumen/ai/assistant.py @@ -4,7 +4,7 @@ import re from io import StringIO -from typing import Literal +from typing import TYPE_CHECKING, Any import param import yaml @@ -15,8 +15,6 @@ from panel.pane import HTML, Markdown from panel.viewable import Viewer from panel.widgets import Button, FileDownload -from pydantic import create_model -from pydantic.fields import FieldInfo from .agents import ( Agent, AnalysisAgent, ChatAgent, SQLAgent, @@ -26,9 +24,15 @@ from .llm import Llama, Llm from .logs import ChatLogs from .memory import memory -from .models import Validity +from .models import Validity, make_agent_model, make_plan_models from .utils import get_schema, 
render_template, retry_llm_output +if TYPE_CHECKING: + from panel.chat.step import ChatStep + from pydantic import BaseModel + + from ..sources import Source + class Assistant(Viewer): """ @@ -114,6 +118,13 @@ def download_notebook(): instantiated = [] self._analyses = [] for agent in agents or self.agents: + if isinstance(agent, AnalysisAgent): + analyses = "\n".join( + f"- `{analysis.__name__}`: {(analysis.__doc__ or '').strip()}" + for analysis in agent.analyses if analysis._callable_by_llm + ) + agent.__doc__ = f"Available analyses include:\n{analyses}\nSelect this agent to perform one of these analyses." + break if not isinstance(agent, Agent): kwargs = {"llm": llm} if agent.llm is None else {} agent = agent(interface=interface, **kwargs) @@ -312,25 +323,8 @@ async def _chat_invoke(self, contents: list | str, user: str, instance: ChatInte print("\033[94mNEW\033[0m" + "-" * 100) await self.invoke(contents) - @staticmethod - def _create_agent_model(agent_names): - agent_model = create_model( - "RelevantAgent", - chain_of_thought=( - str, - FieldInfo( - description="Explain in your own words, what the user wants." 
- ), - ), - agent=( - Literal[agent_names], - FieldInfo(default=..., description="The most relevant agent to use.") - ), - ) - return agent_model - @retry_llm_output() - async def _create_valid_agent(self, messages, system, agent_model, errors=None): + async def _fill_model(self, messages, system, agent_model, errors=None): if errors: errors = '\n'.join(errors) messages += [{"role": "user", "content": f"\nExpertly resolve these issues:\n{errors}"}] @@ -342,57 +336,42 @@ async def _create_valid_agent(self, messages, system, agent_model, errors=None): ) return out - async def _choose_agent(self, messages: list | str, agents: list[Agent]): + async def _choose_agent(self, messages: list | str, agents: list[Agent] | None = None, primary: bool = False, unmet_dependencies: tuple[str] | None = None): + if agents is None: + agents = self.agents agents = [agent for agent in agents if await agent.applies()] agent_names = tuple(sagent.name[:-5] for sagent in agents) + agent_model = make_agent_model(agent_names, primary=primary) if len(agent_names) == 0: raise ValueError("No agents available to choose from.") if len(agent_names) == 1: - return agent_names[0] + return agent_model(agent=agent_names[0], chain_of_thought='') self._current_agent.object = "## **Current Agent**: [Lumen.ai](https://lumen.holoviz.org/)" - agent_model = self._create_agent_model(agent_names) - - for agent in agents: - if isinstance(agent, AnalysisAgent): - analyses = "\n".join( - f"- `{analysis.__name__}`: {(analysis.__doc__ or '').strip()}" - for analysis in agent.analyses if analysis._callable_by_llm - ) - agent.__doc__ = f"Available analyses include:\n{analyses}\nSelect this agent to perform one of these analyses." 
- break - system = render_template( - "pick_agent.jinja2", agents=agents, current_agent=self._current_agent.object + 'pick_agent.jinja2', agents=agents, current_agent=self._current_agent.object, + primary=primary, unmet_dependencies=unmet_dependencies ) - return await self._create_valid_agent(messages, system, agent_model) + return await self._fill_model(messages, system, agent_model) - async def _get_agent(self, messages: list | str): - if len(self.agents) == 1: - return self.agents[0] - agent_types = tuple(agent.name[:-5] for agent in self.agents) - agents = {agent.name[:-5]: agent for agent in self.agents} - - if len(agent_types) == 1: - agent = agent_types[0] + async def _resolve_dependencies(self, messages, agents: dict[str, Agent]) -> list[tuple(Agent, any)]: + if len(agents) == 1: + agent = next(iter(agents.values())) else: - with self.interface.add_step(title="Selecting relevant agent...", user="Assistant") as step: - output = await self._choose_agent(messages, self.agents) + with self.interface.add_step(title="Selecting primary agent...", user="Assistant") as step: + output = await self._choose_agent(messages, self.agents, primary=True) step.stream(output.chain_of_thought, replace=True) - agent = output.agent - step.success_title = f"Selected {agent}" + step.success_title = f"Selected {output.agent}" + agent = agents[output.agent] if agent is None: - return None + return [] - print( - f"Assistant decided on \033[95m{agent!r}\033[0m" - ) - selected = subagent = agents[agent] + subagent = agent agent_chain = [] - while unmet_dependencies := tuple( + while (unmet_dependencies := tuple( r for r in await subagent.requirements(messages) if r not in memory - ): - with self.interface.add_step(title="Solving dependency chain...") as step: + )): + with self.interface.add_step(title="Resolving dependencies...", user="Assistant") as step: step.stream(f"Found {len(unmet_dependencies)} unmet dependencies: {', '.join(unmet_dependencies)}") print(f"\033[91m### Unmet 
dependencies: {unmet_dependencies}\033[0m") subagents = [ @@ -400,20 +379,33 @@ async def _get_agent(self, messages: list | str): for agent in self.agents if any(ur in agent.provides for ur in unmet_dependencies) ] - output = await self._choose_agent(messages, subagents) - subagent_name = output.agent if not isinstance(output, str) else output - if subagent_name is None: + output = await self._choose_agent(messages, subagents, unmet_dependencies) + if output.agent is None: continue - subagent = agents[subagent_name] - agent_chain.append((subagent, unmet_dependencies)) - step.success_title = f"Solved a dependency with {subagent_name}" - for subagent, deps in agent_chain[::-1]: - with self.interface.add_step(title="Choosing subagent...") as step: - step.stream(f"Assistant decided the {subagent.name[:-5]!r} will provide {', '.join(deps)}.") - self._current_agent.object = f"## **Current Agent**: {subagent.name[:-5]}" + subagent = agents[output.agent] + agent_chain.append((subagent, unmet_dependencies, output.chain_of_thought)) + step.success_title = f"Solved a dependency with {output.agent}" + return agent_chain[::-1]+[(agent, (), None)] + + async def _get_agent(self, messages: list | str): + if len(self.agents) == 1: + return self.agents[0] + agents = {agent.name[:-5]: agent for agent in self.agents} + agent_chain = await self._resolve_dependencies(messages, agents) + + if not agent_chain: + return + + selected = agent = agent_chain[-1][0] + print(f"Assistant decided on \033[95m{agent!r}\033[0m") + for subagent, deps, instruction in agent_chain[:-1]: + agent_name = type(subagent).name.replace('Agent', '') + with self.interface.add_step(title=f"Querying {agent_name} agent...") as step: + step.stream(f"Assistant decided the {agent_name!r} will provide {', '.join(deps)}.") + self._current_agent.object = f"## **Current Agent**: {agent_name}" + custom_messages = messages.copy() if isinstance(subagent, SQLAgent): - custom_messages = messages.copy() custom_agent = 
next((agent for agent in self.agents if isinstance(agent, AnalysisAgent)), None) if custom_agent: custom_analysis_doc = custom_agent.__doc__.replace("Available analyses include:\n", "") @@ -422,10 +414,10 @@ async def _get_agent(self, messages: list | str): f"Most likely, you'll just need to do a simple SELECT * FROM {{table}};" ) custom_messages.append({"role": "user", "content": custom_message}) - await subagent.answer(custom_messages) - else: - await subagent.answer(messages) - step.success_title = f"Selected {subagent.name[:-5]}" + if instruction: + custom_messages.append({"role": "user", "content": instruction}) + await subagent.answer(custom_messages) + step.success_title = f"{agent_name} agent responded" return selected def _serialize(self, obj, exclude_passwords=True): @@ -490,3 +482,96 @@ def controls(self): def __panel__(self): return self.interface + + +class PlanningAssistant(Assistant): + """ + The PlanningAssistant develops a plan and then executes it + instead of simply resolving the dependencies step-by-step. 
+ """ + + async def _make_plan( + self, + user_msg: dict[str, Any], + messages: list, + agents: dict[str, Agent], + tables: dict[str, Source], + unmet_dependencies: set[str], + reason_model: type[BaseModel], + plan_model: type[BaseModel], + step: ChatStep + ): + info = '' + reasoning = None + requested_tables, provided_tables = [], [] + if 'current_table' in memory: + requested_tables.append(memory['current_table']) + elif len(tables) == 1: + requested_tables.append(next(iter(tables))) + while reasoning is None or requested_tables: + # Add context of table schemas + schemas = [] + requested = getattr(reasoning, 'tables', requested_tables) + for table in requested: + if table in provided_tables: + continue + provided_tables.append(table) + schemas.append(get_schema(tables[table], table, limit=3)) + for table, schema in zip(requested, await asyncio.gather(*schemas)): + info += f'- {table}: {schema}\n\n' + system = render_template( + 'plan_agent.jinja2', agents=list(agents.values()), current_agent=self._current_agent.object, + unmet_dependencies=unmet_dependencies, memory=memory, table_info=info, tables=list(tables) + ) + async for reasoning in self.llm.stream( + messages=messages, + system=system, + response_model=reason_model, + ): + step.stream(reasoning.chain_of_thought, replace=True) + requested_tables = [t for t in reasoning.tables if t and t not in provided_tables] + if requested_tables: + continue + new_msg = dict(role=user_msg['role'], content=f"{user_msg['content']} {reasoning.chain_of_thought}") + messages = messages[:-1] + [new_msg] + plan = await self._fill_model(messages, system, plan_model) + return plan + + async def _resolve_dependencies(self, messages: list, agents: dict[str, Agent]) -> list[tuple(Agent, any)]: + agent_names = tuple(sagent.name[:-5] for sagent in agents.values()) + tables = {} + for src in memory['available_sources']: + for table in src.get_tables(): + tables[table] = src + + reason_model, plan_model = make_plan_models(agent_names, 
list(tables)) + planned = False + unmet_dependencies = set() + user_msg = messages[-1] + with self.interface.add_step(title="Planning how to solve user query...", user="Assistant") as istep: + while not planned or unmet_dependencies: + plan = await self._make_plan( + user_msg, messages, agents, tables, unmet_dependencies, reason_model, plan_model, istep + ) + step = plan.steps[-1] + subagent = agents[step.expert] + unmet_dependencies = { + r for r in await subagent.requirements(messages) if r not in memory + } + agent_chain = [(subagent, unmet_dependencies, step.instruction)] + for step in plan.steps[:-1][::-1]: + subagent = agents[step.expert] + requires = set(await subagent.requirements(messages)) + unmet_dependencies = { + dep for dep in (unmet_dependencies | requires) + if dep not in subagent.provides and dep not in memory + } + agent_chain.append((subagent, subagent.provides, step.instruction)) + if unmet_dependencies: + istep.stream(f"The plan didn't account for {unmet_dependencies!r}", replace=True) + planned = True + istep.stream('\n\nHere are the steps:\n\n') + for i, step in enumerate(plan.steps): + istep.stream(f"{i+1}. {step.expert}: {step.instruction}\n") + istep.success_title = "Successfully came up with a plan." 
+ return agent_chain[::-1] diff --git a/lumen/ai/models.py b/lumen/ai/models.py index 892dca554..abb16dac5 100644 --- a/lumen/ai/models.py +++ b/lumen/ai/models.py @@ -2,7 +2,8 @@ from typing import Literal -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, create_model +from pydantic.fields import FieldInfo class FuzzyTable(BaseModel): @@ -99,3 +100,59 @@ class Topic(BaseModel): class VegaLiteSpec(BaseModel): json_spec: str = Field(description="A vega-lite JSON specification WITHOUT the data field, which will be added automatically.") + + +def make_plan_models(agent_names: list[str], tables: list[str]): + step = create_model( + "Step", + expert=(Literal[agent_names], FieldInfo(description="The name of the expert to assign a task to.")), + instruction=(str, FieldInfo(description="Instructions to the expert to assist in the task.")) + ) + extras = {} + if tables: + extras['tables'] = ( + list[Literal[tuple(tables)]], + FieldInfo( + description="A list of tables you want to inspect before coming up with a plan." + ) + ) + reasoning = create_model( + 'Reasoning', + chain_of_thought=( + str, + FieldInfo( + description="Describe at a high-level how the actions of each expert will solve the user query." + ), + ), + **extras + ) + plan = create_model( + "Plan", + steps=( + list[step], + FieldInfo( + description="A list of steps to perform that will solve user query. Ensure you include ALL the steps needed to solve the task, matching the chain of thought." + ) + ) + ) + return reasoning, plan + + +def make_agent_model(agent_names: list[str], primary: bool = False): + if primary: + description = "The agent that will provide the output the user requested, e.g. a plot or a table. This should be the FINAL step in your chain of thought." + else: + description = "The most relevant agent to use." + return create_model( + "Agent", + chain_of_thought=( + str, + FieldInfo( + description="Describe what this agent should do." 
+ ), + ), + agent=( + Literal[tuple(agent_names)], + FieldInfo(default=..., description=description) + ), + ) diff --git a/lumen/ai/prompts/pick_agent.jinja2 b/lumen/ai/prompts/pick_agent.jinja2 index 1d877606f..17d5c534f 100644 --- a/lumen/ai/prompts/pick_agent.jinja2 +++ b/lumen/ai/prompts/pick_agent.jinja2 @@ -1,10 +1,14 @@ -Select the most relevant agent for the user's query. - -Each agent can request other agents to fill in the blanks, so pick the agent that can best answer the entire query. - -Here's the choice of agents and their uses: -``` +{% if primary %} +Select the agent that will provide the user with the output they requested. +{% endif %} +Here's the choice of experts and their uses: {% for agent in agents %} -- `{{ agent.name[:-5] }}`: {{ agent.__doc__.strip().split() | join(' ') }} +- `{{ agent.name[:-5] }}` + Provides: {{ agent.provides }} + Description: {{ agent.__doc__.strip().split() | join(' ') }} {% endfor %} -``` +{% if primary %} +If the request requires multiple steps, pick the agent that will perform the final step and provide the user with the output the user asked for, e.g. if the request requires performing some calculation and a plot, pick the plotting agent. The agent you select can request other agents to fill in the blanks. +{% else %} +The agent is only responsible for answering part of the query and should provide one (or more) of the following pieces of information: {{ unmet_dependencies }}. +{% endif %} diff --git a/lumen/ai/prompts/plan_agent.jinja2 b/lumen/ai/prompts/plan_agent.jinja2 new file mode 100644 index 000000000..2ab1fa572 --- /dev/null +++ b/lumen/ai/prompts/plan_agent.jinja2 @@ -0,0 +1,40 @@ +You are the team lead and have to make a plan to address the user query. + +Ensure that the plan solves the entire problem step-by-step, and that every step mentioned in the chain of thought is included in the plan! 
+ +If some piece of information is already available to you, only call an agent to provide the same piece of information if absolutely necessary, e.g. if 'current_table' is available do not call the TableAgent again. + +You have to choose which of the experts at your disposal should address the problem. + +Each of these experts requires certain information and has the ability to provide certain information. + +Ensure that you provide each expert some context to ensure they do not repeat previous steps. + +Currently you have the following information available to you: +{% for item in memory.keys() %} +- {{ item }} +{% endfor %} +And have access to the following data tables: +{% for table in tables %} +- {{ table }} +{% endfor %} +{% if table_info %} +In order to make an informed decision, here are schemas for the most relevant tables: +{{ table_info }} +Do not request any additional tables. +{% endif %} +Here's the choice of experts and their uses: +{% for agent in agents %} +- `{{ agent.name[:-5] }}` + Requires: {{ agent.requires }} + Provides: {{ agent.provides }} + Description: {{ agent.__doc__.strip().split() | join(' ') }} +{% endfor %} + +{% if not table_info %} +If you do not think you can solve the problem given the current information, provide a list of requested tables. +{% endif %} + +{% if unmet_dependencies %} +Note that a previous plan was unsuccessful because it did not satisfy the following required pieces of information: {{ unmet_dependencies }} +{% endif %} diff --git a/lumen/ai/utils.py b/lumen/ai/utils.py index c4fcb169a..1322eef23 100644 --- a/lumen/ai/utils.py +++ b/lumen/ai/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import inspect import time