From 3313ae32d7e90d9f4d1d35f388278a0386829802 Mon Sep 17 00:00:00 2001 From: Tobias Pfeiffer Date: Tue, 10 Aug 2021 16:35:29 +0900 Subject: [PATCH 1/3] do not treat generic numpy.ndarray as a generic collection --- dacite/types.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/dacite/types.py b/dacite/types.py index 1d4dfea..37c1089 100644 --- a/dacite/types.py +++ b/dacite/types.py @@ -142,11 +142,15 @@ def is_generic_collection(type_: Type) -> bool: return False origin = extract_origin_collection(type_) try: - return bool(origin and issubclass(origin, Collection)) + return bool(origin and issubclass(origin, Collection) and not skip_generic_conversion(origin)) except (TypeError, AttributeError): return False +def skip_generic_conversion(origin: Type) -> bool: + return origin.__module__ == "numpy" and origin.__qualname__ == "ndarray" + + def extract_generic(type_: Type, defaults: Tuple = ()) -> tuple: try: if hasattr(type_, "_special") and type_._special: From bc54426a0fc19c85c5497fc4fcb4ab3b6cc0d640 Mon Sep 17 00:00:00 2001 From: Tobias Pfeiffer Date: Tue, 10 Aug 2021 16:59:11 +0900 Subject: [PATCH 2/3] allow is_instance checks for arbitrary generic types --- dacite/types.py | 3 +++ tests/test_types.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/dacite/types.py b/dacite/types.py index 37c1089..bcdfd5b 100644 --- a/dacite/types.py +++ b/dacite/types.py @@ -127,6 +127,9 @@ def is_instance(value: Any, type_: Type) -> bool: return is_instance(value, extract_init_var(type_)) elif is_type_generic(type_): return is_subclass(value, extract_generic(type_)[0]) + elif is_generic(type_): + origin = extract_origin_collection(type_) + return isinstance(value, origin) else: try: # As described in PEP 484 - section: "The numeric tower" diff --git a/tests/test_types.py b/tests/test_types.py index 948ff16..5e90cfc 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -207,7 +207,7 @@ def test_is_instance_with_not_supported_generic_types(): class X(Generic[T]): pass - assert not is_instance(X[str](), X[str]) + assert is_instance(X[str](), X[str]) def test_is_instance_with_generic_mapping_and_matching_value_type(): From 0e378974cfe425d76cde4b285563825e64a47880 Mon Sep 17 00:00:00 2001 From: Tobias Pfeiffer Date: Tue, 10 Aug 2021 17:03:34 +0900 Subject: [PATCH 3/3] add test for npt.NDArray handling --- setup.py | 12 +++++++++- tests/common.py | 1 + tests/core/test_ndarray.py | 47 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 1 deletion(-) create mode 100644 tests/core/test_ndarray.py diff --git a/setup.py b/setup.py index 89c86c5..13e8a8e 100644 --- a/setup.py +++ b/setup.py @@ -27,5 +27,15 @@ packages=["dacite"], package_data={"dacite": ["py.typed"]}, install_requires=['dataclasses;python_version<"3.7"'], - extras_require={"dev": ["pytest>=5", "pytest-cov", "coveralls", "black", "mypy", "pylint"]}, + extras_require={ + "dev": [ + "pytest>=5", + "pytest-cov", + "coveralls", + "black", + "mypy", + "pylint", + 'numpy>=1.21.0;python_version>="3.7"', + ] + }, ) diff --git a/tests/common.py b/tests/common.py index 71a557a..be8380f 100644 --- a/tests/common.py +++ b/tests/common.py @@ -3,3 +3,4 @@ import pytest literal_support = init_var_type_support = pytest.mark.skipif(sys.version_info < (3, 8), reason="requires Python 3.8") +ndarray_support = pytest.mark.skipif(sys.version_info < (3, 7), reason="requires Python 3.7") diff --git a/tests/core/test_ndarray.py b/tests/core/test_ndarray.py new file mode 100644 index 0000000..df04813 --- /dev/null +++ b/tests/core/test_ndarray.py @@ -0,0 +1,47 @@ +from dataclasses import dataclass +from typing import Sequence, TypeVar + +import numpy +import numpy.typing +import numpy.testing + +from dacite import from_dict, Config +from tests.common import ndarray_support + + +@ndarray_support +def test_from_dict_with_ndarray(): + @dataclass + class X: + a: numpy.ndarray + + result = from_dict(X, {"a": numpy.array([1, 2, 3])}) + + numpy.testing.assert_allclose(result.a, numpy.array([1, 2, 3])) + + +@ndarray_support +def test_from_dict_with_nptndarray(): + @dataclass + class X: + a: numpy.typing.NDArray[numpy.float64] + + result = from_dict(X, {"a": numpy.array([1, 2, 3])}) + + numpy.testing.assert_allclose(result.a, numpy.array([1, 2, 3])) + + +@ndarray_support +def test_from_dict_with_nptndarray_and_converter(): + @dataclass + class X: + a: numpy.typing.NDArray[numpy.float64] + + D = TypeVar("D", bound=numpy.generic) + + def coerce_to_array(s: Sequence[D]) -> numpy.typing.NDArray[D]: + return numpy.array(s) + + result = from_dict(X, {"a": [1, 2, 3]}, Config(type_hooks={numpy.typing.NDArray[numpy.float64]: coerce_to_array})) + + numpy.testing.assert_allclose(result.a, numpy.array([1, 2, 3]))