diff --git a/lumen/ai/agents.py b/lumen/ai/agents.py
index 3869290e..70914dc1 100644
--- a/lumen/ai/agents.py
+++ b/lumen/ai/agents.py
@@ -172,11 +172,11 @@ async def invoke(self, messages: list | str):
         system_prompt = await self._system_prompt_with_context(messages)
         message = None
-        async for output in self.llm.stream(
-            messages, system=system_prompt, response_model=self.response_model, field="output"
+        async for chunk in self.llm.stream(
+            messages, system=system_prompt, response_model=self.response_model
         ):
             message = self.interface.stream(
-                output, replace=True, message=message, user=self.user
+                chunk, replace=True, message=message, user=self.user
             )
@@ -285,20 +285,18 @@ async def requirements(self, messages: list | str, errors=None):
         if 'current_data' in memory:
             return self.requires
 
-        with self.interface.add_step(title="Checking if data is required") as step:
-            response = self.llm.stream(
-                messages,
-                system=(
-                    "The user may or may not want to chat about a particular dataset. "
-                    "Determine whether the provided user prompt requires access to "
-                    "actual data. If they're only searching for one, it's not required."
-                ),
-                response_model=DataRequired,
-            )
-            async for output in response:
-                step.stream(output.chain_of_thought, replace=True)
-                if output.data_required:
-                    return self.requires + ['current_table']
+        result = await self.llm.invoke(
+            messages,
+            system=(
+                "The user may or may not want to chat about a particular dataset. "
+                "Determine whether the provided user prompt requires access to "
+                "actual data. If they're only searching for one, it's not required."
+            ),
+            response_model=DataRequired,
+            allow_partial=False,
+        )
+        if result.data_required:
+            return self.requires + ['current_table']
         return self.requires
 
     async def _system_prompt_with_context(
@@ -510,7 +508,6 @@ def _render_sql(self, query):
         pipeline = memory['current_pipeline']
         out = SQLOutput(component=pipeline, spec=query)
         self.interface.stream(out, user="SQL", replace=True)
-        return out
 
     @retry_llm_output()
     async def _create_valid_sql(self, messages, system, source, tables, errors=None):
@@ -525,24 +522,33 @@ async def _create_valid_sql(self, messages, system, source, tables, errors=None)
                        f"expertly revise, and please do not repeat these issues:\n{errors}")
                }
            ]
-
-        with self.interface.add_step(title="Creating SQL query...", success_title="SQL Query") as step:
-            response = self.llm.stream(messages, system=system, response_model=Sql)
-            sql_query = None
-            async for output in response:
-                step_message = output.chain_of_thought
-                if output.query:
-                    sql_query = output.query.replace("```sql", "").replace("```", "").strip()
-                    step_message += f"\n```sql\n{sql_query}\n```"
-                step.stream(step_message, replace=True)
-
+        message = ""
+        with self.interface.add_step(title="Conjuring SQL query...", success_title="SQL Query") as step:
+            async for model in self.llm.stream(
+                messages,
+                system=system,
+                response_model=Sql,
+                field=None
+            ):
+                chunk = model.query
+                if chunk is None:
+                    continue
+                message = chunk
+                step.stream(
+                    f"```sql\n{message}\n```",
+                    replace=True,
+                )
+        if not message:
+            return
+        sql_query = message.replace("```sql", "").replace("```", "").strip()
         if not sql_query:
             raise ValueError("No SQL query was generated.")
 
         # check whether the SQL query is valid
-        expr_name = output.expr_name
-        sql_expr_source = source.create_sql_expr_source({expr_name: sql_query})
-        pipeline = Pipeline(source=sql_expr_source, table=expr_name)
+        sql_expr_source = source.create_sql_expr_source({model.expr_name: sql_query})
+        pipeline = Pipeline(
+            source=sql_expr_source, table=model.expr_name
+        )
         df = pipeline.data
         if len(df) > 0:
             memory["current_data"] = describe_data(df)
@@ -557,26 +563,25 @@ async def answer(self, messages: list | str):
             return None
 
         with self.interface.add_step(title="Checking if join is required") as step:
-            response = self.llm.stream(
+            join = (await self.llm.invoke(
                 messages,
                 system="Determine whether a table join is required to answer the user's query.",
                 response_model=JoinRequired,
-            )
-            async for output in response:
-                step.stream(output.chain_of_thought, replace=True)
-            join_required = output.join_required
-            step.success_title = 'Query requires join' if join_required else 'No join required'
+                allow_partial=False,
+            ))
+            step.stream(join.chain_of_thought)
+            step.success_title = 'Query requires join' if join.join_required else 'No join required'
 
-        if join_required:
-            available_tables = " ".join(str(table) for table in source.get_tables())
+        if join.join_required:
+            available_tables = source.get_tables()
             with self.interface.add_step(title="Determining tables required for join") as step:
-                output = await self.llm.invoke(
+                tables = (await self.llm.invoke(
                     messages,
                     system=f"List the tables that need to be joined: {available_tables}.",
                     response_model=TableJoins,
-                )
-                tables = output.tables
-                step.stream(f'\nJoin requires following tables: {tables}', replace=True)
+                    allow_partial=False,
+                )).tables
+                step.stream(f'\nJoin requires following tables: {tables}')
                 step.success_title = 'Found tables required for join'
         else:
             tables = [table]
@@ -599,8 +604,8 @@ async def answer(self, messages: list | str):
         return sql_query
 
     async def invoke(self, messages: list | str):
-        sql_query = await self.answer(messages)
-        self._render_sql(sql_query)
+        sql = await self.answer(messages)
+        self._render_sql(sql)
 
 
 class PipelineAgent(LumenBaseAgent):
@@ -812,23 +817,13 @@ async def answer(self, messages: list | str) -> hvPlotUIView:
         print(f"{self.name} is being instructed that {view_prompt}.")
 
         # Query
-        response = self.llm.stream(
+        spec_model = await self.llm.invoke(
             messages,
             system=system_prompt + view_prompt,
             response_model=self._get_model(schema),
+            allow_partial=False,
         )
-
-        json_pane = None
-        with self.interface.add_step(title="Generating view...") as step:
-            async for output in response:
-                if json_pane is None:
-                    json_pane = pn.pane.JSON()
-                    step.append(json_pane)
-                try:
-                    spec = self._extract_spec(output)
-                    json_pane.object = spec
-                except Exception:
-                    pass
+        spec = self._extract_spec(spec_model)
         print(f"{self.name} settled on {spec=!r}.")
         memory["current_view"] = dict(spec, type=self.view_type)
         return self.view_type(pipeline=pipeline, **spec)
diff --git a/lumen/ai/assistant.py b/lumen/ai/assistant.py
index 8b0092db..c528b054 100644
--- a/lumen/ai/assistant.py
+++ b/lumen/ai/assistant.py
@@ -236,24 +236,25 @@ async def _invalidate_memory(self, messages):
         sql = memory.get("current_sql")
         system = render_template("check_validity.jinja2", table=table, spec=spec, sql=sql, analyses=self._analyses)
         with self.interface.add_step(title="Checking memory...", user="Assistant") as step:
-            response = self.llm.stream(
+            validity = await self.llm.invoke(
                 messages=messages,
                 system=system,
                 response_model=Validity,
+                allow_partial=False,
             )
-            async for output in response:
-                step.stream(output.correct_assessment, replace=True)
-            step.success_title = f"{output.is_invalid.title()} needs refresh" if output.is_invalid else "Memory still valid"
+            if validity.correct_assessment:
+                step.stream(validity.correct_assessment)
+            step.success_title = f"{validity.is_invalid.title()} needs refresh" if validity.is_invalid else "Memory still valid"
 
-        if output and output.is_invalid:
-            if output.is_invalid == "table":
+        if validity and validity.is_invalid:
+            if validity.is_invalid == "table":
                 memory.pop("current_table", None)
                 memory.pop("current_data", None)
                 memory.pop("current_sql", None)
                 memory.pop("current_pipeline", None)
                 memory.pop("closest_tables", None)
                 print("\033[91mInvalidated from memory.\033[0m")
-            elif output.is_invalid == "sql":
+            elif validity.is_invalid == "sql":
                 memory.pop("current_sql", None)
                 memory.pop("current_data", None)
                 memory.pop("current_pipeline", None)
@@ -261,14 +262,15 @@ async def _create_suggestion(self, instance, event):
         messages = self.interface.serialize(custom_serializer=self._serialize)[-3:-1]
-        response = self.llm.stream(
+        string = self.llm.stream(
             messages,
             system="Generate a follow-up question that a user might ask; ask from the user POV",
+            allow_partial=True,
         )
         try:
             self.interface.disabled = True
-            async for output in response:
-                self.interface.active_widget.value_input = output
+            async for chunk in string:
+                self.interface.active_widget.value_input = chunk
         finally:
             self.interface.disabled = False
@@ -294,19 +296,24 @@ def _create_agent_model(agent_names):
         return agent_model
 
     @retry_llm_output()
-    async def _create_valid_agent(self, messages, system, agent_model, errors=None):
+    async def _create_valid_agent(self, messages, system, agent_model, return_reasoning, errors=None):
         if errors:
             errors = '\n'.join(errors)
             messages += [{"role": "user", "content": f"\nExpertly resolve these issues:\n{errors}"}]
 
-        out = self.llm.stream(
+        out = await self.llm.invoke(
             messages=messages,
             system=system,
-            response_model=agent_model
+            response_model=agent_model,
+            allow_partial=False
         )
-        return out
+        if not (out and out.agent):
+            raise ValueError("No agent selected.")
+        elif return_reasoning:
+            return out.agent, out.chain_of_thought
+        return out.agent
 
-    async def _choose_agent(self, messages: list | str, agents: list[Agent]):
+    async def _choose_agent(self, messages: list | str, agents: list[Agent], return_reasoning: bool = False):
         agent_names = tuple(sagent.name[:-5] for sagent in agents)
         if len(agent_names) == 0:
             raise ValueError("No agents available to choose from.")
@@ -323,7 +330,7 @@ async def _choose_agent(self, messages: list | str, agents: list[Agent]):
         system = render_template(
             "pick_agent.jinja2", agents=agents, current_agent=self._current_agent.object
         )
-        return await self._create_valid_agent(messages, system, agent_model)
+        return await self._create_valid_agent(messages, system, agent_model, return_reasoning)
 
     async def _get_agent(self, messages: list | str):
         if len(self.agents) == 1:
@@ -335,12 +342,10 @@ async def _get_agent(self, messages: list | str):
             agent = agent_types[0]
         else:
             with self.interface.add_step(title="Selecting relevant agent...", user="Assistant") as step:
-                response = await self._choose_agent(messages, self.agents)
-                async for output in response:
-                    step.stream(output.chain_of_thought, replace=True)
-                agent = output.agent
+                agent, reasoning = await self._choose_agent(messages, self.agents, return_reasoning=True)
+                step.stream(reasoning)
                 step.success_title = f"Selected {agent}"
-            messages.append({"role": "assistant", "content": output.chain_of_thought})
+            messages.append({"role": "assistant", "content": reasoning})
 
         if agent is None:
             return None
@@ -361,12 +366,7 @@ async def _get_agent(self, messages: list | str):
                 for agent in self.agents
                 if any(ur in agent.provides for ur in unmet_dependencies)
             ]
-            response = await self._choose_agent(messages, subagents)
-            if isinstance(response, str):
-                subagent_name = response
-            else:
-                async for output in response:
-                    subagent_name = output.agent
+            subagent_name = await self._choose_agent(messages, subagents)
             if subagent_name is None:
                 continue
             subagent = agents[subagent_name]
diff --git a/lumen/ai/llm.py b/lumen/ai/llm.py
index 84649071..40a17f51 100644
--- a/lumen/ai/llm.py
+++ b/lumen/ai/llm.py
@@ -45,7 +45,7 @@ async def invoke(
         messages: list | str,
         system: str = "",
         response_model: BaseModel | None = None,
-        allow_partial: bool = False,
+        allow_partial: bool = True,
         model_key: str = "default",
         **input_kwargs,
     ) -> BaseModel:
@@ -53,7 +53,6 @@ async def invoke(
             messages = [{"role": "user", "content": messages}]
         if system:
             messages = [{"role": "system", "content": system}] + messages
-        print(messages)
 
         kwargs = dict(self._client_kwargs)
         kwargs.update(input_kwargs)
@@ -79,7 +78,7 @@ async def stream(
         messages: list | str,
         system: str = "",
         response_model: BaseModel | None = None,
-        field: str | None = None,
+        field: str = "output",
         model_key: str = "default",
         **kwargs,
     ):
@@ -89,7 +88,6 @@ async def stream(
             system=system,
             response_model=response_model,
             stream=True,
-            allow_partial=True,
             model_key=model_key,
             **kwargs,
         )
diff --git a/lumen/ai/utils.py b/lumen/ai/utils.py
index 51ef23a8..8b010542 100644
--- a/lumen/ai/utils.py
+++ b/lumen/ai/utils.py
@@ -56,7 +56,6 @@ async def async_wrapper(*args, **kwargs):
            errors = []
            for i in range(retries):
                if errors:
-                   print(errors)
                    kwargs["errors"] = errors
                try: