6
6
Any ,
7
7
Callable ,
8
8
GenericAlias ,
9
- Literal ,
10
9
Optional ,
11
10
TypeVar ,
12
11
Union ,
25
24
import controlflow
26
25
from controlflow .agents import BaseAgent
27
26
from controlflow .instructions import get_instructions
28
- from controlflow .tools import Tool
27
+ from controlflow .tools import Tool , tool
29
28
from controlflow .tools .talk_to_user import talk_to_user
30
29
from controlflow .utilities .context import ctx
31
30
from controlflow .utilities .general import (
@@ -100,10 +99,10 @@ class Task(ControlFlowModel):
100
99
)
101
100
status : TaskStatus = TaskStatus .PENDING
102
101
result : T = None
103
- result_type : Union [type [T ], GenericAlias , _LiteralGenericAlias , None ] = Field (
102
+ result_type : Union [type [T ], GenericAlias , tuple , None ] = Field (
104
103
str ,
105
104
description = "The expected type of the result. This should be a type"
106
- ", generic alias, BaseModel subclass, pd.DataFrame, or pd.Series . "
105
+ ", generic alias, BaseModel subclass, or list of choices . "
107
106
"Can be None if no result is expected or the agent should communicate internally." ,
108
107
)
109
108
error : Union [str , None ] = None
@@ -264,9 +263,11 @@ def _default_parent(cls, v):
264
263
return v
265
264
266
265
@field_validator ("result_type" , mode = "before" )
267
- def _turn_list_into_literal_result_type (cls , v ):
266
+ def _ensure_result_type_is_list_if_literal (cls , v ):
267
+ if isinstance (v , _LiteralGenericAlias ):
268
+ v = v .__args__
268
269
if isinstance (v , (list , tuple , set )):
269
- return Literal [ tuple (v )] # type: ignore
270
+ v = tuple (v )
270
271
return v
271
272
272
273
@field_serializer ("parent" )
@@ -560,6 +561,85 @@ def generate_subtasks(self, instructions: str = None, agent: BaseAgent = None):
560
561
context = self .context ,
561
562
)
562
563
564
+ def create_success_tool (self ) -> Tool :
565
+ """
566
+ Create an agent-compatible tool for marking this task as successful.
567
+ """
568
+ options = {}
569
+ instructions = None
570
+ result_schema = None
571
+
572
+ # if the result_type is a tuple of options, then we want the LLM to provide
573
+ # a single integer index instead of writing out the entire option
574
+ if isinstance (self .result_type , tuple ):
575
+ result_schema = int
576
+ for i , option in enumerate (self .result_type ):
577
+ try :
578
+ serialized = TypeAdapter (type (option )).dump_python (option )
579
+ except PydanticSchemaGenerationError :
580
+ serialized = repr (option )
581
+ options [i ] = serialized
582
+ options_str = "\n \n " .join (
583
+ f"Option { i } : { option } " for i , option in options .items ()
584
+ )
585
+ instructions = f"""
586
+ Provide a single integer as the result, corresponding to the index
587
+ of your chosen option. You options are: { options_str }
588
+ """
589
+
590
+ # otherwise try to load the schema for the result type
591
+ elif self .result_type is not None :
592
+ try :
593
+ TypeAdapter (self .result_type )
594
+ result_schema = self .result_type
595
+ except PydanticSchemaGenerationError :
596
+ pass
597
+ if result_schema is None :
598
+ raise ValueError (
599
+ f"Could not load or infer schema for result type { self .result_type } . "
600
+ "Please use a custom type or add compatibility."
601
+ )
602
+
603
+ @tool (
604
+ name = f"mark_task_{ self .id } _successful" ,
605
+ description = f"Mark task { self .id } as successful." ,
606
+ instructions = instructions ,
607
+ private = True ,
608
+ include_return_description = False ,
609
+ )
610
+ def succeed (result : result_schema ) -> str : # type: ignore
611
+ if self .is_successful ():
612
+ raise ValueError (
613
+ f"{ self .friendly_name ()} is already marked successful."
614
+ )
615
+ if options :
616
+ if result not in options :
617
+ raise ValueError (f"Invalid option. Please choose one of { options } " )
618
+ result = options [result ]
619
+ self .mark_successful (result = result )
620
+ return f"{ self .friendly_name ()} marked successful."
621
+
622
+ return succeed
623
+
624
+ def create_fail_tool (self ) -> Tool :
625
+ """
626
+ Create an agent-compatible tool for failing this task.
627
+ """
628
+
629
+ @tool (
630
+ name = f"mark_task_{ self .id } _failed" ,
631
+ description = (
632
+ f"Mark task { self .id } as failed. Only use when technical errors prevent success. Provide a detailed reason for the failure."
633
+ ),
634
+ private = True ,
635
+ include_return_description = False ,
636
+ )
637
+ def fail (reason : str ) -> str :
638
+ self .mark_failed (reason = reason )
639
+ return f"{ self .friendly_name ()} marked failed."
640
+
641
+ return fail
642
+
563
643
# Deprecated ---------------------------
564
644
565
645
@deprecated ("Use Task.run(steps=1) instead." , version = "0.9" )
@@ -574,6 +654,11 @@ async def run_once_async(self, *args, **kwargs):
574
654
def validate_result (result : Any , result_type : type [T ]) -> T :
575
655
if result_type is None and result is not None :
576
656
raise ValueError ("Task has result_type=None, but a result was provided." )
657
+ elif isinstance (result_type , tuple ):
658
+ if result not in result_type :
659
+ raise ValueError (
660
+ f"Result { result } is not in the list of valid result types: { result_type } "
661
+ )
577
662
elif result_type is not None :
578
663
try :
579
664
result = TypeAdapter (result_type ).validate_python (result )
@@ -594,3 +679,22 @@ def validate_result(result: Any, result_type: type[T]) -> T:
594
679
# result = pd.Series(**result)
595
680
596
681
return result
682
+
683
+
684
+ def _generate_result_schema (result_type : type [T ]) -> type [T ]:
685
+ if result_type is None :
686
+ return None
687
+
688
+ result_schema = None
689
+ # try loading pydantic-compatible schemas
690
+ try :
691
+ TypeAdapter (result_type )
692
+ result_schema = result_type
693
+ except PydanticSchemaGenerationError :
694
+ pass
695
+ if result_schema is None :
696
+ raise ValueError (
697
+ f"Could not load or infer schema for result type { result_type } . "
698
+ "Please use a custom type or add compatibility."
699
+ )
700
+ return result_schema
0 commit comments