Skip to content

Commit

Permalink
Merge pull request #44 from jlowin/df
Browse files Browse the repository at this point in the history
Add support for pandas results
  • Loading branch information
jlowin authored May 16, 2024
2 parents ae390e7 + 8244a14 commit 8992cad
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 26 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ tests = [
"pytest-timeout",
"pytest-xdist",
"pre-commit>=3.7.0",
"pandas",
]
dev = ["controlflow[tests]", "ipython", "pdbpp", "ruff>=0.3.4"]

Expand Down
26 changes: 15 additions & 11 deletions src/controlflow/core/controller/instruction_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,20 +109,24 @@ class TasksTemplate(Template):
satisfy the task objective, accounting for any other instructions. If a
task does not require a result (`result_type=None`), you must still
complete its stated objective by posting messages or using other tools
before marking the task as complete.
before marking the task as complete. Your result must be compatible with
the result constructor. For most results, the tool schema will indicate
the correct types. For some, like a DataFrame, provide an appropriate
kwargs dict.
#### Using messages as results
You can reuse the contents of any message as a task's result by
providing a special `ThreadMessage` object when marking a task
successful. Only do this if the thread message is exactly compatible
with task's result_type (e.g. a string of JSON representation). Indicate
the number of messages ago that the message was posted (defaults to 1).
Also provide any characters to strip from the start or end of the
message, to make sure that the result doesn't reveal any internal
details (for example, always remove your name prefix and irrelevant
comments from the beginning or end of the response such as
"I'll mark the task complete now.").
If you posted a message whose contents could be reused as the result of
a task, you can quickly load the contents by providing a special
`ThreadMessage` object when marking a task successful. Indicate the
number of messages ago that the message was posted (defaults to 1), as
well as any characters to strip from the start or end of the message
(for example, always remove your name prefix and irrelevant comments
from the beginning or end of the response such as "I'll mark the task
complete now."). This will only work if the literal message contents (as
a string or JSON reprseentation) are exactly compatible with the result
type and will error otherwise. This is not magic; the string or JSON are
passed to the result constructor as-is.
"""
tasks: list[Task]
Expand Down
86 changes: 71 additions & 15 deletions src/controlflow/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
GenericAlias,
Literal,
Optional,
Expand All @@ -18,6 +19,7 @@
from marvin.utilities.tools import FunctionTool
from pydantic import (
Field,
PydanticSchemaGenerationError,
TypeAdapter,
field_serializer,
field_validator,
Expand All @@ -38,6 +40,8 @@
NOTSET,
AssistantTool,
ControlFlowModel,
PandasDataFrame,
PandasSeries,
ToolType,
)
from controlflow.utilities.user_access import talk_to_human
Expand Down Expand Up @@ -113,22 +117,29 @@ class Task(ControlFlowModel):
)
agents: list["Agent"] = Field(
None,
description="The agents assigned to the task. If not provided, agents will be inferred from the parent task, flow, or global default.",
description="The agents assigned to the task. If not provided, agents "
"will be inferred from the parent task, flow, or global default.",
)
context: dict = Field(
default_factory=dict,
description="Additional context for the task. If tasks are provided as context, they are automatically added as `depends_on`",
description="Additional context for the task. If tasks are provided as "
"context, they are automatically added as `depends_on`",
)
subtasks: list["Task"] = Field(
default_factory=list,
description="A list of subtasks that are part of this task. Subtasks are considered dependencies, though they may be skipped.",
description="A list of subtasks that are part of this task. Subtasks are "
"considered dependencies, though they may be skipped.",
)
depends_on: list["Task"] = Field(
default_factory=list, description="Tasks that this task depends on explicitly."
)
status: TaskStatus = TaskStatus.INCOMPLETE
result: T = None
result_type: Union[type[T], GenericAlias, _LiteralGenericAlias, None] = None
result_type: Union[type[T], GenericAlias, _LiteralGenericAlias, None] = Field(
None,
description="The expected type of the result. This should be a type"
", generic alias, BaseModel subclass, pd.DataFrame, or pd.Series.",
)
error: Union[str, None] = None
tools: list[ToolType] = []
user_access: bool = False
Expand Down Expand Up @@ -398,8 +409,9 @@ def succeed() -> str:

# generate tool for other result types
else:
result_schema = generate_result_schema(self.result_type)

def succeed(result: Union[ThreadMessage, self.result_type]) -> str:
def succeed(result: Union[ThreadMessage, result_schema]) -> str: # type: ignore
# a shortcut for loading results from recent messages
if isinstance(result, dict) and result.get("type") == "ThreadMessage":
result = ThreadMessage(**result)
Expand Down Expand Up @@ -455,8 +467,8 @@ def get_tools(self) -> list[ToolType]:
def dependencies(self):
return self.depends_on + self.subtasks

def mark_successful(self, result: T = None, validate: bool = True):
if validate:
def mark_successful(self, result: T = None, validate_upstreams: bool = True):
if validate_upstreams:
if any(t.is_incomplete() for t in self.depends_on):
raise ValueError(
f"Task {self.objective} cannot be marked successful until all of its "
Expand All @@ -470,14 +482,7 @@ def mark_successful(self, result: T = None, validate: bool = True):
f"are: {', '.join(t.friendly_name() for t in self.subtasks if t.is_incomplete())}"
)

if self.result_type is None and result is not None:
raise ValueError(
f"Task {self.objective} has result_type=None, but a result was provided."
)
elif self.result_type is not None:
result = TypeAdapter(self.result_type).validate_python(result)

self.result = result
self.result = validate_result(result, self.result_type)
self.status = TaskStatus.SUCCESSFUL
return f"{self.friendly_name()} marked successful. Updated task definition: {self.model_dump()}"

Expand All @@ -489,3 +494,54 @@ def mark_failed(self, message: Union[str, None] = None):
def mark_skipped(self):
self.status = TaskStatus.SKIPPED
return f"{self.friendly_name()} marked skipped. Updated task definition: {self.model_dump()}"


def generate_result_schema(result_type: type[T]) -> type[T]:
result_schema = None
# try loading pydantic-compatible schemas
try:
TypeAdapter(result_type)
result_schema = result_type
except PydanticSchemaGenerationError:
pass
# try loading as dataframe
try:
import pandas as pd

if result_type is pd.DataFrame:
result_schema = PandasDataFrame
elif result_type is pd.Series:
result_schema = PandasSeries
except ImportError:
pass
if result_schema is None:
raise ValueError(
f"Could not load or infer schema for result type {result_type}. "
"Please use a custom type or add compatibility."
)
return result_schema


def validate_result(result: Any, result_type: type[T]) -> T:
if result_type is None and result is not None:
raise ValueError("Task has result_type=None, but a result was provided.")
elif result_type is not None:
try:
result = TypeAdapter(result_type).validate_python(result)
except PydanticSchemaGenerationError:
if isinstance(result, dict):
result = result_type(**result)
else:
result = result_type(result)

# Convert DataFrame schema back into pd.DataFrame object
if result_type == PandasDataFrame:
import pandas as pd

result = pd.DataFrame(**result)
elif result_type == PandasSeries:
import pandas as pd

result = pd.Series(**result)

return result
22 changes: 22 additions & 0 deletions src/controlflow/utilities/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,30 @@
# flag for unset defaults
NOTSET = "__NOTSET__"


ToolType = Union[FunctionTool, AssistantTool, Callable]


class ControlFlowModel(BaseModel):
model_config = dict(validate_assignment=True, extra="forbid")


class PandasDataFrame(ControlFlowModel):
"""Schema for a pandas dataframe"""

data: Union[
list[list[Union[str, int, float, bool]]],
dict[str, list[Union[str, int, float, bool]]],
]
columns: list[str] = None
index: list[str] = None
dtype: dict[str, str] = None


class PandasSeries(ControlFlowModel):
"""Schema for a pandas series"""

data: list[Union[str, int, float]]
index: list[str] = None
name: str = None
dtype: str = None
Empty file added tests/ai_tests/__init__.py
Empty file.
31 changes: 31 additions & 0 deletions tests/ai_tests/test_tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pandas as pd
import pytest
from controlflow import Task
from pydantic import BaseModel


class Name(BaseModel):
first: str
last: str


@pytest.mark.usefixtures("unit_test_instructions")
class TestTaskResults:
def test_task_int_result(self):
task = Task("return 3", result_type=int)
assert task.run() == 3

def test_task_pydantic_result(self):
task = Task("the name is John Doe", result_type=Name)
result = task.run()
assert isinstance(result, Name)
assert result == Name(first="John", last="Doe")

def test_task_dataframe_result(self):
task = Task(
'return a dataframe with column "x" that has values 1 and 2 and column "y" that has values 3 and 4',
result_type=pd.DataFrame,
)
result = task.run()
assert isinstance(result, pd.DataFrame)
assert result == pd.DataFrame(data={"x": [1, 2], "y": [3, 4]})
10 changes: 10 additions & 0 deletions tests/fixtures/instructions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import pytest
from controlflow import instructions


@pytest.fixture
def unit_test_instructions():
with instructions(
"You are being unit tested. Be as fast and concise as possible. Do not post unecessary messages."
):
yield
Empty file added tests/flows/__init__.py
Empty file.

0 comments on commit 8992cad

Please sign in to comment.