Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support numpy.typing.NDArray dataclass members #156

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion dacite/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ def is_instance(value: Any, type_: Type) -> bool:
return is_instance(value, extract_init_var(type_))
elif is_type_generic(type_):
return is_subclass(value, extract_generic(type_)[0])
elif is_generic(type_):
origin = extract_origin_collection(type_)
return isinstance(value, origin)
else:
try:
# As described in PEP 484 - section: "The numeric tower"
Expand All @@ -142,11 +145,15 @@ def is_generic_collection(type_: Type) -> bool:
return False
origin = extract_origin_collection(type_)
try:
return bool(origin and issubclass(origin, Collection))
return bool(origin and issubclass(origin, Collection) and not skip_generic_conversion(origin))
except (TypeError, AttributeError):
return False


def skip_generic_conversion(origin: Type) -> bool:
return origin.__module__ == "numpy" and origin.__qualname__ == "ndarray"


def extract_generic(type_: Type, defaults: Tuple = ()) -> tuple:
try:
if hasattr(type_, "_special") and type_._special:
Expand Down
12 changes: 11 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,15 @@
packages=["dacite"],
package_data={"dacite": ["py.typed"]},
install_requires=['dataclasses;python_version<"3.7"'],
extras_require={"dev": ["pytest>=5", "pytest-cov", "coveralls", "black", "mypy", "pylint"]},
extras_require={
"dev": [
"pytest>=5",
"pytest-cov",
"coveralls",
"black",
"mypy",
"pylint",
'numpy>=1.21.0;python_version>="3.7"',
]
},
)
1 change: 1 addition & 0 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
import pytest

literal_support = init_var_type_support = pytest.mark.skipif(sys.version_info < (3, 8), reason="requires Python 3.8")
ndarray_support = pytest.mark.skipif(sys.version_info < (3, 7), reason="requires Python 3.7")
47 changes: 47 additions & 0 deletions tests/core/test_ndarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from dataclasses import dataclass
from typing import Sequence, TypeVar

import numpy
import numpy.typing
import numpy.testing

from dacite import from_dict, Config
from tests.common import ndarray_support


@ndarray_support
def test_from_dict_with_ndarray():
@dataclass
class X:
a: numpy.ndarray

result = from_dict(X, {"a": numpy.array([1, 2, 3])})

numpy.testing.assert_allclose(result.a, numpy.array([1, 2, 3]))


@ndarray_support
def test_from_dict_with_nptndarray():
@dataclass
class X:
a: numpy.typing.NDArray[numpy.float64]

result = from_dict(X, {"a": numpy.array([1, 2, 3])})

numpy.testing.assert_allclose(result.a, numpy.array([1, 2, 3]))


@ndarray_support
def test_from_dict_with_nptndarray_and_converter():
@dataclass
class X:
a: numpy.typing.NDArray[numpy.float64]

D = TypeVar("D", bound=numpy.generic)

def coerce_to_array(s: Sequence[D]) -> numpy.typing.NDArray[D]:
return numpy.array(s)

result = from_dict(X, {"a": [1, 2, 3]}, Config(type_hooks={numpy.typing.NDArray[numpy.float64]: coerce_to_array}))

numpy.testing.assert_allclose(result.a, numpy.array([1, 2, 3]))
2 changes: 1 addition & 1 deletion tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def test_is_instance_with_not_supported_generic_types():
class X(Generic[T]):
pass

assert not is_instance(X[str](), X[str])
assert is_instance(X[str](), X[str])


def test_is_instance_with_generic_mapping_and_matching_value_type():
Expand Down