Improve handling of empty tasks #6

Merged 2 commits on Apr 7, 2024
66 changes: 42 additions & 24 deletions src/control_flow/agent.py
@@ -35,6 +35,7 @@
T = TypeVar("T")
logger = logging.getLogger(__name__)

+NOT_PROVIDED = object()
TEMP_THREADS = {}


@@ -116,7 +117,9 @@

{% endfor %}
{% else %}
-You have no explicit tasks to complete. Follow your instructions as best as you can.
+You have no explicit tasks to complete. Follow your instructions as best as you
+can. If it is not possible to comply with the instructions in any way, use the
+`end_run` tool to manually stop the run.
{% endif %}

## Communication
@@ -321,15 +324,15 @@ def _get_tools(self) -> list[AssistantTool]:
# if there is only one task, and the agent can't send a response to the
# system, then we can quit as soon as it is marked finished
if not self.system_access and len(self.tasks) == 1:
-end_run = True
+early_end_run = True
else:
-end_run = False
+early_end_run = False

for i, task in self.numbered_tasks():
tools.extend(
[
-task._create_complete_tool(task_id=i, end_run=end_run),
-task._create_fail_tool(task_id=i, end_run=end_run),
+task._create_complete_tool(task_id=i, end_run=early_end_run),
+task._create_fail_tool(task_id=i, end_run=early_end_run),
]
)

@@ -386,13 +389,14 @@ def _get_openai_run_task(self):
This needs to be regenerated each time in case the instructions change.
"""

@prefect_task(task_run_name="Run OpenAI assistant")
@prefect_task(task_run_name=f"Run OpenAI assistant ({self.assistant.name})")
async def execute_openai_run(
context: dict = None, run_kwargs: dict = None
) -> Run:
run_kwargs = run_kwargs or {}
model = run_kwargs.pop(
"model", self.assistant.model or settings.assistant_model
"model",
self.assistant.model or self.flow.model or settings.assistant_model,
)
thread = run_kwargs.pop("thread", self.flow.thread)

@@ -408,12 +412,14 @@ async def execute_openai_run(
await run.run_async()
create_json_artifact(
key="messages",
-data=run.messages,
+# dump explicitly because of odd OAI serialization issue
+data=[m.model_dump() for m in run.messages],
description="All messages sent and received during the run.",
)
create_json_artifact(
key="actions",
-data=run.steps,
+# dump explicitly because of odd OAI serialization issue
+data=[s.model_dump() for s in run.steps],
description="All actions taken by the assistant during the run.",
)
return run
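Side note on the explicit `model_dump()` calls above: pydantic model instances (such as the OpenAI message and run-step objects) are not directly JSON-serializable, so dumping them to plain dicts first sidesteps the serialization issue mentioned in the comment. A minimal, self-contained illustration; `DemoMessage` is a stand-in model, not the real OpenAI type:

```python
import json

from pydantic import BaseModel


class DemoMessage(BaseModel):
    """Stand-in for an OpenAI thread message (illustration only)."""

    role: str
    content: str


messages = [DemoMessage(role="assistant", content="hello")]

# json.dumps(messages) raises TypeError: pydantic models are not JSON-serializable as-is.
payload = [m.model_dump() for m in messages]  # plain dicts serialize cleanly
print(json.dumps(payload, indent=2))
```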
@@ -426,8 +432,8 @@ async def run_async(self, context: dict = None, **run_kwargs) -> list[AITask]:

openai_run(context=context, run_kwargs=run_kwargs)

-# if this is not an interactive run, continue to run the AI
-# until all tasks are no longer pending
+# if this AI can't post messages to the system, then continue to invoke
+# it until all tasks are finished
if not self.system_access:
counter = 0
while (
@@ -513,16 +519,19 @@ def _name_from_objective():
"""Helper function for naming task runs"""
from prefect.runtime import task_run

objective = task_run.parameters["task"]
if len(objective) > 50:
return f"Task: {objective[:50]}..."
objective = task_run.parameters.get("task")

if not objective:
objective = "Follow general instructions"
if len(objective) > 75:
return f"Task: {objective[:75]}..."
return f"Task: {objective}"


@prefect_task(task_run_name=_name_from_objective)
def run_agent(
-task: str,
-cast: T = str,
+task: str = None,
+cast: T = NOT_PROVIDED,
context: dict = None,
user_access: bool = None,
model: str = None,
@@ -533,18 +542,27 @@ def run_agent(
response will be cast to the given result type.
"""

+if cast is NOT_PROVIDED:
+if not task:
+cast = None
+else:
+cast = str

# load flow
flow = ctx.get("flow", None)

-# create task
-ai_task = AITask[cast](objective=task, context=context)
+# create tasks
+if task:
+ai_tasks = [AITask[cast](objective=task, context=context)]
+else:
+ai_tasks = []

# run agent
-agent = Agent(tasks=[ai_task], flow=flow, user_access=user_access, **agent_kwargs)
+agent = Agent(tasks=ai_tasks, flow=flow, user_access=user_access, **agent_kwargs)
agent.run(model=model)

-# return
-if ai_task.status == TaskStatus.COMPLETED:
-return ai_task.result
-elif ai_task.status == TaskStatus.FAILED:
-raise ValueError(ai_task.error)
+if ai_tasks:
+if ai_tasks[0].status == TaskStatus.COMPLETED:
+return ai_tasks[0].result
+elif ai_tasks[0].status == TaskStatus.FAILED:
+raise ValueError(ai_tasks[0].error)
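For readers unfamiliar with the idiom behind `NOT_PROVIDED`: a module-level sentinel lets `run_agent` tell "the caller omitted `cast`" apart from "the caller passed `cast=None`", which is what makes the no-task case work, since with no task there is nothing to cast. A simplified, standalone sketch of that pattern, not the library's actual implementation:

```python
NOT_PROVIDED = object()  # unique sentinel; `is` checks can't confuse it with None


def run_agent(task: str = None, cast=NOT_PROVIDED):
    """Simplified stand-in for the real function; returns its resolved arguments."""
    if cast is NOT_PROVIDED:
        # Only apply the str default when there is actually a task to cast.
        cast = str if task else None
    return task, cast


print(run_agent())                       # (None, None) -- empty run, no cast
print(run_agent("summarize the notes"))  # ('summarize the notes', <class 'str'>)
print(run_agent("pick a number", int))   # ('pick a number', <class 'int'>)
```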
19 changes: 19 additions & 0 deletions src/control_flow/flow.py
@@ -4,6 +4,7 @@
from marvin.beta.assistants import Assistant, Thread
from marvin.beta.assistants.assistants import AssistantTool
from marvin.utilities.logging import get_logger
+from openai.types.beta.threads import Message
from prefect import flow as prefect_flow
from prefect import task as prefect_task
from pydantic import BaseModel, Field, field_validator
@@ -19,6 +20,7 @@ class AIFlow(BaseModel):
assistant: Optional[Assistant] = Field(None, validate_default=True)
tools: list[Union[AssistantTool, Callable]] = Field(None, validate_default=True)
instructions: Optional[str] = None
+model: Optional[str] = None

model_config: dict = dict(validate_assignment=True, extra="forbid")

@@ -56,6 +58,7 @@ def ai_flow(
thread: Thread = None,
tools: list[Union[AssistantTool, Callable]] = None,
instructions: str = None,
+model: str = None,
):
"""
Prepare a function to be executed as a Control Flow flow.
@@ -68,6 +71,7 @@
thread=thread,
tools=tools,
instructions=instructions,
+model=model,
)

@functools.wraps(fn)
@@ -77,6 +81,7 @@ def wrapper(
_thread: Thread = None,
_tools: list[Union[AssistantTool, Callable]] = None,
_instructions: str = None,
+_model: str = None,
**kwargs,
):
p_fn = prefect_flow(fn)
@@ -89,11 +94,13 @@ def wrapper(
)
flow_instructions = _instructions or instructions
flow_tools = _tools or tools
+flow_model = _model or model
flow_obj = AIFlow(
thread=flow_thread,
assistant=flow_assistant,
tools=flow_tools,
instructions=flow_instructions,
+model=flow_model,
)

logger.info(
@@ -104,3 +111,15 @@
return p_fn(*args, **kwargs)

return wrapper


+def get_messages(limit: int = None) -> list[Message]:
+"""
+Loads messages from the flow's thread.
+
+Will error if no flow is found in the context.
+"""
+flow: Optional[AIFlow] = ctx.get("flow")
+if not flow:
+raise ValueError("No flow found in context")
+return flow.thread.get_messages(limit=limit)
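Taken together, the flow-level `model` field and the `get_messages` helper might be used roughly like this. The import paths, the model name, and the exported names are assumptions for illustration; they follow the module layout shown in this diff rather than a documented public API:

```python
# Hypothetical usage sketch; import paths and model name are assumptions.
from control_flow.agent import run_agent
from control_flow.flow import ai_flow, get_messages


@ai_flow(model="gpt-3.5-turbo")  # agents in this flow fall back to the flow's model
def review_flow():
    # No task given: the agent just follows its instructions (the "empty tasks" case),
    # and can call the `end_run` tool if the instructions can't be satisfied.
    run_agent(context={"topic": "release notes"})

    # Inspect what was said on the flow's thread afterwards.
    for message in get_messages(limit=5):
        print(message.role)


review_flow()
```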