Skip to content

Commit

Permalink
Merge pull request #295 from PrefectHQ/alias
Browse files Browse the repository at this point in the history
Fix issue with special type aliases
  • Loading branch information
jlowin authored Sep 9, 2024
2 parents 0704f0f + 137982a commit fa0b415
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 5 deletions.
33 changes: 29 additions & 4 deletions src/controlflow/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@
TypeVar,
Union,
_AnnotatedAlias,
_GenericAlias,
_LiteralGenericAlias,
_SpecialGenericAlias,
)

from pydantic import (
Field,
PydanticSchemaGenerationError,
RootModel,
TypeAdapter,
field_serializer,
field_validator,
Expand Down Expand Up @@ -44,6 +47,19 @@
logger = get_logger(__name__)


class Labels(RootModel):
root: tuple[Any, ...]

def __iter__(self):
return iter(self.root)

def __getitem__(self, item):
return self.root[item]

def __repr__(self) -> str:
return f'Labels: {", ".join(self.root)}'


class TaskStatus(Enum):
PENDING = "PENDING"
RUNNING = "RUNNING"
Expand Down Expand Up @@ -89,7 +105,15 @@ class Task(ControlFlowModel):
)
status: TaskStatus = TaskStatus.PENDING
result: Optional[Union[T, str]] = None
result_type: Union[type[T], GenericAlias, _AnnotatedAlias, tuple, None] = Field(
result_type: Union[
type[T],
GenericAlias,
_GenericAlias,
_SpecialGenericAlias,
_AnnotatedAlias,
Labels,
None,
] = Field(
NOTSET,
description="The expected type of the result. This should be a type"
", generic alias, BaseModel subclass, or list of choices. "
Expand Down Expand Up @@ -228,7 +252,7 @@ def _validate_result_type(cls, v):
if isinstance(v, _LiteralGenericAlias):
v = v.__args__
if isinstance(v, (list, tuple, set)):
v = tuple(v)
v = Labels(v)
return v

@field_serializer("parent")
Expand Down Expand Up @@ -489,7 +513,7 @@ def create_success_tool(self) -> Tool:
# a single integer index instead of writing out the entire option. Therefore
# we create a tool that describes a series of options and accepts the index
# as a result.
if isinstance(self.result_type, tuple):
if isinstance(self.result_type, Labels):
result_schema = int
options = {}
serialized_options = {}
Expand All @@ -511,6 +535,7 @@ def create_success_tool(self) -> Tool:
# otherwise try to load the schema for the result type
elif self.result_type is not None:
try:
# see if the result type is a valid pydantic type
TypeAdapter(self.result_type)
result_schema = self.result_type
except PydanticSchemaGenerationError:
Expand Down Expand Up @@ -562,7 +587,7 @@ def fail(reason: str) -> str:
def validate_result(self, raw_result: Any) -> T:
if self.result_type is None and raw_result is not None:
raise ValueError("Task has result_type=None, but a result was provided.")
elif isinstance(self.result_type, tuple):
elif isinstance(self.result_type, Labels):
if raw_result not in self.result_type:
raise ValueError(
f"Result {raw_result} is not in the list of valid result types: {self.result_type}"
Expand Down
17 changes: 16 additions & 1 deletion tests/tasks/test_tasks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Annotated, Any
from typing import Annotated, Any, Dict, List

import pytest
from pydantic import BaseModel
Expand Down Expand Up @@ -288,6 +288,21 @@ def test_tuple_of_ints_validates(self):
with pytest.raises(ValueError):
task.mark_successful(result=7)

def test_typed_dict_result(self):
task = Task("", result_type=dict[str, int])
task.mark_successful(result={"a": 5, "b": "6"})
assert task.result == {"a": 5, "b": 6}

def test_special_list_type_result(self):
task = Task("", result_type=List[int])
task.mark_successful(result=[5, 6])
assert task.result == [5, 6]

def test_special_dict_type_result(self):
task = Task("", result_type=Dict[str, int])
task.mark_successful(result={"a": 5, "b": "6"})
assert task.result == {"a": 5, "b": 6}

def test_pydantic_result(self):
class Name(BaseModel):
first: str
Expand Down

0 comments on commit fa0b415

Please sign in to comment.