Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve handling of default values #32

Merged
merged 1 commit into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "controlflow"
version = "0.1.0"
version = "0.3.0"
description = "AI Workflows"
authors = [
{ name = "Jeremiah Lowin", email = "153965+jlowin@users.noreply.github.com" },
Expand All @@ -17,16 +17,20 @@ keywords = [
"ai",
"chatbot",
"llm",
"NLP",
"natural language processing",
"ai orchestration",
"llm orchestration",
"agentic workflows",
"flow engineering",
"prefect",
"workflow",
"orchestration",
"python",
"GPT",
"openai",
"assistant",
"agent",
"agents",
"AI agents",
"natural language processing",
]

[project.urls]
Expand Down
6 changes: 5 additions & 1 deletion src/controlflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@

# from .agent_old import task, Agent, run_ai
from .core.flow import Flow, reset_global_flow as _reset_global_flow, flow
from .core.agent import Agent
from .core.task import Task, task
from .core.agent import Agent
from .core.controller.controller import Controller
from .instructions import instructions
from .dx import run_ai

Flow.model_rebuild()
Task.model_rebuild()
Agent.model_rebuild()

_reset_global_flow()
10 changes: 10 additions & 0 deletions src/controlflow/core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@
logger = logging.getLogger(__name__)


def default_agent():
return Agent(
name="Marvin",
instructions="""
You are a diligent AI assistant. You complete
your tasks efficiently and without error.
""",
)


class Agent(Assistant, ControlFlowModel, ExposeSyncMethodsMixin):
name: str
user_access: bool = Field(
Expand Down
36 changes: 28 additions & 8 deletions src/controlflow/core/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class Controller(BaseModel, ExposeSyncMethodsMixin):

"""

# the flow is tracked by the Controller, not the Task, so that tasks can be
# defined and even instantiated outside a flow. When a Controller is
# created, we know we're inside a flow context and ready to load defaults
# and run.
flow: Flow = Field(
default_factory=get_flow,
description="The flow that the controller is a part of.",
Expand All @@ -65,6 +69,12 @@ def _create_graph(cls, data: Any) -> Any:
data["graph"] = Graph.from_tasks(data.get("tasks", []))
return data

@model_validator(mode="after")
def _finalize(self):
for task in self.tasks:
self.flow.add_task(task)
return self

@field_validator("tasks", mode="before")
def _validate_tasks(cls, v):
if v is None:
Expand Down Expand Up @@ -92,12 +102,21 @@ async def _run_agent(
"""

@prefect_task(task_run_name=f'Run Agent: "{agent.name}"')
async def _run_agent(agent: Agent, tasks: list[Task], thread: Thread = None):
async def _run_agent(
controller: Controller,
agent: Agent,
tasks: list[Task],
thread: Thread = None,
):
from controlflow.core.controller.instruction_template import MainTemplate

tasks = tasks or self.tasks
tasks = tasks or controller.tasks

tools = self.flow.tools + agent.get_tools() + [self._create_end_run_tool()]
tools = (
controller.flow.tools
+ agent.get_tools()
+ [controller._create_end_run_tool()]
)

# add tools for any inactive tasks that the agent is assigned to
for task in tasks:
Expand All @@ -106,12 +125,11 @@ async def _run_agent(agent: Agent, tasks: list[Task], thread: Thread = None):

instructions_template = MainTemplate(
agent=agent,
controller=self,
controller=controller,
tasks=tasks,
context=self.context,
context=controller.context,
instructions=get_instructions(),
)

instructions = instructions_template.render()

# filter tools because duplicate names are not allowed
Expand All @@ -126,7 +144,7 @@ async def _run_agent(agent: Agent, tasks: list[Task], thread: Thread = None):

run = Run(
assistant=agent,
thread=thread or self.flow.thread,
thread=thread or controller.flow.thread,
instructions=instructions,
tools=final_tools,
event_handler_class=AgentHandler,
Expand All @@ -146,7 +164,9 @@ async def _run_agent(agent: Agent, tasks: list[Task], thread: Thread = None):
)
return run

return await _run_agent(agent=agent, tasks=tasks, thread=thread)
return await _run_agent(
controller=self, agent=agent, tasks=tasks, thread=thread
)

@expose_sync_method("run_once")
async def run_once_async(self):
Expand Down
4 changes: 1 addition & 3 deletions src/controlflow/core/controller/instruction_template.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import inspect

from pydantic import BaseModel

from controlflow.core.agent import Agent
from controlflow.core.task import Task
from controlflow.utilities.jinja import jinja_env
Expand Down Expand Up @@ -187,7 +185,7 @@ def should_render(self):
return bool(self.flow_context or self.controller_context)


class MainTemplate(BaseModel):
class MainTemplate(ControlFlowModel):
agent: Agent
controller: Controller
context: dict
Expand Down
42 changes: 25 additions & 17 deletions src/controlflow/core/flow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools
import inspect
from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, Literal
from typing import TYPE_CHECKING, Any, Callable, Literal

import prefect
from marvin.beta.assistants import Thread
Expand All @@ -16,31 +17,23 @@

if TYPE_CHECKING:
from controlflow.core.agent import Agent
from controlflow.core.task import Task
logger = get_logger(__name__)


def default_agent():
from controlflow.core.agent import Agent

return [
Agent(
name="Marvin",
description="I am Marvin, the default agent for Control Flow.",
)
]


class Flow(ControlFlowModel):
thread: Thread = Field(None, validate_default=True)
tools: list[AssistantTool | Callable] = Field(
[], description="Tools that will be available to every agent in the flow"
default_factory=list,
description="Tools that will be available to every agent in the flow",
)
agents: list["Agent"] = Field(
default_factory=default_agent,
description="The default agents for the flow. These agents will be used "
"for any task that does not specify agents.",
default_factory=list,
)
context: dict = {}
_tasks: dict[str, "Task"] = {}
context: dict[str, Any] = {}

@field_validator("thread", mode="before")
def _load_thread_from_ctx(cls, v):
Expand All @@ -53,6 +46,13 @@ def _load_thread_from_ctx(cls, v):

return v

def add_task(self, task: "Task"):
if self._tasks.get(task.id, task) is not task:
raise ValueError(
f"A different task with id '{task.id}' already exists in flow."
)
self._tasks[task.id] = task

def add_message(self, message: str, role: Literal["user", "assistant"] = None):
prefect_task(self.thread.add)(message, role=role)

Expand Down Expand Up @@ -107,6 +107,7 @@ def flow(
fn=None,
*,
thread: Thread = None,
instructions: str = None,
tools: list[AssistantTool | Callable] = None,
agents: list["Agent"] = None,
):
Expand All @@ -122,12 +123,18 @@ def flow(
agents=agents,
)

sig = inspect.signature(fn)

@functools.wraps(fn)
def wrapper(
*args,
flow_kwargs: dict = None,
**kwargs,
):
# first process callargs
bound = sig.bind(*args, **kwargs)
bound.apply_defaults()

flow_kwargs = flow_kwargs or {}

if thread is not None:
Expand All @@ -139,13 +146,14 @@ def wrapper(

p_fn = prefect.flow(fn)

flow_obj = Flow(**flow_kwargs)
flow_obj = Flow(**flow_kwargs, context=bound.arguments)

logger.info(
f'Executing AI flow "{fn.__name__}" on thread "{flow_obj.thread.id}"'
)

with ctx(flow=flow_obj), patch_marvin():
return p_fn(*args, **kwargs)
with controlflow.instructions.instructions(instructions):
return p_fn(*args, **kwargs)

return wrapper
2 changes: 1 addition & 1 deletion src/controlflow/core/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Edge(BaseModel):
type: EdgeType

def __repr__(self):
return f"{self.type}: {self.upstream.id} -> {self.downstream.id}"
return f"{self.type}: {self.upstream.friendly_name()} -> {self.downstream.friendly_name()}"

def __hash__(self) -> int:
return id(self)
Expand Down
Loading
Loading