Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pint arrays support #26

Open
wants to merge 30 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
8ba696a
Fix broken dask_array_type import
kadykov Jan 4, 2023
99ccd71
Merge branch 'xarray-contrib:main' into main
kadykov Jan 12, 2023
485d9c1
Fix is_cupy for pint arrays
kadykov Jan 13, 2023
45aac6c
Fix units drop during as_cupy conversion from pint xarray
kadykov Jan 13, 2023
528b258
Construct pint array without pint xarray accessor
kadykov Jan 16, 2023
a864e6a
Refactor with _as_dataarray wrapper
kadykov Jan 16, 2023
1c4b1a7
Add basic tests for pint arrays
kadykov Jan 16, 2023
fa059bf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2023
afd258e
Merge branch 'xarray-contrib:main' into kadykov-pint
kadykov Jan 16, 2023
c7580b3
Fix as_cupy() for arrays created by pint_xarray
kadykov Jan 23, 2023
9ad2357
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 23, 2023
ca9b51c
Add pint nested array support
kadykov Jan 24, 2023
efac145
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 24, 2023
12e476a
Add basic pint(dask) array tests
kadykov Jan 24, 2023
44bd5ad
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 24, 2023
4ac8ac6
Merge branch 'xarray-contrib:main' into kadykov-pint
kadykov Feb 9, 2023
50637bb
Merge branch 'main' into kadykov-pint
dcherian Feb 14, 2023
0d841d1
Move helper functions outside of the class
kadykov Mar 9, 2023
5f10799
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 9, 2023
fa21b8a
Merge branch 'xarray-contrib:main' into kadykov-pint
kadykov Apr 4, 2023
34136e3
Fix Dataset.cupy.as_numpy() error
kadykov Apr 5, 2023
5eeaa4c
Refactor is_cupy, as_cupy() and as_numpy() tests
kadykov Apr 5, 2023
04a830d
Ignore flake8 unused import warning
kadykov May 25, 2023
4be620c
Merge branch 'xarray-contrib:main' into kadykov-pint
kadykov Oct 26, 2023
41c0d8d
cleanup
dcherian Oct 27, 2023
c076239
Clean up
dcherian Oct 27, 2023
53e5a86
Fixes.
dcherian Oct 27, 2023
cff0456
more fix
dcherian Oct 27, 2023
0c36ef0
Failing test
dcherian Oct 27, 2023
d9ca81b
Merge branch 'xarray-contrib:main' into kadykov-pint
kadykov Oct 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 21 additions & 26 deletions cupy_xarray/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
)
from xarray.core.pycompat import DuckArrayModule

dsk = DuckArrayModule("dask")
dask_array_type = dsk.type
dask_array_type = DuckArrayModule("dask").type
pint_array_type = DuckArrayModule("pint").type


@register_dataarray_accessor("cupy")
Expand All @@ -26,6 +26,8 @@ def is_cupy(self):
"""bool: The underlying data is a cupy array."""
if isinstance(self.da.data, dask_array_type):
return isinstance(self.da.data._meta, cp.ndarray)
if isinstance(self.da.data, pint_array_type):
return isinstance(self.da.data.magnitude, cp.ndarray)
return isinstance(self.da.data, cp.ndarray)

def as_cupy(self):
Expand All @@ -52,20 +54,14 @@ def as_cupy(self):

"""
if isinstance(self.da.data, dask_array_type):
return DataArray(
return self._as_dataarray(
data=self.da.data.map_blocks(cp.asarray),
coords=self.da.coords,
dims=self.da.dims,
name=self.da.name,
attrs=self.da.attrs,
)
return DataArray(
data=cp.asarray(self.da.data),
coords=self.da.coords,
dims=self.da.dims,
name=self.da.name,
attrs=self.da.attrs,
)
if isinstance(self.da.data, pint_array_type):
dcherian marked this conversation as resolved.
Show resolved Hide resolved
return self._as_dataarray(
data=(self.da.data.units * cp.asarray(self.da.data.magnitude)),
dcherian marked this conversation as resolved.
Show resolved Hide resolved
)
return self._as_dataarray(data=cp.asarray(self.da.data))

def as_numpy(self):
"""
Expand All @@ -79,27 +75,26 @@ def as_numpy(self):
"""
if self.is_cupy:
if isinstance(self.da.data, dask_array_type):
return DataArray(
return self._as_dataarray(
data=self.da.data.map_blocks(
lambda block: block.get(), dtype=self.da.data._meta.dtype
),
coords=self.da.coords,
dims=self.da.dims,
name=self.da.name,
attrs=self.da.attrs,
)
return DataArray(
data=self.da.data.get(),
coords=self.da.coords,
dims=self.da.dims,
name=self.da.name,
attrs=self.da.attrs,
)
return self._as_dataarray(data=self.da.data.get())
return self.da.as_numpy()

def get(self):
return self.da.data.get()

def _as_dataarray(self, data):
return DataArray(
data=data,
coords=self.da.coords,
dims=self.da.dims,
name=self.da.name,
attrs=self.da.attrs,
)


@register_dataset_accessor("cupy")
class CupyDatasetAccessor:
Expand Down
34 changes: 33 additions & 1 deletion cupy_xarray/tests/test_accessors.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
import numpy as np
import pytest
import xarray as xr
from xarray.core.pycompat import dask_array_type
from xarray.core.pycompat import DuckArrayModule
from xarray.tests import requires_pint
Copy link
Contributor

@dcherian dcherian Feb 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's copy this over to tests/__init__.py from xarray and similarly from DuckArrayModule


dask_array_type = DuckArrayModule("dask").type
pint_array_type = DuckArrayModule("pint").type

from pint import UnitRegistry

import cupy_xarray # noqa: F401

ureg = UnitRegistry()


@pytest.fixture
def tutorial_ds_air():
Expand All @@ -26,6 +34,16 @@ def tutorial_da_air_dask(tutorial_ds_air_dask):
return tutorial_ds_air_dask.air


@pytest.fixture
def tutorial_ds_air_pint():
return xr.tutorial.load_dataset("air_temperature") * ureg.Quantity("degree_Kelvin")


@pytest.fixture
def tutorial_da_air_pint(tutorial_ds_air_pint):
return tutorial_ds_air_pint.air


def test_data_set_accessor(tutorial_ds_air):
ds = tutorial_ds_air
assert hasattr(ds, "cupy")
Expand Down Expand Up @@ -64,3 +82,17 @@ def test_data_array_accessor_dask(tutorial_da_air_dask):

da = da.cupy.as_numpy()
assert not da.cupy.is_cupy


@requires_pint
def test_data_array_accessor_pint(tutorial_da_air_pint):
Copy link
Contributor

@negin513 negin513 Mar 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test checks this case:

pint + xr.DataArray --> cupy --> numpy

da = tutorial_da_air_pint
assert hasattr(da, "cupy")
assert not da.cupy.is_cupy

da = da.as_cupy()
assert da.cupy.is_cupy
assert isinstance(da.data, pint_array_type)

da = da.cupy.as_numpy()
assert not da.cupy.is_cupy