-
Notifications
You must be signed in to change notification settings - Fork 15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add pint arrays support #26
Open
kadykov
wants to merge
30
commits into
xarray-contrib:main
Choose a base branch
from
kadykov:kadykov-pint
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 29 commits
Commits
Show all changes
30 commits
Select commit
Hold shift + click to select a range
8ba696a
Fix broken dask_array_type import
kadykov 99ccd71
Merge branch 'xarray-contrib:main' into main
kadykov 485d9c1
Fix is_cupy for pint arrays
kadykov 45aac6c
Fix units drop during as_cupy conversion from pint xarray
kadykov 528b258
Construct pint array without pint xarray accessor
kadykov a864e6a
Refactor with _as_dataarray wrapper
kadykov 1c4b1a7
Add basic tests for pint arrays
kadykov fa059bf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] afd258e
Merge branch 'xarray-contrib:main' into kadykov-pint
kadykov c7580b3
Fix as_cupy() for arrays created by pint_xarray
kadykov 9ad2357
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] ca9b51c
Add pint nested array support
kadykov efac145
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 12e476a
Add basic pint(dask) array tests
kadykov 44bd5ad
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 4ac8ac6
Merge branch 'xarray-contrib:main' into kadykov-pint
kadykov 50637bb
Merge branch 'main' into kadykov-pint
dcherian 0d841d1
Move helper functions outside of the class
kadykov 5f10799
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] fa21b8a
Merge branch 'xarray-contrib:main' into kadykov-pint
kadykov 34136e3
Fix Dataset.cupy.as_numpy() error
kadykov 5eeaa4c
Refactor is_cupy, as_cupy() and as_numpy() tests
kadykov 04a830d
Ignore flake8 unused import warning
kadykov 4be620c
Merge branch 'xarray-contrib:main' into kadykov-pint
kadykov 41c0d8d
cleanup
dcherian c076239
Clean up
dcherian 53e5a86
Fixes.
dcherian cff0456
more fix
dcherian 0c36ef0
Failing test
dcherian d9ca81b
Merge branch 'xarray-contrib:main' into kadykov-pint
kadykov File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,66 +1,63 @@ | ||
"""Tests for cupy-xarray accessors""" | ||
import cupy as cp | ||
import numpy as np | ||
import pytest | ||
import xarray as xr | ||
from xarray.core.pycompat import dask_array_type | ||
from xarray.tests import requires_dask, requires_pint | ||
|
||
import cupy_xarray # noqa: F401 | ||
import cupy_xarray # noqa: F401 pylint:disable=unused-import | ||
|
||
da = xr.DataArray(np.random.rand(2, 3), attrs={"units": "candle"}) | ||
ds = xr.Dataset({"a": da}) | ||
|
||
@pytest.fixture | ||
def tutorial_ds_air(): | ||
return xr.tutorial.load_dataset("air_temperature") | ||
|
||
@pytest.mark.parametrize("obj", [da, ds]) | ||
def test_numpy(obj): | ||
"""Test is_cupy property in cupy xarray accessor""" | ||
|
||
@pytest.fixture | ||
def tutorial_da_air(tutorial_ds_air): | ||
return tutorial_ds_air.air | ||
|
||
|
||
@pytest.fixture | ||
def tutorial_ds_air_dask(): | ||
return xr.tutorial.open_dataset("air_temperature", chunks={"lat": 25, "lon": 25, "time": -1}) | ||
|
||
|
||
@pytest.fixture | ||
def tutorial_da_air_dask(tutorial_ds_air_dask): | ||
return tutorial_ds_air_dask.air | ||
|
||
|
||
def test_data_set_accessor(tutorial_ds_air): | ||
ds = tutorial_ds_air | ||
assert hasattr(ds, "cupy") | ||
assert not ds.cupy.is_cupy | ||
|
||
ds = ds.as_cupy() | ||
assert ds.cupy.is_cupy | ||
|
||
ds = ds.cupy.as_numpy() | ||
assert not ds.cupy.is_cupy | ||
|
||
|
||
def test_data_array_accessor(tutorial_da_air): | ||
da = tutorial_da_air | ||
assert hasattr(da, "cupy") | ||
assert not da.cupy.is_cupy | ||
|
||
da = da.as_cupy() | ||
assert da.cupy.is_cupy | ||
|
||
garr = da.cupy.get() | ||
assert isinstance(garr, np.ndarray) | ||
|
||
da = da.cupy.as_numpy() | ||
assert not da.cupy.is_cupy | ||
|
||
|
||
def test_data_array_accessor_dask(tutorial_da_air_dask): | ||
da = tutorial_da_air_dask | ||
assert hasattr(da, "cupy") | ||
assert not da.cupy.is_cupy | ||
|
||
da = da.as_cupy() | ||
assert da.cupy.is_cupy | ||
assert isinstance(da.data, dask_array_type) | ||
|
||
da = da.cupy.as_numpy() | ||
assert not da.cupy.is_cupy | ||
cpda = da.cupy.as_cupy() | ||
assert cpda.cupy.is_cupy | ||
|
||
as_numpy = cpda.as_numpy() | ||
assert not as_numpy.cupy.is_cupy | ||
if isinstance(as_numpy, xr.DataArray): | ||
assert isinstance(as_numpy.data, np.ndarray) | ||
|
||
|
||
@requires_dask | ||
@pytest.mark.parametrize("obj", [da, ds]) | ||
def test_dask(obj): | ||
"""Test is_cupy property in cupy xarray accessor""" | ||
as_dask = obj.chunk() | ||
assert not as_dask.cupy.is_cupy | ||
cpda = as_dask.cupy.as_cupy() | ||
assert cpda.cupy.is_cupy | ||
|
||
if isinstance(cpda, xr.DataArray): | ||
assert isinstance(cpda.data._meta, cp.ndarray) | ||
|
||
|
||
@requires_pint | ||
@pytest.mark.parametrize("obj", [da, ds]) | ||
def test_pint(obj): | ||
import pint | ||
import pint_xarray # noqa | ||
|
||
as_pint = obj.pint.quantify() | ||
|
||
assert not as_pint.cupy.is_cupy | ||
cpda = as_pint.cupy.as_cupy() | ||
if isinstance(cpda, xr.DataArray): | ||
assert isinstance(cpda.data, pint.Quantity) | ||
assert cpda.cupy.is_cupy | ||
|
||
as_dask = as_pint.chunk() | ||
if isinstance(as_dask, xr.DataArray): | ||
assert isinstance(as_dask.data, pint.Quantity) | ||
assert isinstance(as_dask.data.magnitude._meta, np.ndarray) | ||
assert not as_dask.cupy.is_cupy | ||
cpda = as_dask.cupy.as_cupy() | ||
assert cpda.cupy.is_cupy | ||
if isinstance(cpda, xr.DataArray): | ||
assert isinstance(cpda.data._meta, cp.ndarray) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@keewis is this expectation right for
DataArray(pint(np.ndarray)).chunk()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
be aware that
DataArray.chunk()
simply passes the wrapped duck array todask.array.from_array
if it was not already a dask collection, which means that it would produce adask(pint(np.ndarray))
. This is not what we want, so there'sDataArray.pint.chunk
to work around this.