Skip to content

Commit

Permalink
Add planning mode for AI Assistant (#719)
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr authored Oct 16, 2024
1 parent 2de7246 commit 1d713fd
Show file tree
Hide file tree
Showing 7 changed files with 281 additions and 91 deletions.
2 changes: 1 addition & 1 deletion lumen/ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from . import agents, embeddings, llm # noqa
from .agents import Analysis # noqa
from .assistant import Assistant # noqa
from .assistant import Assistant, PlanningAssistant # noqa
from .memory import memory # noqa

pn.chat.message.DEFAULT_AVATARS.update({
Expand Down
20 changes: 11 additions & 9 deletions lumen/ai/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,10 @@ async def invoke(self, messages: list | str):

class SourceAgent(Agent):
"""
The SourceAgent allows a user to provide an input source.
The SourceAgent allows a user to upload datasets.
Should only be used if the user explicitly requests adding a source
or no source is in memory.
Use this if the user is requesting to add a dataset or you think
additional information is required to solve the user query.
"""

requires = param.List(default=[], readonly=True)
Expand Down Expand Up @@ -456,8 +456,7 @@ async def invoke(self, messages: list | str):

class TableListAgent(LumenBaseAgent):
"""
List all of the available tables or datasets inventory. Not useful
if the user requests a specific table.
Provides a list of all availables tables/datasets.
"""

system_prompt = param.String(
Expand Down Expand Up @@ -507,7 +506,9 @@ async def invoke(self, messages: list | str):
class SQLAgent(LumenBaseAgent):
"""
Responsible for generating and modifying SQL queries to answer user queries about the data,
such querying subsets of the data, aggregating the data and calculating results.
such querying subsets of the data, aggregating the data and calculating results. If the
current table does not contain all the available data the SQL agent is also capable of
joining it with other tables.
"""

system_prompt = param.String(
Expand Down Expand Up @@ -959,9 +960,10 @@ async def answer(self, messages: list | str) -> hvPlotUIView:
response_model=self._get_model(schema),
)
spec = await self._extract_spec(output)
chain_of_thought = spec.pop("chain_of_thought")
with self.interface.add_step(title="Generating view...") as step:
step.stream(chain_of_thought)
chain_of_thought = spec.pop("chain_of_thought", None)
if chain_of_thought:
with self.interface.add_step(title="Generating view...") as step:
step.stream(chain_of_thought)
print(f"{self.name} settled on {spec=!r}.")
memory["current_view"] = dict(spec, type=self.view_type)
return self.view_type(pipeline=pipeline, **spec)
Expand Down
229 changes: 157 additions & 72 deletions lumen/ai/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re

from io import StringIO
from typing import Literal
from typing import TYPE_CHECKING, Any

import param
import yaml
Expand All @@ -15,8 +15,6 @@
from panel.pane import HTML, Markdown
from panel.viewable import Viewer
from panel.widgets import Button, FileDownload
from pydantic import create_model
from pydantic.fields import FieldInfo

from .agents import (
Agent, AnalysisAgent, ChatAgent, SQLAgent,
Expand All @@ -26,9 +24,15 @@
from .llm import Llama, Llm
from .logs import ChatLogs
from .memory import memory
from .models import Validity
from .models import Validity, make_agent_model, make_plan_models
from .utils import get_schema, render_template, retry_llm_output

if TYPE_CHECKING:
from panel.chat.step import ChatStep
from pydantic import BaseModel

from ..sources import Source


class Assistant(Viewer):
"""
Expand Down Expand Up @@ -114,6 +118,13 @@ def download_notebook():
instantiated = []
self._analyses = []
for agent in agents or self.agents:
if isinstance(agent, AnalysisAgent):
analyses = "\n".join(
f"- `{analysis.__name__}`: {(analysis.__doc__ or '').strip()}"
for analysis in agent.analyses if analysis._callable_by_llm
)
agent.__doc__ = f"Available analyses include:\n{analyses}\nSelect this agent to perform one of these analyses."
break
if not isinstance(agent, Agent):
kwargs = {"llm": llm} if agent.llm is None else {}
agent = agent(interface=interface, **kwargs)
Expand Down Expand Up @@ -312,25 +323,8 @@ async def _chat_invoke(self, contents: list | str, user: str, instance: ChatInte
print("\033[94mNEW\033[0m" + "-" * 100)
await self.invoke(contents)

@staticmethod
def _create_agent_model(agent_names):
agent_model = create_model(
"RelevantAgent",
chain_of_thought=(
str,
FieldInfo(
description="Explain in your own words, what the user wants."
),
),
agent=(
Literal[agent_names],
FieldInfo(default=..., description="The most relevant agent to use.")
),
)
return agent_model

@retry_llm_output()
async def _create_valid_agent(self, messages, system, agent_model, errors=None):
async def _fill_model(self, messages, system, agent_model, errors=None):
if errors:
errors = '\n'.join(errors)
messages += [{"role": "user", "content": f"\nExpertly resolve these issues:\n{errors}"}]
Expand All @@ -342,78 +336,76 @@ async def _create_valid_agent(self, messages, system, agent_model, errors=None):
)
return out

async def _choose_agent(self, messages: list | str, agents: list[Agent]):
async def _choose_agent(self, messages: list | str, agents: list[Agent] | None = None, primary: bool = False, unmet_dependencies: tuple[str] | None = None):
if agents is None:
agents = self.agents
agents = [agent for agent in agents if await agent.applies()]
agent_names = tuple(sagent.name[:-5] for sagent in agents)
agent_model = make_agent_model(agent_names, primary=primary)
if len(agent_names) == 0:
raise ValueError("No agents available to choose from.")
if len(agent_names) == 1:
return agent_names[0]
return agent_model(agent=agent_names[0], chain_of_thought='')
self._current_agent.object = "## **Current Agent**: [Lumen.ai](https://lumen.holoviz.org/)"
agent_model = self._create_agent_model(agent_names)

for agent in agents:
if isinstance(agent, AnalysisAgent):
analyses = "\n".join(
f"- `{analysis.__name__}`: {(analysis.__doc__ or '').strip()}"
for analysis in agent.analyses if analysis._callable_by_llm
)
agent.__doc__ = f"Available analyses include:\n{analyses}\nSelect this agent to perform one of these analyses."
break

system = render_template(
"pick_agent.jinja2", agents=agents, current_agent=self._current_agent.object
'pick_agent.jinja2', agents=agents, current_agent=self._current_agent.object,
primary=primary, unmet_dependencies=unmet_dependencies
)
return await self._create_valid_agent(messages, system, agent_model)
return await self._fill_model(messages, system, agent_model)

async def _get_agent(self, messages: list | str):
if len(self.agents) == 1:
return self.agents[0]
agent_types = tuple(agent.name[:-5] for agent in self.agents)
agents = {agent.name[:-5]: agent for agent in self.agents}

if len(agent_types) == 1:
agent = agent_types[0]
async def _resolve_dependencies(self, messages, agents: dict[str, Agent]) -> list[tuple(Agent, any)]:
if len(agents) == 1:
agent = next(iter(agents.values()))
else:
with self.interface.add_step(title="Selecting relevant agent...", user="Assistant") as step:
output = await self._choose_agent(messages, self.agents)
with self.interface.add_step(title="Selecting primary agent...", user="Assistant") as step:
output = await self._choose_agent(messages, self.agents, primary=True)
step.stream(output.chain_of_thought, replace=True)
agent = output.agent
step.success_title = f"Selected {agent}"
step.success_title = f"Selected {output.agent}"
agent = agents[output.agent]

if agent is None:
return None
return []

print(
f"Assistant decided on \033[95m{agent!r}\033[0m"
)
selected = subagent = agents[agent]
subagent = agent
agent_chain = []
while unmet_dependencies := tuple(
while (unmet_dependencies := tuple(
r for r in await subagent.requirements(messages) if r not in memory
):
with self.interface.add_step(title="Solving dependency chain...") as step:
)):
with self.interface.add_step(title="Resolving dependencies...", user="Assistant") as step:
step.stream(f"Found {len(unmet_dependencies)} unmet dependencies: {', '.join(unmet_dependencies)}")
print(f"\033[91m### Unmet dependencies: {unmet_dependencies}\033[0m")
subagents = [
agent
for agent in self.agents
if any(ur in agent.provides for ur in unmet_dependencies)
]
output = await self._choose_agent(messages, subagents)
subagent_name = output.agent if not isinstance(output, str) else output
if subagent_name is None:
output = await self._choose_agent(messages, subagents, unmet_dependencies)
if output.agent is None:
continue
subagent = agents[subagent_name]
agent_chain.append((subagent, unmet_dependencies))
step.success_title = f"Solved a dependency with {subagent_name}"
for subagent, deps in agent_chain[::-1]:
with self.interface.add_step(title="Choosing subagent...") as step:
step.stream(f"Assistant decided the {subagent.name[:-5]!r} will provide {', '.join(deps)}.")
self._current_agent.object = f"## **Current Agent**: {subagent.name[:-5]}"
subagent = agents[output.agent]
agent_chain.append((subagent, unmet_dependencies, output.chain_of_thought))
step.success_title = f"Solved a dependency with {output.agent}"
return agent_chain[::-1]+[(agent, (), None)]

async def _get_agent(self, messages: list | str):
if len(self.agents) == 1:
return self.agents[0]

agents = {agent.name[:-5]: agent for agent in self.agents}
agent_chain = await self._resolve_dependencies(messages, agents)

if not agent_chain:
return

selected = agent = agent_chain[-1][0]
print(f"Assistant decided on \033[95m{agent!r}\033[0m")
for subagent, deps, instruction in agent_chain[:-1]:
agent_name = type(subagent).name.replace('Agent', '')
with self.interface.add_step(title=f"Querying {agent_name} agent...") as step:
step.stream(f"Assistant decided the {agent_name!r} will provide {', '.join(deps)}.")
self._current_agent.object = f"## **Current Agent**: {agent_name}"
custom_messages = messages.copy()
if isinstance(subagent, SQLAgent):
custom_messages = messages.copy()
custom_agent = next((agent for agent in self.agents if isinstance(agent, AnalysisAgent)), None)
if custom_agent:
custom_analysis_doc = custom_agent.__doc__.replace("Available analyses include:\n", "")
Expand All @@ -422,10 +414,10 @@ async def _get_agent(self, messages: list | str):
f"Most likely, you'll just need to do a simple SELECT * FROM {{table}};"
)
custom_messages.append({"role": "user", "content": custom_message})
await subagent.answer(custom_messages)
else:
await subagent.answer(messages)
step.success_title = f"Selected {subagent.name[:-5]}"
if instruction:
custom_messages.append({"role": "user", "content": instruction})
await subagent.answer(custom_messages)
step.success_title = f"{agent_name} agent responded"
return selected

def _serialize(self, obj, exclude_passwords=True):
Expand Down Expand Up @@ -490,3 +482,96 @@ def controls(self):

def __panel__(self):
return self.interface


class PlanningAssistant(Assistant):
"""
The PlanningAssistant develops a plan and then executes it
instead of simply resolving the dependencies step-by-step.
"""

async def _make_plan(
self,
user_msg: dict[str, Any],
messages: list,
agents: dict[str, Agent],
tables: dict[str, Source],
unmet_dependencies: set[str],
reason_model: type[BaseModel],
plan_model: type[BaseModel],
step: ChatStep
):
info = ''
reasoning = None
requested_tables, provided_tables = [], []
if 'current_table' in memory:
requested_tables.append(memory['current_table'])
elif len(tables) == 1:
requested_tables.append(next(iter(tables)))
while reasoning is None or requested_tables:
# Add context of table schemas
schemas = []
requested = getattr(reasoning, 'tables', requested_tables)
for table in requested:
if table in provided_tables:
continue
provided_tables.append(table)
schemas.append(get_schema(tables[table], table, limit=3))
for table, schema in zip(requested, await asyncio.gather(*schemas)):
info += f'- {table}: {schema}\n\n'
system = render_template(
'plan_agent.jinja2', agents=list(agents.values()), current_agent=self._current_agent.object,
unmet_dependencies=unmet_dependencies, memory=memory, table_info=info, tables=list(tables)
)
async for reasoning in self.llm.stream(
messages=messages,
system=system,
response_model=reason_model,
):
step.stream(reasoning.chain_of_thought, replace=True)
requested_tables = [t for t in reasoning.tables if t and t not in provided_tables]
if requested_tables:
continue
new_msg = dict(role=user_msg['role'], content=f"<user query>{user_msg['content']}</user query> {reasoning.chain_of_thought}")
messages = messages[:-1] + [new_msg]
plan = await self._fill_model(messages, system, plan_model)
return plan

async def _resolve_dependencies(self, messages: list, agents: dict[str, Agent]) -> list[tuple(Agent, any)]:
agent_names = tuple(sagent.name[:-5] for sagent in agents.values())
tables = {}
for src in memory['available_sources']:
for table in src.get_tables():
tables[table] = src

reason_model, plan_model = make_plan_models(agent_names, list(tables))
planned = False
unmet_dependencies = set()
user_msg = messages[-1]
with self.interface.add_step(title="Planning how to solve user query...", user="Assistant") as istep:
while not planned or unmet_dependencies:
plan = await self._make_plan(
user_msg, messages, agents, tables, unmet_dependencies, reason_model, plan_model, istep
)
step = plan.steps[-1]
subagent = agents[step.expert]
unmet_dependencies = {
r for r in await subagent.requirements(messages) if r not in memory
}
agent_chain = [(subagent, unmet_dependencies, step.instruction)]
for step in plan.steps[:-1][::-1]:
subagent = agents[step.expert]
requires = set(await subagent.requirements(messages))
unmet_dependencies = {
dep for dep in (unmet_dependencies | requires)
if dep not in subagent.provides and dep not in memory
}
agent_chain.append((subagent, subagent.provides, step.instruction))
if unmet_dependencies:
istep.stream(f"The plan didn't account for {unmet_dependencies!r}", replace=True)
planned = True
istep.stream('\n\nHere are the steps:\n\n')
for i, step in enumerate(plan.steps):
istep.stream(f"{i+1}. {step.expert}: {step.instruction}\n")
istep.success_title = "Successfully came up with a plan."
return agent_chain[::-1]
Loading

0 comments on commit 1d713fd

Please sign in to comment.