Skip to content

Commit

Permalink
Merge pull request #32 from jlowin/defaults
Browse files Browse the repository at this point in the history
Improve handling of default values
  • Loading branch information
jlowin authored May 14, 2024
2 parents 8fd626e + 37f531c commit 6658173
Show file tree
Hide file tree
Showing 9 changed files with 116 additions and 47 deletions.
12 changes: 8 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "controlflow"
version = "0.1.0"
version = "0.3.0"
description = "AI Workflows"
authors = [
{ name = "Jeremiah Lowin", email = "153965+jlowin@users.noreply.github.com" },
Expand All @@ -17,16 +17,20 @@ keywords = [
"ai",
"chatbot",
"llm",
"NLP",
"natural language processing",
"ai orchestration",
"llm orchestration",
"agentic workflows",
"flow engineering",
"prefect",
"workflow",
"orchestration",
"python",
"GPT",
"openai",
"assistant",
"agent",
"agents",
"AI agents",
"natural language processing",
]

[project.urls]
Expand Down
6 changes: 5 additions & 1 deletion src/controlflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@

# from .agent_old import task, Agent, run_ai
from .core.flow import Flow, reset_global_flow as _reset_global_flow, flow
from .core.agent import Agent
from .core.task import Task, task
from .core.agent import Agent
from .core.controller.controller import Controller
from .instructions import instructions
from .dx import run_ai

Flow.model_rebuild()
Task.model_rebuild()
Agent.model_rebuild()

_reset_global_flow()
10 changes: 10 additions & 0 deletions src/controlflow/core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@
logger = logging.getLogger(__name__)


def default_agent():
return Agent(
name="Marvin",
instructions="""
You are a diligent AI assistant. You complete
your tasks efficiently and without error.
""",
)


class Agent(Assistant, ControlFlowModel, ExposeSyncMethodsMixin):
name: str
user_access: bool = Field(
Expand Down
36 changes: 28 additions & 8 deletions src/controlflow/core/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class Controller(BaseModel, ExposeSyncMethodsMixin):
"""

# the flow is tracked by the Controller, not the Task, so that tasks can be
# defined and even instantiated outside a flow. When a Controller is
# created, we know we're inside a flow context and ready to load defaults
# and run.
flow: Flow = Field(
default_factory=get_flow,
description="The flow that the controller is a part of.",
Expand All @@ -65,6 +69,12 @@ def _create_graph(cls, data: Any) -> Any:
data["graph"] = Graph.from_tasks(data.get("tasks", []))
return data

@model_validator(mode="after")
def _finalize(self):
for task in self.tasks:
self.flow.add_task(task)
return self

@field_validator("tasks", mode="before")
def _validate_tasks(cls, v):
if v is None:
Expand Down Expand Up @@ -92,12 +102,21 @@ async def _run_agent(
"""

@prefect_task(task_run_name=f'Run Agent: "{agent.name}"')
async def _run_agent(agent: Agent, tasks: list[Task], thread: Thread = None):
async def _run_agent(
controller: Controller,
agent: Agent,
tasks: list[Task],
thread: Thread = None,
):
from controlflow.core.controller.instruction_template import MainTemplate

tasks = tasks or self.tasks
tasks = tasks or controller.tasks

tools = self.flow.tools + agent.get_tools() + [self._create_end_run_tool()]
tools = (
controller.flow.tools
+ agent.get_tools()
+ [controller._create_end_run_tool()]
)

# add tools for any inactive tasks that the agent is assigned to
for task in tasks:
Expand All @@ -106,12 +125,11 @@ async def _run_agent(agent: Agent, tasks: list[Task], thread: Thread = None):

instructions_template = MainTemplate(
agent=agent,
controller=self,
controller=controller,
tasks=tasks,
context=self.context,
context=controller.context,
instructions=get_instructions(),
)

instructions = instructions_template.render()

# filter tools because duplicate names are not allowed
Expand All @@ -126,7 +144,7 @@ async def _run_agent(agent: Agent, tasks: list[Task], thread: Thread = None):

run = Run(
assistant=agent,
thread=thread or self.flow.thread,
thread=thread or controller.flow.thread,
instructions=instructions,
tools=final_tools,
event_handler_class=AgentHandler,
Expand All @@ -146,7 +164,9 @@ async def _run_agent(agent: Agent, tasks: list[Task], thread: Thread = None):
)
return run

return await _run_agent(agent=agent, tasks=tasks, thread=thread)
return await _run_agent(
controller=self, agent=agent, tasks=tasks, thread=thread
)

@expose_sync_method("run_once")
async def run_once_async(self):
Expand Down
4 changes: 1 addition & 3 deletions src/controlflow/core/controller/instruction_template.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import inspect

from pydantic import BaseModel

from controlflow.core.agent import Agent
from controlflow.core.task import Task
from controlflow.utilities.jinja import jinja_env
Expand Down Expand Up @@ -187,7 +185,7 @@ def should_render(self):
return bool(self.flow_context or self.controller_context)


class MainTemplate(BaseModel):
class MainTemplate(ControlFlowModel):
agent: Agent
controller: Controller
context: dict
Expand Down
42 changes: 25 additions & 17 deletions src/controlflow/core/flow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools
import inspect
from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, Literal
from typing import TYPE_CHECKING, Any, Callable, Literal

import prefect
from marvin.beta.assistants import Thread
Expand All @@ -16,31 +17,23 @@

if TYPE_CHECKING:
from controlflow.core.agent import Agent
from controlflow.core.task import Task
logger = get_logger(__name__)


def default_agent():
from controlflow.core.agent import Agent

return [
Agent(
name="Marvin",
description="I am Marvin, the default agent for Control Flow.",
)
]


class Flow(ControlFlowModel):
thread: Thread = Field(None, validate_default=True)
tools: list[AssistantTool | Callable] = Field(
[], description="Tools that will be available to every agent in the flow"
default_factory=list,
description="Tools that will be available to every agent in the flow",
)
agents: list["Agent"] = Field(
default_factory=default_agent,
description="The default agents for the flow. These agents will be used "
"for any task that does not specify agents.",
default_factory=list,
)
context: dict = {}
_tasks: dict[str, "Task"] = {}
context: dict[str, Any] = {}

@field_validator("thread", mode="before")
def _load_thread_from_ctx(cls, v):
Expand All @@ -53,6 +46,13 @@ def _load_thread_from_ctx(cls, v):

return v

def add_task(self, task: "Task"):
if self._tasks.get(task.id, task) is not task:
raise ValueError(
f"A different task with id '{task.id}' already exists in flow."
)
self._tasks[task.id] = task

def add_message(self, message: str, role: Literal["user", "assistant"] = None):
prefect_task(self.thread.add)(message, role=role)

Expand Down Expand Up @@ -107,6 +107,7 @@ def flow(
fn=None,
*,
thread: Thread = None,
instructions: str = None,
tools: list[AssistantTool | Callable] = None,
agents: list["Agent"] = None,
):
Expand All @@ -122,12 +123,18 @@ def flow(
agents=agents,
)

sig = inspect.signature(fn)

@functools.wraps(fn)
def wrapper(
*args,
flow_kwargs: dict = None,
**kwargs,
):
# first process callargs
bound = sig.bind(*args, **kwargs)
bound.apply_defaults()

flow_kwargs = flow_kwargs or {}

if thread is not None:
Expand All @@ -139,13 +146,14 @@ def wrapper(

p_fn = prefect.flow(fn)

flow_obj = Flow(**flow_kwargs)
flow_obj = Flow(**flow_kwargs, context=bound.arguments)

logger.info(
f'Executing AI flow "{fn.__name__}" on thread "{flow_obj.thread.id}"'
)

with ctx(flow=flow_obj), patch_marvin():
return p_fn(*args, **kwargs)
with controlflow.instructions.instructions(instructions):
return p_fn(*args, **kwargs)

return wrapper
2 changes: 1 addition & 1 deletion src/controlflow/core/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Edge(BaseModel):
type: EdgeType

def __repr__(self):
return f"{self.type}: {self.upstream.id} -> {self.downstream.id}"
return f"{self.type}: {self.upstream.friendly_name()} -> {self.downstream.friendly_name()}"

def __hash__(self) -> int:
return id(self)
Expand Down
Loading

0 comments on commit 6658173

Please sign in to comment.