Skip to content

Commit

Permalink
Merge pull request #6 from jlowin/empty-tasks
Browse files Browse the repository at this point in the history
Improve handling of empty tasks
  • Loading branch information
jlowin authored Apr 7, 2024
2 parents 1c75b4e + 737fffc commit d2955ec
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 24 deletions.
66 changes: 42 additions & 24 deletions src/control_flow/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
T = TypeVar("T")
logger = logging.getLogger(__name__)

NOT_PROVIDED = object()
TEMP_THREADS = {}


Expand Down Expand Up @@ -116,7 +117,9 @@
{% endfor %}
{% else %}
You have no explicit tasks to complete. Follow your instructions as best as you can.
You have no explicit tasks to complete. Follow your instructions as best as you
can. If it is not possible to comply with the instructions in any way, use the
`end_run` tool to manually stop the run.
{% endif %}
## Communication
Expand Down Expand Up @@ -321,15 +324,15 @@ def _get_tools(self) -> list[AssistantTool]:
# if there is only one task, and the agent can't send a response to the
# system, then we can quit as soon as it is marked finished
if not self.system_access and len(self.tasks) == 1:
end_run = True
early_end_run = True
else:
end_run = False
early_end_run = False

for i, task in self.numbered_tasks():
tools.extend(
[
task._create_complete_tool(task_id=i, end_run=end_run),
task._create_fail_tool(task_id=i, end_run=end_run),
task._create_complete_tool(task_id=i, end_run=early_end_run),
task._create_fail_tool(task_id=i, end_run=early_end_run),
]
)

Expand Down Expand Up @@ -386,13 +389,14 @@ def _get_openai_run_task(self):
This needs to be regenerated each time in case the instructions change.
"""

@prefect_task(task_run_name="Run OpenAI assistant")
@prefect_task(task_run_name=f"Run OpenAI assistant ({self.assistant.name})")
async def execute_openai_run(
context: dict = None, run_kwargs: dict = None
) -> Run:
run_kwargs = run_kwargs or {}
model = run_kwargs.pop(
"model", self.assistant.model or settings.assistant_model
"model",
self.assistant.model or self.flow.model or settings.assistant_model,
)
thread = run_kwargs.pop("thread", self.flow.thread)

Expand All @@ -408,12 +412,14 @@ async def execute_openai_run(
await run.run_async()
create_json_artifact(
key="messages",
data=run.messages,
# dump explicilty because of odd OAI serialization issue
data=[m.model_dump() for m in run.messages],
description="All messages sent and received during the run.",
)
create_json_artifact(
key="actions",
data=run.steps,
# dump explicilty because of odd OAI serialization issue
data=[s.model_dump() for s in run.steps],
description="All actions taken by the assistant during the run.",
)
return run
Expand All @@ -426,8 +432,8 @@ async def run_async(self, context: dict = None, **run_kwargs) -> list[AITask]:

openai_run(context=context, run_kwargs=run_kwargs)

# if this is not an interactive run, continue to run the AI
# until all tasks are no longer pending
# if this AI can't post messages to the system, then continue to invoke
# it until all tasks are finished
if not self.system_access:
counter = 0
while (
Expand Down Expand Up @@ -513,16 +519,19 @@ def _name_from_objective():
"""Helper function for naming task runs"""
from prefect.runtime import task_run

objective = task_run.parameters["task"]
if len(objective) > 50:
return f"Task: {objective[:50]}..."
objective = task_run.parameters.get("task")

if not objective:
objective = "Follow general instructions"
if len(objective) > 75:
return f"Task: {objective[:75]}..."
return f"Task: {objective}"


@prefect_task(task_run_name=_name_from_objective)
def run_agent(
task: str,
cast: T = str,
task: str = None,
cast: T = NOT_PROVIDED,
context: dict = None,
user_access: bool = None,
model: str = None,
Expand All @@ -533,18 +542,27 @@ def run_agent(
response will be cast to the given result type.
"""

if cast is NOT_PROVIDED:
if not task:
cast = None
else:
cast = str

# load flow
flow = ctx.get("flow", None)

# create task
ai_task = AITask[cast](objective=task, context=context)
# create tasks
if task:
ai_tasks = [AITask[cast](objective=task, context=context)]
else:
ai_tasks = []

# run agent
agent = Agent(tasks=[ai_task], flow=flow, user_access=user_access, **agent_kwargs)
agent = Agent(tasks=ai_tasks, flow=flow, user_access=user_access, **agent_kwargs)
agent.run(model=model)

# return
if ai_task.status == TaskStatus.COMPLETED:
return ai_task.result
elif ai_task.status == TaskStatus.FAILED:
raise ValueError(ai_task.error)
if ai_tasks:
if ai_tasks[0].status == TaskStatus.COMPLETED:
return ai_tasks[0].result
elif ai_tasks[0].status == TaskStatus.FAILED:
raise ValueError(ai_tasks[0].error)
19 changes: 19 additions & 0 deletions src/control_flow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from marvin.beta.assistants import Assistant, Thread
from marvin.beta.assistants.assistants import AssistantTool
from marvin.utilities.logging import get_logger
from openai.types.beta.threads import Message
from prefect import flow as prefect_flow
from prefect import task as prefect_task
from pydantic import BaseModel, Field, field_validator
Expand All @@ -19,6 +20,7 @@ class AIFlow(BaseModel):
assistant: Optional[Assistant] = Field(None, validate_default=True)
tools: list[Union[AssistantTool, Callable]] = Field(None, validate_default=True)
instructions: Optional[str] = None
model: Optional[str] = None

model_config: dict = dict(validate_assignment=True, extra="forbid")

Expand Down Expand Up @@ -56,6 +58,7 @@ def ai_flow(
thread: Thread = None,
tools: list[Union[AssistantTool, Callable]] = None,
instructions: str = None,
model: str = None,
):
"""
Prepare a function to be executed as a Control Flow flow.
Expand All @@ -68,6 +71,7 @@ def ai_flow(
thread=thread,
tools=tools,
instructions=instructions,
model=model,
)

@functools.wraps(fn)
Expand All @@ -77,6 +81,7 @@ def wrapper(
_thread: Thread = None,
_tools: list[Union[AssistantTool, Callable]] = None,
_instructions: str = None,
_model: str = None,
**kwargs,
):
p_fn = prefect_flow(fn)
Expand All @@ -89,11 +94,13 @@ def wrapper(
)
flow_instructions = _instructions or instructions
flow_tools = _tools or tools
flow_model = _model or model
flow_obj = AIFlow(
thread=flow_thread,
assistant=flow_assistant,
tools=flow_tools,
instructions=flow_instructions,
model=flow_model,
)

logger.info(
Expand All @@ -104,3 +111,15 @@ def wrapper(
return p_fn(*args, **kwargs)

return wrapper


def get_messages(limit: int = None) -> list[Message]:
"""
Loads messages from the flow's thread.
Will error if no flow is found in the context.
"""
flow: Optional[AIFlow] = ctx.get("flow")
if not flow:
raise ValueError("No flow found in context")
return flow.thread.get_messages(limit=limit)

0 comments on commit d2955ec

Please sign in to comment.