Skip to content

Version 2024.8.16: Functional XarrayActive #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
Aug 20, 2024
Merged
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
8c1d172
Added initial testing directory
dwest77a Aug 9, 2024
794b006
Added active_options kwarg
dwest77a Aug 9, 2024
aad6985
Added chunk space function, could be used in CFAPyX
dwest77a Aug 9, 2024
1833fe8
Removed old imports
dwest77a Aug 9, 2024
58d0bdf
New name for script with wrapper classes
dwest77a Aug 9, 2024
c98e6e2
Added handling of variable-only chunking, kwargs for ActiveArrayWrapper
dwest77a Aug 9, 2024
0f0549f
Updated testing
dwest77a Aug 9, 2024
d98c37c
Fixed some bugs with product/np.prod
dwest77a Aug 14, 2024
ca97bf1
Fixed issue with dims as dict not tuple
dwest77a Aug 14, 2024
79c0912
Removed NotImplemented init
dwest77a Aug 14, 2024
fd0e296
Added initial basic test suite
dwest77a Aug 16, 2024
9bb5629
Updated to 1.2.1, adds identical extents handler
dwest77a Aug 16, 2024
f80e4c9
Standardised copy function for active partition
dwest77a Aug 16, 2024
7e03564
Documentation changes, functional active_mean operation
dwest77a Aug 16, 2024
191e5fd
Fixed dask reduction to use combine functions
dwest77a Aug 16, 2024
26ac08a
Added bypass for when no active chunks are specified
dwest77a Aug 16, 2024
8db6935
Made dim a kwarg for ActiveDataArray mean
dwest77a Aug 16, 2024
d4763a1
Functional version of recursive active mean
dwest77a Aug 16, 2024
3bc1314
Updated with recursive test separately
dwest77a Aug 16, 2024
20c9ed2
Minor syntax changes
dwest77a Aug 16, 2024
006e9cf
Commit of dask/xarray issue to new branch
dwest77a Aug 20, 2024
544d7b4
Updated all scripts with minor fix to get_extent
dwest77a Aug 20, 2024
21860b6
Merge pull request #2 from dwest77a/xdIssue
dwest77a Aug 20, 2024
e142156
Updated to use arraypartition-1.0
dwest77a Aug 20, 2024
ccdb558
Updated with arraypartition requirement
dwest77a Aug 20, 2024
3b5fb94
Testing with pypi installed package for ArrayPartition
dwest77a Aug 20, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
name: Automatic Test
# Specify which GitHub events will trigger a CI build

on: push
# Define a single job, build

jobs:
  build:
    # Specify an OS for the runner
    runs-on: ubuntu-latest

    # Define steps
    steps:

      # Firstly, checkout repo
      - name: Checkout repository
        uses: actions/checkout@v2
      # Set up Python env
      - name: Setup Python
        uses: actions/setup-python@v2
        with:
          # Quoted so YAML treats the version as a string; unquoted values
          # like 3.10 would parse as the float 3.1.
          python-version: "3.11"
      # Install dependencies
      - name: Install Python dependencies
        run: |
          python3 -m pip install --upgrade pip
          pip3 install -r requirements.txt
          pip3 install -e .
      # Test with pytest
      - name: Run pytest
        run: |
          pytest
149 changes: 106 additions & 43 deletions XarrayActive/active_chunk.py
Original file line number Diff line number Diff line change
@@ -1,95 +1,158 @@
import numpy as np
from itertools import product


# Holds all CFA-specific Active routines.
class ActiveOptionsContainer:
"""
Container for ActiveOptions properties.
"""
@property
def active_options(self):
"""
Property of the datastore that relates private option variables to the standard
``active_options`` parameter.
"""
return {
'chunks': self._active_chunks,
'chunk_limits': self._chunk_limits,
}

@active_options.setter
def active_options(self, value):
self._set_active_options(**value)

def _set_active_options(self, chunks={}, chunk_limits=True):

if chunks == {}:
raise NotImplementedError(
'Default chunking is not implemented, please provide a chunk scheme '
' - active_options = {"chunks": {}}'
)

self._active_chunks = chunks
self._chunk_limits = chunk_limits

# Holds all Active routines.
class ActiveChunk:

description = "Container class for Active routines performed on each chunk. All active-per-chunk content can be found here."

def __init__(self, *args, **kwargs):
raise NotImplementedError

def _post_process_data(self, data):
# Perform any post-processing steps on the data here
return data

def _standard_mean(self, axis=None, skipna=None, **kwargs):
def _standard_sum(self, axes=None, skipna=None, **kwargs):
"""
Standard Mean routine matches the normal routine for dask, required at this
stage if Active mean not available.
"""
size = 1
for i in axis:
size *= self.shape[i]

arr = np.array(self)
if skipna:
total = np.nanmean(arr, axis=axis, **kwargs) *size
total = np.nansum(arr, axis=axes, **kwargs)
else:
total = np.mean(arr, axis=axis, **kwargs) *size
return {'n': self._numel(arr, axis=axis), 'total': total}
total = np.sum(arr, axis=axes, **kwargs)
return total

def _standard_max(self, axes=None, skipna=None, **kwargs):
return np.max(self, axis=axes)

def _standard_min(self, axes=None, skipna=None, **kwargs):
return np.min(self, axis=axes)

def _numel(self, axis=None):
if not axis:
def _numel(self, method, axes=None):
if not axes:
return self.size

size = 1
for i in axis:
for i in axes:
size *= self.shape[i]
newshape = list(self.shape)
newshape[axis] = 1
for ax in axes:
newshape[ax] = 1

return np.full(newshape, size)

def active_mean(self, axis=None, skipna=None, **kwargs):
def active_method(self, method, axis=None, skipna=None, **kwargs):
"""
Use PyActiveStorage package functionality to perform the requested reduction of this Fragment.
:param axis: (int) The axis over which to perform the active_mean operation.
:param axis: (int) The axes over which to perform the active_mean operation.
:param skipna: (bool) Skip NaN values when calculating the mean.
:returns: A ``duck array`` (numpy-like) with the reduced array or scalar value,
as specified by the axis parameter.
as specified by the axes parameter.
"""

standard_methods = {
'mean': self._standard_sum,
'sum' : self._standard_sum,
'max' : self._standard_max,
'min' : self._standard_min
}
ret = None
n = self._numel(method, axes=axis)

try:
from activestorage.active import Active
except ImportError:
# Unable to import Active package. Default to using normal mean.
print("ActiveWarning: Unable to import active module - defaulting to standard method.")
return self._standard_mean(axis=axis, skipna=skipna, **kwargs)

active = Active(self.filename, self.address)
active.method = "mean"
extent = self.get_extent()

if not axis is None:
return {
'n': self._numel(axis=axis),
'total': self._post_process_data(active[extent])
ret = {
'n': n,
'total': standard_methods[method](axes=axis, skipna=skipna, **kwargs)
}

# Experimental Recursive requesting to get each 1D column along the axis being requested.
range_recursives = []
for dim in range(self.ndim):
if dim != axis:
range_recursives.append(range(extent[dim].start, extent[dim].stop+1))
else:
range_recursives.append(extent[dim])
results = np.array(self._get_elements(active, range_recursives, hyperslab=[]))
if not ret:

active = Active(self.filename, self.address)
active.method = method
extent = tuple(self.get_extent())

if axis == None:
axis = tuple([i for i in range(self.ndim)])

n = self._numel(method, axes=axis)

if len(axis) == self.ndim:
data = active[extent]
t = self._post_process_data(data) * n

ret = {
'n': n,
'total': t
}

if not ret:
# Experimental Recursive requesting to get each 1D column along the axes being requested.
range_recursives = []
for dim in range(self.ndim):
if dim not in axis:
range_recursives.append(range(extent[dim].start, extent[dim].stop))
else:
range_recursives.append(extent[dim])
results = np.array(self._get_elements(active, range_recursives, hyperslab=[]))

t = self._post_process_data(results) * n
ret = {
'n': n,
'total': t
}

return {
'n': self._numel(axis=axis),
'total': self._post_process_data(results)
}
if method == 'mean':
return ret
else:
return ret['total']/ret['n']

def _get_elements(self, active, recursives, hyperslab=[]):
dimarray = []
current = recursives[0]
if not len(recursives) > 1:
if not len(recursives) > 0:

# Perform active slicing and meaning here.
return active[hyperslab]
return active[tuple(hyperslab)].flatten()[0]

current = recursives[0]

if type(current) == slice:
newslab = hyperslab + [current]
138 changes: 123 additions & 15 deletions XarrayActive/active_dask.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,44 @@
import dask.array as da
from dask.array.reductions import mean_agg
from dask.array.reductions import mean_agg, mean_combine, nanmax, nanmin
from dask.utils import deepmap
from dask.array.core import _concatenate2
import numpy as np


def block_active_mean(arr, *args, **kwargs):
if hasattr(arr,'active_mean'):
return arr.active_mean(*args, **kwargs)
def partition_mean(arr, *args, **kwargs):
return partition_method(arr, 'mean', *args, **kwargs)

def partition_max(arr, *args, **kwargs):
return partition_method(arr, 'max', *args, **kwargs)

def partition_min(arr, *args, **kwargs):
return partition_method(arr, 'min', *args, **kwargs)

def partition_sum(arr, *args, **kwargs):
return partition_method(arr, 'sum', *args, **kwargs)

def partition_method(arr, method, *args, **kwargs):
if hasattr(arr,'active_method'):
return arr.active_method(method,*args, **kwargs)
else:
# Here's where barebones Xarray might fall over - may need a non-CFA custom class.
raise NotImplementedError
# Additional handling for 'meta' calculations in dask.
# Not currently implemented, bypassed using None
if arr.size == 0:
return None
return None

def general_combine(pairs, axis=None):
if not isinstance(pairs, list):
pairs = [pairs]
return _concatenate2(pairs, axes=axis)

def max_agg(pairs, axis=None, **kwargs):
return general_combine(pairs, axis=axis).max(axis=axis, **kwargs)

def min_agg(pairs, axis=None, **kwargs):
return general_combine(pairs, axis=axis).min(axis=axis, **kwargs)

def sum_agg(pairs, axis=None, **kwargs):
return general_combine(pairs, axis=axis).sum(axis=axis, **kwargs)

class DaskActiveArray(da.Array):

@@ -17,11 +48,12 @@ class DaskActiveArray(da.Array):
def is_active(self):
return True

def copy(self):
"""
Create a new DaskActiveArray instance with all the same parameters as the current instance.
"""
return DaskActiveArray(self.dask, self.name, self.chunks, meta=self)
#def copy(self):
# """
# Create a new DaskActiveArray instance with all the same parameters as the current instance.
# """
# copy_arr = DaskActiveArray(self.dask, self.name, self.chunks, meta=self)
# return copy_arr

def __getitem__(self, index):
"""
@@ -49,10 +81,86 @@ def active_mean(self, axis=None, skipna=None):

newarr = da.reduction(
self,
block_active_mean,
partition_mean,
mean_agg,
combine=mean_combine,
axis=axis,
dtype=self.dtype,
)

return newarr

def active_max(self, axis=None, skipna=None):
"""
Perform ``dask delayed`` active mean for each ``dask block`` which corresponds to a single ``chunk``.
Combines the results of the dask delayed ``active_max`` operations on each block into a single dask Array,
which is then mapped to a new DaskActiveArray object.
:param axis: (int) The index of the axis on which to perform the active max.
:param skipna: (bool) Skip NaN values when calculating the max.
:returns: A new ``DaskActiveArray`` object which has been reduced along the specified axes using
the concatenations of active_means from each chunk.
"""

newarr = da.reduction(
self,
partition_max,
max_agg,
combine=max_agg,
axis=axis,
dtype=self.dtype,
)

return newarr

def active_min(self, axis=None, skipna=None):
"""
Perform ``dask delayed`` active mean for each ``dask block`` which corresponds to a single ``chunk``.
Combines the results of the dask delayed ``active_min`` operations on each block into a single dask Array,
which is then mapped to a new DaskActiveArray object.
:param axis: (int) The index of the axis on which to perform the active min.
:param skipna: (bool) Skip NaN values when calculating the min.
:returns: A new ``DaskActiveArray`` object which has been reduced along the specified axes using
the concatenations of active_means from each chunk.
"""

newarr = da.reduction(
self,
partition_min,
min_agg,
combine=min_agg,
axis=axis,
dtype=self.dtype,
)

return newarr

def active_sum(self, axis=None, skipna=None):
"""
Perform ``dask delayed`` active mean for each ``dask block`` which corresponds to a single ``chunk``.
Combines the results of the dask delayed ``active_sum`` operations on each block into a single dask Array,
which is then mapped to a new DaskActiveArray object.
:param axis: (int) The index of the axis on which to perform the active sum.
:param skipna: (bool) Skip NaN values when calculating the sum.
:returns: A new ``DaskActiveArray`` object which has been reduced along the specified axes using
the concatenations of active_means from each chunk.
"""

newarr = da.reduction(
self,
partition_sum,
sum_agg,
combine=sum_agg,
axis=axis,
dtype=self.dtype
dtype=self.dtype,
)

return DaskActiveArray(newarr.dask, newarr.name, newarr.chunks, meta=newarr)
return newarr
107 changes: 94 additions & 13 deletions XarrayActive/active_xarray.py
Original file line number Diff line number Diff line change
@@ -8,26 +8,76 @@
from xarray.core.dataarray import DataArray

from .active_dask import DaskActiveArray
from xarray.core import duck_array_ops

class ActiveDataArray(DataArray):
# No additional properties
__slots__ = ()

def mean(
self,
dim,
*args,
**kwargs,
):

return self._active_op(
dataarray_active_mean,
*args,
**kwargs,
)

def max(
self,
*args,
**kwargs,
):

return self._active_op(
dataarray_active_max,#duck_array_ops.max,
*args,
**kwargs,
)

def min(
self,
*args,
**kwargs,
):

return self._active_op(
dataarray_active_min,
*args,
**kwargs,
)

def sum(
self,
*args,
**kwargs,
):

return self._active_op(
dataarray_active_sum,
*args,
**kwargs,
)

def _active_op(
self,
op = None,
dim = None,
*,
skipna = None,
keep_attrs = None,
skipna: bool | None = None,
keep_attrs: bool | None = None,
**kwargs,
):
"""
Reduce this DataArray's data by applying ``mean`` along some dimension(s).
Reduce this DataArray's data by applying an operation along some dimension(s).
Parameters
----------
dim : str, Iterable of Hashable, "..." or None, default: None
Name of dimension[s] along which to apply ``mean``. For e.g. ``dim="x"``
Name of dimension[s] along which to apply the operation. For e.g. ``dim="x"``
or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions.
skipna : bool or None, optional
If True, skip missing values (as marked by NaN). By default, only
@@ -40,24 +90,28 @@ def mean(
returned without attributes.
**kwargs : Any
Additional keyword arguments passed on to the appropriate array
function for calculating ``mean`` on this object's data.
function for calculating the operation on this object's data.
These could include dask-specific kwargs like ``split_every``.
Returns
-------
reduced : DataArray
New DataArray with ``mean`` applied to its data and the
New DataArray with the operation applied to its data and the
indicated dimension(s) removed
See Also
--------
numpy.max
dask.array.max
"""
return self.reduce(
dataarray_active_mean, # from duck_array_ops.mean
op,
dim=dim,
skipna=skipna,
keep_attrs=keep_attrs,
**kwargs,
)

class ActiveDataset(Dataset):

# No additional properties
@@ -68,12 +122,12 @@ def _construct_dataarray(self, name):

darr = super()._construct_dataarray(name)


is_active_variable = True

# Convert variable to DaskActiveArray if not already defined as that type.
# CFAPyX - FragmentArrayWrapper returns a DaskActiveArray upon indexing.
variable = darr.variable
# If the active parts have been lost at this point.
if not isinstance(variable.data, DaskActiveArray) and is_active_variable:
variable.data = DaskActiveArray(
variable.data.dask,
@@ -96,7 +150,19 @@ def _construct_dataarray(self, name):
fastpath=True
)

def dataarray_active_mean(array: DaskActiveArray, axis=None, skipna=None, **kwargs):
def dataarray_active_mean(array, *args, **kwargs):
return dataarray_active_method(array, 'mean', *args, **kwargs)

def dataarray_active_max(array, *args, **kwargs):
return dataarray_active_method(array, 'max', *args, **kwargs)

def dataarray_active_min(array, *args, **kwargs):
return dataarray_active_method(array, 'min', *args, **kwargs)

def dataarray_active_sum(array, *args, **kwargs):
return dataarray_active_method(array, 'sum', *args, **kwargs)

def dataarray_active_method(array: DaskActiveArray, method: str, axis=None, skipna=None, **kwargs):
"""
Function provided to dask reduction, activates the requested ``active_*`` method of the ``DaskActiveArray``.
@@ -109,10 +175,25 @@ def dataarray_active_mean(array: DaskActiveArray, axis=None, skipna=None, **kwar
:returns: The result from performing the ``DaskActiveArray.active_mean`` method, which gives a new
``DaskActiveArray`` object.
"""
from xarray.core import duck_array_ops
arr_methods = {
'mean': array.active_mean,
'max': array.active_max,
'min': array.active_min,
'sum': array.active_sum
}

duck_methods = {
'mean': duck_array_ops.mean,
'max': duck_array_ops.max,
'min': duck_array_ops.min,
'sum': duck_array_ops.sum
}

from xarray.core import duck_array_ops
try:
return array.active_mean(axis, skipna=skipna, **kwargs)
return arr_methods[method](axis, skipna=skipna, **kwargs)
except AttributeError:
print("ActiveWarning: Unable to compute active mean - array has already been loaded.")
print("NetCDF file size may prohibit lazy loading and thus Active methods.")
return duck_array_ops.mean(array, axis=axis, skipna=skipna, **kwargs)
return duck_methods[method](array, axis=axis, skipna=skipna, **kwargs)
8 changes: 7 additions & 1 deletion XarrayActive/backend.py
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@
)

from .active_xarray import ActiveDataset
from .datastore import ActiveDataStore

def open_active_dataset(
filename_or_obj,
@@ -18,6 +19,7 @@ def open_active_dataset(
decode_coords=None,
use_cftime=None,
decode_timedelta=None,
active_options={},
group=None,
):
"""
@@ -31,7 +33,9 @@ def open_active_dataset(
"""

# Load the normal datastore from the provided file (object not supported).
store = NetCDF4DataStore.open(filename_or_obj, group=group)
store = ActiveDataStore.open(filename_or_obj, group=group)

store.active_options = active_options

# Xarray makes use of StoreBackendEntrypoints to provide the Dataset 'ds'
store_entrypoint = ActiveStoreBackendEntrypoint()
@@ -64,6 +68,7 @@ def open_dataset(
decode_coords=None,
use_cftime=None,
decode_timedelta=None,
active_options={},
group=None,
# backend specific keyword arguments
# do not use 'chunks' or 'cache' here
@@ -82,6 +87,7 @@ def open_dataset(
decode_coords=decode_coords,
use_cftime=use_cftime,
decode_timedelta=decode_timedelta,
active_options=active_options,
group=group)

class ActiveStoreBackendEntrypoint(StoreBackendEntrypoint):
151 changes: 39 additions & 112 deletions XarrayActive/datastore.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,57 @@
from xarray.backends import NetCDF4DataStore
from xarray.backends.common import (
BackendArray,
robust_getitem
)
from xarray.core.utils import FrozenDict
from xarray.coding.variables import pop_to
from xarray.coding.strings import create_vlen_dtype

from xarray.core import indexing
from xarray.core.variable import Variable

from dask.utils import SerializableLock
from dask.array.core import getter
from dask.base import tokenize

from contextlib import suppress
import functools
import operator

import numpy as np

from .active_dask import DaskActiveArray
from .active_chunk import (
ActiveOptionsContainer,
)

from .wrappers import ActiveArrayWrapper

class ActiveDataStore(NetCDF4DataStore, ActiveOptionsContainer):

class ActiveDataStore(NetCDF4DataStore):
def open_store_variable(self, name: str, var):
def get_variables(self):
"""
"""
return FrozenDict(
(k, self.open_variable(k, v)) for k, v in self.ds.variables.items()
)

def open_variable(self, name: str, var):
if name in self.ds.dimensions or not self._active_chunks:
return self.open_store_variable(name, var)
else:
return self.open_active_variable(name, var)

def open_active_variable(self, name: str, var):
import netCDF4

dimensions = var.dimensions

units = ''
if hasattr(var, 'units'):
units = getattr(var, 'units')

attributes = {k: var.getncattr(k) for k in var.ncattrs()}
data = indexing.LazilyIndexedArray(ActiveSubarrayWrapper(name, self))
data = indexing.LazilyIndexedArray(
ActiveArrayWrapper(
self._filename,
var,
var.shape,
units,
var.dtype,
named_dims=dimensions,
active_options=self.active_options
)
)

encoding = {}

if isinstance(var.datatype, netCDF4.EnumType):
@@ -66,100 +90,3 @@ def open_store_variable(self, name: str, var):
encoding["original_shape"] = data.shape

return Variable(dimensions, data, attributes, encoding)

class ActiveSubarrayWrapper(BackendArray, SuperLazyArrayLike):

def __init__(self, variable_name, datastore, chunks=None, extent=None):
self.datastore = datastore
self.variable_name = variable_name

self._chunks = chunks
self._extent = extent
self._lock = SerializableLock()

self._variable = self._get_variable()
self.shape = self._variable.shape
self.ndim = len(self.shape)

dtype = self._variable.dtype
if dtype is str:
# use object dtype (with additional vlen string metadata) because that's
# the only way in numpy to represent variable length strings and to
# check vlen string dtype in further steps
# it also prevents automatic string concatenation via
# conventions.decode_cf_variable
dtype = create_vlen_dtype(str)
self.dtype = dtype

self.__array_function__ = self.get_array

def _get_variable(self, needs_lock=True):
ds = self.datastore._acquire(needs_lock)
variable = ds.variables[self.variable_name]
variable.set_auto_maskandscale(False)
# only added in netCDF4-python v1.2.8
with suppress(AttributeError):
variable.set_auto_chartostring(False)
return variable

def __array__(self):

if not self._chunks:
# get_array should just get the whole array if that's what we're trying to do.
# indexing should just be added to the instance of this class, and then the
# built-in mean from _ActiveFragment should take care of things.
return self._variable


# for every dask chunk return a smaller object with the right extent.
# Create a chunk_shape tuple from chunks and _variable (figure out which chunk and which axis, divide etc.)
# Define a subarray for each chunk, with appropriate index.

chunks = None # Need to find out what this needs to be.

name = (f"{self.__class__.__name__}-{tokenize(self)}",)
dsk = {}
for pos in positions:

subarray = ArrayPartition(
filename,
address,
dtype=,
shape=,
position=pos,
)

key = f"{subarray.__class__.__name__}-{tokenize(subarray)}"
dsk[key] = subarray
dsk[name + f_index] = (
getter, # Dask default should be enough with the new indexing routine.
key,
extent,
False,
getattr(subarray,"_lock",False)
)

return DaskActiveArray(dsk, name, chunks=chunks, dtype=self.dtype)


def _getitem(self, key):
if self.datastore.is_remote: # pragma: no cover
getitem = functools.partial(robust_getitem, catch=RuntimeError)
else:
getitem = operator.getitem

try:
with self.datastore.lock:
original_array = self.get_array(needs_lock=False)
array = getitem(original_array, key)
except IndexError:
# Catch IndexError in netCDF4 and return a more informative
# error message. This is most often called when an unsorted
# indexer is used before the data is loaded from disk.
msg = (
"The indexing operation you are attempting to perform "
"is not valid on netCDF4.Variable object. Try loading "
"your data into memory first by calling .load()."
)
raise IndexError(msg)
return array
414 changes: 0 additions & 414 deletions XarrayActive/partition.py

This file was deleted.

117 changes: 0 additions & 117 deletions XarrayActive/subarray.py

This file was deleted.

144 changes: 144 additions & 0 deletions XarrayActive/wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
from arraypartition import (
ArrayPartition,
ArrayLike,
get_chunk_space,
get_chunk_shape,
get_chunk_positions,
get_chunk_extent,
get_dask_chunks,
combine_slices
)
from .active_chunk import (
ActiveChunk,
ActiveOptionsContainer
)

from .active_dask import DaskActiveArray

from dask.array.core import getter
from dask.base import tokenize

from itertools import product

class ActivePartition(ArrayPartition):
    """
    Combines ActiveChunk - active methods, and ArrayPartition - array methods
    into a single ChunkWrapper class.

    NOTE(review): ``ActiveChunk`` is imported by this module but is not a
    base of this class, despite the docstring above - confirm whether it
    should be mixed in here.
    """
    def copy(self, extent=None):
        """
        Create a standardised copy of this partition, optionally restricted
        to a sub-extent of the current data extent.

        :param extent: (list) Optional slices to combine with this
            partition's existing extent; if omitted, the copy addresses the
            same region as the original.

        :returns: A new ``ActivePartition`` pointing at the same file and
            address, with the (possibly reduced) extent.
        """
        kwargs = self.get_kwargs()
        if extent:
            # Merge the requested extent into the current one so the copy
            # addresses the correct sub-region of the underlying array.
            kwargs['extent'] = combine_slices(self.shape, list(self.get_extent()), extent)
        return ActivePartition(
            self.filename,
            self.address,
            **kwargs
        )

class ActiveArrayWrapper(ArrayLike, ActiveOptionsContainer):
    """
    ActiveArrayWrapper behaves like an Array that can be indexed or referenced to
    return a Dask-like array object. This class is essentially a constructor for the
    partitions that feed into the returned Dask-like array into Xarray.
    """
    def __init__(
            self,
            filename,
            var,
            shape,
            units=None,
            dtype=None,
            named_dims=None,
            active_options=None,
        ):
        """
        Wrap a netCDF variable so indexing it yields a ``DaskActiveArray``.

        :param filename: (str) Path of the file the variable belongs to.
        :param var: The underlying (netCDF4-style) variable object.
        :param shape: (tuple) Full shape of the wrapped variable.
        :param units: (str) Units attribute of the variable, if any.
        :param dtype: Data type of the variable.
        :param named_dims: Names of the variable's dimensions, in order.
        :param active_options: (dict) Options forwarded to the
            ``ActiveOptionsContainer`` setter (``chunks``/``chunk_limits``).
        """
        self._variable = var

        self.filename = filename
        self.name     = var.name
        # ``None`` default avoids the shared-mutable-default pitfall;
        # the setter unpacks 'chunks' and 'chunk_limits'.
        self.active_options = active_options if active_options is not None else {}

        self.named_dims = named_dims

        super().__init__(shape, units=units, dtype=dtype)

        # Shape of a single chunk, derived from the user's chunk scheme.
        self.chunk_shape = get_chunk_shape(
            self._active_chunks,
            self.shape,
            self.named_dims,
            chunk_limits=self._chunk_limits
        )

        # Number of chunks along each dimension (the "chunk space").
        self.chunk_space = get_chunk_space(
            self.chunk_shape,
            self.shape
        )

        self.__array_function__ = self.__array__

    def __getitem__(self, selection):
        """
        Non-lazy retrieval of the dask array when this object is indexed.
        """
        arr = self.__array__()
        return arr[selection]

    def __array__(self, *args, **kwargs):
        """
        Construct the array this wrapper represents: either the raw variable
        (no active chunks configured) or a ``DaskActiveArray`` whose graph
        contains one ``ActivePartition`` per chunk position.
        """
        if not self._active_chunks:
            # get_array should just get the whole array if that's what we're trying to do.
            # indexing should just be added to the instance of this class, and then the
            # built-in mean from _ActiveFragment should take care of things.
            return self._variable
        else:

            # for every dask chunk return a smaller object with the right extent.
            # Create a chunk_shape tuple from chunks and _variable (figure out which chunk and which axis, divide etc.)
            # Define a subarray for each chunk, with appropriate index.

            array_name = (f"{self.__class__.__name__}-{tokenize(self)}",)
            dsk        = {}
            positions  = get_chunk_positions(self.chunk_space)
            # Representative extent used as the getter request for every task.
            request    = get_chunk_extent(positions[0], self.shape, self.chunk_space)

            global_extent = {}

            for position in positions:
                position = tuple(position)

                extent  = get_chunk_extent(position, self.shape, self.chunk_space)
                cformat = None
                global_extent[position] = extent

                chunk = ActivePartition(
                    self.filename,
                    self.name,
                    dtype=self.dtype,
                    units=self.units,
                    shape=self.chunk_shape,
                    position=position,
                    extent=extent,
                    format=cformat
                )

                c_identifier = f"{chunk.__class__.__name__}-{tokenize(chunk)}"
                dsk[c_identifier] = chunk
                dsk[array_name + position] = (
                    getter, # Dask default should be enough with the new indexing routine.
                    c_identifier,
                    request,
                    False,
                    getattr(chunk, "_lock", False)
                )

            dask_chunks = get_dask_chunks(
                self.shape,
                self.chunk_space,
                extent=global_extent,
                dtype=self.dtype,
                explicit_shapes=None
            )

            return DaskActiveArray(dsk, array_name[0], chunks=dask_chunks, dtype=self.dtype)
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -7,4 +7,5 @@ h5py==3.11.0
dask==2024.7.0
cftime==1.6.4
cfunits==3.3.7
pytest==7.2.0
pytest==7.2.0
ArrayPartition==1.0
Empty file added tests/__init__.py
Empty file.
Binary file added tests/rain_test.nc
Binary file not shown.
102 changes: 102 additions & 0 deletions tests/test_active.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# All routines for testing CFA general methods.
import xarray as xr
import numpy as np

def test_active():
    """Whole-array mean over a chunked, actively-opened dataset."""

    path_to_active = 'tests/rain_test.nc'

    # Opening with an empty chunk scheme must raise NotImplementedError;
    # track the raise explicitly so a silent success also fails the test.
    raised = False
    try:
        xr.open_dataset(
            path_to_active,
            engine='Active',
            active_options={})
    except NotImplementedError:
        raised = True
    assert raised

    ds = xr.open_dataset(
        path_to_active,
        engine='Active',
        active_options={'chunks': {'time': 2}})

    assert 'p' in ds
    assert ds['p'].shape == (20, 180, 360)

    p_sel = ds['p'].isel(time=slice(0, 3), latitude=slice(140, 145), longitude=slice(90, 100))

    assert p_sel.shape == (3, 5, 10)

    p_value = p_sel.mean()

    assert p_value.shape == ()
    # abs() so the check fails on deviation in either direction.
    assert abs(p_value.to_numpy() - 0.53279) < 0.01

def test_active_recursive():
    """Partial (single-dimension) mean exercising the recursive request path."""

    path_to_active = 'tests/rain_test.nc'

    # Opening with an empty chunk scheme must raise NotImplementedError;
    # track the raise explicitly so a silent success also fails the test.
    raised = False
    try:
        xr.open_dataset(
            path_to_active,
            engine='Active',
            active_options={})
    except NotImplementedError:
        raised = True
    assert raised

    ds = xr.open_dataset(
        path_to_active,
        engine='Active',
        active_options={'chunks': {'time': 2}})

    assert 'p' in ds
    assert ds['p'].shape == (20, 180, 360)

    p_sel = ds['p'].isel(time=slice(0, 3), latitude=slice(140, 145), longitude=slice(90, 100))

    assert p_sel.shape == (3, 5, 10)

    p_mean = p_sel.mean(dim='time')

    assert p_mean.shape == (5, 10)
    # abs() so the check fails on deviation in either direction.
    assert abs(p_mean[0][0].to_numpy() - 0.683402) < 0.01

def test_active_methods():
    """Whole-array max/min/sum reductions over a chunked active dataset."""

    path_to_active = 'tests/rain_test.nc'

    # Opening with an empty chunk scheme must raise NotImplementedError;
    # track the raise explicitly so a silent success also fails the test.
    raised = False
    try:
        xr.open_dataset(
            path_to_active,
            engine='Active',
            active_options={})
    except NotImplementedError:
        raised = True
    assert raised

    ds = xr.open_dataset(
        path_to_active,
        engine='Active',
        active_options={'chunks': {'time': 2}})

    assert 'p' in ds
    assert ds['p'].shape == (20, 180, 360)

    p_sel = ds['p'].isel(latitude=slice(140, 145), longitude=slice(90, 100))

    assert p_sel.shape == (20, 5, 10)

    # abs() in each tolerance check so deviation in either direction fails.
    p_value = p_sel.isel(time=slice(0, 3)).max()
    assert p_value.shape == ()
    assert abs(p_value.to_numpy() - 0.9978273) < 0.01

    p_value = p_sel.isel(time=slice(0, 3)).min()
    assert p_value.shape == ()
    assert abs(p_value.to_numpy() - 0.0014456) < 0.01

    p_value = p_sel.isel(time=slice(0, 3)).sum()
    assert p_value.shape == ()
    assert abs(p_value.to_numpy() - 76.7931739) < 0.01

if __name__ == '__main__':
    # Allow running the suite directly without pytest.
    test_active()
    test_active_recursive()
    test_active_methods()