Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add planning mode for AI Assistant #719

Merged
merged 9 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions lumen/ai/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,9 +955,10 @@ async def answer(self, messages: list | str) -> hvPlotUIView:
response_model=self._get_model(schema),
)
spec = await self._extract_spec(output)
chain_of_thought = spec.pop("chain_of_thought")
with self.interface.add_step(title="Generating view...") as step:
step.stream(chain_of_thought)
chain_of_thought = spec.pop("chain_of_thought", None)
if chain_of_thought:
with self.interface.add_step(title="Generating view...") as step:
step.stream(chain_of_thought)
print(f"{self.name} settled on {spec=!r}.")
memory["current_view"] = dict(spec, type=self.view_type)
return self.view_type(pipeline=pipeline, **spec)
Expand Down
172 changes: 99 additions & 73 deletions lumen/ai/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import re

from io import StringIO
from typing import Literal

import param
import yaml
Expand All @@ -15,8 +14,6 @@
from panel.pane import HTML, Markdown
from panel.viewable import Viewer
from panel.widgets import Button, FileDownload
from pydantic import create_model
from pydantic.fields import FieldInfo

from .agents import (
Agent, AnalysisAgent, ChatAgent, SQLAgent,
Expand All @@ -26,7 +23,7 @@
from .llm import Llama, Llm
from .logs import ChatLogs
from .memory import memory
from .models import Validity
from .models import Validity, make_agent_model, make_plan_model
from .utils import get_schema, render_template, retry_llm_output


Expand All @@ -51,6 +48,8 @@ class Assistant(Viewer):

logs_filename = param.String()

planning = param.Boolean(default=False)

philippjfr marked this conversation as resolved.
Show resolved Hide resolved
def __init__(
self,
llm: Llm | None = None,
Expand Down Expand Up @@ -114,6 +113,13 @@ def download_notebook():
instantiated = []
self._analyses = []
for agent in agents or self.agents:
if isinstance(agent, AnalysisAgent):
analyses = "\n".join(
f"- `{analysis.__name__}`: {(analysis.__doc__ or '').strip()}"
for analysis in agent.analyses if analysis._callable_by_llm
)
agent.__doc__ = f"Available analyses include:\n{analyses}\nSelect this agent to perform one of these analyses."
break
if not isinstance(agent, Agent):
kwargs = {"llm": llm} if agent.llm is None else {}
agent = agent(interface=interface, **kwargs)
Expand Down Expand Up @@ -312,25 +318,8 @@ async def _chat_invoke(self, contents: list | str, user: str, instance: ChatInte
print("\033[94mNEW\033[0m" + "-" * 100)
await self.invoke(contents)

@staticmethod
def _create_agent_model(agent_names):
agent_model = create_model(
"RelevantAgent",
chain_of_thought=(
str,
FieldInfo(
description="Explain in your own words, what the user wants."
),
),
agent=(
Literal[agent_names],
FieldInfo(default=..., description="The most relevant agent to use.")
),
)
return agent_model

@retry_llm_output()
async def _create_valid_agent(self, messages, system, agent_model, errors=None):
async def _fill_model(self, messages, system, agent_model, errors=None):
if errors:
errors = '\n'.join(errors)
messages += [{"role": "user", "content": f"\nExpertly resolve these issues:\n{errors}"}]
Expand All @@ -342,78 +331,115 @@ async def _create_valid_agent(self, messages, system, agent_model, errors=None):
)
return out

async def _choose_agent(self, messages: list | str, agents: list[Agent]):
async def _choose_agent(self, messages: list | str, agents: list[Agent] | None = None, primary: bool = False, unmet_dependencies: tuple[str] | None = None):
if agents is None:
agents = self.agents
agents = [agent for agent in agents if await agent.applies()]
agent_names = tuple(sagent.name[:-5] for sagent in agents)
agent_model = make_agent_model(agent_names, primary=primary)
if len(agent_names) == 0:
raise ValueError("No agents available to choose from.")
if len(agent_names) == 1:
return agent_names[0]
return agent_model(agent=agent_names[0], chain_of_thought='')
self._current_agent.object = "## **Current Agent**: [Lumen.ai](https://lumen.holoviz.org/)"
agent_model = self._create_agent_model(agent_names)

for agent in agents:
if isinstance(agent, AnalysisAgent):
analyses = "\n".join(
f"- `{analysis.__name__}`: {(analysis.__doc__ or '').strip()}"
for analysis in agent.analyses if analysis._callable_by_llm
)
agent.__doc__ = f"Available analyses include:\n{analyses}\nSelect this agent to perform one of these analyses."
break

system = render_template(
"pick_agent.jinja2", agents=agents, current_agent=self._current_agent.object
'pick_agent.jinja2', agents=agents, current_agent=self._current_agent.object,
primary=primary, unmet_dependencies=unmet_dependencies
)
return await self._create_valid_agent(messages, system, agent_model)

async def _get_agent(self, messages: list | str):
if len(self.agents) == 1:
return self.agents[0]
agent_types = tuple(agent.name[:-5] for agent in self.agents)
agents = {agent.name[:-5]: agent for agent in self.agents}

if len(agent_types) == 1:
agent = agent_types[0]
return await self._fill_model(messages, system, agent_model)

async def _make_plan(self, messages: list | str, agents: dict[str, Agent]) -> list[tuple(Agent, any)]:
agent_names = tuple(sagent.name[:-5] for sagent in agents.values())
plan_model = make_plan_model(agent_names)
planned = False
unmet_dependencies = ()
with self.interface.add_step(title="Planning how to solve user query...", user="Assistant") as istep:
while not planned or unmet_dependencies:
system = render_template(
'plan_agent.jinja2', agents=list(agents.values()), current_agent=self._current_agent.object,
unmet_dependencies=unmet_dependencies, memory=memory
)
agent_chain = []
plan = await self._fill_model(messages, system, plan_model)
istep.stream(plan.chain_of_thought, replace=True)
step = plan.steps[-1]
subagent = agents[step.expert]
unmet_dependencies = tuple(
r for r in await subagent.requirements(messages) if r not in memory
)
agent_chain.append((subagent, unmet_dependencies, step.instruction))
for step in plan.steps[:-1][::-1]:
requires = tuple(await subagent.requirements(messages))
subagent = agents[step.expert]
unmet_dependencies = tuple(
dep for dep in set(unmet_dependencies+requires) if dep not in subagent.provides and dep not in memory
)
agent_chain.append((subagent, subagent.provides, step.instruction))
if unmet_dependencies:
istep.stream(f"The plan didn't account for {unmet_dependencies!r}", replace=True)
planned = True
istep.stream(f'{plan.chain_of_thought}\n\nHere are the steps:\n\n', replace=True)
for i, step in enumerate(plan.steps):
istep.stream(f"{i+1}. {step.instruction}\n")
istep.success_title = "Successfully came up with a plan."
return agent_chain[::-1]

async def _resolve_dependencies(self, messages, agents: dict[str, Agent]) -> list[tuple(Agent, any)]:
if len(agents) == 1:
agent = next(iter(agents.values()))
else:
with self.interface.add_step(title="Selecting relevant agent...", user="Assistant") as step:
output = await self._choose_agent(messages, self.agents)
with self.interface.add_step(title="Selecting primary agent...", user="Assistant") as step:
output = await self._choose_agent(messages, self.agents, primary=True)
step.stream(output.chain_of_thought, replace=True)
agent = output.agent
step.success_title = f"Selected {agent}"
step.success_title = f"Selected {output.agent}"
agent = agents[output.agent]

if agent is None:
return None
return []

print(
f"Assistant decided on \033[95m{agent!r}\033[0m"
)
selected = subagent = agents[agent]
subagent = agent
agent_chain = []
while unmet_dependencies := tuple(
while (unmet_dependencies := tuple(
r for r in await subagent.requirements(messages) if r not in memory
):
with self.interface.add_step(title="Solving dependency chain...") as step:
)):
with self.interface.add_step(title="Resolving dependencies...", user="Assistant") as step:
step.stream(f"Found {len(unmet_dependencies)} unmet dependencies: {', '.join(unmet_dependencies)}")
print(f"\033[91m### Unmet dependencies: {unmet_dependencies}\033[0m")
subagents = [
agent
for agent in self.agents
if any(ur in agent.provides for ur in unmet_dependencies)
]
output = await self._choose_agent(messages, subagents)
subagent_name = output.agent if not isinstance(output, str) else output
if subagent_name is None:
output = await self._choose_agent(messages, subagents, unmet_dependencies)
if output.agent is None:
continue
subagent = agents[subagent_name]
agent_chain.append((subagent, unmet_dependencies))
step.success_title = f"Solved a dependency with {subagent_name}"
for subagent, deps in agent_chain[::-1]:
with self.interface.add_step(title="Choosing subagent...") as step:
step.stream(f"Assistant decided the {subagent.name[:-5]!r} will provide {', '.join(deps)}.")
self._current_agent.object = f"## **Current Agent**: {subagent.name[:-5]}"
subagent = agents[output.agent]
agent_chain.append((subagent, unmet_dependencies, output.chain_of_thought))
step.success_title = f"Solved a dependency with {output.agent}"
return agent_chain[::-1]+[(agent, (), None)]

async def _get_agent(self, messages: list | str):
if len(self.agents) == 1:
return self.agents[0]

agents = {agent.name[:-5]: agent for agent in self.agents}
if self.planning:
agent_chain = await self._make_plan(messages, agents)
else:
agent_chain = await self._resolve_dependencies(messages, agents)

if not agent_chain:
return

selected = agent = agent_chain[-1][0]
print(f"Assistant decided on \033[95m{agent!r}\033[0m")
for subagent, deps, instruction in agent_chain[:-1]:
agent_name = type(subagent).name.replace('Agent', '')
with self.interface.add_step(title=f"Querying {agent_name} agent...") as step:
step.stream(f"Assistant decided the {agent_name!r} will provide {', '.join(deps)}.")
self._current_agent.object = f"## **Current Agent**: {agent_name}"
custom_messages = messages.copy()
if isinstance(subagent, SQLAgent):
custom_messages = messages.copy()
custom_agent = next((agent for agent in self.agents if isinstance(agent, AnalysisAgent)), None)
if custom_agent:
custom_analysis_doc = custom_agent.__doc__.replace("Available analyses include:\n", "")
Expand All @@ -422,10 +448,10 @@ async def _get_agent(self, messages: list | str):
f"Most likely, you'll just need to do a simple SELECT * FROM {{table}};"
)
custom_messages.append({"role": "user", "content": custom_message})
await subagent.answer(custom_messages)
else:
await subagent.answer(messages)
step.success_title = f"Selected {subagent.name[:-5]}"
if instruction:
custom_messages.append({"role": "user", "content": instruction})
await subagent.answer(custom_messages)
step.success_title = f"{agent_name} agent responded"
return selected

def _serialize(self, obj, exclude_passwords=True):
Expand Down
46 changes: 45 additions & 1 deletion lumen/ai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from typing import Literal

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, create_model
from pydantic.fields import FieldInfo


class FuzzyTable(BaseModel):
Expand Down Expand Up @@ -99,3 +100,46 @@ class Topic(BaseModel):
class VegaLiteSpec(BaseModel):
    # Structured-output model for an LLM-generated chart. Intentionally no class
    # docstring: pydantic exposes __doc__ in the model's JSON schema, which
    # would leak into the prompt sent to the LLM.

    # Raw vega-lite JSON; the `data` field is deliberately excluded and is
    # injected by the caller afterwards (see the Field description).
    json_spec: str = Field(description="A vega-lite JSON specification WITHOUT the data field, which will be added automatically.")


def make_plan_model(agent_names: tuple[str, ...]):
    """
    Dynamically build a pydantic ``Plan`` model for LLM structured output.

    The model constrains each step's ``expert`` to one of the supplied
    agent names via ``Literal``, so the LLM cannot plan with a
    non-existent agent.

    Parameters
    ----------
    agent_names : tuple[str, ...]
        Names of the available agents (must be a tuple, not a list, so it
        is usable as a ``Literal`` argument).

    Returns
    -------
    A pydantic model class with ``chain_of_thought: str`` and
    ``steps: list[Step]`` fields.
    """
    # Per-step schema: which expert acts, and what instruction it receives.
    step = create_model(
        "Step",
        expert=(Literal[agent_names], FieldInfo(description="The name of the expert to assign a task to.")),
        instruction=(str, FieldInfo(description="Instructions to the expert to assist in the task."))
    )
    return create_model(
        "Plan",
        # Free-form reasoning emitted before the structured steps; consumers
        # stream this to the UI (see Assistant._make_plan).
        chain_of_thought=(
            str,
            FieldInfo(
                description="Describe at a high-level how the actions of each expert will solve the user query."
            ),
        ),
        steps=(
            list[step],
            FieldInfo(
                description="A list of tuples of the experts name and instructions to that expert to help him solve the overall task"
            )
        ),
    )
philippjfr marked this conversation as resolved.
Show resolved Hide resolved


def make_agent_model(agent_names: tuple[str, ...], primary: bool = False):
    """
    Dynamically build a pydantic ``Agent`` selection model for LLM output.

    The ``agent`` field is constrained by ``Literal`` to the supplied
    names, so the LLM must pick an existing agent.

    Parameters
    ----------
    agent_names : tuple[str, ...]
        Names of the candidate agents (tuple, as required by ``Literal``).
    primary : bool
        When True, the field description steers the LLM towards the agent
        that produces the user's final deliverable; otherwise it asks for
        the most relevant agent (used when resolving dependencies).

    Returns
    -------
    A pydantic model class with ``chain_of_thought: str`` and
    ``agent: Literal[...]`` fields.
    """
    # The description doubles as prompt text, so it differs by selection mode.
    if primary:
        description = "The agent that will provide the output the user requested, e.g. a plot or a table. This should be the FINAL step in your chain of thought."
    else:
        description = "The most relevant agent to use."
    return create_model(
        "Agent",
        chain_of_thought=(
            str,
            FieldInfo(
                description="Describe what this agent should do."
            ),
        ),
        agent=(
            Literal[agent_names],
            # default=... marks the field as required in pydantic.
            FieldInfo(default=..., description=description)
        ),
    )
20 changes: 12 additions & 8 deletions lumen/ai/prompts/pick_agent.jinja2
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
Select the most relevant agent for the user's query.

Each agent can request other agents to fill in the blanks, so pick the agent that can best answer the entire query.

Here's the choice of agents and their uses:
```
{% if primary %}
Select the agent that will provide the user with the output they requested.
{% endif %}
Here's the choice of experts and their uses:
{% for agent in agents %}
- `{{ agent.name[:-5] }}`: {{ agent.__doc__.strip().split() | join(' ') }}
- `{{ agent.name[:-5] }}`
Provides: {{ agent.provides }}
Description: {{ agent.__doc__.strip().split() | join(' ') }}
{% endfor %}
```
{% if primary %}
If the request requires multiple steps, pick the agent that will perform the final step and provide the user with the output they asked for, e.g. if the request requires performing some calculation and a plot, pick the plotting agent. The agent you select can request other agents to fill in the blanks.
{% else %}
The agent is only responsible for answering part of the query and should provide one (or more) of the following pieces of information {{ unmet_dependencies }}.
{% endif %}
Loading