Skip to content

Commit

Permalink
added check for default_factory (#1149)
Browse files Browse the repository at this point in the history
* added check for default_factory

* moved the check up to AdapterCommon

* removed comment

---------

Co-authored-by: zilto <tjean@DESKTOP-V6JDCS2>
  • Loading branch information
zilto and zilto authored Sep 23, 2024
1 parent 0bf4c45 commit c6dcd00
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 2 deletions.
4 changes: 2 additions & 2 deletions hamilton/io/data_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def get_required_arguments(cls) -> Dict[str, Type[Type]]:
return {
field.name: type_hints.get(field.name)
for field in dataclasses.fields(cls)
if field.default == dataclasses.MISSING
if field.default == dataclasses.MISSING and field.default_factory == dataclasses.MISSING
}

@classmethod
Expand All @@ -87,7 +87,7 @@ def get_optional_arguments(cls) -> Dict[str, Type[Type]]:
return {
field.name: type_hints.get(field.name)
for field in dataclasses.fields(cls)
if field.default != dataclasses.MISSING
if field.default != dataclasses.MISSING or field.default_factory != dataclasses.MISSING
}

@classmethod
Expand Down
64 changes: 64 additions & 0 deletions tests/function_modifiers/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,70 @@ def fn() -> dict:
assert list(saver_node.input_types) == ["fn"]


@dataclasses.dataclass
class DefaultFactoryLoader(DataLoader):
field_with_factory: int = dataclasses.field(default_factory=int)

def __post_init__(self):
self.param2 = self.field_with_factory + 1

def load_data(self, type_: Type[int]) -> Tuple[int, Dict[str, Any]]:
return self.param2, {}

@classmethod
def applicable_types(cls) -> Collection[Type]:
return [int]

@classmethod
def name(cls) -> str:
return "factory"


def test_loader_default_factory_field():
@LoadFromDecorator([DefaultFactoryLoader])
def foo(param: int) -> int:
return param

fg = graph.create_function_graph(
ad_hoc_utils.create_temporary_module(foo),
config={},
)
assert len(fg) == 3
assert "foo" in fg


@dataclasses.dataclass
class DefaultFactorySaver(DataSaver):
field_with_factory: int = dataclasses.field(default_factory=int)

def __post_init__(self):
self.param2 = self.field_with_factory + 1

def save_data(self, data: int) -> Dict[str, Any]:
return {}

@classmethod
def applicable_types(cls) -> Collection[Type]:
return [int]

@classmethod
def name(cls) -> str:
return "factory"


def test_saver_default_factory_field():
@SaveToDecorator([DefaultFactorySaver])
def foo(param: int) -> int:
return param

fg = graph.create_function_graph(
ad_hoc_utils.create_temporary_module(foo),
config={},
)
assert len(fg) == 3
assert "foo" in fg


@dataclasses.dataclass
class OptionalParamDataLoader(DataLoader):
param: int = 1
Expand Down

0 comments on commit c6dcd00

Please sign in to comment.