diff --git a/dacite/core.py b/dacite/core.py index d7553e5..c58f5b2 100644 --- a/dacite/core.py +++ b/dacite/core.py @@ -1,6 +1,11 @@ from dataclasses import is_dataclass from itertools import zip_longest -from typing import TypeVar, Type, Optional, get_type_hints, Mapping, Any, Collection, MutableMapping, get_origin +from typing import TypeVar, Type, Optional, get_type_hints, Mapping, Any, Collection, MutableMapping + +try: + from typing import get_origin # type: ignore +except ImportError: + from typing_extensions import get_origin from dacite.cache import cache from dacite.config import Config @@ -99,7 +104,9 @@ def _build_value(type_: Type, data: Any, config: Config) -> Any: elif cache(is_dataclass)(type_) and isinstance(data, Mapping): data = from_dict(data_class=type_, data=data, config=config) elif is_generic_subclass(type_) and is_dataclass(get_origin(type_)): - data = from_dict(data_class=get_origin(type_), data=data, config=config) + origin = get_origin(type_) + assert origin is not None + data = from_dict(data_class=origin, data=data, config=config) for cast_type in config.cast: if is_subclass(type_, cast_type): if is_generic_collection(type_): diff --git a/dacite/types.py b/dacite/types.py index 5fa9545..1bdd209 100644 --- a/dacite/types.py +++ b/dacite/types.py @@ -1,19 +1,10 @@ from dataclasses import InitVar -from typing import ( - Type, - Any, - Optional, - Union, - Collection, - TypeVar, - Mapping, - Tuple, - get_origin, - get_type_hints, - get_args, - cast as typing_cast, - _GenericAlias, # Remove import and check for Generic in a different way -) +from typing import Type, Any, Optional, Union, Collection, TypeVar, Mapping, Tuple, get_type_hints, cast as typing_cast + +try: + from typing import get_origin, get_args # type: ignore +except ImportError: + from typing_extensions import get_origin, get_args from inspect import isclass from dacite.cache import cache @@ -53,14 +44,9 @@ def is_generic_subclass(type_: Type) -> bool: return is_generic(type_) and hasattr(type_, "__args__") -@cache -def is_generic_alias(type_: Type) -> bool: - return type(type_.__args__) == _GenericAlias - - @cache def is_union(type_: Type) -> bool: - if is_generic(type_) and type_.__origin__ == Union: + if is_generic(type_) and get_origin(type_) == Union: return True try: @@ -81,7 +67,7 @@ def is_literal(type_: Type) -> bool: try: from typing import Literal # type: ignore - return is_generic(type_) and type_.__origin__ == Literal + return is_generic(type_) and get_origin(type_) == Literal except ImportError: return False @@ -101,16 +87,26 @@ def is_init_var(type_: Type) -> bool: return isinstance(type_, InitVar) or type_ is InitVar +@cache +def is_generic_alias(type_: Type) -> bool: + """Since `typing._GenericAlias` is not explicitly exported, we instead rely on this check.""" + return str(type_) == "" + + +@cache +def has_generic_alias_in_args(type_: Type) -> bool: + return is_generic_alias(type(get_args(type_))) + + def is_valid_generic_class(value: Any, type_: Type) -> bool: - if not isinstance(value, get_origin(type_)): + origin = get_origin(type_) + if not (origin and isinstance(value, origin)): return False - type_hints = get_type_hints(value) + type_hints = get_type_hints(type(value)) for field_name, field_type in type_hints.items(): if isinstance(field_type, TypeVar): return ( - any([isinstance(getattr(value, field_name), arg) for arg in get_args(type_)]) - if get_args(type_) - else True + any(isinstance(getattr(value, field_name), arg) for arg in get_args(type_)) if get_args(type_) else True ) else: return is_instance(value, type_) @@ -165,7 +161,7 @@ def is_instance(value: Any, type_: Type) -> bool: return value in extract_generic(type_) elif is_init_var(type_): return is_instance(value, extract_init_var(type_)) - elif isclass(type(type_)) and type(type_) == _GenericAlias: + elif isclass(type(type_)) and is_generic_alias(type(type_)): return is_valid_generic_class(value, type_) elif isinstance(type_, TypeVar): if hasattr(type_, "__constraints__") and type_.__constraints__: @@ -174,7 +170,7 @@ def is_instance(value: Any, type_: Type) -> bool: if isinstance(type_.__bound__, tuple): return any(is_instance(value, t) for t in type_.__bound__) if type_.__bound__ is not None and is_generic(type_.__bound__): - return isinstance(value, type_.__bound__) + return isinstance(value, extract_generic(type_.__bound__)) return True elif is_type_generic(type_): return is_subclass(value, extract_generic(type_)[0]) @@ -218,6 +214,6 @@ def is_subclass(sub_type: Type, base_type: Type) -> bool: @cache def is_type_generic(type_: Type) -> bool: try: - return type_.__origin__ in (type, Type) + return get_origin(type_) in (type, Type) except AttributeError: return False diff --git a/tests/core/test_base.py b/tests/core/test_base.py index 6bb8cd5..093b2d9 100644 --- a/tests/core/test_base.py +++ b/tests/core/test_base.py @@ -193,7 +193,7 @@ class X: assert result == X(s=MyStr("test")) -def test_from_dict_generic(): +def test_from_dict_generic_valid(): T = TypeVar("T", bound=Union[str, int]) @dataclass