diff --git a/cupy_xarray/accessors.py b/cupy_xarray/accessors.py index 06a5631..1bcb7f8 100644 --- a/cupy_xarray/accessors.py +++ b/cupy_xarray/accessors.py @@ -1,14 +1,17 @@ import cupy as cp -from xarray import ( - DataArray, - Dataset, - register_dataarray_accessor, - register_dataset_accessor, -) +from xarray import Dataset, register_dataarray_accessor, register_dataset_accessor from xarray.core.pycompat import DuckArrayModule -dsk = DuckArrayModule("dask") -dask_array_type = dsk.type +dask_array_type = DuckArrayModule("dask").type +pint_array_type = DuckArrayModule("pint").type + + +def _get_datatype(data): + if isinstance(data, dask_array_type): + return isinstance(data._meta, cp.ndarray) + elif isinstance(data, pint_array_type): + return _get_datatype(data.magnitude) + return isinstance(data, cp.ndarray) @register_dataarray_accessor("cupy") @@ -22,11 +25,9 @@ def __init__(self, da): self.da = da @property - def is_cupy(self): - """bool: The underlying data is a cupy array.""" - if isinstance(self.da.data, dask_array_type): - return isinstance(self.da.data._meta, cp.ndarray) - return isinstance(self.da.data, cp.ndarray) + def is_cupy(self) -> bool: + """True if the underlying data is a cupy array.""" + return _get_datatype(self.da.data) def as_cupy(self): """ @@ -51,21 +52,7 @@ def as_cupy(self): """ - if isinstance(self.da.data, dask_array_type): - return DataArray( - data=self.da.data.map_blocks(cp.asarray), - coords=self.da.coords, - dims=self.da.dims, - name=self.da.name, - attrs=self.da.attrs, - ) - return DataArray( - data=cp.asarray(self.da.data), - coords=self.da.coords, - dims=self.da.dims, - name=self.da.name, - attrs=self.da.attrs, - ) + return self.da.copy(data=_as_cupy_data(self.da.data)) def as_numpy(self): """ @@ -77,30 +64,38 @@ def as_numpy(self): DataArray with underlying data cast to numpy. """ - if self.is_cupy: - if isinstance(self.da.data, dask_array_type): - return DataArray( - data=self.da.data.map_blocks( - lambda block: block.get(), dtype=self.da.data._meta.dtype - ), - coords=self.da.coords, - dims=self.da.dims, - name=self.da.name, - attrs=self.da.attrs, - ) - return DataArray( - data=self.da.data.get(), - coords=self.da.coords, - dims=self.da.dims, - name=self.da.name, - attrs=self.da.attrs, - ) - return self.da.as_numpy() + raise NotImplementedError("Please use .as_numpy DataArray method directly.") def get(self): return self.da.data.get() +def _as_cupy_data(data): + if isinstance(data, dask_array_type): + return data.map_blocks(cp.asarray) + if isinstance(data, pint_array_type): + from pint import Quantity # pylint: disable=import-outside-toplevel + + return Quantity( + _as_cupy_data(data.magnitude), + units=data.units, + ) + return cp.asarray(data) + + +def _as_numpy_data(data): + if isinstance(data, dask_array_type): + return data.map_blocks(lambda block: block.get(), dtype=data._meta.dtype) + if isinstance(data, pint_array_type): + from pint import Quantity # pylint: disable=import-outside-toplevel + + return Quantity( + _as_numpy_data(data.magnitude), + units=data.units, + ) + return data.get() if isinstance(data, cp.ndarray) else data + + @register_dataset_accessor("cupy") class CupyDatasetAccessor: """ @@ -112,12 +107,24 @@ def __init__(self, ds): self.ds = ds @property - def is_cupy(self): + def has_cupy(self) -> bool: + """True if any data variable contains a cupy array.""" + return any([da.cupy.is_cupy for da in self.ds.data_vars.values()]) + + @property + def is_cupy(self) -> bool: + """True if all data variables contain cupy arrays.""" return all([da.cupy.is_cupy for da in self.ds.data_vars.values()]) def as_cupy(self): - data_vars = {var: da.as_cupy() for var, da in self.ds.data_vars.items()} - return Dataset(data_vars=data_vars, coords=self.ds.coords, attrs=self.ds.attrs) + if not self.is_cupy: + data_vars = {var: da.as_cupy() for var, da in self.ds.data_vars.items()} + return Dataset( + data_vars=data_vars, + coords=self.ds.coords, + attrs=self.ds.attrs, + ) + return self.ds def as_numpy(self): if self.is_cupy: @@ -127,8 +134,7 @@ def as_numpy(self): coords=self.ds.coords, attrs=self.ds.attrs, ) - else: - return self.ds.as_numpy() + return self.ds # Attach the `as_cupy` methods to the top level `Dataset` and `Dataarray` objects. diff --git a/cupy_xarray/tests/test_accessors.py b/cupy_xarray/tests/test_accessors.py index cad6955..18c8f02 100644 --- a/cupy_xarray/tests/test_accessors.py +++ b/cupy_xarray/tests/test_accessors.py @@ -1,66 +1,63 @@ +"""Tests for cupy-xarray accessors""" +import cupy as cp import numpy as np import pytest import xarray as xr -from xarray.core.pycompat import dask_array_type +from xarray.tests import requires_dask, requires_pint -import cupy_xarray # noqa: F401 +import cupy_xarray # noqa: F401 pylint:disable=unused-import +da = xr.DataArray(np.random.rand(2, 3), attrs={"units": "candle"}) +ds = xr.Dataset({"a": da}) -@pytest.fixture -def tutorial_ds_air(): - return xr.tutorial.load_dataset("air_temperature") +@pytest.mark.parametrize("obj", [da, ds]) +def test_numpy(obj): + """Test is_cupy property in cupy xarray accessor""" -@pytest.fixture -def tutorial_da_air(tutorial_ds_air): - return tutorial_ds_air.air - - -@pytest.fixture -def tutorial_ds_air_dask(): - return xr.tutorial.open_dataset("air_temperature", chunks={"lat": 25, "lon": 25, "time": -1}) - - -@pytest.fixture -def tutorial_da_air_dask(tutorial_ds_air_dask): - return tutorial_ds_air_dask.air - - -def test_data_set_accessor(tutorial_ds_air): - ds = tutorial_ds_air - assert hasattr(ds, "cupy") - assert not ds.cupy.is_cupy - - ds = ds.as_cupy() - assert ds.cupy.is_cupy - - ds = ds.cupy.as_numpy() - assert not ds.cupy.is_cupy - - -def test_data_array_accessor(tutorial_da_air): - da = tutorial_da_air - assert hasattr(da, "cupy") - assert not da.cupy.is_cupy - - da = da.as_cupy() - assert da.cupy.is_cupy - - garr = da.cupy.get() - assert isinstance(garr, np.ndarray) - - da = da.cupy.as_numpy() - assert not da.cupy.is_cupy - - -def test_data_array_accessor_dask(tutorial_da_air_dask): - da = tutorial_da_air_dask - assert hasattr(da, "cupy") - assert not da.cupy.is_cupy - - da = da.as_cupy() - assert da.cupy.is_cupy - assert isinstance(da.data, dask_array_type) - - da = da.cupy.as_numpy() assert not da.cupy.is_cupy + cpda = da.cupy.as_cupy() + assert cpda.cupy.is_cupy + + as_numpy = cpda.as_numpy() + assert not as_numpy.cupy.is_cupy + if isinstance(as_numpy, xr.DataArray): + assert isinstance(as_numpy.data, np.ndarray) + + +@requires_dask +@pytest.mark.parametrize("obj", [da, ds]) +def test_dask(obj): + """Test is_cupy property in cupy xarray accessor""" + as_dask = obj.chunk() + assert not as_dask.cupy.is_cupy + cpda = as_dask.cupy.as_cupy() + assert cpda.cupy.is_cupy + + if isinstance(cpda, xr.DataArray): + assert isinstance(cpda.data._meta, cp.ndarray) + + +@requires_pint +@pytest.mark.parametrize("obj", [da, ds]) +def test_pint(obj): + import pint + import pint_xarray # noqa + + as_pint = obj.pint.quantify() + + assert not as_pint.cupy.is_cupy + cpda = as_pint.cupy.as_cupy() + if isinstance(cpda, xr.DataArray): + assert isinstance(cpda.data, pint.Quantity) + assert cpda.cupy.is_cupy + + as_dask = as_pint.chunk() + if isinstance(as_dask, xr.DataArray): + assert isinstance(as_dask.data, pint.Quantity) + assert isinstance(as_dask.data.magnitude._meta, np.ndarray) + assert not as_dask.cupy.is_cupy + cpda = as_dask.cupy.as_cupy() + assert cpda.cupy.is_cupy + if isinstance(cpda, xr.DataArray): + assert isinstance(cpda.data._meta, cp.ndarray)