diff --git a/docs/concepts/agents.mdx b/docs/concepts/agents.mdx index 63c609b4..702df75f 100644 --- a/docs/concepts/agents.mdx +++ b/docs/concepts/agents.mdx @@ -25,7 +25,7 @@ An agent has the following key properties: - `name` (str): The name of the agent, which serves as an identifier and is visible to other agents in the workflow. - `description` (str, optional): A brief description of the agent's role or specialization, which is also visible to other agents. - `instructions` (str, optional): Specific instructions or guidelines for the agent to follow during task execution. These instructions are private to the agent and not visible to other agents. -- `tools` (List[ToolType], optional): A list of tools available to the agent. Tools are Python functions that the agent can call to perform specific actions or computations. +- `tools` (List[Callable], optional): A list of tools available to the agent. Tools are Python functions that the agent can call to perform specific actions or computations. - `user_access` (bool, default=False): Indicates whether the agent has access to user interactions. If set to `True`, the agent will be provided with the `talk_to_human` tool to communicate with users. These properties help define the agent's characteristics, behavior, and capabilities within the flow. 
diff --git a/docs/guides/deployment.mdx b/docs/guides/deployment.mdx index 00183869..abb982f3 100644 --- a/docs/guides/deployment.mdx +++ b/docs/guides/deployment.mdx @@ -9,9 +9,8 @@ Here are recommendations for production settings: ```bash -# set log levels to INFO for CF and Marvin +# set log levels to INFO CONTROLFLOW_LOGGING_LEVEL=INFO -MARVIN_LOGGING_LEVEL=INFO # disable local (terminal) inputs for flows that take user inputs CONTROLFLOW_ENABLE_LOCAL_INPUT=0 diff --git a/examples/documentation.py b/examples/documentation.py deleted file mode 100644 index 35394a7c..00000000 --- a/examples/documentation.py +++ /dev/null @@ -1,64 +0,0 @@ -import glob as glob_module -from pathlib import Path - -import controlflow -from controlflow import flow, task -from marvin.beta.assistants import Assistant, Thread -from marvin.tools.filesystem import read, write - -ROOT = Path(controlflow.__file__).parents[2] - - -def glob(pattern: str) -> list[str]: - """ - Returns a list of paths matching a valid glob pattern. - The pattern can include ** for recursive matching, such as - '~/path/to/root/**/*.py' - """ - return glob_module.glob(pattern, recursive=True) - - -assistant = Assistant( - instructions=""" - You are an expert technical writer who writes wonderful documentation for - open-source tools and believes that documentation is a product unto itself. - """, - tools=[read, write, glob], -) - - -@task(model="gpt-3.5-turbo") -def examine_source_code(source_dir: Path, extensions: list[str]): - """ - Load all matching files in the root dir and all subdirectories and - read them carefully. - """ - - -@task(model="gpt-3.5-turbo") -def read_docs(docs_dir: Path): - """ - Read all documentation in the docs dir and subdirectories, if any. - """ - - -@task -def write_docs(docs_dir: Path, instructions: str = None): - """ - Write new documentation based on the provided instructions. 
- """ - - -@flow(assistant=assistant) -def docs_flow(instructions: str): - examine_source_code(ROOT / "src", extensions=[".py"]) - # read_docs(ROOT / "docs") - write_docs(ROOT / "docs", instructions=instructions) - - -if __name__ == "__main__": - thread = Thread() - docs_flow( - _thread=thread, - instructions="Write documentation for the AI Flow class and save it in docs/flow.md", - ) diff --git a/pyproject.toml b/pyproject.toml index 2d2f8904..a2f583db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,12 +6,13 @@ authors = [ { name = "Jeremiah Lowin", email = "153965+jlowin@users.noreply.github.com" }, ] dependencies = [ - "marvin @ git+https://github.com/prefecthq/marvin@main", "prefect[dev] @ git+https://github.com/prefecthq/prefect@main", # can remove when prefect fully migrates to pydantic 2 "pydantic>=2", "textual>=0.61.1", "litellm>=1.37.17", + "jinja2>=3.1.4", + "pydantic-settings>=2.2.1", ] readme = "README.md" requires-python = ">= 3.9" @@ -79,7 +80,7 @@ skip-magic-trailing-comma = false "__init__.py" = ['I', 'F401', 'E402'] "conftest.py" = ["F401", "F403"] 'tests/fixtures/*.py' = ['F401', 'F403'] -"src/controlflow/utilities/types.py" = ['F401'] +# "src/controlflow/utilities/types.py" = ['F401'] [tool.pytest.ini_options] timeout = 120 diff --git a/requirements-dev.lock b/requirements-dev.lock index 860fae07..093abe1a 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -22,7 +22,6 @@ anyio==3.7.1 # via httpx # via openai # via prefect - # via starlette # via watchfiles apprise==1.7.5 # via prefect @@ -50,7 +49,6 @@ bytecode==0.15.1 # via ddtrace cachetools==5.3.3 # via google-auth - # via marvin # via prefect cairocffi==1.7.0 # via cairosvg @@ -121,8 +119,6 @@ execnet==2.1.1 # via pytest-xdist executing==2.0.1 # via stack-data -fastapi==0.110.0 - # via marvin filelock==3.13.3 # via huggingface-hub # via virtualenv @@ -154,7 +150,6 @@ httpcore==1.0.5 # via httpx # via prefect httpx==0.27.0 - # via marvin # via openai # via prefect # via 
respx @@ -187,10 +182,10 @@ itsdangerous==2.1.2 # via prefect jedi==0.19.1 # via ipython -jinja2==3.1.3 +jinja2==3.1.4 + # via controlflow # via jinja2-humanize-extension # via litellm - # via marvin # via mike # via mkdocs # via mkdocs-material @@ -203,7 +198,6 @@ jmespath==1.0.1 # via boto3 # via botocore jsonpatch==1.33 - # via marvin # via prefect jsonpointer==2.4 # via jsonpatch @@ -237,8 +231,6 @@ markupsafe==2.1.5 # via mkdocs-autorefs # via mkdocstrings # via werkzeug -marvin @ git+https://github.com/prefecthq/marvin@9ff559318af3dd1185ab7d6a1e85b39967915b81 - # via controlflow matplotlib-inline==0.1.6 # via ipython mdit-py-plugins==0.4.1 @@ -289,7 +281,6 @@ oauthlib==3.2.2 # via requests-oauthlib openai==1.28.1 # via litellm - # via marvin opentelemetry-api==1.24.0 # via ddtrace orjson==3.10.0 @@ -304,8 +295,6 @@ paginate==0.5.6 # via mkdocs-material parso==0.8.3 # via jedi -partialjson==0.0.7 - # via marvin pathspec==0.12.1 # via mkdocs # via prefect @@ -325,11 +314,10 @@ pluggy==1.4.0 # via pytest pre-commit==3.7.0 # via prefect -prefect @ git+https://github.com/prefecthq/prefect@aad7f63ccfc7767729f65fe20b0defb208b2c451 +prefect @ git+https://github.com/prefecthq/prefect@fa0ffc96b9740b6ab00a81b14b50b17285abdefb # via controlflow prompt-toolkit==3.0.43 # via ipython - # via marvin protobuf==5.26.1 # via ddtrace ptyprocess==0.7.0 @@ -347,8 +335,6 @@ pycparser==2.22 # via cffi pydantic==2.6.4 # via controlflow - # via fastapi - # via marvin # via openai # via prefect # via pydantic-settings @@ -356,7 +342,7 @@ pydantic-core==2.16.3 # via prefect # via pydantic pydantic-settings==2.2.1 - # via marvin + # via controlflow pygments==2.17.2 # via ipython # via mkdocs-material @@ -458,7 +444,6 @@ respx==0.21.1 rfc3339-validator==0.1.4 # via prefect rich==13.7.1 - # via marvin # via prefect # via textual # via typer-slim @@ -502,15 +487,12 @@ sqlparse==0.5.0 # via ddtrace stack-data==0.6.3 # via ipython -starlette==0.36.3 - # via fastapi text-unidecode==1.3 # via 
python-slugify textual==0.61.1 # via controlflow tiktoken==0.6.0 # via litellm - # via marvin time-machine==2.14.1 # via pendulum tinycss2==1.3.0 @@ -527,7 +509,6 @@ traitlets==5.14.2 # via ipython # via matplotlib-inline typer==0.12.0 - # via marvin # via prefect typer-cli==0.12.0 # via typer @@ -542,9 +523,7 @@ typing-extensions==4.10.0 # via aiosqlite # via alembic # via ddtrace - # via fastapi # via huggingface-hub - # via marvin # via mypy # via openai # via prefect @@ -554,7 +533,6 @@ typing-extensions==4.10.0 # via textual # via typer-slim tzdata==2024.1 - # via marvin # via pendulum tzlocal==5.2 # via dateparser @@ -569,7 +547,6 @@ urllib3==2.2.1 # via requests # via responses uvicorn==0.28.1 - # via marvin # via prefect vermin==1.6.0 # via prefect diff --git a/requirements.lock b/requirements.lock index 2557fb37..eef45d9f 100644 --- a/requirements.lock +++ b/requirements.lock @@ -22,7 +22,6 @@ anyio==3.7.1 # via httpx # via openai # via prefect - # via starlette # via watchfiles apprise==1.7.5 # via prefect @@ -50,7 +49,6 @@ bytecode==0.15.1 # via ddtrace cachetools==5.3.3 # via google-auth - # via marvin # via prefect cairocffi==1.7.0 # via cairosvg @@ -121,8 +119,6 @@ execnet==2.1.1 # via pytest-xdist executing==2.0.1 # via stack-data -fastapi==0.110.0 - # via marvin filelock==3.14.0 # via huggingface-hub # via virtualenv @@ -154,7 +150,6 @@ httpcore==1.0.5 # via httpx # via prefect httpx==0.27.0 - # via marvin # via openai # via prefect # via respx @@ -187,10 +182,10 @@ itsdangerous==2.1.2 # via prefect jedi==0.19.1 # via ipython -jinja2==3.1.3 +jinja2==3.1.4 + # via controlflow # via jinja2-humanize-extension # via litellm - # via marvin # via mike # via mkdocs # via mkdocs-material @@ -203,7 +198,6 @@ jmespath==1.0.1 # via boto3 # via botocore jsonpatch==1.33 - # via marvin # via prefect jsonpointer==2.4 # via jsonpatch @@ -237,8 +231,6 @@ markupsafe==2.1.5 # via mkdocs-autorefs # via mkdocstrings # via werkzeug -marvin @ 
git+https://github.com/prefecthq/marvin@9ff559318af3dd1185ab7d6a1e85b39967915b81 - # via controlflow matplotlib-inline==0.1.7 # via ipython mdit-py-plugins==0.4.1 @@ -289,7 +281,6 @@ oauthlib==3.2.2 # via requests-oauthlib openai==1.28.1 # via litellm - # via marvin opentelemetry-api==1.24.0 # via ddtrace orjson==3.10.0 @@ -304,8 +295,6 @@ paginate==0.5.6 # via mkdocs-material parso==0.8.4 # via jedi -partialjson==0.0.7 - # via marvin pathspec==0.12.1 # via mkdocs # via prefect @@ -325,11 +314,10 @@ pluggy==1.5.0 # via pytest pre-commit==3.7.1 # via prefect -prefect @ git+https://github.com/prefecthq/prefect@aad7f63ccfc7767729f65fe20b0defb208b2c451 +prefect @ git+https://github.com/prefecthq/prefect@fa0ffc96b9740b6ab00a81b14b50b17285abdefb # via controlflow prompt-toolkit==3.0.43 # via ipython - # via marvin protobuf==5.26.1 # via ddtrace ptyprocess==0.7.0 @@ -347,8 +335,6 @@ pycparser==2.22 # via cffi pydantic==2.6.4 # via controlflow - # via fastapi - # via marvin # via openai # via prefect # via pydantic-settings @@ -356,7 +342,7 @@ pydantic-core==2.16.3 # via prefect # via pydantic pydantic-settings==2.2.1 - # via marvin + # via controlflow pygments==2.17.2 # via ipython # via mkdocs-material @@ -458,7 +444,6 @@ respx==0.21.1 rfc3339-validator==0.1.4 # via prefect rich==13.7.1 - # via marvin # via prefect # via textual # via typer-slim @@ -502,15 +487,12 @@ sqlparse==0.5.0 # via ddtrace stack-data==0.6.3 # via ipython -starlette==0.36.3 - # via fastapi text-unidecode==1.3 # via python-slugify textual==0.61.1 # via controlflow tiktoken==0.6.0 # via litellm - # via marvin time-machine==2.14.1 # via pendulum tinycss2==1.3.0 @@ -527,7 +509,6 @@ traitlets==5.14.3 # via ipython # via matplotlib-inline typer==0.12.0 - # via marvin # via prefect typer-cli==0.12.0 # via typer @@ -542,9 +523,7 @@ typing-extensions==4.10.0 # via aiosqlite # via alembic # via ddtrace - # via fastapi # via huggingface-hub - # via marvin # via mypy # via openai # via prefect @@ -554,7 +533,6 
@@ typing-extensions==4.10.0 # via textual # via typer-slim tzdata==2024.1 - # via marvin # via pendulum tzlocal==5.2 # via dateparser @@ -569,7 +547,6 @@ urllib3==2.2.1 # via requests # via responses uvicorn==0.28.1 - # via marvin # via prefect vermin==1.6.0 # via prefect diff --git a/src/controlflow/__init__.py b/src/controlflow/__init__.py index ef7533c1..a9fbaff7 100644 --- a/src/controlflow/__init__.py +++ b/src/controlflow/__init__.py @@ -1,5 +1,7 @@ from .settings import settings +from . import llm + from .core.flow import Flow from .core.task import Task from .core.agent import Agent diff --git a/src/controlflow/agents/agents.py b/src/controlflow/agents/agents.py deleted file mode 100644 index 20c297f3..00000000 --- a/src/controlflow/agents/agents.py +++ /dev/null @@ -1,42 +0,0 @@ -import marvin - -from controlflow.core.agent import Agent -from controlflow.instructions import get_instructions -from controlflow.utilities.context import ctx -from controlflow.utilities.threads import get_history - - -def choose_agent( - agents: list[Agent], - instructions: str = None, - context: dict = None, - model: str = None, -) -> Agent: - """ - Given a list of potential agents, choose the most qualified assistant to complete the tasks. - """ - - instructions = get_instructions() - history = [] - if (flow := ctx.get("flow")) and flow.thread.id: - history = get_history(thread_id=flow.thread.id) - - info = dict( - history=history, - global_instructions=instructions, - context=context, - ) - - agent = marvin.classify( - info, - agents, - instructions=""" - Given the conversation context, choose the AI agent most - qualified to take the next turn at completing the tasks. Take into - account the instructions, each agent's own instructions, and the - tools they have available. 
- """, - model_kwargs=dict(model=model), - ) - - return agent diff --git a/src/controlflow/core/agent.py b/src/controlflow/core/agent.py index 57ab3364..b26663c0 100644 --- a/src/controlflow/core/agent.py +++ b/src/controlflow/core/agent.py @@ -1,17 +1,12 @@ import logging -from typing import Callable, Optional, Union +from typing import Callable, Optional -from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method from pydantic import Field import controlflow -from controlflow.core.flow import Flow, get_flow -from controlflow.core.task import Task from controlflow.tools.talk_to_human import talk_to_human -from controlflow.utilities.prefect import ( - wrap_prefect_tool, -) -from controlflow.utilities.types import Assistant, ControlFlowModel, ToolType +from controlflow.utilities.asyncio import ExposeSyncMethodsMixin +from controlflow.utilities.types import ControlFlowModel logger = logging.getLogger(__name__) @@ -20,42 +15,6 @@ def get_default_agent() -> "Agent": return controlflow.default_agent -class AgentOLD(Assistant, ControlFlowModel, ExposeSyncMethodsMixin): - name: str - user_access: bool = Field( - False, - description="If True, the agent is given tools for interacting with a human user.", - ) - - def get_tools(self) -> list[ToolType]: - tools = super().get_tools() - if self.user_access: - tools.append(talk_to_human) - - return [wrap_prefect_tool(tool) for tool in tools] - - @expose_sync_method("run") - async def run_async( - self, - tasks: Union[list[Task], Task, None] = None, - flow: Flow = None, - ): - from controlflow.core.controller import Controller - - if isinstance(tasks, Task): - tasks = [tasks] - - flow = flow or get_flow() - - if not flow: - raise ValueError( - "Agents must be run within a flow context or with a flow argument." 
- ) - - controller = Controller(agents=[self], tasks=tasks or [], flow=flow) - return await controller.run_agent_async(agent=self) - - class Agent(ControlFlowModel, ExposeSyncMethodsMixin): name: str = Field( ..., @@ -88,34 +47,6 @@ def get_tools(self) -> list[Callable]: tools.append(talk_to_human) return tools - # def say( - # self, messages: Union[str, dict], thread_id: str = None, history: History = None - # ) -> Response: - - # if thread_id is None: - # thread_id = self.default_thread_id - # if history is None: - # history = get_default_history() - # if not isinstance(messages, list): - # raise ValueError("Messages must be provided as a list.") - - # messages = [ - # Message(role="user", content=m) if isinstance(m, str) else m - # for m in messages - # ] - # history_messages = history.load_messages(thread_id=thread_id, limit=50) - - # response = completion( - # messages=history_messages + messages, - # model=self.model, - # tools=self.tools, - # ) - # history.save_messages( - # thread_id=thread_id, - # messages=messages + history_messages + response.messages, - # ) - # return response - DEFAULT_AGENT = Agent( name="Marvin", diff --git a/src/controlflow/core/controller/controller.py b/src/controlflow/core/controller/controller.py index d5755070..ccb5b4ad 100644 --- a/src/controlflow/core/controller/controller.py +++ b/src/controlflow/core/controller/controller.py @@ -3,14 +3,13 @@ from collections import defaultdict from contextlib import asynccontextmanager from functools import cached_property -from typing import Union +from typing import Callable, Union -from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method from pydantic import BaseModel, Field, PrivateAttr, computed_field, model_validator import controlflow from controlflow.core.agent import Agent -from controlflow.core.controller.moderators import marvin_moderator +from controlflow.core.controller.moderators import classify_moderator from controlflow.core.flow import Flow, get_flow 
from controlflow.core.graph import Graph from controlflow.core.task import Task @@ -20,9 +19,9 @@ from controlflow.llm.history import History from controlflow.llm.messages import AssistantMessage, ControlFlowMessage, SystemMessage from controlflow.tui.app import TUIApp as TUI +from controlflow.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method from controlflow.utilities.context import ctx from controlflow.utilities.tasks import all_complete, any_incomplete -from controlflow.utilities.types import FunctionTool logger = logging.getLogger(__name__) @@ -87,7 +86,7 @@ def _finalize(self): self.flow.add_task(task) return self - def _create_end_turn_tool(self) -> FunctionTool: + def _create_end_turn_tool(self) -> Callable: def end_turn(): """ Call this tool to skip your turn and let another agent go next. This @@ -96,13 +95,17 @@ def end_turn(): automatically, so only use it if you are truly stuck and unable to proceed. """ - self._end_run_counts[ctx.get("controller_agent")] += 1 - if self._end_run_counts[ctx.get("controller_agent")] >= 3: + + # the agent's name is used as the key to track the number of times + key = getattr(ctx.get("controller_agent", None), "name", None) + + self._end_run_counts[key] += 1 + if self._end_run_counts[key] >= 3: self._should_abort = True - self._end_run_counts[ctx.get("controller_agent")] = 0 + self._end_run_counts[key] = 0 return ( - f"Ending turn. {3 - self._end_run_counts[ctx.get('controller_agent')]}" + f"Ending turn. {3 - self._end_run_counts[key]}" " more uses will abort the workflow." 
) @@ -151,6 +154,7 @@ async def _run_agent(self, agent: Agent, tasks: list[Task]): tools=tools, handlers=handlers, max_iterations=1, + assistant_name=agent.name, stream=True, message_preprocessor=add_agent_name_to_message, ): @@ -173,7 +177,7 @@ async def _run_agent(self, agent: Agent, tasks: list[Task]): # ) def choose_agent(self, agents: list[Agent], tasks: list[Task]) -> Agent: - return marvin_moderator( + return classify_moderator( agents=agents, tasks=tasks, iteration=self._iteration, @@ -183,11 +187,13 @@ def choose_agent(self, agents: list[Agent], tasks: list[Task]) -> Agent: async def tui(self): if tui := ctx.get("tui"): yield tui - else: + elif controlflow.settings.enable_tui: tui = TUI(flow=self.flow) with ctx(tui=tui): - async with tui.run_context(run=controlflow.settings.enable_tui): + async with tui.run_context(): yield tui + else: + yield @expose_sync_method("run_once") async def run_once_async(self): @@ -228,7 +234,6 @@ async def run_once_async(self): elif len(agents) == 1: agent = agents[0] else: - raise NotImplementedError("Need to reimplement multi-agent") agent = self.choose_agent(agents=agents, tasks=tasks) with ctx(controller_agent=agent): diff --git a/src/controlflow/core/controller/instruction_template.py b/src/controlflow/core/controller/instruction_template.py index 0ced8893..b35d847b 100644 --- a/src/controlflow/core/controller/instruction_template.py +++ b/src/controlflow/core/controller/instruction_template.py @@ -19,7 +19,9 @@ def render(self) -> str: if self.should_render(): render_kwargs = dict(self) render_kwargs.pop("template") - return jinja_env.render(inspect.cleandoc(self.template), **render_kwargs) + return jinja_env.from_string(inspect.cleandoc(self.template)).render( + **render_kwargs + ) class AgentTemplate(Template): @@ -131,10 +133,14 @@ class CommunicationTemplate(Template): messages unless you need to record information in addition to what you provide as a task's result, or for the following reasons: - - You need to post 
a message or otherwise communicate to complete a task. + - You need to post a message or otherwise communicate to complete a + task. For example, the task instructs you to write, discuss, or + otherwise produce content (and does not accept a result, or the result + that meets the objective is different than the instructed actions). - You need to communicate with other agents to complete a task. - You want to write your thought process for future reference. + Note that You may see other agents post messages; they may have different instructions than you do, so do not follow their example automatically. diff --git a/src/controlflow/core/controller/moderators.py b/src/controlflow/core/controller/moderators.py index 195b0b0f..ae264e50 100644 --- a/src/controlflow/core/controller/moderators.py +++ b/src/controlflow/core/controller/moderators.py @@ -1,9 +1,8 @@ -import marvin - from controlflow.core.agent import Agent from controlflow.core.flow import get_flow_messages from controlflow.core.task import Task from controlflow.instructions import get_instructions +from controlflow.llm.classify import classify def round_robin( @@ -15,7 +14,7 @@ def round_robin( return agents[iteration % len(agents)] -def marvin_moderator( +def classify_moderator( agents: list[Agent], tasks: list[Task], context: dict = None, @@ -26,9 +25,9 @@ def marvin_moderator( instructions = get_instructions() context = context or {} context.update(tasks=tasks, history=history, instructions=instructions) - agent = marvin.classify( + agent = classify( context, - agents, + labels=agents, instructions=""" Given the context, choose the AI agent best suited to take the next turn at completing the tasks in the task graph. Take into account @@ -37,6 +36,6 @@ def marvin_moderator( completed before their downstream/parents can be completed. An agent can only work on a task that it is assigned to. 
""", - model_kwargs=dict(model=model) if model else None, + model=model, ) return agent diff --git a/src/controlflow/core/flow.py b/src/controlflow/core/flow.py index c81d120c..c8897d51 100644 --- a/src/controlflow/core/flow.py +++ b/src/controlflow/core/flow.py @@ -4,6 +4,7 @@ from pydantic import Field +from controlflow.llm.history import get_default_history from controlflow.llm.messages import MessageType from controlflow.utilities.context import ctx from controlflow.utilities.logging import get_logger @@ -82,6 +83,6 @@ def get_flow_messages(limit: int = None) -> list[MessageType]: """ flow = get_flow() if flow: - return flow.thread.get_messages(limit=limit) + return get_default_history().load_messages(flow.thread_id, limit=limit) else: return [] diff --git a/src/controlflow/core/task.py b/src/controlflow/core/task.py index 65a6dd1d..1beac2e0 100644 --- a/src/controlflow/core/task.py +++ b/src/controlflow/core/task.py @@ -39,7 +39,6 @@ ControlFlowModel, PandasDataFrame, PandasSeries, - ToolType, ) if TYPE_CHECKING: @@ -219,7 +218,7 @@ def _serialize_agents(self, agents: Optional[list["Agent"]]): ] @field_serializer("tools") - def _serialize_tools(self, tools: list[ToolType]): + def _serialize_tools(self, tools: list[Callable]): tools = controlflow.llm.tools.as_tools(tools) return [t.model_dump({"name", "description"}) for t in tools] diff --git a/src/controlflow/decorators.py b/src/controlflow/decorators.py index b8aa62fa..faa659a6 100644 --- a/src/controlflow/decorators.py +++ b/src/controlflow/decorators.py @@ -1,17 +1,17 @@ import functools import inspect +from typing import Callable import prefect -from marvin.beta.assistants import Thread import controlflow from controlflow.core.agent import Agent from controlflow.core.flow import Flow from controlflow.core.task import Task from controlflow.utilities.logging import get_logger -from controlflow.utilities.marvin import patch_marvin + +# from controlflow.utilities.marvin import patch_marvin from 
controlflow.utilities.tasks import resolve_tasks -from controlflow.utilities.types import ToolType logger = get_logger(__name__) @@ -19,9 +19,9 @@ def flow( fn=None, *, - thread: Thread = None, + thread: str = None, instructions: str = None, - tools: list[ToolType] = None, + tools: list[Callable] = None, agents: list["Agent"] = None, lazy: bool = None, ): @@ -38,9 +38,9 @@ def flow( Args: fn (callable, optional): The function to be wrapped as a flow. If not provided, the decorator will act as a partial function and return a new flow decorator. - thread (Thread, optional): The thread to execute the flow on. Defaults to None. + thread (str, optional): The thread to execute the flow on. Defaults to None. instructions (str, optional): Instructions for the flow. Defaults to None. - tools (list[ToolType], optional): List of tools to be used in the flow. Defaults to None. + tools (list[Callable], optional): List of tools to be used in the flow. Defaults to None. agents (list[Agent], optional): List of agents to be used in the flow. Defaults to None. lazy (bool, optional): Whether the flow should be run lazily. If not set, behavior is determined by the global `eager_mode` setting. Lazy execution means @@ -92,7 +92,7 @@ def wrapper( # create a function to wrap as a Prefect flow @prefect.flow def wrapped_flow(*args, lazy_=None, **kwargs): - with flow_obj, patch_marvin(): + with flow_obj: with controlflow.instructions(instructions): result = fn(*args, **kwargs) @@ -126,7 +126,7 @@ def task( objective: str = None, instructions: str = None, agents: list["Agent"] = None, - tools: list[ToolType] = None, + tools: list[Callable] = None, user_access: bool = None, lazy: bool = None, ): @@ -147,7 +147,7 @@ def task( instructions (str, optional): Instructions for the task. Defaults to None, in which case the function docstring is used as the instructions. agents (list[Agent], optional): List of agents to be used in the task. Defaults to None. 
- tools (list[ToolType], optional): List of tools to be used in the task. Defaults to None. + tools (list[Callable], optional): List of tools to be used in the task. Defaults to None. user_access (bool, optional): Whether the task requires user access. Defaults to None, in which case it is set to False. lazy (bool, optional): Whether the task should be run lazily. If not diff --git a/src/controlflow/llm/classify.py b/src/controlflow/llm/classify.py new file mode 100644 index 00000000..a92ba282 --- /dev/null +++ b/src/controlflow/llm/classify.py @@ -0,0 +1,90 @@ +from pydantic import TypeAdapter + +import controlflow +from controlflow.llm.messages import AssistantMessage, SystemMessage, UserMessage + + +def classify( + data: str, + labels: list, + instructions: str = None, + context: dict = None, + model: str = None, +): + try: + label_strings = [TypeAdapter(type(t)).dump_json(t).decode() for t in labels] + except Exception as exc: + raise ValueError(f"Unable to cast labels to strings: {exc}") + + messages = [ + SystemMessage( + """ + You are an expert classifier that always maintains as much semantic meaning + as possible when labeling information. You use inference or deduction whenever + necessary to understand missing or omitted data. Classify the provided data, + text, or information as one of the provided labels. For boolean labels, + consider "truthy" or affirmative inputs to be "true". + + ## Labels + + You must classify the data as one of the following labels, which are + numbered (starting from 0) and provide a brief description. Output + the label number only. 
+ + {% for label in labels %} + - Label #{{ loop.index0 }}: {{ label }} + {% endfor %} + """ + ).render(labels=label_strings), + UserMessage( + """ + ## Information to classify + + {{ data }} + + {% if instructions -%} + ## Additional instructions + + {{ instructions }} + {% endif %} + + {% if context -%} + ## Additional context + + {% for key, value in context.items() -%} + - {{ key }}: {{ value }} + + {% endfor %} + {% endif %} + + """ + ).render(data=data, instructions=instructions, context=context), + AssistantMessage(""" + The best label for the data is Label # + """), + ] + + result = controlflow.llm.completions.completion( + messages=messages, + model=model, + max_tokens=1, + logit_bias={ + str(encoding): 100 + for i in range(len(labels)) + for encoding in _encoder(model)(str(i)) + }, + ) + + index = int(result[0].content) + return labels[index] + + +def _encoder(model: str): + import tiktoken + + try: + encoding = tiktoken.encoding_for_model(model) + except (KeyError, AttributeError): + encoding = tiktoken.encoding_for_model("gpt-3.5-turbo") + + return encoding.encode diff --git a/src/controlflow/llm/completions.py b/src/controlflow/llm/completions.py index 48f206fc..c30b91ad 100644 --- a/src/controlflow/llm/completions.py +++ b/src/controlflow/llm/completions.py @@ -20,10 +20,10 @@ def completion( messages: list[Union[dict, ControlFlowMessage]], - model=None, + model: str = None, tools: list[Callable] = None, assistant_name: str = None, - max_iterations=None, + max_iterations: int = None, handlers: list[CompletionHandler] = None, message_preprocessor: Callable[[ControlFlowMessage], ControlFlowMessage] = None, stream: bool = False, @@ -124,7 +124,7 @@ def completion( def _completion_stream( messages: list[Union[dict, ControlFlowMessage]], - model=None, + model: str = None, tools: list[Callable] = None, assistant_name: str = None, max_iterations: int = None, @@ -243,10 +243,10 @@ def _completion_stream( async def completion_async( messages: list[Union[dict, 
ControlFlowMessage]], - model=None, + model: str = None, tools: list[Callable] = None, assistant_name: str = None, - max_iterations=None, + max_iterations: int = None, handlers: list[CompletionHandler] = None, message_preprocessor: Callable[[ControlFlowMessage], ControlFlowMessage] = None, stream: bool = False, @@ -344,30 +344,9 @@ async def completion_async( handler.on_end() -""" -Perform asynchronous streaming completion using the LLM model. - -Args: - messages: A list of messages to be used for completion. - model: The LLM model to be used for completion. If not provided, the default model from controlflow.settings will be used. - tools: A list of callable tools to be used during completion. - assistant_name: The name of the assistant, which will be set as the `name` attribute of any messages it generates. - max_iterations: The maximum number of iterations to perform completion. If not provided, it will continue until completion is done. - handlers: A list of CompletionHandler objects to handle completion events. - message_preprocessor: A callable function to preprocess each ControlFlowMessage before completion. - **kwargs: Additional keyword arguments to be passed to the litellm.acompletion function. - -Yields: - Each ControlFlowMessage generated during completion. - -Returns: - The final completion response as a list of ControlFlowMessage objects. -""" - - async def _completion_stream_async( messages: list[Union[dict, ControlFlowMessage]], - model=None, + model: str = None, tools: list[Callable] = None, assistant_name: str = None, max_iterations: int = None, @@ -383,10 +362,13 @@ async def _completion_stream_async( model: The LLM model to be used for completion. If not provided, the default model from controlflow.settings will be used. tools: A list of callable tools to be used during completion. assistant_name: The name of the assistant, which will be set as the `name` attribute of any messages it generates. 
+ max_iterations: The maximum number of iterations to perform completion. If not provided, it will continue until completion is done. + handlers: A list of CompletionHandler objects to handle completion events. + message_preprocessor: A callable function to preprocess each ControlFlowMessage before completion. **kwargs: Additional keyword arguments to be passed to the litellm.acompletion function. Yields: - Each message + Each ControlFlowMessage generated during completion. Returns: The final completion response as a list of ControlFlowMessage objects. @@ -476,5 +458,6 @@ async def _completion_stream_async( break except Exception as exc: handler.on_exception(exc) + raise finally: handler.on_end() diff --git a/src/controlflow/llm/formatting.py b/src/controlflow/llm/formatting.py index e1843c98..dc6d7bea 100644 --- a/src/controlflow/llm/formatting.py +++ b/src/controlflow/llm/formatting.py @@ -41,9 +41,14 @@ def format_message( def format_text_message(message: MessageType) -> Panel: + if message.role == "assistant" and message.name: + title = f"Agent: {message.name}" + else: + title = ROLE_NAMES.get(message.role, "Agent") + return Panel( Markdown(message.content or ""), - title=f"[bold]{ROLE_NAMES.get(message.role, 'Agent')}[/]", + title=f"[bold]{title}[/]", subtitle=f"[italic]{format_timestamp(message.timestamp)}[/]", title_align="left", subtitle_align="right", diff --git a/src/controlflow/llm/messages.py b/src/controlflow/llm/messages.py index 50cd7c21..91e2928e 100644 --- a/src/controlflow/llm/messages.py +++ b/src/controlflow/llm/messages.py @@ -1,4 +1,5 @@ import datetime +import inspect import json import uuid from typing import Any, List, Literal, Optional, Union @@ -12,6 +13,7 @@ model_validator, ) +from controlflow.utilities.jinja import jinja_env from controlflow.utilities.types import _OpenAIBaseType # ----------------------------------------------- @@ -67,7 +69,7 @@ def _lowercase_role(cls, v): @field_validator("timestamp", mode="before") def 
_validate_timestamp(cls, v): if isinstance(v, int): - v = datetime.datetime.fromtimestamp(v) + v = datetime.datetime.fromtimestamp(v, tz=datetime.timezone.utc) return v @model_validator(mode="after") @@ -94,6 +96,20 @@ class SystemMessage(ControlFlowMessage): _openai_fields = {"role", "content", "name"} # ---- end openai fields + @field_validator("content", mode="before") + def _validate_content(cls, v): + if isinstance(v, str): + v = inspect.cleandoc(v) + return v + + def render(self, **kwargs) -> "SystemMessage": + """ + Renders the content as a jinja template with the given keyword arguments + and returns a new SystemMessage. + """ + content = jinja_env.from_string(self.content).render(**kwargs) + return self.model_copy(update=dict(content=content)) + class UserMessage(ControlFlowMessage): # ---- begin openai fields @@ -106,9 +122,24 @@ class UserMessage(ControlFlowMessage): @field_validator("content", mode="before") def _validate_content(cls, v): if isinstance(v, str): + v = inspect.cleandoc(v) v = [TextContent(text=v)] return v + def render(self, **kwargs) -> "UserMessage": + """ + Renders the content as a jinja template with the given keyword arguments + and returns a new UserMessage.
+ """ + content = [] + for c in self.content: + if isinstance(c, TextContent): + text = jinja_env.from_string(c.text).render(**kwargs) + content.append(TextContent(text=text)) + else: + content.append(c) + return self.model_copy(update=dict(content=content)) + class AssistantMessage(ControlFlowMessage): """A message from the assistant.""" @@ -126,9 +157,26 @@ class AssistantMessage(ControlFlowMessage): description="If True, this message is a streamed delta, or chunk, of a full message.", ) + @field_validator("content", mode="before") + def _validate_content(cls, v): + if isinstance(v, str): + v = inspect.cleandoc(v) + return v + def has_tool_calls(self): return bool(self.tool_calls) + def render(self, **kwargs) -> "AssistantMessage": + """ + Renders the content as a jinja template with the given keyword arguments + and returns a new AssistantMessage. + """ + if self.content is None: + content = self.content + else: + content = jinja_env.from_string(self.content).render(**kwargs) + return self.model_copy(update=dict(content=content)) + class ToolMessage(ControlFlowMessage): """A message for reporting the result of a tool call.""" @@ -144,6 +192,12 @@ class ToolMessage(ControlFlowMessage): tool_result: Any = Field(None, exclude=True) tool_metadata: dict = Field(default_factory=dict) + @field_validator("content", mode="before") + def _validate_content(cls, v): + if isinstance(v, str): + v = inspect.cleandoc(v) + return v + MessageType = Union[SystemMessage, UserMessage, AssistantMessage, ToolMessage] diff --git a/src/controlflow/settings.py b/src/controlflow/settings.py index cc823936..2ec6bc75 100644 --- a/src/controlflow/settings.py +++ b/src/controlflow/settings.py @@ -60,7 +60,6 @@ class Settings(ControlFlowSettings): "on a task.", ) prefect: PrefectSettings = Field(default_factory=PrefectSettings) - openai_api_key: Optional[str] = Field(None, validate_assignment=True) # ------------ home settings ------------ @@ -119,14 +118,6 @@ def __init__(self, **data): 
super().__init__(**data) self.prefect.apply() - @field_validator("openai_api_key", mode="after") - def _apply_api_key(cls, v): - if v is not None: - import marvin - - marvin.settings.openai.api_key = v - return v - @field_validator("home_path", mode="before") def _validate_home_path(cls, v): v = Path(v).expanduser() diff --git a/src/controlflow/tasks/auto_tasks.py b/src/controlflow/tasks/auto_tasks.py index 7157598b..5c57797a 100644 --- a/src/controlflow/tasks/auto_tasks.py +++ b/src/controlflow/tasks/auto_tasks.py @@ -5,7 +5,7 @@ from controlflow.core.agent import Agent from controlflow.core.task import Task -from controlflow.utilities.types import AssistantTool, ControlFlowModel, FunctionTool +from controlflow.utilities.types import AssistantTool, ControlFlowModel ToolLiteral = TypeVar("ToolLiteral", bound=str) @@ -176,12 +176,7 @@ def auto_tasks( ) -> list[Task]: tool_names = [] for tool in available_tools or []: - if isinstance(tool, FunctionTool): - tool_names.append(tool.function.name) - elif isinstance(tool, AssistantTool): - tool_names.append(tool.type) - else: - tool_names.append(tool.__name__) + tool_names.append(tool.__name__) if tool_names: literal_tool_names = Literal[*tool_names] # type: ignore diff --git a/src/controlflow/tools/talk_to_human.py b/src/controlflow/tools/talk_to_human.py index 4455137c..cd58a367 100644 --- a/src/controlflow/tools/talk_to_human.py +++ b/src/controlflow/tools/talk_to_human.py @@ -2,11 +2,11 @@ import contextlib from typing import TYPE_CHECKING -from marvin.utilities.tools import tool_from_function from prefect.context import FlowRunContext from prefect.input.run_input import receive_input import controlflow +from controlflow.llm.tools import tool from controlflow.utilities.context import ctx if TYPE_CHECKING: @@ -36,7 +36,7 @@ async def listen_for_response(): return response -@tool_from_function +@tool async def talk_to_human(message: str, get_response: bool = True) -> str: """ Send a message to the human user and 
optionally wait for a response. diff --git a/src/controlflow/tui/test2.py b/src/controlflow/tui/test2.py index f78ae9fa..378af6b8 100644 --- a/src/controlflow/tui/test2.py +++ b/src/controlflow/tui/test2.py @@ -1,12 +1,9 @@ import asyncio -from marvin.utilities.asyncio import run_sync - from controlflow import Task from controlflow.core.flow import Flow from controlflow.tui.app import TUIApp -run_sync asyncio with Flow() as flow: t = Task("get the user name", user_access=True) diff --git a/src/controlflow/utilities/asyncio.py b/src/controlflow/utilities/asyncio.py new file mode 100644 index 00000000..63d7e623 --- /dev/null +++ b/src/controlflow/utilities/asyncio.py @@ -0,0 +1,94 @@ +import asyncio +import functools +from typing import Any, Callable, Coroutine, TypeVar, cast + +from prefect.utilities.asyncutils import run_sync + +T = TypeVar("T") + +BACKGROUND_TASKS = set() + + +def create_task(coro): + """ + Creates async background tasks in a way that is safe from garbage + collection. + + See + https://textual.textualize.io/blog/2023/02/11/the-heisenbug-lurking-in-your-async-code/ + + Example: + + async def my_coro(x: int) -> int: + return x + 1 + + # safely submits my_coro for background execution + create_task(my_coro(1)) + """ # noqa: E501 + task = asyncio.create_task(coro) + BACKGROUND_TASKS.add(task) + task.add_done_callback(BACKGROUND_TASKS.discard) + return task + + +class ExposeSyncMethodsMixin: + """ + A mixin that can take functions decorated with `expose_sync_method` + and automatically create synchronous versions. + """ + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + for method in list(cls.__dict__.values()): + if callable(method) and hasattr(method, "_sync_name"): + sync_method_name = method._sync_name + setattr(cls, sync_method_name, method._sync_wrapper) + + +def expose_sync_method(name: str) -> Callable[..., Any]: + """ + Decorator that automatically exposes synchronous versions of async methods. 
+ Note it doesn't work with classmethods. + + Args: + name: The name of the synchronous method. + + Returns: + The decorated function. + + Example: + Basic usage: + ```python + class MyClass(ExposeSyncMethodsMixin): + + @expose_sync_method("my_method") + async def my_method_async(self): + return 42 + + my_instance = MyClass() + await my_instance.my_method_async() # returns 42 + my_instance.my_method() # returns 42 + ``` + """ + + def decorator( + async_method: Callable[..., Coroutine[Any, Any, T]], + ) -> Callable[..., Coroutine[Any, Any, T]]: + @functools.wraps(async_method) + def sync_wrapper(*args: Any, **kwargs: Any) -> T: + coro = async_method(*args, **kwargs) + return run_sync(coro) + + # Cast the sync_wrapper to the same type as the async_method to give the + # type checker the needed information. + casted_sync_wrapper = cast(Callable[..., T], sync_wrapper) + + # Attach attributes to the async wrapper + setattr(async_method, "_sync_wrapper", casted_sync_wrapper) + setattr(async_method, "_sync_name", name) + + # return the original async method; the sync wrapper will be added to + # the class by the init hook + return async_method + + return decorator diff --git a/src/controlflow/utilities/context.py b/src/controlflow/utilities/context.py index 007f78b1..4640002b 100644 --- a/src/controlflow/utilities/context.py +++ b/src/controlflow/utilities/context.py @@ -1,4 +1,73 @@ -from marvin.utilities.context import ScopedContext +"""Module for defining context utilities.""" + +import contextvars +from contextlib import contextmanager +from typing import Any, Generator + + +class ScopedContext: + """ + `ScopedContext` provides a context management mechanism using `contextvars`. + + This class allows setting and retrieving key-value pairs in a scoped context, + which is preserved across asynchronous tasks and threads within the same context. + + Attributes: + _context_storage (ContextVar): A context variable to store the context data. 
+ + Example: + Basic Usage of ScopedContext + ```python + context = ScopedContext() + with context(key="value"): + assert context.get("key") == "value" + # Outside the context, the value is no longer available. + assert context.get("key") is None + ``` + """ + + def __init__(self, initial_value: dict = None): + """Initializes the ScopedContext with an initial value dictionary.""" + self._context_storage = contextvars.ContextVar( + "scoped_context_storage", default=initial_value or {} + ) + + def get(self, key: str, default: Any = None) -> Any: + return self._context_storage.get().get(key, default) + + def __getitem__(self, key: str) -> Any: + notfound = object() + result = self.get(key, default=notfound) + if result == notfound: + raise KeyError(key) + return result + + def set(self, **kwargs: Any) -> None: + ctx = self._context_storage.get() + updated_ctx = {**ctx, **kwargs} + token = self._context_storage.set(updated_ctx) + return token + + @contextmanager + def __call__(self, **kwargs: Any) -> Generator[None, None, Any]: + current_context_copy = self._context_storage.get().copy() + token = self.set(**kwargs) + try: + yield + finally: + try: + self._context_storage.reset(token) + except ValueError as exc: + if "was created in a different context" in str(exc).lower(): + # the only way we can reach this line is if the setup and + # teardown of this context are run in different frames or + # threads (which happens with pytest fixtures!), in which case + # the token is considered invalid.
This catch serves as a + # "manual" reset of the context values + self._context_storage.set(current_context_copy) + else: + raise + ctx = ScopedContext( dict( diff --git a/src/controlflow/utilities/jinja.py b/src/controlflow/utilities/jinja.py index 92d4ca71..30649c59 100644 --- a/src/controlflow/utilities/jinja.py +++ b/src/controlflow/utilities/jinja.py @@ -1,13 +1,23 @@ import inspect +import os from datetime import datetime from zoneinfo import ZoneInfo -from marvin.utilities.jinja import BaseEnvironment +from jinja2 import Environment as JinjaEnvironment +from jinja2 import StrictUndefined, select_autoescape -jinja_env = BaseEnvironment( - globals={ +jinja_env = JinjaEnvironment( + autoescape=select_autoescape(default_for_string=False), + trim_blocks=True, + lstrip_blocks=True, + auto_reload=True, + undefined=StrictUndefined, +) + +jinja_env.globals.update( + { "now": lambda: datetime.now(ZoneInfo("UTC")), "inspect": inspect, - "id": id, + "getcwd": os.getcwd, } ) diff --git a/src/controlflow/utilities/logging.py b/src/controlflow/utilities/logging.py index c04b4c79..04909d31 100644 --- a/src/controlflow/utilities/logging.py +++ b/src/controlflow/utilities/logging.py @@ -2,8 +2,6 @@ from functools import lru_cache from typing import Optional -from marvin.utilities.logging import add_logging_methods - @lru_cache() def get_logger(name: Optional[str] = None) -> logging.Logger: @@ -40,5 +38,4 @@ def get_logger(name: Optional[str] = None) -> logging.Logger: else: logger = parent_logger - add_logging_methods(logger) return logger diff --git a/src/controlflow/utilities/marvin.py b/src/controlflow/utilities/marvin.py index 5cf60426..9de99a46 100644 --- a/src/controlflow/utilities/marvin.py +++ b/src/controlflow/utilities/marvin.py @@ -1,93 +1,93 @@ -import inspect -from contextlib import contextmanager -from typing import Any, Callable +# import inspect +# from contextlib import contextmanager +# from typing import Any, Callable -import marvin.ai.text -from 
marvin.client.openai import AsyncMarvinClient -from marvin.settings import temporary_settings as temporary_marvin_settings -from openai.types.chat import ChatCompletion -from prefect import task as prefect_task +# import marvin.ai.text +# from marvin.client.openai import AsyncMarvinClient +# from marvin.settings import temporary_settings as temporary_marvin_settings +# from openai.types.chat import ChatCompletion +# from prefect import task as prefect_task -from controlflow.utilities.prefect import ( - create_json_artifact, -) +# from controlflow.utilities.prefect import ( +# create_json_artifact, +# ) -original_classify_async = marvin.classify_async -original_cast_async = marvin.cast_async -original_extract_async = marvin.extract_async -original_generate_async = marvin.generate_async -original_paint_async = marvin.paint_async -original_speak_async = marvin.speak_async -original_transcribe_async = marvin.transcribe_async +# original_classify_async = marvin.classify_async +# original_cast_async = marvin.cast_async +# original_extract_async = marvin.extract_async +# original_generate_async = marvin.generate_async +# original_paint_async = marvin.paint_async +# original_speak_async = marvin.speak_async +# original_transcribe_async = marvin.transcribe_async -class AsyncControlFlowClient(AsyncMarvinClient): - async def generate_chat(self, **kwargs: Any) -> "ChatCompletion": - super_method = super().generate_chat +# class AsyncControlFlowClient(AsyncMarvinClient): +# async def generate_chat(self, **kwargs: Any) -> "ChatCompletion": +# super_method = super().generate_chat - @prefect_task(task_run_name="Generate OpenAI chat completion") - async def _generate_chat(**kwargs): - messages = kwargs.get("messages", []) - create_json_artifact(key="prompt", data=messages) - response = await super_method(**kwargs) - create_json_artifact(key="response", data=response) - return response +# @prefect_task(task_run_name="Generate OpenAI chat completion") +# async def 
_generate_chat(**kwargs): +# messages = kwargs.get("messages", []) +# create_json_artifact(key="prompt", data=messages) +# response = await super_method(**kwargs) +# create_json_artifact(key="response", data=response) +# return response - return await _generate_chat(**kwargs) +# return await _generate_chat(**kwargs) -def generate_task(name: str, original_fn: Callable): - if inspect.iscoroutinefunction(original_fn): +# def generate_task(name: str, original_fn: Callable): +# if inspect.iscoroutinefunction(original_fn): - @prefect_task(name=name) - async def wrapper(*args, **kwargs): - create_json_artifact(key="args", data=[args, kwargs]) - result = await original_fn(*args, **kwargs) - create_json_artifact(key="result", data=result) - return result - else: +# @prefect_task(name=name) +# async def wrapper(*args, **kwargs): +# create_json_artifact(key="args", data=[args, kwargs]) +# result = await original_fn(*args, **kwargs) +# create_json_artifact(key="result", data=result) +# return result +# else: - @prefect_task(name=name) - def wrapper(*args, **kwargs): - create_json_artifact(key="args", data=[args, kwargs]) - result = original_fn(*args, **kwargs) - create_json_artifact(key="result", data=result) - return result +# @prefect_task(name=name) +# def wrapper(*args, **kwargs): +# create_json_artifact(key="args", data=[args, kwargs]) +# result = original_fn(*args, **kwargs) +# create_json_artifact(key="result", data=result) +# return result - return wrapper +# return wrapper -@contextmanager -def patch_marvin(): - with temporary_marvin_settings(default_async_client_cls=AsyncControlFlowClient): - try: - marvin.ai.text.classify_async = generate_task( - "marvin.classify", original_classify_async - ) - marvin.ai.text.cast_async = generate_task( - "marvin.cast", original_cast_async - ) - marvin.ai.text.extract_async = generate_task( - "marvin.extract", original_extract_async - ) - marvin.ai.text.generate_async = generate_task( - "marvin.generate", original_generate_async - ) 
- marvin.ai.images.paint_async = generate_task( - "marvin.paint", original_paint_async - ) - marvin.ai.audio.speak_async = generate_task( - "marvin.speak", original_speak_async - ) - marvin.ai.audio.transcribe_async = generate_task( - "marvin.transcribe", original_transcribe_async - ) - yield - finally: - marvin.ai.text.classify_async = original_classify_async - marvin.ai.text.cast_async = original_cast_async - marvin.ai.text.extract_async = original_extract_async - marvin.ai.text.generate_async = original_generate_async - marvin.ai.images.paint_async = original_paint_async - marvin.ai.audio.speak_async = original_speak_async - marvin.ai.audio.transcribe_async = original_transcribe_async +# @contextmanager +# def patch_marvin(): +# with temporary_marvin_settings(default_async_client_cls=AsyncControlFlowClient): +# try: +# marvin.ai.text.classify_async = generate_task( +# "marvin.classify", original_classify_async +# ) +# marvin.ai.text.cast_async = generate_task( +# "marvin.cast", original_cast_async +# ) +# marvin.ai.text.extract_async = generate_task( +# "marvin.extract", original_extract_async +# ) +# marvin.ai.text.generate_async = generate_task( +# "marvin.generate", original_generate_async +# ) +# marvin.ai.images.paint_async = generate_task( +# "marvin.paint", original_paint_async +# ) +# marvin.ai.audio.speak_async = generate_task( +# "marvin.speak", original_speak_async +# ) +# marvin.ai.audio.transcribe_async = generate_task( +# "marvin.transcribe", original_transcribe_async +# ) +# yield +# finally: +# marvin.ai.text.classify_async = original_classify_async +# marvin.ai.text.cast_async = original_cast_async +# marvin.ai.text.extract_async = original_extract_async +# marvin.ai.text.generate_async = original_generate_async +# marvin.ai.images.paint_async = original_paint_async +# marvin.ai.audio.speak_async = original_speak_async +# marvin.ai.audio.transcribe_async = original_transcribe_async diff --git a/src/controlflow/utilities/prefect.py 
b/src/controlflow/utilities/prefect.py index 305d4b0a..493ad130 100644 --- a/src/controlflow/utilities/prefect.py +++ b/src/controlflow/utilities/prefect.py @@ -1,20 +1,13 @@ import inspect -import json -from typing import Any, Callable +from typing import Any from uuid import UUID -import prefect -from marvin.types import FunctionTool -from marvin.utilities.asyncio import run_sync -from marvin.utilities.tools import tool_from_function from prefect import get_client as get_prefect_client -from prefect import task as prefect_task from prefect.artifacts import ArtifactRequest from prefect.context import FlowRunContext, TaskRunContext +from prefect.utilities.asyncutils import run_sync from pydantic import TypeAdapter -from controlflow.utilities.types import AssistantTool, ToolType - def create_markdown_artifact( key: str, @@ -117,66 +110,49 @@ def create_python_artifact( ) -def safe_isinstance(obj, type_) -> bool: - # FunctionTool objects are typed generics, and - # Python 3.9 will raise an error if you try to isinstance a typed generic... 
- try: - return isinstance(obj, type_) - except TypeError: - try: - return issubclass(type(obj), type_) - except TypeError: - return False - - -def wrap_prefect_tool(tool: ToolType) -> AssistantTool: - """ - Wraps a Marvin tool in a prefect task - """ - if not ( - safe_isinstance(tool, AssistantTool) or safe_isinstance(tool, FunctionTool) - ): - tool = tool_from_function(tool) - - if safe_isinstance(tool, FunctionTool): - # for functions, we modify the function to become a Prefect task and - # publish an artifact that contains details about the function call - - if isinstance(tool.function._python_fn, prefect.tasks.Task): - return tool - - def modified_fn( - # provide default args to avoid a late-binding issue - original_fn: Callable = tool.function._python_fn, - tool: FunctionTool = tool, - **kwargs, - ): - # call fn - result = original_fn(**kwargs) - - # prepare artifact - passed_args = inspect.signature(original_fn).bind(**kwargs).arguments - try: - passed_args = json.dumps(passed_args, indent=2) - except Exception: - pass - create_markdown_artifact( - markdown=TOOL_CALL_FUNCTION_RESULT_TEMPLATE.format( - name=tool.function.name, - description=tool.function.description or "(none provided)", - args=passed_args, - result=result, - ), - key="result", - ) - - # return result - return result - - # replace the function with the modified version - tool.function._python_fn = prefect_task( - modified_fn, - task_run_name=f"Tool call: {tool.function.name}", - ) - - return tool +# def wrap_prefect_tool(tool: ToolType) -> AssistantTool: +# if not (isinstance(tool, AssistantTool) or isinstance(tool, ToolFunction)): +# tool = tool(tool) + +# if isinstance(tool, ToolFunction): +# # for functions, we modify the function to become a Prefect task and +# # publish an artifact that contains details about the function call + +# if isinstance(tool.function._python_fn, prefect.tasks.Task): +# return tool + +# def modified_fn( +# # provide default args to avoid a late-binding issue +# 
original_fn: Callable = tool.function._python_fn, +# tool: ToolFunction = tool, +# **kwargs, +# ): +# # call fn +# result = original_fn(**kwargs) + +# # prepare artifact +# passed_args = inspect.signature(original_fn).bind(**kwargs).arguments +# try: +# passed_args = json.dumps(passed_args, indent=2) +# except Exception: +# pass +# create_markdown_artifact( +# markdown=TOOL_CALL_FUNCTION_RESULT_TEMPLATE.format( +# name=tool.function.name, +# description=tool.function.description or "(none provided)", +# args=passed_args, +# result=result, +# ), +# key="result", +# ) + +# # return result +# return result + +# # replace the function with the modified version +# tool.function._python_fn = prefect_task( +# modified_fn, +# task_run_name=f"Tool call: {tool.function.name}", +# ) + +# return tool diff --git a/src/controlflow/utilities/threads.py b/src/controlflow/utilities/threads.py deleted file mode 100644 index c707a9f8..00000000 --- a/src/controlflow/utilities/threads.py +++ /dev/null @@ -1,27 +0,0 @@ -from marvin.beta.assistants.threads import Message, Thread - -THREAD_REGISTRY = {} - - -def save_thread(name: str, thread: Thread): - """ - Save an OpenAI thread to the thread registry under a known name - """ - THREAD_REGISTRY[name] = thread - - -def load_thread(name: str): - """ - Load an OpenAI thread from the thread registry by name - """ - if name not in THREAD_REGISTRY: - thread = Thread() - save_thread(name, thread) - return THREAD_REGISTRY[name] - - -def get_history(thread_id: str, limit: int = None) -> list[Message]: - """ - Get the history of a thread - """ - return Thread(id=thread_id).get_messages(limit=limit) diff --git a/src/controlflow/utilities/types.py b/src/controlflow/utilities/types.py index 722cd7cb..43d878ab 100644 --- a/src/controlflow/utilities/types.py +++ b/src/controlflow/utilities/types.py @@ -1,23 +1,13 @@ -from enum import Enum -from functools import partial, update_wrapper -from typing import Callable, Optional, Union - -from 
marvin.beta.assistants import Assistant, Thread -from marvin.beta.assistants.assistants import AssistantTool -from marvin.types import FunctionTool -from marvin.utilities.asyncio import ExposeSyncMethodsMixin +from typing import Union + from pydantic import ( BaseModel, - computed_field, ) # flag for unset defaults NOTSET = "__NOTSET__" -ToolType = Union[FunctionTool, AssistantTool, Callable] - - class ControlFlowModel(BaseModel): model_config = dict(validate_assignment=True, extra="forbid") diff --git a/tests/core/test_controller.py b/tests/core/test_controller.py index a9953828..6422dc6c 100644 --- a/tests/core/test_controller.py +++ b/tests/core/test_controller.py @@ -67,9 +67,9 @@ def test_controller_agent_selection(self, flow, monkeypatch): agent2 = Agent(name="Agent 2") task = Task(objective="Test Task", agents=[agent1, agent2]) controller = Controller(flow=flow, tasks=[task], agents=[agent1, agent2]) - mocked_marvin_moderator = AsyncMock(return_value=agent1) + mocked_classify_moderator = AsyncMock(return_value=agent1) monkeypatch.setattr( - "controlflow.core.controller.moderators.marvin_moderator", - mocked_marvin_moderator, + "controlflow.core.controller.moderators.classify_moderator", + mocked_classify_moderator, ) assert controller.agents == [agent1, agent2] diff --git a/tests/core/test_flows.py b/tests/core/test_flows.py index f159b278..180a08f9 100644 --- a/tests/core/test_flows.py +++ b/tests/core/test_flows.py @@ -6,7 +6,7 @@ class TestFlow: def test_flow_initialization(self): flow = Flow() - assert flow.thread is not None + assert flow.thread_id is not None assert len(flow.tools) == 0 assert len(flow.agents) == 0 assert len(flow.context) == 0 diff --git a/tests/fixtures/mocks.py b/tests/fixtures/mocks.py index 4009476b..a4cbf656 100644 --- a/tests/fixtures/mocks.py +++ b/tests/fixtures/mocks.py @@ -6,7 +6,7 @@ from controlflow.core.agent import Agent from controlflow.core.task import Task, TaskStatus from controlflow.llm.completions import Response 
-from marvin.settings import temporary_settings as temporary_marvin_settings +from controlflow.settings import temporary_settings def new_chunk(): @@ -18,35 +18,10 @@ def new_chunk(): @pytest.fixture def prevent_openai_calls(): """Prevent any calls to the OpenAI API from being made.""" - with temporary_marvin_settings(openai__api_key="unset"): + with temporary_settings(llm_api_key="unset"): yield -@pytest.fixture -def mock_run(monkeypatch, prevent_openai_calls): - """ - This fixture mocks the calls to the OpenAI Assistants API. Use it in a test - and assign any desired side effects (like completing a task) to the mock - object's `.side_effect` attribute. - - For example: - - def test_example(mock_run): - task = Task(objective="Say hello") - - def side_effect(): - task.mark_complete() - - mock_run.side_effect = side_effect - - task.run() - - """ - MockRun = AsyncMock() - monkeypatch.setattr("controlflow.core.controller.controller.Run.run_async", MockRun) - yield MockRun - - @pytest.fixture def mock_controller_run_agent(monkeypatch, prevent_openai_calls): MockRunAgent = AsyncMock() @@ -68,9 +43,6 @@ def get_messages(*args, **kwargs): monkeypatch.setattr( "controlflow.core.controller.controller.Controller._run_agent", MockRunAgent ) - monkeypatch.setattr( - "marvin.beta.assistants.Thread.get_messages", MockThreadGetMessages - ) yield MockRunAgent