Skip to content

Commit

Permalink
Add LiteAgent
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin committed May 21, 2024
1 parent f95c040 commit 586dceb
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 17 deletions.
52 changes: 51 additions & 1 deletion src/controlflow/core/agent.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import logging
from typing import Union
from typing import Callable, Optional, Union

from litellm import Message
from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method
from pydantic import Field

from controlflow.core.flow import Flow, get_flow
from controlflow.core.task import Task
from controlflow.llm.completions import Response, completion, completion_async
from controlflow.tools.talk_to_human import talk_to_human
from controlflow.utilities.prefect import (
wrap_prefect_tool,
Expand Down Expand Up @@ -62,3 +64,51 @@ async def run_async(

def __hash__(self):
return id(self)


class LiteAgent(ControlFlowModel, ExposeSyncMethodsMixin):
name: str = Field(
...,
description="The name of the agent. This is used to identify the agent in the system and should be unique per assigned task.",
)
description: Optional[str] = Field(
None, description="A description of the agent, visible to other agents."
)
instructions: Optional[str] = Field(
None, description="Instructions for the agent, private to this agent."
)
tools: list[Callable] = Field(
[], description="List of tools availble to the agent."
)
user_access: bool = Field(
False,
description="If True, the agent is given tools for interacting with a human user.",
)
model: Optional[str] = Field(
None,
description="The model used by the agent. If not provided, the default model will be used.",
)

async def say_async(self, messages: Union[str, dict]) -> Response:
if not isinstance(messages, list):
raise ValueError("Messages must be provided as a list.")

messages = [
Message(role="user", content=m) if isinstance(m, str) else m
for m in messages
]

return await completion_async(
messages=messages, model=self.model, tools=self.tools
)

async def say(self, messages: Union[str, dict]) -> Response:
if not isinstance(messages, list):
raise ValueError("Messages must be provided as a list.")

messages = [
Message(role="user", content=m) if isinstance(m, str) else m
for m in messages
]

return completion(messages=messages, model=self.model, tools=self.tools)
23 changes: 7 additions & 16 deletions src/controlflow/llm/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ def completion(
Returns:
A litellm.ModelResponse object representing the completion response.
"""

intermediate_messages = []
intermediate_responses = []

if model is None:
model = controlflow.settings.model
if tools is not None:
tool_dicts = [function_to_tool_dict(tool) for tool in tools]
else:
tool_dicts = None

tool_dicts = [function_to_tool_dict(tool) for tool in tools or []] or None

response = litellm.completion(
model=model,
messages=messages,
Expand Down Expand Up @@ -101,10 +101,7 @@ def stream_completion(
if model is None:
model = controlflow.settings.model

if tools is not None:
tool_dicts = [function_to_tool_dict(tool) for tool in tools]
else:
tool_dicts = None
tool_dicts = [function_to_tool_dict(tool) for tool in tools or []] or None

chunks = []
for chunk in litellm.completion(
Expand Down Expand Up @@ -163,10 +160,7 @@ async def completion_async(
if model is None:
model = controlflow.settings.model

if tools is not None:
tool_dicts = [function_to_tool_dict(tool) for tool in tools]
else:
tool_dicts = None
tool_dicts = [function_to_tool_dict(tool) for tool in tools or []] or None

response = await litellm.acompletion(
model=model,
Expand Down Expand Up @@ -221,10 +215,7 @@ async def stream_completion_async(
if model is None:
model = controlflow.settings.model

if tools is not None:
tool_dicts = [function_to_tool_dict(tool) for tool in tools]
else:
tool_dicts = None
tool_dicts = [function_to_tool_dict(tool) for tool in tools or []] or None

chunks = []
async for chunk in litellm.acompletion(
Expand Down

0 comments on commit 586dceb

Please sign in to comment.