Skip to content

Commit

Permalink
Fix issue with classifying complex labels
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin committed Sep 4, 2024
1 parent d886202 commit 613f099
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/controlflow/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,21 +535,26 @@ def create_success_tool(self) -> Tool:
result_schema = None

# if the result_type is a tuple of options, then we want the LLM to provide
# a single integer index instead of writing out the entire option
# a single integer index instead of writing out the entire option. Therefore
# we create a tool that describes a series of options and accepts the index
# as a result.
if isinstance(self.result_type, tuple):
result_schema = int
options = {}
serialized_options = {}
for i, option in enumerate(self.result_type):
options[i] = option
try:
serialized = TypeAdapter(type(option)).dump_python(option)
except PydanticSchemaGenerationError:
serialized = repr(option)
options[i] = serialized
serialized_options[i] = serialized
options_str = "\n\n".join(
f"Option {i}: {option}" for i, option in options.items()
f"Option {i}: {option}" for i, option in serialized_options.items()
)
instructions = f"""
Provide a single integer as the result, corresponding to the index
of your chosen option. You options are: {options_str}
of your chosen option. Your options are: {options_str}
"""

# otherwise try to load the schema for the result type
Expand Down
20 changes: 20 additions & 0 deletions tests/tasks/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,26 @@ def test_success_tool_with_list_of_options_requires_int(self):
with pytest.raises(ValueError):
tool.run(input=dict(result="good"))

def test_tuple_of_ints_result(self):
task = Task("choose 5", result_type=(4, 5, 6))
tool = task.create_success_tool()
tool.run(input=dict(result=1))
assert task.result == 5

def test_tuple_of_pydantic_models_result(self):
class Person(BaseModel):
name: str
age: int

task = Task(
"Who is the oldest?",
result_type=(Person(name="Alice", age=30), Person(name="Bob", age=35)),
)
tool = task.create_success_tool()
tool.run(input=dict(result=1))
assert task.result == Person(name="Bob", age=35)
assert isinstance(task.result, Person)


class TestRun:
@pytest.mark.parametrize(
Expand Down

0 comments on commit 613f099

Please sign in to comment.