Skip to content

Commit

Permalink
Improve type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin committed Oct 2, 2024
1 parent 0c16e9c commit 54b54c6
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
5 changes: 3 additions & 2 deletions src/controlflow/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from langchain_core.language_models import BaseChatModel
from pydantic import Field, field_serializer, field_validator
from typing_extensions import Self

import controlflow
from controlflow.agents.names import AGENT_NAMES
Expand Down Expand Up @@ -183,11 +184,11 @@ def get_prompt(self) -> str:
return template.render()

@contextmanager
def create_context(self):
def create_context(self) -> Generator[Self, None, None]:
with ctx(agent=self):
yield self

def __enter__(self):
def __enter__(self) -> Self:
self._cm_stack.append(self.create_context())
return self._cm_stack[-1].__enter__()

Expand Down
7 changes: 4 additions & 3 deletions src/controlflow/flows/flow.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import uuid
from contextlib import contextmanager, nullcontext
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Generator, Optional, Union

from prefect.context import FlowRunContext
from pydantic import Field
from typing_extensions import Self

import controlflow
from controlflow.agents import Agent
Expand Down Expand Up @@ -54,7 +55,7 @@ class Flow(ControlFlowModel):
context: dict[str, Any] = {}
_cm_stack: list[contextmanager] = []

def __enter__(self):
def __enter__(self) -> Self:
# use stack so we can enter the context multiple times
cm = self.create_context()
self._cm_stack.append(cm)
Expand Down Expand Up @@ -111,7 +112,7 @@ def add_events(self, events: list[Event]):
self.history.add_events(thread_id=self.thread_id, events=events)

@contextmanager
def create_context(self, **prefect_kwargs):
def create_context(self, **prefect_kwargs) -> Generator[Self, None, None]:
# create a new Prefect flow if we're not already in a flow run
if FlowRunContext.get() is None:
prefect_context = prefect_flow_context(**prefect_kwargs)
Expand Down
6 changes: 4 additions & 2 deletions src/controlflow/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
TYPE_CHECKING,
Any,
Callable,
Generator,
GenericAlias,
Literal,
Optional,
Expand All @@ -27,6 +28,7 @@
field_serializer,
field_validator,
)
from typing_extensions import Self

import controlflow
from controlflow.agents import Agent
Expand Down Expand Up @@ -426,12 +428,12 @@ async def run_async(
raise ValueError(f"{self.friendly_name()} failed: {self.result}")

@contextmanager
def create_context(self):
def create_context(self) -> Generator[Self, None, None]:
stack = ctx.get("tasks") or []
with ctx(tasks=stack + [self]):
yield self

def __enter__(self):
def __enter__(self) -> Self:
# use stack so we can enter the context multiple times
self._cm_stack.append(ExitStack())
return self._cm_stack[-1].enter_context(self.create_context())
Expand Down

0 comments on commit 54b54c6

Please sign in to comment.