Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pint arrays support #26

Open
wants to merge 30 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
8ba696a
Fix broken dask_array_type import
kadykov Jan 4, 2023
99ccd71
Merge branch 'xarray-contrib:main' into main
kadykov Jan 12, 2023
485d9c1
Fix is_cupy for pint arrays
kadykov Jan 13, 2023
45aac6c
Fix units drop during as_cupy conversion from pint xarray
kadykov Jan 13, 2023
528b258
Construct pint array without pint xarray accessor
kadykov Jan 16, 2023
a864e6a
Refactor with _as_dataarray wrapper
kadykov Jan 16, 2023
1c4b1a7
Add basic tests for pint arrays
kadykov Jan 16, 2023
fa059bf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2023
afd258e
Merge branch 'xarray-contrib:main' into kadykov-pint
kadykov Jan 16, 2023
c7580b3
Fix as_cupy() for arrays created by pint_xarray
kadykov Jan 23, 2023
9ad2357
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 23, 2023
ca9b51c
Add pint nested array support
kadykov Jan 24, 2023
efac145
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 24, 2023
12e476a
Add basic pint(dask) array tests
kadykov Jan 24, 2023
44bd5ad
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 24, 2023
4ac8ac6
Merge branch 'xarray-contrib:main' into kadykov-pint
kadykov Feb 9, 2023
50637bb
Merge branch 'main' into kadykov-pint
dcherian Feb 14, 2023
0d841d1
Move helper functions outside of the class
kadykov Mar 9, 2023
5f10799
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 9, 2023
fa21b8a
Merge branch 'xarray-contrib:main' into kadykov-pint
kadykov Apr 4, 2023
34136e3
Fix Dataset.cupy.as_numpy() error
kadykov Apr 5, 2023
5eeaa4c
Refactor is_cupy, as_cupy() and as_numpy() tests
kadykov Apr 5, 2023
04a830d
Ignore flake8 unused import warning
kadykov May 25, 2023
4be620c
Merge branch 'xarray-contrib:main' into kadykov-pint
kadykov Oct 26, 2023
41c0d8d
cleanup
dcherian Oct 27, 2023
c076239
Clean up
dcherian Oct 27, 2023
53e5a86
Fixes.
dcherian Oct 27, 2023
cff0456
more fix
dcherian Oct 27, 2023
0c36ef0
Failing test
dcherian Oct 27, 2023
d9ca81b
Merge branch 'xarray-contrib:main' into kadykov-pint
kadykov Oct 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 58 additions & 52 deletions cupy_xarray/accessors.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import cupy as cp
from xarray import (
DataArray,
Dataset,
register_dataarray_accessor,
register_dataset_accessor,
)
from xarray import Dataset, register_dataarray_accessor, register_dataset_accessor
from xarray.core.pycompat import DuckArrayModule

dsk = DuckArrayModule("dask")
dask_array_type = dsk.type
dask_array_type = DuckArrayModule("dask").type
pint_array_type = DuckArrayModule("pint").type


def _get_datatype(data):
if isinstance(data, dask_array_type):
return isinstance(data._meta, cp.ndarray)
elif isinstance(data, pint_array_type):
return _get_datatype(data.magnitude)
return isinstance(data, cp.ndarray)


@register_dataarray_accessor("cupy")
Expand All @@ -22,11 +25,9 @@ def __init__(self, da):
self.da = da

@property
def is_cupy(self):
"""bool: The underlying data is a cupy array."""
if isinstance(self.da.data, dask_array_type):
return isinstance(self.da.data._meta, cp.ndarray)
return isinstance(self.da.data, cp.ndarray)
def is_cupy(self) -> bool:
"""True if the underlying data is a cupy array."""
return _get_datatype(self.da.data)

def as_cupy(self):
"""
Expand All @@ -51,21 +52,7 @@ def as_cupy(self):
<class 'cupy.core.core.ndarray'>

"""
if isinstance(self.da.data, dask_array_type):
return DataArray(
data=self.da.data.map_blocks(cp.asarray),
coords=self.da.coords,
dims=self.da.dims,
name=self.da.name,
attrs=self.da.attrs,
)
return DataArray(
data=cp.asarray(self.da.data),
coords=self.da.coords,
dims=self.da.dims,
name=self.da.name,
attrs=self.da.attrs,
)
return self.da.copy(data=_as_cupy_data(self.da.data))

def as_numpy(self):
"""
Expand All @@ -77,30 +64,38 @@ def as_numpy(self):
DataArray with underlying data cast to numpy.

"""
if self.is_cupy:
if isinstance(self.da.data, dask_array_type):
return DataArray(
data=self.da.data.map_blocks(
lambda block: block.get(), dtype=self.da.data._meta.dtype
),
coords=self.da.coords,
dims=self.da.dims,
name=self.da.name,
attrs=self.da.attrs,
)
return DataArray(
data=self.da.data.get(),
coords=self.da.coords,
dims=self.da.dims,
name=self.da.name,
attrs=self.da.attrs,
)
return self.da.as_numpy()
raise NotImplementedError("Please use .as_numpy DataArray method directly.")

def get(self):
return self.da.data.get()


def _as_cupy_data(data):
if isinstance(data, dask_array_type):
return data.map_blocks(cp.asarray)
if isinstance(data, pint_array_type):
from pint import Quantity # pylint: disable=import-outside-toplevel

return Quantity(
_as_cupy_data(data.magnitude),
units=data.units,
)
return cp.asarray(data)


def _as_numpy_data(data):
if isinstance(data, dask_array_type):
return data.map_blocks(lambda block: block.get(), dtype=data._meta.dtype)
if isinstance(data, pint_array_type):
from pint import Quantity # pylint: disable=import-outside-toplevel

return Quantity(
_as_numpy_data(data.magnitude),
units=data.units,
)
return data.get() if isinstance(data, cp.ndarray) else data


@register_dataset_accessor("cupy")
class CupyDatasetAccessor:
"""
Expand All @@ -112,12 +107,24 @@ def __init__(self, ds):
self.ds = ds

@property
def is_cupy(self):
def has_cupy(self) -> bool:
"""True if any data variable contains a cupy array."""
return any([da.cupy.is_cupy for da in self.ds.data_vars.values()])

@property
def is_cupy(self) -> bool:
"""True if all data variables contain cupy arrays."""
return all([da.cupy.is_cupy for da in self.ds.data_vars.values()])

def as_cupy(self):
data_vars = {var: da.as_cupy() for var, da in self.ds.data_vars.items()}
return Dataset(data_vars=data_vars, coords=self.ds.coords, attrs=self.ds.attrs)
if not self.is_cupy:
data_vars = {var: da.as_cupy() for var, da in self.ds.data_vars.items()}
return Dataset(
data_vars=data_vars,
coords=self.ds.coords,
attrs=self.ds.attrs,
)
return self.ds

def as_numpy(self):
if self.is_cupy:
Expand All @@ -127,8 +134,7 @@ def as_numpy(self):
coords=self.ds.coords,
attrs=self.ds.attrs,
)
else:
return self.ds.as_numpy()
return self.ds


# Attach the `as_cupy` methods to the top level `Dataset` and `Dataarray` objects.
Expand Down
111 changes: 54 additions & 57 deletions cupy_xarray/tests/test_accessors.py
Original file line number Diff line number Diff line change
@@ -1,66 +1,63 @@
"""Tests for cupy-xarray accessors"""
import cupy as cp
import numpy as np
import pytest
import xarray as xr
from xarray.core.pycompat import dask_array_type
from xarray.tests import requires_dask, requires_pint

import cupy_xarray # noqa: F401
import cupy_xarray # noqa: F401 pylint:disable=unused-import

da = xr.DataArray(np.random.rand(2, 3), attrs={"units": "candle"})
ds = xr.Dataset({"a": da})

@pytest.fixture
def tutorial_ds_air():
return xr.tutorial.load_dataset("air_temperature")

@pytest.mark.parametrize("obj", [da, ds])
def test_numpy(obj):
"""Test is_cupy property in cupy xarray accessor"""

@pytest.fixture
def tutorial_da_air(tutorial_ds_air):
return tutorial_ds_air.air


@pytest.fixture
def tutorial_ds_air_dask():
return xr.tutorial.open_dataset("air_temperature", chunks={"lat": 25, "lon": 25, "time": -1})


@pytest.fixture
def tutorial_da_air_dask(tutorial_ds_air_dask):
return tutorial_ds_air_dask.air


def test_data_set_accessor(tutorial_ds_air):
ds = tutorial_ds_air
assert hasattr(ds, "cupy")
assert not ds.cupy.is_cupy

ds = ds.as_cupy()
assert ds.cupy.is_cupy

ds = ds.cupy.as_numpy()
assert not ds.cupy.is_cupy


def test_data_array_accessor(tutorial_da_air):
da = tutorial_da_air
assert hasattr(da, "cupy")
assert not da.cupy.is_cupy

da = da.as_cupy()
assert da.cupy.is_cupy

garr = da.cupy.get()
assert isinstance(garr, np.ndarray)

da = da.cupy.as_numpy()
assert not da.cupy.is_cupy


def test_data_array_accessor_dask(tutorial_da_air_dask):
da = tutorial_da_air_dask
assert hasattr(da, "cupy")
assert not da.cupy.is_cupy

da = da.as_cupy()
assert da.cupy.is_cupy
assert isinstance(da.data, dask_array_type)

da = da.cupy.as_numpy()
assert not da.cupy.is_cupy
cpda = da.cupy.as_cupy()
assert cpda.cupy.is_cupy

as_numpy = cpda.as_numpy()
assert not as_numpy.cupy.is_cupy
if isinstance(as_numpy, xr.DataArray):
assert isinstance(as_numpy.data, np.ndarray)


@requires_dask
@pytest.mark.parametrize("obj", [da, ds])
def test_dask(obj):
"""Test is_cupy property in cupy xarray accessor"""
as_dask = obj.chunk()
assert not as_dask.cupy.is_cupy
cpda = as_dask.cupy.as_cupy()
assert cpda.cupy.is_cupy

if isinstance(cpda, xr.DataArray):
assert isinstance(cpda.data._meta, cp.ndarray)


@requires_pint
@pytest.mark.parametrize("obj", [da, ds])
def test_pint(obj):
import pint
import pint_xarray # noqa

as_pint = obj.pint.quantify()

assert not as_pint.cupy.is_cupy
cpda = as_pint.cupy.as_cupy()
if isinstance(cpda, xr.DataArray):
assert isinstance(cpda.data, pint.Quantity)
assert cpda.cupy.is_cupy

as_dask = as_pint.chunk()
if isinstance(as_dask, xr.DataArray):
assert isinstance(as_dask.data, pint.Quantity)
assert isinstance(as_dask.data.magnitude._meta, np.ndarray)
Comment on lines +55 to +58
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@keewis is this expectation right for DataArray(pint(np.ndarray)).chunk()

Copy link
Contributor

@keewis keewis Oct 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

be aware that DataArray.chunk() simply passes the wrapped duck array to dask.array.from_array if it was not already a dask collection, which means that it would produce a dask(pint(np.ndarray)). This is not what we want, so there's DataArray.pint.chunk to work around this.

assert not as_dask.cupy.is_cupy
cpda = as_dask.cupy.as_cupy()
assert cpda.cupy.is_cupy
if isinstance(cpda, xr.DataArray):
assert isinstance(cpda.data._meta, cp.ndarray)