From 613f0994ca486feddfff9fcaf1bf3f0066babd86 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Tue, 3 Sep 2024 21:53:02 -0400 Subject: [PATCH] Fix issue with classifying complex labels --- src/controlflow/tasks/task.py | 13 +++++++++---- tests/tasks/test_tasks.py | 20 ++++++++++++++++++++ 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py index 5df4991b..039b940e 100644 --- a/src/controlflow/tasks/task.py +++ b/src/controlflow/tasks/task.py @@ -535,21 +535,26 @@ def create_success_tool(self) -> Tool: result_schema = None # if the result_type is a tuple of options, then we want the LLM to provide - # a single integer index instead of writing out the entire option + # a single integer index instead of writing out the entire option. Therefore + # we create a tool that describes a series of options and accepts the index + # as a result. if isinstance(self.result_type, tuple): result_schema = int + options = {} + serialized_options = {} for i, option in enumerate(self.result_type): + options[i] = option try: serialized = TypeAdapter(type(option)).dump_python(option) except PydanticSchemaGenerationError: serialized = repr(option) - options[i] = serialized + serialized_options[i] = serialized options_str = "\n\n".join( - f"Option {i}: {option}" for i, option in options.items() + f"Option {i}: {option}" for i, option in serialized_options.items() ) instructions = f""" Provide a single integer as the result, corresponding to the index - of your chosen option. You options are: {options_str} + of your chosen option. Your options are: {options_str} """ # otherwise try to load the schema for the result type diff --git a/tests/tasks/test_tasks.py b/tests/tasks/test_tasks.py index 40faea60..e3fbaf17 100644 --- a/tests/tasks/test_tasks.py +++ b/tests/tasks/test_tasks.py @@ -421,6 +421,26 @@ def test_success_tool_with_list_of_options_requires_int(self): with pytest.raises(ValueError): tool.run(input=dict(result="good")) + def test_tuple_of_ints_result(self): + task = Task("choose 5", result_type=(4, 5, 6)) + tool = task.create_success_tool() + tool.run(input=dict(result=1)) + assert task.result == 5 + + def test_tuple_of_pydantic_models_result(self): + class Person(BaseModel): + name: str + age: int + + task = Task( + "Who is the oldest?", + result_type=(Person(name="Alice", age=30), Person(name="Bob", age=35)), + ) + tool = task.create_success_tool() + tool.run(input=dict(result=1)) + assert task.result == Person(name="Bob", age=35) + assert isinstance(task.result, Person) + class TestRun: @pytest.mark.parametrize(