From 42c612d9abcbe56f64bbf3a4957f58982273743b Mon Sep 17 00:00:00 2001 From: Philipp Rudiger Date: Wed, 23 Oct 2024 12:21:50 +0200 Subject: [PATCH 1/5] Refactor PlanningAssistant --- lumen/ai/assistant.py | 108 ++++++++++++++++++----------- lumen/ai/prompts/plan_agent.jinja2 | 19 +++-- 2 files changed, 75 insertions(+), 52 deletions(-) diff --git a/lumen/ai/assistant.py b/lumen/ai/assistant.py index 610a7853..7cbeb70e 100644 --- a/lumen/ai/assistant.py +++ b/lumen/ai/assistant.py @@ -402,7 +402,7 @@ async def _get_agent(self, messages: list | str): for subagent, deps, instruction in agent_chain[:-1]: agent_name = type(subagent).name.replace('Agent', '') with self.interface.add_step(title=f"Querying {agent_name} agent...") as step: - step.stream(f"Assistant decided the {agent_name!r} will provide {', '.join(deps)}.") + step.stream(f"`{agent_name}` agent is working on the following task:\n\n{instruction}") self._current_agent.object = f"## **Current Agent**: {agent_name}" custom_messages = messages.copy() if isinstance(subagent, SQLAgent): @@ -417,7 +417,8 @@ async def _get_agent(self, messages: list | str): if instruction: custom_messages.append({"role": "user", "content": instruction}) await subagent.answer(custom_messages) - step.success_title = f"{agent_name} agent responded" + step.stream(f"`{agent_name}` agent successfully completed the following task:\n\n- {instruction}", replace=True) + step.success_title = f"{agent_name} agent successfully responded" return selected def _serialize(self, obj, exclude_passwords=True): @@ -490,38 +491,57 @@ class PlanningAssistant(Assistant): instead of simply resolving the dependencies step-by-step. """ + @classmethod + async def _lookup_schemas( + cls, + tables: dict[str, Source], + requested: list[str], + provided: list[str], + cache: dict[str, dict] | None = None + ) -> str: + cache = cache or {} + to_query, queries = [], [] + for table in requested: + if table in provided or table in cache: + continue + to_query.append(table) + queries.append(get_schema(tables[table], table, limit=3)) + for table, schema in zip(to_query, await asyncio.gather(*queries)): + cache[table] = schema + schema_info = '' + for table in requested: + if table in provided: + continue + provided.append(table) + schema_info += f'- {table}: {cache[table]}\n\n' + return schema_info + async def _make_plan( self, - user_msg: dict[str, Any], messages: list, agents: dict[str, Agent], tables: dict[str, Source], unmet_dependencies: set[str], reason_model: type[BaseModel], plan_model: type[BaseModel], - step: ChatStep - ): + step: ChatStep, + schemas: dict[str, dict] | None = None + ) -> BaseModel: + user_msg = messages[-1] info = '' reasoning = None - requested_tables, provided_tables = [], [] + requested, provided = [], [] if 'current_table' in memory: - requested_tables.append(memory['current_table']) + requested.append(memory['current_table']) elif len(tables) == 1: - requested_tables.append(next(iter(tables))) - while reasoning is None or requested_tables: - # Add context of table schemas - schemas = [] - requested = getattr(reasoning, 'tables', requested_tables) - for table in requested: - if table in provided_tables: - continue - provided_tables.append(table) - schemas.append(get_schema(tables[table], table, limit=3)) - for table, schema in zip(requested, await asyncio.gather(*schemas)): - info += f'- {table}: {schema}\n\n' + requested.append(next(iter(tables))) + print(requested) + while reasoning is None or requested: + info += await self._lookup_schemas(tables, requested, provided, cache=schemas) + available = [t for t in tables if t not in provided] system = render_template( 'plan_agent.jinja2', agents=list(agents.values()), current_agent=self._current_agent.object, - unmet_dependencies=unmet_dependencies, memory=memory, table_info=info, tables=list(tables) + unmet_dependencies=unmet_dependencies, memory=memory, table_info=info, tables=available ) async for reasoning in self.llm.stream( messages=messages, @@ -529,14 +549,32 @@ async def _make_plan( response_model=reason_model, ): step.stream(reasoning.chain_of_thought, replace=True) - requested_tables = [t for t in reasoning.tables if t and t not in provided_tables] - if requested_tables: - continue + requested = [ + t for t in getattr(reasoning, 'tables', []) + if t and t not in provided + ] new_msg = dict(role=user_msg['role'], content=f"{user_msg['content']} {reasoning.chain_of_thought}") messages = messages[:-1] + [new_msg] plan = await self._fill_model(messages, system, plan_model) return plan + async def _resolve_plan(self, plan, agents, messages): + step = plan.steps[-1] + subagent = agents[step.expert] + unmet_dependencies = { + r for r in await subagent.requirements(messages) if r not in memory + } + agent_chain = [(subagent, unmet_dependencies, step.instruction)] + for step in plan.steps[:-1][::-1]: + subagent = agents[step.expert] + requires = set(await subagent.requirements(messages)) + unmet_dependencies = { + dep for dep in (unmet_dependencies | requires) + if dep not in subagent.provides and dep not in memory + } + agent_chain.append((subagent, subagent.provides, step.instruction)) + return agent_chain, unmet_dependencies + async def _resolve_dependencies(self, messages: list, agents: dict[str, Agent]) -> list[tuple(Agent, any)]: agent_names = tuple(sagent.name[:-5] for sagent in agents.values()) tables = {} @@ -547,29 +585,17 @@ async def _resolve_dependencies(self, messages: list, agents: dict[str, Agent]) reason_model, plan_model = make_plan_models(agent_names, list(tables)) planned = False unmet_dependencies = set() - user_msg = messages[-1] + schemas = {} with self.interface.add_step(title="Planning how to solve user query...", user="Assistant") as istep: - while not planned or unmet_dependencies: + while not planned: plan = await self._make_plan( - user_msg, messages, agents, tables, unmet_dependencies, reason_model, plan_model, istep + messages, agents, tables, unmet_dependencies, reason_model, plan_model, istep, schemas ) - step = plan.steps[-1] - subagent = agents[step.expert] - unmet_dependencies = { - r for r in await subagent.requirements(messages) if r not in memory - } - agent_chain = [(subagent, unmet_dependencies, step.instruction)] - for step in plan.steps[:-1][::-1]: - subagent = agents[step.expert] - requires = set(await subagent.requirements(messages)) - unmet_dependencies = { - dep for dep in (unmet_dependencies | requires) - if dep not in subagent.provides and dep not in memory - } - agent_chain.append((subagent, subagent.provides, step.instruction)) + agent_chain, unmet_dependencies = await self._resolve_plan(plan, agents, messages) if unmet_dependencies: istep.stream(f"The plan didn't account for {unmet_dependencies!r}", replace=True) - planned = True + else: + planned = True istep.stream('\n\nHere are the steps:\n\n') for i, step in enumerate(plan.steps): istep.stream(f"{i+1}. {step.expert}: {step.instruction}\n") diff --git a/lumen/ai/prompts/plan_agent.jinja2 b/lumen/ai/prompts/plan_agent.jinja2 index 2ab1fa57..b9cd6ec2 100644 --- a/lumen/ai/prompts/plan_agent.jinja2 +++ b/lumen/ai/prompts/plan_agent.jinja2 @@ -13,16 +13,17 @@ Ensure that you provide each expert some context to ensure they do not repeat pr Currently you have the following information available to you: {% for item in memory.keys() %} - {{ item }} -{% endfor %} -And have access to the following data tables: -{% for table in tables %} -- {{ table }} -{% endfor %} +{%- endfor %} {% if table_info %} -In order to make an informed decision here are schemas for the most relevant tables: +In order to make an informed decision here are schemas for the most relevant tables (note that these schemas are computed on a subset of data): {{ table_info }} -Do not request any additional tables. {% endif %} +{%- if tables %} +Additionally the following tables are available and you may request to look at them before revising your plan: +{% for table in tables %} +- {{ table }} +{% endfor %} +{%- endif -%} Here's the choice of experts and their uses: {% for agent in agents %} - `{{ agent.name[:-5] }}` @@ -31,10 +32,6 @@ Here's the choice of experts and their uses: Description: {{ agent.__doc__.strip().split() | join(' ') }} {% endfor %} -{% if not table_info %} -If you do not think you can solve the problem given the current information provide a list of requested tables. -{% endif %} - {% if unmet_dependencies %} Note that a previous plan was unsuccessful because it did not satisfy the following required pieces of information: {unmet_dependencies!r} {% endif %} From e0229f64e1e2357d8ad6c7b449d64c9bb7d75e0b Mon Sep 17 00:00:00 2001 From: Philipp Rudiger Date: Wed, 23 Oct 2024 12:22:22 +0200 Subject: [PATCH 2/5] Minor agent tweaks --- lumen/ai/agents.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/lumen/ai/agents.py b/lumen/ai/agents.py index 80443bfd..89ee6cc5 100644 --- a/lumen/ai/agents.py +++ b/lumen/ai/agents.py @@ -198,9 +198,9 @@ async def invoke(self, messages: list | str): class SourceAgent(Agent): """ - The SourceAgent allows a user to upload datasets. + The SourceAgent allows a user to upload new datasets. - Use this if the user is requesting to add a dataset or you think + Only use this if the user is requesting to add a dataset or you think additional information is required to solve the user query. """ @@ -352,7 +352,6 @@ def _render_lumen(self, component: Component, message: pn.chat.ChatMessage = Non self.interface.stream(message=message, **message_kwargs, replace=True, max_width=self._max_width) - class TableAgent(LumenBaseAgent): """ Displays a single table / dataset. Does not discuss. @@ -479,8 +478,9 @@ def _use_table(self, event): self.interface.send(f"Show the table: {table!r}") async def answer(self, messages: list | str): - source = memory["current_source"] - tables = source.get_tables() + tables = [] + for source in memory['available_sources']: + tables += source.get_tables() if not tables: return @@ -720,7 +720,17 @@ async def answer(self, messages: list | str): table_schema = schema else: table_schema = await get_schema(source, source_table, include_min_max=False) - table_schemas[source_table] = { + + # Look up underlying table name + table_name = source_table + if ( + 'tables' in source.param and + isinstance(source.tables, dict) and + 'select ' not in source.tables[table_name].lower() + ): + table_name = source.tables[table_name] + + table_schemas[table_name] = { "schema": yaml.dump(table_schema), "sql": source.get_sql_expr(source_table) } From 3af0314d855f16d9bdea25c11ffe3d8d840f149d Mon Sep 17 00:00:00 2001 From: Philipp Rudiger Date: Wed, 23 Oct 2024 12:22:44 +0200 Subject: [PATCH 3/5] Tweaks for schema generation --- lumen/ai/utils.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/lumen/ai/utils.py b/lumen/ai/utils.py index 1322eef2..5c9c53c5 100644 --- a/lumen/ai/utils.py +++ b/lumen/ai/utils.py @@ -2,6 +2,7 @@ import asyncio import inspect +import math import time from functools import wraps @@ -140,10 +141,13 @@ async def get_schema( if "max" in spec: spec.pop("max") - if not include_enum: - for field, spec in schema.items(): - if "enum" in spec: - spec.pop("enum") + for field, spec in schema.items(): + if "enum" not in spec: + continue + elif not include_enum: + spec.pop("enum") + elif "limit" in get_kwargs: + spec["enum"].append("...") if count and include_count: spec["count"] = count @@ -174,7 +178,7 @@ def get_data_sync(): async def describe_data(df: pd.DataFrame) -> str: def format_float(num): - if pd.isna(num): + if pd.isna(num) or math.isinf(num): return num # if is integer, round to 0 decimals if num == int(num): @@ -209,7 +213,10 @@ def describe_data_sync(df): for col in df.select_dtypes(include=["object"]).columns: if col not in df_describe_dict: df_describe_dict[col] = {} - df_describe_dict[col]["nunique"] = df[col].nunique() + try: + df_describe_dict[col]["nunique"] = df[col].nunique() + except Exception: + df_describe_dict[col]["nunique"] = 'unknown' try: df_describe_dict[col]["lengths"] = { "max": df[col].str.len().max(), From a4815e01a5c1248a4b6d7f6fa03b5d58f12427c1 Mon Sep 17 00:00:00 2001 From: Philipp Rudiger Date: Wed, 23 Oct 2024 12:24:32 +0200 Subject: [PATCH 4/5] Remove print --- lumen/ai/assistant.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lumen/ai/assistant.py b/lumen/ai/assistant.py index 7cbeb70e..e1d152b5 100644 --- a/lumen/ai/assistant.py +++ b/lumen/ai/assistant.py @@ -535,7 +535,6 @@ async def _make_plan( requested.append(memory['current_table']) elif len(tables) == 1: requested.append(next(iter(tables))) - print(requested) while reasoning is None or requested: info += await self._lookup_schemas(tables, requested, provided, cache=schemas) available = [t for t in tables if t not in provided] From f8120ea772ceda74572ccd0ba56a6e9b15b4265c Mon Sep 17 00:00:00 2001 From: Philipp Rudiger Date: Wed, 23 Oct 2024 14:21:59 +0200 Subject: [PATCH 5/5] Update assistant.py --- lumen/ai/assistant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lumen/ai/assistant.py b/lumen/ai/assistant.py index e1d152b5..d43266a2 100644 --- a/lumen/ai/assistant.py +++ b/lumen/ai/assistant.py @@ -4,7 +4,7 @@ import re from io import StringIO -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import param import yaml