Skip to content

Commit

Permalink
Add basic handling of typing.Generic
Browse files Browse the repository at this point in the history
  • Loading branch information
mciszczon committed Jan 11, 2023
1 parent c831d57 commit 71a093d
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 9 deletions.
9 changes: 6 additions & 3 deletions dacite/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import is_dataclass
from itertools import zip_longest
from typing import TypeVar, Type, Optional, get_type_hints, Mapping, Any, Collection, MutableMapping
from typing import TypeVar, Type, Optional, get_type_hints, Mapping, Any, Collection, MutableMapping, get_origin

from dacite.cache import cache
from dacite.config import Config
Expand Down Expand Up @@ -31,6 +31,7 @@
is_init_var,
extract_init_var,
is_subclass,
is_generic_subclass,
)

T = TypeVar("T")
Expand Down Expand Up @@ -58,9 +59,9 @@ def from_dict(data_class: Type[T], data: Data, config: Optional[Config] = None)
raise UnexpectedDataError(keys=extra_fields)
for field in data_class_fields:
field_type = data_class_hints[field.name]
if field.name in data:
if hasattr(data, field.name) or (isinstance(data, Mapping) and field.name in data):
field_data = getattr(data, field.name, None) or data[field.name]
try:
field_data = data[field.name]
value = _build_value(type_=field_type, data=field_data, config=config)
except DaciteFieldError as error:
error.update_path(field.name)
Expand Down Expand Up @@ -97,6 +98,8 @@ def _build_value(type_: Type, data: Any, config: Config) -> Any:
data = _build_value_for_collection(collection=type_, data=data, config=config)
elif cache(is_dataclass)(type_) and isinstance(data, Mapping):
data = from_dict(data_class=type_, data=data, config=config)
elif is_generic_subclass(type_) and is_dataclass(get_origin(type_)):
data = from_dict(data_class=get_origin(type_), data=data, config=config)
for cast_type in config.cast:
if is_subclass(type_, cast_type):
if is_generic_collection(type_):
Expand Down
42 changes: 42 additions & 0 deletions dacite/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,13 @@
TypeVar,
Mapping,
Tuple,
get_origin,
get_type_hints,
get_args,
cast as typing_cast,
_GenericAlias, # Remove import and check for Generic in a different way
)
from inspect import isclass

from dacite.cache import cache

Expand Down Expand Up @@ -43,6 +48,16 @@ def is_generic(type_: Type) -> bool:
return hasattr(type_, "__origin__")


@cache
def is_generic_subclass(type_: Type) -> bool:
return is_generic(type_) and hasattr(type_, "__args__")


@cache
def is_generic_alias(type_: Type) -> bool:
return type(type_.__args__) == _GenericAlias


@cache
def is_union(type_: Type) -> bool:
if is_generic(type_) and type_.__origin__ == Union:
Expand Down Expand Up @@ -86,6 +101,22 @@ def is_init_var(type_: Type) -> bool:
return isinstance(type_, InitVar) or type_ is InitVar


def is_valid_generic_class(value: Any, type_: Type) -> bool:
if not isinstance(value, get_origin(type_)):
return False
type_hints = get_type_hints(value)
for field_name, field_type in type_hints.items():
if isinstance(field_type, TypeVar):
return (
any([isinstance(getattr(value, field_name), arg) for arg in get_args(type_)])
if get_args(type_)
else True
)
else:
return is_instance(value, type_)
return True


@cache
def extract_init_var(type_: Type) -> Union[Type, Any]:
try:
Expand Down Expand Up @@ -134,6 +165,17 @@ def is_instance(value: Any, type_: Type) -> bool:
return value in extract_generic(type_)
elif is_init_var(type_):
return is_instance(value, extract_init_var(type_))
elif isclass(type(type_)) and type(type_) == _GenericAlias:
return is_valid_generic_class(value, type_)
elif isinstance(type_, TypeVar):
if hasattr(type_, "__constraints__") and type_.__constraints__:
return any(is_instance(value, t) for t in type_.__constraints__)
if hasattr(type_, "__bound__") and type_.__bound__:
if isinstance(type_.__bound__, tuple):
return any(is_instance(value, t) for t in type_.__bound__)
if type_.__bound__ is not None and is_generic(type_.__bound__):
return isinstance(value, type_.__bound__)
return True
elif is_type_generic(type_):
return is_subclass(value, extract_generic(type_)[0])
else:
Expand Down
70 changes: 67 additions & 3 deletions tests/core/test_base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from dataclasses import dataclass, field
from typing import Any, NewType, Optional
from dataclasses import dataclass, field, asdict
from typing import Any, NewType, Optional, TypeVar, Generic, List, Union

import pytest

from dacite import from_dict, MissingValueError, WrongTypeError
from dacite import from_dict, MissingValueError, WrongTypeError, Config


def test_from_dict_with_correct_data():
Expand Down Expand Up @@ -191,3 +191,67 @@ class X:
result = from_dict(X, {"s": "test"})

assert result == X(s=MyStr("test"))


def test_from_dict_generic():
T = TypeVar("T", bound=Union[str, int])

@dataclass
class A(Generic[T]):
a: T

@dataclass
class B:
a_str: A[str]
a_int: A[int]

assert from_dict(B, {"a_str": {"a": "test"}, "a_int": {"a": 1}}) == B(a_str=A[str](a="test"), a_int=A[int](a=1))


def test_from_dict_generic_invalid():
T = TypeVar("T")

@dataclass
class A(Generic[T]):
a: T

@dataclass
class B:
a_str: A[str]
a_int: A[int]

with pytest.raises(WrongTypeError):
from_dict(B, {"a_str": {"a": "test"}, "a_int": {"a": "not int"}})


def test_from_dict_generic_common_invalid():
T = TypeVar("T", str, List[str])

@dataclass
class Common(Generic[T]):
foo: T
bar: T

@dataclass
class A:
elements: List[Common[int]]

with pytest.raises(WrongTypeError):
from_dict(A, {"elements": [{"foo": 1, "bar": 2}, {"foo": 3, "bar": 4}]})


def test_from_dict_generic_common():
T = TypeVar("T", bound=int)

@dataclass
class Common(Generic[T]):
foo: T
bar: T

@dataclass
class A:
elements: List[Common[int]]

result = from_dict(A, {"elements": [{"foo": 1, "bar": 2}, {"foo": 3, "bar": 4}]})

assert result == A(elements=[Common[int](1, 2), Common[int](3, 4)])
10 changes: 7 additions & 3 deletions tests/test_types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import InitVar
from dataclasses import InitVar, dataclass
from typing import Optional, Union, List, Any, Dict, NewType, TypeVar, Generic, Collection, Tuple, Type
from unittest.mock import patch, Mock

Expand Down Expand Up @@ -268,13 +268,13 @@ def test_is_instance_with_with_type_and_not_matching_value_type():
assert not is_instance(1, Type[str])


def test_is_instance_with_not_supported_generic_types():
def test_is_instance_with_generic_types():
T = TypeVar("T")

class X(Generic[T]):
pass

assert not is_instance(X[str](), X[str])
assert is_instance(X[str](), X[str])


def test_is_instance_with_generic_mapping_and_matching_value_type():
Expand Down Expand Up @@ -364,6 +364,10 @@ def test_is_instance_with_empty_tuple_and_not_matching_type():
assert not is_instance((1, 2), Tuple[()])


def test_is_instance_list_type():
assert is_instance([{}], List)


def test_extract_generic():
assert extract_generic(List[int]) == (int,)

Expand Down

0 comments on commit 71a093d

Please sign in to comment.