diff --git a/src/controlflow/core/agent.py b/src/controlflow/core/agent.py
index 42ab9e46..ad33705b 100644
--- a/src/controlflow/core/agent.py
+++ b/src/controlflow/core/agent.py
@@ -1,11 +1,13 @@
 import logging
-from typing import Union
+from typing import Callable, Optional, Union
 
+from litellm import Message
 from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method
 from pydantic import Field
 
 from controlflow.core.flow import Flow, get_flow
 from controlflow.core.task import Task
+from controlflow.llm.completions import Response, completion, completion_async
 from controlflow.tools.talk_to_human import talk_to_human
 from controlflow.utilities.prefect import (
     wrap_prefect_tool,
@@ -62,3 +64,55 @@ async def run_async(
 
     def __hash__(self):
         return id(self)
+
+
+class LiteAgent(ControlFlowModel, ExposeSyncMethodsMixin):
+    """A lightweight agent that exchanges messages with an LLM via litellm."""
+
+    name: str = Field(
+        ...,
+        description="The name of the agent. This is used to identify the agent in the system and should be unique per assigned task.",
+    )
+    description: Optional[str] = Field(
+        None, description="A description of the agent, visible to other agents."
+    )
+    instructions: Optional[str] = Field(
+        None, description="Instructions for the agent, private to this agent."
+    )
+    tools: list[Callable] = Field(
+        [], description="List of tools available to the agent."
+    )
+    user_access: bool = Field(
+        False,
+        description="If True, the agent is given tools for interacting with a human user.",
+    )
+    model: Optional[str] = Field(
+        None,
+        description="The model used by the agent. If not provided, the default model will be used.",
+    )
+
+    async def say_async(self, messages: list[Union[str, dict]]) -> Response:
+        """Send *messages* to the LLM asynchronously; strings become user messages."""
+        if not isinstance(messages, list):
+            raise ValueError("Messages must be provided as a list.")
+
+        messages = [
+            Message(role="user", content=m) if isinstance(m, str) else m
+            for m in messages
+        ]
+
+        return await completion_async(
+            messages=messages, model=self.model, tools=self.tools
+        )
+
+    def say(self, messages: list[Union[str, dict]]) -> Response:
+        """Blocking counterpart of `say_async`; strings become user messages."""
+        if not isinstance(messages, list):
+            raise ValueError("Messages must be provided as a list.")
+
+        messages = [
+            Message(role="user", content=m) if isinstance(m, str) else m
+            for m in messages
+        ]
+
+        return completion(messages=messages, model=self.model, tools=self.tools)
diff --git a/src/controlflow/llm/completions.py b/src/controlflow/llm/completions.py
index bc22ed08..3e548800 100644
--- a/src/controlflow/llm/completions.py
+++ b/src/controlflow/llm/completions.py
@@ -39,15 +39,15 @@ def completion(
     Returns:
         A litellm.ModelResponse object representing the completion response.
     """
+    intermediate_messages = []
     intermediate_responses = []
 
     if model is None:
         model = controlflow.settings.model
-    if tools is not None:
-        tool_dicts = [function_to_tool_dict(tool) for tool in tools]
-    else:
-        tool_dicts = None
+
+    tool_dicts = [function_to_tool_dict(tool) for tool in tools or []] or None
+
     response = litellm.completion(
         model=model,
         messages=messages,
@@ -101,10 +101,7 @@ def stream_completion(
     if model is None:
         model = controlflow.settings.model
 
-    if tools is not None:
-        tool_dicts = [function_to_tool_dict(tool) for tool in tools]
-    else:
-        tool_dicts = None
+    tool_dicts = [function_to_tool_dict(tool) for tool in tools or []] or None
 
     chunks = []
     for chunk in litellm.completion(
@@ -163,10 +160,7 @@ async def completion_async(
     if model is None:
         model = controlflow.settings.model
 
-    if tools is not None:
-        tool_dicts = [function_to_tool_dict(tool) for tool in tools]
-    else:
-        tool_dicts = None
+    tool_dicts = [function_to_tool_dict(tool) for tool in tools or []] or None
 
     response = await litellm.acompletion(
         model=model,
@@ -221,10 +215,7 @@ async def stream_completion_async(
     if model is None:
         model = controlflow.settings.model
 
-    if tools is not None:
-        tool_dicts = [function_to_tool_dict(tool) for tool in tools]
-    else:
-        tool_dicts = None
+    tool_dicts = [function_to_tool_dict(tool) for tool in tools or []] or None
 
     chunks = []
     async for chunk in litellm.acompletion(