Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various tweaks for assistant, agents and utilities #728

Merged
merged 5 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions lumen/ai/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,9 @@ async def invoke(self, messages: list | str):

class SourceAgent(Agent):
"""
The SourceAgent allows a user to upload datasets.
The SourceAgent allows a user to upload new datasets.

Use this if the user is requesting to add a dataset or you think
Only use this if the user is requesting to add a dataset or you think
additional information is required to solve the user query.
"""

Expand Down Expand Up @@ -352,7 +352,6 @@ def _render_lumen(self, component: Component, message: pn.chat.ChatMessage = Non
self.interface.stream(message=message, **message_kwargs, replace=True, max_width=self._max_width)



class TableAgent(LumenBaseAgent):
"""
Displays a single table / dataset. Does not discuss.
Expand Down Expand Up @@ -479,8 +478,9 @@ def _use_table(self, event):
self.interface.send(f"Show the table: {table!r}")

async def answer(self, messages: list | str):
source = memory["current_source"]
tables = source.get_tables()
tables = []
for source in memory['available_sources']:
tables += source.get_tables()
if not tables:
return

Expand Down Expand Up @@ -720,7 +720,17 @@ async def answer(self, messages: list | str):
table_schema = schema
else:
table_schema = await get_schema(source, source_table, include_min_max=False)
table_schemas[source_table] = {

# Look up underlying table name
table_name = source_table
if (
'tables' in source.param and
isinstance(source.tables, dict) and
'select ' not in source.tables[table_name].lower()
):
table_name = source.tables[table_name]

table_schemas[table_name] = {
"schema": yaml.dump(table_schema),
"sql": source.get_sql_expr(source_table)
}
Expand Down
109 changes: 67 additions & 42 deletions lumen/ai/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re

from io import StringIO
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

import param
import yaml
Expand Down Expand Up @@ -402,7 +402,7 @@ async def _get_agent(self, messages: list | str):
for subagent, deps, instruction in agent_chain[:-1]:
agent_name = type(subagent).name.replace('Agent', '')
with self.interface.add_step(title=f"Querying {agent_name} agent...") as step:
step.stream(f"Assistant decided the {agent_name!r} will provide {', '.join(deps)}.")
step.stream(f"`{agent_name}` agent is working on the following task:\n\n{instruction}")
self._current_agent.object = f"## **Current Agent**: {agent_name}"
custom_messages = messages.copy()
if isinstance(subagent, SQLAgent):
Expand All @@ -417,7 +417,8 @@ async def _get_agent(self, messages: list | str):
if instruction:
custom_messages.append({"role": "user", "content": instruction})
await subagent.answer(custom_messages)
step.success_title = f"{agent_name} agent responded"
step.stream(f"`{agent_name}` agent successfully completed the following task:\n\n- {instruction}", replace=True)
step.success_title = f"{agent_name} agent successfully responded"
return selected

def _serialize(self, obj, exclude_passwords=True):
Expand Down Expand Up @@ -490,53 +491,89 @@ class PlanningAssistant(Assistant):
instead of simply resolving the dependencies step-by-step.
"""

@classmethod
async def _lookup_schemas(
    cls,
    tables: dict[str, Source],
    requested: list[str],
    provided: list[str],
    cache: dict[str, dict] | None = None
) -> str:
    """
    Fetch schemas for the requested tables and render them as markdown
    bullet points, skipping tables whose schema was already provided.

    Parameters
    ----------
    tables: dict[str, Source]
        Mapping from table name to the Source that serves it.
    requested: list[str]
        Table names whose schemas should be included.
    provided: list[str]
        Table names already rendered previously; updated in place so
        subsequent calls do not re-emit the same schema.
    cache: dict[str, dict] | None
        Optional cross-call cache of previously fetched schemas;
        populated in place with any newly fetched schemas.

    Returns
    -------
    str
        Markdown-formatted `- table: schema` entries for each newly
        provided table (empty string if nothing new was requested).
    """
    # NOTE: `cache = cache or {}` would silently replace a caller-supplied
    # *empty* dict with a fresh one, so newly fetched schemas would never
    # reach the caller's cache and every call would re-query the source.
    # Only substitute a new dict when no cache was given at all.
    if cache is None:
        cache = {}
    # Gather the tables that still need a schema lookup and fetch them
    # concurrently; limit=3 keeps the sampled schema cheap to compute.
    to_query, queries = [], []
    for table in requested:
        if table in provided or table in cache:
            continue
        to_query.append(table)
        queries.append(get_schema(tables[table], table, limit=3))
    for table, schema in zip(to_query, await asyncio.gather(*queries)):
        cache[table] = schema
    # Render every requested-but-not-yet-provided schema, marking each
    # table as provided so it is not repeated on a later call.
    schema_info = ''
    for table in requested:
        if table in provided:
            continue
        provided.append(table)
        schema_info += f'- {table}: {cache[table]}\n\n'
    return schema_info

async def _make_plan(
self,
user_msg: dict[str, Any],
messages: list,
agents: dict[str, Agent],
tables: dict[str, Source],
unmet_dependencies: set[str],
reason_model: type[BaseModel],
plan_model: type[BaseModel],
step: ChatStep
):
step: ChatStep,
schemas: dict[str, dict] | None = None
) -> BaseModel:
user_msg = messages[-1]
info = ''
reasoning = None
requested_tables, provided_tables = [], []
requested, provided = [], []
if 'current_table' in memory:
requested_tables.append(memory['current_table'])
requested.append(memory['current_table'])
elif len(tables) == 1:
requested_tables.append(next(iter(tables)))
while reasoning is None or requested_tables:
# Add context of table schemas
schemas = []
requested = getattr(reasoning, 'tables', requested_tables)
for table in requested:
if table in provided_tables:
continue
provided_tables.append(table)
schemas.append(get_schema(tables[table], table, limit=3))
for table, schema in zip(requested, await asyncio.gather(*schemas)):
info += f'- {table}: {schema}\n\n'
requested.append(next(iter(tables)))
while reasoning is None or requested:
info += await self._lookup_schemas(tables, requested, provided, cache=schemas)
available = [t for t in tables if t not in provided]
system = render_template(
'plan_agent.jinja2', agents=list(agents.values()), current_agent=self._current_agent.object,
unmet_dependencies=unmet_dependencies, memory=memory, table_info=info, tables=list(tables)
unmet_dependencies=unmet_dependencies, memory=memory, table_info=info, tables=available
)
async for reasoning in self.llm.stream(
messages=messages,
system=system,
response_model=reason_model,
):
step.stream(reasoning.chain_of_thought, replace=True)
requested_tables = [t for t in reasoning.tables if t and t not in provided_tables]
if requested_tables:
continue
requested = [
t for t in getattr(reasoning, 'tables', [])
if t and t not in provided
]
new_msg = dict(role=user_msg['role'], content=f"<user query>{user_msg['content']}</user query> {reasoning.chain_of_thought}")
messages = messages[:-1] + [new_msg]
plan = await self._fill_model(messages, system, plan_model)
return plan

async def _resolve_plan(self, plan, agents, messages):
    """
    Walk a plan's steps in reverse to build the agent execution chain
    and collect any dependencies no step (or memory) satisfies.

    Returns a tuple of (agent_chain, unmet_dependencies) where
    agent_chain holds (agent, provided/unmet deps, instruction) entries
    ordered from the final step backwards.
    """
    # Seed the chain with the final step; its requirements that memory
    # cannot satisfy form the initial unmet set.
    final_step = plan.steps[-1]
    final_agent = agents[final_step.expert]
    unmet = {
        req for req in await final_agent.requirements(messages)
        if req not in memory
    }
    chain = [(final_agent, unmet, final_step.instruction)]
    # Earlier steps may provide what later steps need, so traverse the
    # remaining steps back-to-front, discharging satisfied dependencies.
    for prior_step in reversed(plan.steps[:-1]):
        prior_agent = agents[prior_step.expert]
        needed = unmet | set(await prior_agent.requirements(messages))
        unmet = {
            dep for dep in needed
            if dep not in prior_agent.provides and dep not in memory
        }
        chain.append((prior_agent, prior_agent.provides, prior_step.instruction))
    return chain, unmet

async def _resolve_dependencies(self, messages: list, agents: dict[str, Agent]) -> list[tuple(Agent, any)]:
agent_names = tuple(sagent.name[:-5] for sagent in agents.values())
tables = {}
Expand All @@ -547,29 +584,17 @@ async def _resolve_dependencies(self, messages: list, agents: dict[str, Agent])
reason_model, plan_model = make_plan_models(agent_names, list(tables))
planned = False
unmet_dependencies = set()
user_msg = messages[-1]
schemas = {}
with self.interface.add_step(title="Planning how to solve user query...", user="Assistant") as istep:
while not planned or unmet_dependencies:
while not planned:
plan = await self._make_plan(
user_msg, messages, agents, tables, unmet_dependencies, reason_model, plan_model, istep
messages, agents, tables, unmet_dependencies, reason_model, plan_model, istep, schemas
)
step = plan.steps[-1]
subagent = agents[step.expert]
unmet_dependencies = {
r for r in await subagent.requirements(messages) if r not in memory
}
agent_chain = [(subagent, unmet_dependencies, step.instruction)]
for step in plan.steps[:-1][::-1]:
subagent = agents[step.expert]
requires = set(await subagent.requirements(messages))
unmet_dependencies = {
dep for dep in (unmet_dependencies | requires)
if dep not in subagent.provides and dep not in memory
}
agent_chain.append((subagent, subagent.provides, step.instruction))
agent_chain, unmet_dependencies = await self._resolve_plan(plan, agents, messages)
if unmet_dependencies:
istep.stream(f"The plan didn't account for {unmet_dependencies!r}", replace=True)
planned = True
else:
planned = True
istep.stream('\n\nHere are the steps:\n\n')
for i, step in enumerate(plan.steps):
istep.stream(f"{i+1}. {step.expert}: {step.instruction}\n")
Expand Down
19 changes: 8 additions & 11 deletions lumen/ai/prompts/plan_agent.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,17 @@ Ensure that you provide each expert some context to ensure they do not repeat pr
Currently you have the following information available to you:
{% for item in memory.keys() %}
- {{ item }}
{% endfor %}
And have access to the following data tables:
{% for table in tables %}
- {{ table }}
{% endfor %}
{%- endfor %}
{% if table_info %}
In order to make an informed decision here are schemas for the most relevant tables:
In order to make an informed decision here are schemas for the most relevant tables (note that these schemas are computed on a subset of data):
{{ table_info }}
Do not request any additional tables.
{% endif %}
{%- if tables %}
Additionally the following tables are available and you may request to look at them before revising your plan:
{% for table in tables %}
- {{ table }}
{% endfor %}
{%- endif -%}
Here's the choice of experts and their uses:
{% for agent in agents %}
- `{{ agent.name[:-5] }}`
Expand All @@ -31,10 +32,6 @@ Here's the choice of experts and their uses:
Description: {{ agent.__doc__.strip().split() | join(' ') }}
{% endfor %}

{% if not table_info %}
If you do not think you can solve the problem given the current information provide a list of requested tables.
{% endif %}

{% if unmet_dependencies %}
Note that a previous plan was unsuccessful because it did not satisfy the following required pieces of information: {{ unmet_dependencies }}
{% endif %}
19 changes: 13 additions & 6 deletions lumen/ai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import inspect
import math
import time

from functools import wraps
Expand Down Expand Up @@ -140,10 +141,13 @@ async def get_schema(
if "max" in spec:
spec.pop("max")

if not include_enum:
for field, spec in schema.items():
if "enum" in spec:
spec.pop("enum")
for field, spec in schema.items():
if "enum" not in spec:
continue
elif not include_enum:
spec.pop("enum")
elif "limit" in get_kwargs:
spec["enum"].append("...")

if count and include_count:
spec["count"] = count
Expand Down Expand Up @@ -174,7 +178,7 @@ def get_data_sync():

async def describe_data(df: pd.DataFrame) -> str:
def format_float(num):
if pd.isna(num):
if pd.isna(num) or math.isinf(num):
return num
# if is integer, round to 0 decimals
if num == int(num):
Expand Down Expand Up @@ -209,7 +213,10 @@ def describe_data_sync(df):
for col in df.select_dtypes(include=["object"]).columns:
if col not in df_describe_dict:
df_describe_dict[col] = {}
df_describe_dict[col]["nunique"] = df[col].nunique()
try:
df_describe_dict[col]["nunique"] = df[col].nunique()
except Exception:
df_describe_dict[col]["nunique"] = 'unknown'
try:
df_describe_dict[col]["lengths"] = {
"max": df[col].str.len().max(),
Expand Down
Loading