Skip to content

Commit

Permalink
Move history to the flow
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin committed Jun 17, 2024
1 parent b89b920 commit f1bd06e
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 23 deletions.
4 changes: 2 additions & 2 deletions src/controlflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@

# --- Default history ---
# assign to controlflow.default_history to change the default history
from .llm.history import DEFAULT_HISTORY as default_history
from .llm.history import DEFAULT_HISTORY as default_history, get_default_history

# --- Default agent ---
# assign to controlflow.default_agent to change the default agent
from .core.agent.agent import DEFAULT_AGENT as default_agent
from .core.agent.agent import DEFAULT_AGENT as default_agent, get_default_agent

# --- Version ---
try:
Expand Down
12 changes: 3 additions & 9 deletions src/controlflow/core/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from controlflow.instructions import get_instructions
from controlflow.llm.completions import completion, completion_async
from controlflow.llm.handlers import PrintHandler, ResponseHandler, TUIHandler
from controlflow.llm.history import History
from controlflow.llm.messages import AIMessage, MessageType, SystemMessage
from controlflow.llm.tools import as_tools
from controlflow.tui.app import TUIApp as TUI
Expand Down Expand Up @@ -91,9 +90,6 @@ class Controller(ControlFlowModel):
description="Tasks that the controller will complete.",
)
agents: Union[list[Agent], None] = None
history: History = Field(
default_factory=controlflow.llm.history.get_default_history
)
context: dict = {}
model_config: dict = dict(extra="forbid")
enable_tui: bool = Field(default_factory=lambda: controlflow.settings.enable_tui)
Expand Down Expand Up @@ -203,7 +199,7 @@ def _setup_run(self):

# prepare messages
system_message = SystemMessage(content=instructions)
messages = self.history.load_messages(thread_id=self.flow.thread_id)
messages = self.flow.get_messages()

# setup handlers
handlers = []
Expand Down Expand Up @@ -244,8 +240,7 @@ async def run_once_async(self) -> list[MessageType]:
pass

# save history
self.history.save_messages(
thread_id=self.flow.thread_id,
self.flow.add_messages(
messages=response_handler.response_messages,
)
self._iteration += 1
Expand Down Expand Up @@ -281,8 +276,7 @@ def run_once(self) -> list[MessageType]:
pass

# save history
self.history.save_messages(
thread_id=self.flow.thread_id,
self.flow.add_messages(
messages=response_handler.response_messages,
)
self._iteration += 1
Expand Down
51 changes: 39 additions & 12 deletions src/controlflow/core/flow.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import datetime
import uuid
from contextlib import contextmanager, nullcontext
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

from pydantic import Field

import controlflow
from controlflow.llm.history import get_default_history
import controlflow.llm
from controlflow.llm.history import History, get_default_history
from controlflow.llm.messages import MessageType
from controlflow.utilities.context import ctx
from controlflow.utilities.logging import get_logger
Expand All @@ -22,6 +24,9 @@ class Flow(ControlFlowModel):
name: Optional[str] = None
description: Optional[str] = None
thread_id: str = Field(default_factory=lambda: uuid.uuid4().hex)
history: History = Field(
default_factory=controlflow.llm.history.get_default_history
)
tools: list[Callable] = Field(
default_factory=list,
description="Tools that will be available to every agent in the flow",
Expand All @@ -35,7 +40,39 @@ class Flow(ControlFlowModel):
_tasks: dict[str, "Task"] = {}
_cm_stack: list[contextmanager] = []

# --- Prefect kwargs ---
def __init__(self, *, copy_parent_history: bool = True, **kwargs):
"""
By default, the flow will copy the history from the parent flow if one
exists. Because each flow is a new thread, new messages will not be shared
between the parent and child flow.
"""
super().__init__(**kwargs)
parent = get_flow()
if parent and copy_parent_history:
self.add_messages(parent.get_messages())

def __enter__(self):
# use stack so we can enter the context multiple times
cm = self.create_context()
self._cm_stack.append(cm)
return cm.__enter__()

def __exit__(self, *exc_info):
# exit the context manager
return self._cm_stack.pop().__exit__(*exc_info)

def get_messages(
self,
limit: int = None,
before: datetime.datetime = None,
after: datetime.datetime = None,
) -> list[MessageType]:
return self.history.load_messages(
thread_id=self.thread_id, limit=limit, before=before, after=after
)

def add_messages(self, messages: list[MessageType]):
self.history.save_messages(thread_id=self.thread_id, messages=messages)

def add_task(self, task: "Task"):
if self._tasks.get(task.id, task) is not task:
Expand All @@ -53,16 +90,6 @@ def create_context(self, create_prefect_flow_context: bool = True):
with ctx(flow=self), prefect_ctx:
yield self

def __enter__(self):
# use stack so we can enter the context multiple times
cm = self.create_context()
self._cm_stack.append(cm)
return cm.__enter__()

def __exit__(self, *exc_info):
# exit the context manager
return self._cm_stack.pop().__exit__(*exc_info)

async def run_async(self):
"""
Runs the flow asynchronously.
Expand Down

0 comments on commit f1bd06e

Please sign in to comment.