diff --git a/dacite/core.py b/dacite/core.py index 4d59c66..9ba0de4 100644 --- a/dacite/core.py +++ b/dacite/core.py @@ -1,7 +1,7 @@ import copy -from dataclasses import is_dataclass +from dataclasses import is_dataclass, dataclass from itertools import zip_longest -from typing import TypeVar, Type, Optional, get_type_hints, Mapping, Any, Collection, MutableMapping +from typing import TypeVar, Type, Optional, get_type_hints, Mapping, Any, Collection, MutableMapping, get_origin from dacite.cache import cache from dacite.config import Config @@ -33,6 +33,7 @@ is_init_var, extract_init_var, is_subclass, + is_generic_subclass, ) T = TypeVar("T") @@ -61,9 +62,10 @@ def from_dict(data_class: Type[T], data: Data, config: Optional[Config] = None) for field in data_class_fields: field = copy.copy(field) field.type = data_class_hints[field.name] - if field.name in data: + + if hasattr(data, field.name) or (isinstance(data, Mapping) and field.name in data): + field_data = getattr(data, field.name, None) or data[field.name] try: - field_data = data[field.name] value = _build_value(type_=field.type, data=field_data, config=config) except DaciteFieldError as error: error.update_path(field.name) @@ -98,6 +100,8 @@ def _build_value(type_: Type, data: Any, config: Config) -> Any: data = _build_value_for_collection(collection=type_, data=data, config=config) elif is_dataclass(type_) and isinstance(data, Mapping): data = from_dict(data_class=type_, data=data, config=config) + elif is_generic_subclass(type_) and is_dataclass(get_origin(type_)): + data = from_dict(data_class=get_origin(type_), data=data, config=config) for cast_type in config.cast: if is_subclass(type_, cast_type): if is_generic_collection(type_): diff --git a/dacite/types.py b/dacite/types.py index 4a2ba2f..9494f50 100644 --- a/dacite/types.py +++ b/dacite/types.py @@ -8,8 +8,13 @@ TypeVar, Mapping, Tuple, + get_origin, + get_type_hints, + get_args, cast as typing_cast, + _GenericAlias, # Remove import and check for Generic in a different way ) +from inspect import isclass from dacite.cache import cache @@ -43,6 +48,16 @@ def is_generic(type_: Type) -> bool: return hasattr(type_, "__origin__") +@cache +def is_generic_subclass(type_: Type) -> bool: + return is_generic(type_) and hasattr(type_, "__args__") + + +@cache +def is_generic_alias(type_: Type) -> bool: + return type(type_.__args__) == _GenericAlias + + @cache def is_union(type_: Type) -> bool: if is_generic(type_) and type_.__origin__ == Union: @@ -86,6 +101,22 @@ def is_init_var(type_: Type) -> bool: return isinstance(type_, InitVar) or type_ is InitVar +def is_valid_generic_class(value: Any, type_: Type) -> bool: + if not isinstance(value, get_origin(type_)): + return False + type_hints = get_type_hints(value) + for field_name, field_type in type_hints.items(): + if isinstance(field_type, TypeVar): + return ( + any([isinstance(getattr(value, field_name), arg) for arg in get_args(type_)]) + if get_args(type_) + else True + ) + else: + return is_instance(value, type_) + return True + + @cache def extract_init_var(type_: Type) -> Union[Type, Any]: try: @@ -128,6 +159,17 @@ def is_instance(value: Any, type_: Type) -> bool: return value in extract_generic(type_) elif is_init_var(type_): return is_instance(value, extract_init_var(type_)) + elif isclass(type(type_)) and type(type_) == _GenericAlias: + return is_valid_generic_class(value, type_) + elif isinstance(type_, TypeVar): + if hasattr(type_, "__constraints__") and type_.__constraints__: + return any(is_instance(value, t) for t in type_.__constraints__) + if hasattr(type_, "__bound__") and type_.__bound__: + if isinstance(type_.__bound__, tuple): + return any(is_instance(value, t) for t in type_.__bound__) + if type_.__bound__ is not None and is_generic(type_.__bound__): + return isinstance(value, type_.__bound__) + return True elif is_type_generic(type_): return is_subclass(value, extract_generic(type_)[0]) else: diff --git a/tests/core/test_base.py b/tests/core/test_base.py index 8d82788..6bb8cd5 100644 --- a/tests/core/test_base.py +++ b/tests/core/test_base.py @@ -1,9 +1,9 @@ -from dataclasses import dataclass, field -from typing import Any, NewType, Optional +from dataclasses import dataclass, field, asdict +from typing import Any, NewType, Optional, TypeVar, Generic, List, Union import pytest -from dacite import from_dict, MissingValueError, WrongTypeError +from dacite import from_dict, MissingValueError, WrongTypeError, Config def test_from_dict_with_correct_data(): @@ -191,3 +191,67 @@ class X: result = from_dict(X, {"s": "test"}) assert result == X(s=MyStr("test")) + + +def test_from_dict_generic(): + T = TypeVar("T", bound=Union[str, int]) + + @dataclass + class A(Generic[T]): + a: T + + @dataclass + class B: + a_str: A[str] + a_int: A[int] + + assert from_dict(B, {"a_str": {"a": "test"}, "a_int": {"a": 1}}) == B(a_str=A[str](a="test"), a_int=A[int](a=1)) + + +def test_from_dict_generic_invalid(): + T = TypeVar("T") + + @dataclass + class A(Generic[T]): + a: T + + @dataclass + class B: + a_str: A[str] + a_int: A[int] + + with pytest.raises(WrongTypeError): + from_dict(B, {"a_str": {"a": "test"}, "a_int": {"a": "not int"}}) + + +def test_from_dict_generic_common_invalid(): + T = TypeVar("T", str, List[str]) + + @dataclass + class Common(Generic[T]): + foo: T + bar: T + + @dataclass + class A: + elements: List[Common[int]] + + with pytest.raises(WrongTypeError): + from_dict(A, {"elements": [{"foo": 1, "bar": 2}, {"foo": 3, "bar": 4}]}) + + +def test_from_dict_generic_common(): + T = TypeVar("T", bound=int) + + @dataclass + class Common(Generic[T]): + foo: T + bar: T + + @dataclass + class A: + elements: List[Common[int]] + + result = from_dict(A, {"elements": [{"foo": 1, "bar": 2}, {"foo": 3, "bar": 4}]}) + + assert result == A(elements=[Common[int](1, 2), Common[int](3, 4)]) diff --git a/tests/test_types.py b/tests/test_types.py index 7b074a8..a3b77ea 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -1,4 +1,4 @@ -from dataclasses import InitVar +from dataclasses import InitVar, dataclass from typing import Optional, Union, List, Any, Dict, NewType, TypeVar, Generic, Collection, Tuple, Type from unittest.mock import patch, Mock @@ -268,13 +268,13 @@ def test_is_instance_with_with_type_and_not_matching_value_type(): assert not is_instance(1, Type[str]) -def test_is_instance_with_not_supported_generic_types(): +def test_is_instance_with_generic_types(): T = TypeVar("T") class X(Generic[T]): pass - assert not is_instance(X[str](), X[str]) + assert is_instance(X[str](), X[str]) def test_is_instance_with_generic_mapping_and_matching_value_type(): @@ -364,6 +364,10 @@ def test_is_instance_with_empty_tuple_and_not_matching_type(): assert not is_instance((1, 2), Tuple[()]) +def test_is_instance_list_type(): + assert is_instance([{}], List) + + def test_extract_generic(): assert extract_generic(List[int]) == (int,)