Revert "Stream output from LLM (#608)"
This reverts commit 46798bd.
philippjfr committed Jul 24, 2024
1 parent 84858f9 commit 35d8531
Showing 4 changed files with 83 additions and 91 deletions.
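The central change this revert undoes is how the agents consume structured LLM output: #608 had most agents iterate over partially-filled response models from `llm.stream(...)` and re-render a chat step on every chunk, while the restored code awaits one completed model from `llm.invoke(..., allow_partial=False)` and acts on it once. The sketch below contrasts the two consumption styles; `MockLLM` and its signatures are stand-ins inferred from the diff, not Lumen's actual `Llm` class.

```python
import asyncio
from dataclasses import dataclass

@dataclass
class DataRequired:
    chain_of_thought: str = ""
    data_required: bool = False

class MockLLM:
    """Stand-in client that fakes partial structured decoding."""

    async def stream(self, messages, response_model):
        # Emit increasingly complete instances, the way partial JSON decoding would.
        yield response_model(chain_of_thought="Checking whether rows are needed")
        yield response_model(
            chain_of_thought="Checking whether rows are needed... yes, plotting needs data.",
            data_required=True,
        )

    async def invoke(self, messages, response_model, allow_partial=False):
        # Return only the final, fully populated model.
        result = None
        async for partial in self.stream(messages, response_model):
            result = partial
        return result

async def main():
    llm = MockLLM()

    # Streaming style (removed from most agents by this revert):
    # update the UI on every partially filled model.
    async for output in llm.stream("Plot the penguins table", DataRequired):
        print("partial:", output.chain_of_thought)

    # Blocking style (restored by this revert): act once on the finished model.
    result = await llm.invoke("Plot the penguins table", DataRequired, allow_partial=False)
    print("final:", result.data_required)

asyncio.run(main())
```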
113 changes: 54 additions & 59 deletions lumen/ai/agents.py
@@ -172,11 +172,11 @@ async def invoke(self, messages: list | str):
system_prompt = await self._system_prompt_with_context(messages)

message = None
async for output in self.llm.stream(
messages, system=system_prompt, response_model=self.response_model, field="output"
async for chunk in self.llm.stream(
messages, system=system_prompt, response_model=self.response_model
):
message = self.interface.stream(
output, replace=True, message=message, user=self.user
chunk, replace=True, message=message, user=self.user
)


@@ -285,20 +285,18 @@ async def requirements(self, messages: list | str, errors=None):
if 'current_data' in memory:
return self.requires

with self.interface.add_step(title="Checking if data is required") as step:
response = self.llm.stream(
messages,
system=(
"The user may or may not want to chat about a particular dataset. "
"Determine whether the provided user prompt requires access to "
"actual data. If they're only searching for one, it's not required."
),
response_model=DataRequired,
)
async for output in response:
step.stream(output.chain_of_thought, replace=True)
if output.data_required:
return self.requires + ['current_table']
result = await self.llm.invoke(
messages,
system=(
"The user may or may not want to chat about a particular dataset. "
"Determine whether the provided user prompt requires access to "
"actual data. If they're only searching for one, it's not required."
),
response_model=DataRequired,
allow_partial=False,
)
if result.data_required:
return self.requires + ['current_table']
return self.requires

async def _system_prompt_with_context(
@@ -510,7 +508,6 @@ def _render_sql(self, query):
pipeline = memory['current_pipeline']
out = SQLOutput(component=pipeline, spec=query)
self.interface.stream(out, user="SQL", replace=True)
return out

@retry_llm_output()
async def _create_valid_sql(self, messages, system, source, tables, errors=None):
@@ -525,24 +522,33 @@ async def _create_valid_sql(self, messages, system, source, tables, errors=None)
f"expertly revise, and please do not repeat these issues:\n{errors}")
}
]

with self.interface.add_step(title="Creating SQL query...", success_title="SQL Query") as step:
response = self.llm.stream(messages, system=system, response_model=Sql)
sql_query = None
async for output in response:
step_message = output.chain_of_thought
if output.query:
sql_query = output.query.replace("```sql", "").replace("```", "").strip()
step_message += f"\n```sql\n{sql_query}\n```"
step.stream(step_message, replace=True)

message = ""
with self.interface.add_step(title="Conjuring SQL query...", success_title="SQL Query") as step:
async for model in self.llm.stream(
messages,
system=system,
response_model=Sql,
field=None
):
chunk = model.query
if chunk is None:
continue
message = chunk
step.stream(
f"```sql\n{message}\n```",
replace=True,
)
if not message:
return
sql_query = message.replace("```sql", "").replace("```", "").strip()
if not sql_query:
raise ValueError("No SQL query was generated.")

# check whether the SQL query is valid
expr_name = output.expr_name
sql_expr_source = source.create_sql_expr_source({expr_name: sql_query})
pipeline = Pipeline(source=sql_expr_source, table=expr_name)
sql_expr_source = source.create_sql_expr_source({model.expr_name: sql_query})
pipeline = Pipeline(
source=sql_expr_source, table=model.expr_name
)
df = pipeline.data
if len(df) > 0:
memory["current_data"] = describe_data(df)
@@ -557,26 +563,25 @@ async def answer(self, messages: list | str):
return None

with self.interface.add_step(title="Checking if join is required") as step:
response = self.llm.stream(
join = (await self.llm.invoke(
messages,
system="Determine whether a table join is required to answer the user's query.",
response_model=JoinRequired,
)
async for output in response:
step.stream(output.chain_of_thought, replace=True)
join_required = output.join_required
step.success_title = 'Query requires join' if join_required else 'No join required'
allow_partial=False,
))
step.stream(join.chain_of_thought)
step.success_title = 'Query requires join' if join.join_required else 'No join required'

if join_required:
available_tables = " ".join(str(table) for table in source.get_tables())
if join.join_required:
available_tables = source.get_tables()
with self.interface.add_step(title="Determining tables required for join") as step:
output = await self.llm.invoke(
tables = (await self.llm.invoke(
messages,
system=f"List the tables that need to be joined: {available_tables}.",
response_model=TableJoins,
)
tables = output.tables
step.stream(f'\nJoin requires following tables: {tables}', replace=True)
allow_partial=False,
)).tables
step.stream(f'\nJoin requires following tables: {tables}')
step.success_title = 'Found tables required for join'
else:
tables = [table]
@@ -599,8 +604,8 @@ async def answer(self, messages: list | str):
return sql_query

async def invoke(self, messages: list | str):
sql_query = await self.answer(messages)
self._render_sql(sql_query)
sql = await self.answer(messages)
self._render_sql(sql)


class PipelineAgent(LumenBaseAgent):
@@ -812,23 +817,13 @@ async def answer(self, messages: list | str) -> hvPlotUIView:
print(f"{self.name} is being instructed that {view_prompt}.")

# Query
response = self.llm.stream(
spec_model = await self.llm.invoke(
messages,
system=system_prompt + view_prompt,
response_model=self._get_model(schema),
allow_partial=False,
)

json_pane = None
with self.interface.add_step(title="Generating view...") as step:
async for output in response:
if json_pane is None:
json_pane = pn.pane.JSON()
step.append(json_pane)
try:
spec = self._extract_spec(output)
json_pane.object = spec
except Exception:
pass
spec = self._extract_spec(spec_model)
print(f"{self.name} settled on {spec=!r}.")
memory["current_view"] = dict(spec, type=self.view_type)
return self.view_type(pipeline=pipeline, **spec)
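In the restored `_create_valid_sql`, the agent still streams, but each chunk replaces the whole step body with the latest partial query (`replace=True`), markdown fencing is stripped from the final text, and the query is only accepted after it actually executes through a `Pipeline` and returns rows. Below is a rough, self-contained sketch of that render-then-validate loop; the `Step` class and the SQLite validation are stand-ins for Lumen's chat step and pipeline, not its real API.

```python
import asyncio
import sqlite3

FENCE = "`" * 3  # markdown code fence, built here to keep the example readable

class Step:
    """Minimal stand-in for a chat step whose body can be re-rendered in place."""
    def __init__(self):
        self.body = ""

    def stream(self, text, replace=False):
        self.body = text if replace else self.body + text

async def sql_chunks():
    # Pretend the LLM emits the query in progressively longer fragments.
    for partial in [
        "SELECT name",
        "SELECT name, mass FROM penguins",
        f"{FENCE}sql\nSELECT name, mass FROM penguins LIMIT 5\n{FENCE}",
    ]:
        yield partial

async def create_valid_sql(step, conn):
    message = ""
    async for chunk in sql_chunks():
        message = chunk
        step.stream(f"{FENCE}sql\n{message}\n{FENCE}", replace=True)  # overwrite, don't append
    sql_query = message.replace(f"{FENCE}sql", "").replace(FENCE, "").strip()
    if not sql_query:
        raise ValueError("No SQL query was generated.")
    # Validate by running the query; an error or empty result should trigger a retry.
    rows = conn.execute(sql_query).fetchall()
    if not rows:
        raise ValueError("Query returned no rows.")
    return sql_query

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE penguins (name TEXT, mass REAL)")
conn.execute("INSERT INTO penguins VALUES ('Adelie', 3700), ('Gentoo', 5000)")
step = Step()
print(asyncio.run(create_valid_sql(step, conn)))
```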
54 changes: 27 additions & 27 deletions lumen/ai/assistant.py
@@ -236,39 +236,41 @@ async def _invalidate_memory(self, messages):
sql = memory.get("current_sql")
system = render_template("check_validity.jinja2", table=table, spec=spec, sql=sql, analyses=self._analyses)
with self.interface.add_step(title="Checking memory...", user="Assistant") as step:
response = self.llm.stream(
validity = await self.llm.invoke(
messages=messages,
system=system,
response_model=Validity,
allow_partial=False,
)
async for output in response:
step.stream(output.correct_assessment, replace=True)
step.success_title = f"{output.is_invalid.title()} needs refresh" if output.is_invalid else "Memory still valid"
if validity.correct_assessment:
step.stream(validity.correct_assessment)
step.success_title = f"{validity.is_invalid.title()} needs refresh" if validity.is_invalid else "Memory still valid"

if output and output.is_invalid:
if output.is_invalid == "table":
if validity and validity.is_invalid:
if validity.is_invalid == "table":
memory.pop("current_table", None)
memory.pop("current_data", None)
memory.pop("current_sql", None)
memory.pop("current_pipeline", None)
memory.pop("closest_tables", None)
print("\033[91mInvalidated from memory.\033[0m")
elif output.is_invalid == "sql":
elif validity.is_invalid == "sql":
memory.pop("current_sql", None)
memory.pop("current_data", None)
memory.pop("current_pipeline", None)
print("\033[91mInvalidated SQL from memory.\033[0m")

async def _create_suggestion(self, instance, event):
messages = self.interface.serialize(custom_serializer=self._serialize)[-3:-1]
response = self.llm.stream(
string = self.llm.stream(
messages,
system="Generate a follow-up question that a user might ask; ask from the user POV",
allow_partial=True,
)
try:
self.interface.disabled = True
async for output in response:
self.interface.active_widget.value_input = output
async for chunk in string:
self.interface.active_widget.value_input = chunk
finally:
self.interface.disabled = False

@@ -294,19 +296,24 @@ def _create_agent_model(agent_names):
return agent_model

@retry_llm_output()
async def _create_valid_agent(self, messages, system, agent_model, errors=None):
async def _create_valid_agent(self, messages, system, agent_model, return_reasoning, errors=None):
if errors:
errors = '\n'.join(errors)
messages += [{"role": "user", "content": f"\nExpertly resolve these issues:\n{errors}"}]

out = self.llm.stream(
out = await self.llm.invoke(
messages=messages,
system=system,
response_model=agent_model
response_model=agent_model,
allow_partial=False
)
return out
if not (out and out.agent):
raise ValueError("No agent selected.")
elif return_reasoning:
return out.agent, out.chain_of_thought
return out.agent

async def _choose_agent(self, messages: list | str, agents: list[Agent]):
async def _choose_agent(self, messages: list | str, agents: list[Agent], return_reasoning: bool = False):
agent_names = tuple(sagent.name[:-5] for sagent in agents)
if len(agent_names) == 0:
raise ValueError("No agents available to choose from.")
@@ -323,7 +330,7 @@ async def _choose_agent(self, messages: list | str, agents: list[Agent]):
system = render_template(
"pick_agent.jinja2", agents=agents, current_agent=self._current_agent.object
)
return await self._create_valid_agent(messages, system, agent_model)
return await self._create_valid_agent(messages, system, agent_model, return_reasoning)

async def _get_agent(self, messages: list | str):
if len(self.agents) == 1:
@@ -335,12 +342,10 @@ async def _get_agent(self, messages: list | str):
agent = agent_types[0]
else:
with self.interface.add_step(title="Selecting relevant agent...", user="Assistant") as step:
response = await self._choose_agent(messages, self.agents)
async for output in response:
step.stream(output.chain_of_thought, replace=True)
agent = output.agent
agent, reasoning = await self._choose_agent(messages, self.agents, return_reasoning=True)
step.stream(reasoning)
step.success_title = f"Selected {agent}"
messages.append({"role": "assistant", "content": output.chain_of_thought})
messages.append({"role": "assistant", "content": reasoning})

if agent is None:
return None
@@ -361,12 +366,7 @@ async def _get_agent(self, messages: list | str):
for agent in self.agents
if any(ur in agent.provides for ur in unmet_dependencies)
]
response = await self._choose_agent(messages, subagents)
if isinstance(response, str):
subagent_name = response
else:
async for output in response:
subagent_name = output.agent
subagent_name = await self._choose_agent(messages, subagents)
if subagent_name is None:
continue
subagent = agents[subagent_name]
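One streaming path survives the revert: `_create_suggestion` still disables the chat input, writes each progressively longer chunk into the widget's `value_input`, and re-enables the input in a `finally` block so a failed stream can't leave the UI locked. A mocked sketch of that pattern follows; `Widget` and `text_chunks` are illustrative stand-ins, not Panel or Lumen APIs.

```python
import asyncio

class Widget:
    """Stand-in for a chat input widget."""
    def __init__(self):
        self.disabled = False
        self.value_input = ""

async def text_chunks():
    # Pretend the LLM streams an ever-growing suggestion string.
    for text in ["What", "What does the", "What does the mass column contain?"]:
        yield text

async def create_suggestion(widget):
    try:
        widget.disabled = True          # block typing while the suggestion is prefilled
        async for chunk in text_chunks():
            widget.value_input = chunk  # each chunk is the full text so far
    finally:
        widget.disabled = False         # always unlock, even if the stream errors out

widget = Widget()
asyncio.run(create_suggestion(widget))
print(widget.value_input)
```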
6 changes: 2 additions & 4 deletions lumen/ai/llm.py
@@ -45,15 +45,14 @@ async def invoke(
messages: list | str,
system: str = "",
response_model: BaseModel | None = None,
allow_partial: bool = False,
allow_partial: bool = True,
model_key: str = "default",
**input_kwargs,
) -> BaseModel:
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
if system:
messages = [{"role": "system", "content": system}] + messages
print(messages)

kwargs = dict(self._client_kwargs)
kwargs.update(input_kwargs)
@@ -79,7 +78,7 @@ async def stream(
messages: list | str,
system: str = "",
response_model: BaseModel | None = None,
field: str | None = None,
field: str = "output",
model_key: str = "default",
**kwargs,
):
@@ -89,7 +88,6 @@
system=system,
response_model=response_model,
stream=True,
allow_partial=True,
model_key=model_key,
**kwargs,
)
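`stream()` regains its `field="output"` default here, and the agent call sites above pass either no `field` (treating each chunk as plain text) or `field=None` (receiving the partial `Sql` model itself). That usage suggests the wrapper yields either a single named attribute of each partial model or the whole model; the sketch below is an assumption inferred from those call sites, not Lumen's actual implementation.

```python
import asyncio
from dataclasses import dataclass

@dataclass
class Sql:
    chain_of_thought: str = ""
    query: str = ""

async def raw_stream():
    # Pretend these are partially decoded response models from the LLM client.
    yield Sql(chain_of_thought="Need a row filter")
    yield Sql(chain_of_thought="Need a row filter", query="SELECT * FROM penguins")

async def stream(field="output"):
    """Yield each partial model, or just one of its fields when `field` is set."""
    async for model in raw_stream():
        yield getattr(model, field) if field else model

async def main():
    async for query in stream(field="query"):  # consume a single field as it fills in
        print("field:", query)
    async for model in stream(field=None):     # consume the whole partial model
        print("model:", model)

asyncio.run(main())
```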
1 change: 0 additions & 1 deletion lumen/ai/utils.py
@@ -56,7 +56,6 @@ async def async_wrapper(*args, **kwargs):
errors = []
for i in range(retries):
if errors:
print(errors)
kwargs["errors"] = errors

try:
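The `utils.py` change only drops a debug `print(errors)` from `retry_llm_output`, but the visible loop shows the decorator's job: retry the wrapped coroutine and feed the accumulated error messages back in through an `errors` keyword so the next prompt can address them. A simplified, self-contained sketch of that retry-with-feedback idea, with details assumed from the fragment above rather than the full Lumen implementation:

```python
import asyncio
import functools

def retry_llm_output(retries=3):
    """Retry an async function, passing previous failures back via `errors=`."""
    def decorator(func):
        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            errors = []
            for _ in range(retries):
                if errors:
                    kwargs["errors"] = errors  # let the callee mention past mistakes
                try:
                    return await func(*args, **kwargs)
                except Exception as exc:
                    errors.append(str(exc))
            raise RuntimeError(f"Gave up after {retries} attempts: {errors}")
        return async_wrapper
    return decorator

@retry_llm_output(retries=2)
async def flaky(query, errors=None):
    # Fails on the first call, succeeds once it sees the earlier error fed back in.
    if errors is None:
        raise ValueError("first attempt always fails in this demo")
    return f"retried {query!r} with feedback: {errors}"

print(asyncio.run(flaky("SELECT 1")))
```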
