Skip to content

Commit ba2338c

Browse files
authored
Merge pull request #249 from PrefectHQ/options
Better support for complex options
2 parents d3320c5 + 79c8517 commit ba2338c

File tree

5 files changed

+159
-93
lines changed

5 files changed

+159
-93
lines changed

docs/patterns/result-types.mdx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ assert result is False
7171

7272
## Constrained Choices
7373

74-
Sometimes you want to limit the possible results to a specific set of values. You can do this by specifying a list of allowed values for the result type:
74+
Sometimes you want to limit the possible results to a specific set of values, in order to label or classify a response. You can do this by specifying a list of allowed values for the result type:
7575

7676
```python
7777
import controlflow as cf

src/controlflow/orchestration/orchestrator.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,11 @@
77

88
import controlflow
99
from controlflow.agents.agent import BaseAgent
10-
from controlflow.agents.teams import Team
1110
from controlflow.events.base import Event
1211
from controlflow.flows import Flow
1312
from controlflow.orchestration.agent_context import AgentContext
1413
from controlflow.orchestration.handler import Handler
1514
from controlflow.tasks.task import Task
16-
from controlflow.tools.orchestration import (
17-
create_task_fail_tool,
18-
create_task_success_tool,
19-
)
2015
from controlflow.tools.tools import Tool
2116
from controlflow.utilities.general import ControlFlowModel
2217
from controlflow.utilities.prefect import prefect_task as prefect_task
@@ -64,6 +59,8 @@ def _handlers(cls, v):
6459

6560
@field_validator("agents", mode="before")
6661
def _agents(cls, v):
62+
from controlflow.agents.teams import Team
63+
6764
if v is None:
6865
v = {}
6966

@@ -205,6 +202,6 @@ def get_tools(self, tasks: list[Task]) -> list[Tool]:
205202
tools.extend(self.flow.tools)
206203
for task in tasks:
207204
tools.extend(task.get_tools())
208-
tools.append(create_task_success_tool(task=task))
209-
tools.append(create_task_fail_tool(task=task))
205+
tools.append(task.create_success_tool())
206+
tools.append(task.create_fail_tool())
210207
return tools

src/controlflow/tasks/task.py

Lines changed: 110 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
Any,
77
Callable,
88
GenericAlias,
9-
Literal,
109
Optional,
1110
TypeVar,
1211
Union,
@@ -25,7 +24,7 @@
2524
import controlflow
2625
from controlflow.agents import BaseAgent
2726
from controlflow.instructions import get_instructions
28-
from controlflow.tools import Tool
27+
from controlflow.tools import Tool, tool
2928
from controlflow.tools.talk_to_user import talk_to_user
3029
from controlflow.utilities.context import ctx
3130
from controlflow.utilities.general import (
@@ -100,10 +99,10 @@ class Task(ControlFlowModel):
10099
)
101100
status: TaskStatus = TaskStatus.PENDING
102101
result: T = None
103-
result_type: Union[type[T], GenericAlias, _LiteralGenericAlias, None] = Field(
102+
result_type: Union[type[T], GenericAlias, tuple, None] = Field(
104103
str,
105104
description="The expected type of the result. This should be a type"
106-
", generic alias, BaseModel subclass, pd.DataFrame, or pd.Series. "
105+
", generic alias, BaseModel subclass, or list of choices. "
107106
"Can be None if no result is expected or the agent should communicate internally.",
108107
)
109108
error: Union[str, None] = None
@@ -264,9 +263,11 @@ def _default_parent(cls, v):
264263
return v
265264

266265
@field_validator("result_type", mode="before")
267-
def _turn_list_into_literal_result_type(cls, v):
266+
def _ensure_result_type_is_list_if_literal(cls, v):
267+
if isinstance(v, _LiteralGenericAlias):
268+
v = v.__args__
268269
if isinstance(v, (list, tuple, set)):
269-
return Literal[tuple(v)] # type: ignore
270+
v = tuple(v)
270271
return v
271272

272273
@field_serializer("parent")
@@ -560,6 +561,85 @@ def generate_subtasks(self, instructions: str = None, agent: BaseAgent = None):
560561
context=self.context,
561562
)
562563

564+
def create_success_tool(self) -> Tool:
    """
    Create an agent-compatible tool for marking this task as successful.

    If ``result_type`` is a tuple of options, the tool accepts a single
    integer index and the option at that index becomes the task result, so
    the LLM does not have to write out the chosen option verbatim.
    Otherwise the tool accepts a value of ``result_type`` directly.

    Raises:
        ValueError: if no schema can be built for ``result_type``.
    """
    # index -> human-readable form of each option (only used for display
    # in the tool instructions and in error messages)
    options = {}
    instructions = None
    result_schema = None
    choices = self.result_type if isinstance(self.result_type, tuple) else None

    # if the result_type is a tuple of options, then we want the LLM to provide
    # a single integer index instead of writing out the entire option
    if choices is not None:
        result_schema = int
        for i, option in enumerate(choices):
            try:
                serialized = TypeAdapter(type(option)).dump_python(option)
            except PydanticSchemaGenerationError:
                serialized = repr(option)
            options[i] = serialized
        options_str = "\n\n".join(
            f"Option {i}: {option}" for i, option in options.items()
        )
        instructions = (
            "Provide a single integer as the result, corresponding to the "
            "index of your chosen option. Your options are:\n\n" + options_str
        )

    # otherwise try to load the schema for the result type
    elif self.result_type is not None:
        try:
            TypeAdapter(self.result_type)
        except PydanticSchemaGenerationError as exc:
            # keep the pydantic error attached as the cause for debugging
            raise ValueError(
                f"Could not load or infer schema for result type {self.result_type}. "
                "Please use a custom type or add compatibility."
            ) from exc
        result_schema = self.result_type

    @tool(
        name=f"mark_task_{self.id}_successful",
        description=f"Mark task {self.id} as successful.",
        instructions=instructions,
        private=True,
        include_return_description=False,
    )
    def succeed(result: result_schema) -> str:  # type: ignore
        if self.is_successful():
            raise ValueError(
                f"{self.friendly_name()} is already marked successful."
            )
        if choices is not None:
            if result not in options:
                raise ValueError(f"Invalid option. Please choose one of {options}")
            # Store the original option rather than its serialized display
            # form: result validation checks membership in result_type, which
            # the serialized form of a non-primitive option would fail.
            result = choices[result]
        self.mark_successful(result=result)
        return f"{self.friendly_name()} marked successful."

    return succeed
623+
624+
def create_fail_tool(self) -> Tool:
    """
    Build the agent-facing tool that marks this task as failed.

    The returned tool takes a single ``reason`` string, records the failure
    on this task via ``mark_failed`` and returns a short confirmation.
    """
    tool_name = f"mark_task_{self.id}_failed"
    tool_description = (
        f"Mark task {self.id} as failed. Only use when technical errors prevent success. Provide a detailed reason for the failure."
    )

    def fail(reason: str) -> str:
        self.mark_failed(reason=reason)
        return f"{self.friendly_name()} marked failed."

    # equivalent to decorating `fail` with @tool(...)
    make_tool = tool(
        name=tool_name,
        description=tool_description,
        private=True,
        include_return_description=False,
    )
    return make_tool(fail)
642+
563643
# Deprecated ---------------------------
564644

565645
@deprecated("Use Task.run(steps=1) instead.", version="0.9")
@@ -574,6 +654,11 @@ async def run_once_async(self, *args, **kwargs):
574654
def validate_result(result: Any, result_type: type[T]) -> T:
575655
if result_type is None and result is not None:
576656
raise ValueError("Task has result_type=None, but a result was provided.")
657+
elif isinstance(result_type, tuple):
658+
if result not in result_type:
659+
raise ValueError(
660+
f"Result {result} is not in the list of valid result types: {result_type}"
661+
)
577662
elif result_type is not None:
578663
try:
579664
result = TypeAdapter(result_type).validate_python(result)
@@ -594,3 +679,22 @@ def validate_result(result: Any, result_type: type[T]) -> T:
594679
# result = pd.Series(**result)
595680

596681
return result
682+
683+
684+
def _generate_result_schema(result_type: type[T]) -> Optional[type[T]]:
    """
    Return ``result_type`` if pydantic can build a schema for it.

    Args:
        result_type: the task's expected result type, or None.

    Returns:
        None when ``result_type`` is None, otherwise ``result_type`` itself.

    Raises:
        ValueError: if pydantic cannot generate a schema for ``result_type``.
    """
    if result_type is None:
        return None

    # Constructing the adapter is only a probe: it raises if pydantic has
    # no schema for the type. The adapter itself is not needed.
    try:
        TypeAdapter(result_type)
    except PydanticSchemaGenerationError as exc:
        # chain the pydantic error so the root cause is not lost
        raise ValueError(
            f"Could not load or infer schema for result type {result_type}. "
            "Please use a custom type or add compatibility."
        ) from exc
    return result_type

src/controlflow/tools/orchestration.py

Lines changed: 0 additions & 79 deletions
This file was deleted.

tests/tasks/test_tasks.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,3 +312,47 @@ def test_custom_templated_prompt(self, agent_context):
312312
task = SimpleTask(prompt="{{ task.objective }}", objective="abc")
313313
prompt = task.get_prompt(context=agent_context)
314314
assert prompt == "abc"
315+
316+
317+
class TestResultType:
    """Results passed to mark_successful are checked against result_type."""

    def test_int_result(self):
        int_task = Task("choose 5", result_type=int)
        int_task.mark_successful(result=5)
        assert int_task.result == 5

    def test_str_result(self):
        str_task = Task("choose 5", result_type=str)
        str_task.mark_successful(result="5")
        assert str_task.result == "5"

    def test_tuple_of_ints_result(self):
        # a tuple result_type is a set of allowed choices
        choice_task = Task("choose 5", result_type=(4, 5, 6))
        choice_task.mark_successful(result=5)
        assert choice_task.result == 5

    def test_tuple_of_ints_validates(self):
        # a value outside the allowed choices must be rejected
        choice_task = Task("choose 5", result_type=(4, 5, 6))
        with pytest.raises(ValueError):
            choice_task.mark_successful(result=7)
337+
338+
339+
class TestSuccessTool:
    """Behaviour of the tool returned by Task.create_success_tool()."""

    def test_success_tool(self):
        task = Task("choose 5", result_type=int)
        success_tool = task.create_success_tool()
        success_tool.run(input={"result": 5})
        assert task.is_successful()
        assert task.result == 5

    def test_success_tool_with_list_of_options(self):
        # with a list of options the tool takes the index of the choice
        task = Task('choose "good"', result_type=["bad", "good", "medium"])
        success_tool = task.create_success_tool()
        success_tool.run(input={"result": 1})
        assert task.is_successful()
        assert task.result == "good"

    def test_success_tool_with_list_of_options_requires_int(self):
        # passing the option text instead of its index is an error
        task = Task('choose "good"', result_type=["bad", "good", "medium"])
        success_tool = task.create_success_tool()
        with pytest.raises(ValueError):
            success_tool.run(input={"result": "good"})

0 commit comments

Comments
 (0)