diff --git a/pyproject.toml b/pyproject.toml
index 186cd47c..f07c6993 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,6 +46,7 @@ tests = [
     "pytest-timeout",
     "pytest-xdist",
     "pre-commit>=3.7.0",
+    "pandas",
 ]
 dev = ["controlflow[tests]", "ipython", "pdbpp", "ruff>=0.3.4"]
 
diff --git a/src/controlflow/core/controller/instruction_template.py b/src/controlflow/core/controller/instruction_template.py
index 1bac18a5..549a0aa6 100644
--- a/src/controlflow/core/controller/instruction_template.py
+++ b/src/controlflow/core/controller/instruction_template.py
@@ -109,20 +109,24 @@ class TasksTemplate(Template):
         satisfy the task objective, accounting for any other instructions. If
         a task does not require a result (`result_type=None`), you must still
         complete its stated objective by posting messages or using other tools
-        before marking the task as complete.
+        before marking the task as complete. Your result must be compatible with
+        the result constructor. For most results, the tool schema will indicate
+        the correct types. For some, like a DataFrame, provide an appropriate
+        kwargs dict.
 
         #### Using messages as results
 
-        You can reuse the contents of any message as a task's result by
-        providing a special `ThreadMessage` object when marking a task
-        successful. Only do this if the thread message is exactly compatible
-        with task's result_type (e.g. a string of JSON representation). Indicate
-        the number of messages ago that the message was posted (defaults to 1).
-        Also provide any characters to strip from the start or end of the
-        message, to make sure that the result doesn't reveal any internal
-        details (for example, always remove your name prefix and irrelevant
-        comments from the beginning or end of the response such as
-        "I'll mark the task complete now.").
+        If you posted a message whose contents could be reused as the result of
+        a task, you can quickly load the contents by providing a special
+        `ThreadMessage` object when marking a task successful. Indicate the
+        number of messages ago that the message was posted (defaults to 1), as
+        well as any characters to strip from the start or end of the message
+        (for example, always remove your name prefix and irrelevant comments
+        from the beginning or end of the response such as "I'll mark the task
+        complete now."). This will only work if the literal message contents (as
+        a string or JSON representation) are exactly compatible with the result
+        type and will error otherwise. This is not magic; the string or JSON is
+        passed to the result constructor as-is.
         """
 
     tasks: list[Task]
diff --git a/src/controlflow/core/task.py b/src/controlflow/core/task.py
index 74a17a52..f483848d 100644
--- a/src/controlflow/core/task.py
+++ b/src/controlflow/core/task.py
@@ -4,6 +4,7 @@
 from enum import Enum
 from typing import (
     TYPE_CHECKING,
+    Any,
     GenericAlias,
     Literal,
     Optional,
@@ -18,6 +19,7 @@
 from marvin.utilities.tools import FunctionTool
 from pydantic import (
     Field,
+    PydanticSchemaGenerationError,
     TypeAdapter,
     field_serializer,
     field_validator,
@@ -38,6 +40,8 @@
     NOTSET,
     AssistantTool,
     ControlFlowModel,
+    PandasDataFrame,
+    PandasSeries,
     ToolType,
 )
 from controlflow.utilities.user_access import talk_to_human
@@ -113,22 +117,29 @@ class Task(ControlFlowModel):
     )
     agents: list["Agent"] = Field(
         None,
-        description="The agents assigned to the task. If not provided, agents will be inferred from the parent task, flow, or global default.",
+        description="The agents assigned to the task. If not provided, agents "
+        "will be inferred from the parent task, flow, or global default.",
     )
     context: dict = Field(
         default_factory=dict,
-        description="Additional context for the task. If tasks are provided as context, they are automatically added as `depends_on`",
+        description="Additional context for the task. If tasks are provided as "
+        "context, they are automatically added as `depends_on`",
     )
     subtasks: list["Task"] = Field(
         default_factory=list,
-        description="A list of subtasks that are part of this task. Subtasks are considered dependencies, though they may be skipped.",
+        description="A list of subtasks that are part of this task. Subtasks are "
+        "considered dependencies, though they may be skipped.",
     )
     depends_on: list["Task"] = Field(
         default_factory=list, description="Tasks that this task depends on explicitly."
     )
     status: TaskStatus = TaskStatus.INCOMPLETE
     result: T = None
-    result_type: Union[type[T], GenericAlias, _LiteralGenericAlias, None] = None
+    result_type: Union[type[T], GenericAlias, _LiteralGenericAlias, None] = Field(
+        None,
+        description="The expected type of the result. This should be a type"
+        ", generic alias, BaseModel subclass, pd.DataFrame, or pd.Series.",
+    )
     error: Union[str, None] = None
     tools: list[ToolType] = []
     user_access: bool = False
@@ -398,8 +409,9 @@ def succeed() -> str:
 
         # generate tool for other result types
         else:
+            result_schema = generate_result_schema(self.result_type)
 
-            def succeed(result: Union[ThreadMessage, self.result_type]) -> str:
+            def succeed(result: Union[ThreadMessage, result_schema]) -> str:  # type: ignore
                 # a shortcut for loading results from recent messages
                 if isinstance(result, dict) and result.get("type") == "ThreadMessage":
                     result = ThreadMessage(**result)
@@ -455,8 +467,8 @@ def get_tools(self) -> list[ToolType]:
     def dependencies(self):
         return self.depends_on + self.subtasks
 
-    def mark_successful(self, result: T = None, validate: bool = True):
-        if validate:
+    def mark_successful(self, result: T = None, validate_upstreams: bool = True):
+        if validate_upstreams:
             if any(t.is_incomplete() for t in self.depends_on):
                 raise ValueError(
                     f"Task {self.objective} cannot be marked successful until all of its "
@@ -470,14 +482,7 @@
                     f"are: {', '.join(t.friendly_name() for t in self.subtasks if t.is_incomplete())}"
                 )
 
-        if self.result_type is None and result is not None:
-            raise ValueError(
-                f"Task {self.objective} has result_type=None, but a result was provided."
-            )
-        elif self.result_type is not None:
-            result = TypeAdapter(self.result_type).validate_python(result)
-
-        self.result = result
+        self.result = validate_result(result, self.result_type)
         self.status = TaskStatus.SUCCESSFUL
         return f"{self.friendly_name()} marked successful. Updated task definition: {self.model_dump()}"
 
@@ -489,3 +494,54 @@ def mark_failed(self, message: Union[str, None] = None):
     def mark_skipped(self):
         self.status = TaskStatus.SKIPPED
         return f"{self.friendly_name()} marked skipped. Updated task definition: {self.model_dump()}"
+
+
+def generate_result_schema(result_type: type[T]) -> type[T]:
+    result_schema = None
+    # try loading pydantic-compatible schemas
+    try:
+        TypeAdapter(result_type)
+        result_schema = result_type
+    except PydanticSchemaGenerationError:
+        pass
+    # try loading as dataframe
+    try:
+        import pandas as pd
+
+        if result_type is pd.DataFrame:
+            result_schema = PandasDataFrame
+        elif result_type is pd.Series:
+            result_schema = PandasSeries
+    except ImportError:
+        pass
+    if result_schema is None:
+        raise ValueError(
+            f"Could not load or infer schema for result type {result_type}. "
+            "Please use a custom type or add compatibility."
+        )
+    return result_schema
+
+
+def validate_result(result: Any, result_type: type[T]) -> T:
+    if result_type is None and result is not None:
+        raise ValueError("Task has result_type=None, but a result was provided.")
+    elif result_type is not None:
+        try:
+            result = TypeAdapter(result_type).validate_python(result)
+        except PydanticSchemaGenerationError:
+            if isinstance(result, dict):
+                result = result_type(**result)
+            else:
+                result = result_type(result)
+
+        # Convert DataFrame schema back into pd.DataFrame object
+        if result_type == PandasDataFrame:
+            import pandas as pd
+
+            result = pd.DataFrame(**result)
+        elif result_type == PandasSeries:
+            import pandas as pd
+
+            result = pd.Series(**result)
+
+    return result
diff --git a/src/controlflow/utilities/types.py b/src/controlflow/utilities/types.py
index baf32057..569fe236 100644
--- a/src/controlflow/utilities/types.py
+++ b/src/controlflow/utilities/types.py
@@ -9,8 +9,30 @@
 
 # flag for unset defaults
 NOTSET = "__NOTSET__"
+
 ToolType = Union[FunctionTool, AssistantTool, Callable]
 
 
 class ControlFlowModel(BaseModel):
     model_config = dict(validate_assignment=True, extra="forbid")
+
+
+class PandasDataFrame(ControlFlowModel):
+    """Schema for a pandas dataframe"""
+
+    data: Union[
+        list[list[Union[str, int, float, bool]]],
+        dict[str, list[Union[str, int, float, bool]]],
+    ]
+    columns: list[str] = None
+    index: list[str] = None
+    dtype: dict[str, str] = None
+
+
+class PandasSeries(ControlFlowModel):
+    """Schema for a pandas series"""
+
+    data: list[Union[str, int, float]]
+    index: list[str] = None
+    name: str = None
+    dtype: str = None
diff --git a/tests/ai_tests/__init__.py b/tests/ai_tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/ai_tests/test_tasks.py b/tests/ai_tests/test_tasks.py
new file mode 100644
index 00000000..e37e1aa7
--- /dev/null
+++ b/tests/ai_tests/test_tasks.py
@@ -0,0 +1,31 @@
+import pandas as pd
+import pytest
+from controlflow import Task
+from pydantic import BaseModel
+
+
+class Name(BaseModel):
+    first: str
+    last: str
+
+
+@pytest.mark.usefixtures("unit_test_instructions")
+class TestTaskResults:
+    def test_task_int_result(self):
+        task = Task("return 3", result_type=int)
+        assert task.run() == 3
+
+    def test_task_pydantic_result(self):
+        task = Task("the name is John Doe", result_type=Name)
+        result = task.run()
+        assert isinstance(result, Name)
+        assert result == Name(first="John", last="Doe")
+
+    def test_task_dataframe_result(self):
+        task = Task(
+            'return a dataframe with column "x" that has values 1 and 2 and column "y" that has values 3 and 4',
+            result_type=pd.DataFrame,
+        )
+        result = task.run()
+        assert isinstance(result, pd.DataFrame)
+        assert result.equals(pd.DataFrame(data={"x": [1, 2], "y": [3, 4]}))
diff --git a/tests/fixtures/instructions.py b/tests/fixtures/instructions.py
new file mode 100644
index 00000000..0240d220
--- /dev/null
+++ b/tests/fixtures/instructions.py
@@ -0,0 +1,10 @@
+import pytest
+from controlflow import instructions
+
+
+@pytest.fixture
+def unit_test_instructions():
+    with instructions(
+        "You are being unit tested. Be as fast and concise as possible. Do not post unnecessary messages."
+    ):
+        yield
diff --git a/tests/flows/__init__.py b/tests/flows/__init__.py
new file mode 100644
index 00000000..e69de29b
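
For anyone trying the new result path outside the test suite, here is a minimal sketch (not part of the patch) of the fallback that `validate_result` relies on when `result_type=pd.DataFrame`: pydantic cannot generate a schema for `pd.DataFrame`, so the agent's kwargs dict is handed to the constructor as-is. The `payload` dict is a hypothetical example of what the `succeed` tool might receive.

```python
# Sketch of the kwargs-dict fallback used for DataFrame results (assumes only
# pandas and pydantic are installed; does not import controlflow).
import pandas as pd
from pydantic import PydanticSchemaGenerationError, TypeAdapter

payload = {"data": {"x": [1, 2], "y": [3, 4]}}  # hypothetical agent output

try:
    # pydantic cannot build a schema for pd.DataFrame...
    result = TypeAdapter(pd.DataFrame).validate_python(payload)
except PydanticSchemaGenerationError:
    # ...so the kwargs dict is passed straight to the constructor
    result = pd.DataFrame(**payload)

assert result.equals(pd.DataFrame(data={"x": [1, 2], "y": [3, 4]}))
```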
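
A second sketch, also not part of the patch, of the design choice behind the `types.py` additions: the schema model's field names mirror `pd.DataFrame`'s constructor parameters, which is what lets a validated payload be splatted into the constructor once unset fields are dropped. `DataFrameSchema` below is a hypothetical local stand-in for `controlflow.utilities.types.PandasDataFrame`, not the real class.

```python
# Stand-in schema whose field names mirror pd.DataFrame's constructor kwargs.
from typing import Optional, Union

import pandas as pd
from pydantic import BaseModel


class DataFrameSchema(BaseModel):
    data: Union[
        list[list[Union[str, int, float, bool]]],
        dict[str, list[Union[str, int, float, bool]]],
    ]
    columns: Optional[list[str]] = None
    index: Optional[list[str]] = None
    dtype: Optional[dict[str, str]] = None


validated = DataFrameSchema(data={"x": [1, 2], "y": [3, 4]})
# Dropping the unset fields leaves exactly the kwargs pd.DataFrame accepts here
df = pd.DataFrame(**validated.model_dump(exclude_none=True))
print(df)
```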