From 212f83de2dac8da45e3917259cd8e11b4e37b20d Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 10 Dec 2024 11:39:55 -0500 Subject: [PATCH] added dask array compat --- pynapple/core/time_series.py | 75 ++++++------ pyproject.toml | 3 +- tests/test_lazy_loading.py | 222 +++++++++++++++++++++++++++++++++++ 3 files changed, 261 insertions(+), 39 deletions(-) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index bdf53485..5e8df475 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -19,6 +19,7 @@ import importlib import warnings from numbers import Number +from typing import Callable import numpy as np import pandas as pd @@ -63,6 +64,27 @@ def _get_class(data): return TsdTensor +def _initialize_tsd_output(inp, out): + + if isinstance(out, np.ndarray) or is_array_like(out): + # # if dims increased in any case, we can't return safely a time series + # if out.ndim > self.ndim: + # return out + if out.shape[0] == inp.index.shape[0]: + kwargs = {"load_array": inp._load_array} + if (inp.ndim == 2) and (out.ndim == 2) and (out.shape[1] == inp.shape[1]): + # only pass columns and metadata if number of columns is preserved + if hasattr(inp, "columns"): + kwargs["columns"] = inp.columns + if hasattr(inp, "_metadata"): + kwargs["metadata"] = inp._metadata + return _get_class(out)( + t=inp.index, d=out, time_support=inp.time_support, **kwargs + ) + + return out + + class _BaseTsd(_Base, NDArrayOperatorsMixin, abc.ABC): """ Abstract base class for time series objects. @@ -125,6 +147,19 @@ def method(*args, **kwargs): return np_func(self, *args, **kwargs) return method + if name not in ("__getstate__", "__setstate__", "__reduce__", "__reduce_ex__"): + # apply array specific methods + attr = getattr(self.d, name, None) + + if isinstance(attr, Callable): + + def method(*args, **kwargs): + out = attr(*args, **kwargs) + return _initialize_tsd_output(self, out) + + return method + elif attr: + return attr raise AttributeError( "Time series object does not have the attribute {}".format(name) @@ -167,20 +202,7 @@ def __array_ufunc__(self, ufunc, method, *args, **kwargs): else: out = ufunc(*new_args, **kwargs) - if isinstance(out, np.ndarray) or is_array_like(out): - if out.shape[0] == self.index.shape[0]: - kwargs = {"load_array": self._load_array} - if hasattr(self, "columns"): - kwargs["columns"] = self.columns - if hasattr(self, "_metadata"): - kwargs["metadata"] = self._metadata - return _get_class(out)( - t=self.index, d=out, time_support=self.time_support, **kwargs - ) - else: - return out - else: - return out + return _initialize_tsd_output(self, out) else: return NotImplemented @@ -211,30 +233,7 @@ def __array_function__(self, func, types, args, kwargs): new_args.append(a) out = func._implementation(*new_args, **kwargs) - - if isinstance(out, np.ndarray) or is_array_like(out): - # # if dims increased in any case, we can't return safely a time series - # if out.ndim > self.ndim: - # return out - if out.shape[0] == self.index.shape[0]: - kwargs = {"load_array": self._load_array} - if ( - (self.ndim == 2) - and (out.ndim == 2) - and (out.shape[1] == self.shape[1]) - ): - # only pass columns and metadata if number of columns is preserved - if hasattr(self, "columns"): - kwargs["columns"] = self.columns - if hasattr(self, "_metadata"): - kwargs["metadata"] = self._metadata - return _get_class(out)( - t=self.index, d=out, time_support=self.time_support, **kwargs - ) - else: - return out - else: - return out + return _initialize_tsd_output(self, out) def as_array(self): """ diff --git a/pyproject.toml b/pyproject.toml index 00e0a8f0..f5ba5afd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,7 +86,8 @@ docs = [ "matplotlib", "seaborn", "zarr", - "dandi" + "dandi", + "dask", ] dandi = [ "dandi", # Dandi package diff --git a/tests/test_lazy_loading.py b/tests/test_lazy_loading.py index 9695497d..8078c738 100644 --- a/tests/test_lazy_loading.py +++ b/tests/test_lazy_loading.py @@ -1,17 +1,78 @@ import warnings from contextlib import nullcontext as does_not_raise from pathlib import Path +from tempfile import TemporaryDirectory +import dask.array as da import h5py import numpy as np import pandas as pd import pytest +import zarr from pynwb.testing.mock.base import mock_TimeSeries from pynwb.testing.mock.file import mock_NWBFile import pynapple as nap +@pytest.fixture +def dask_array_tsdframe(): + """Fixture for a Dask array.""" + array = da.random.random((100, 100), chunks=(10, 10)) + return array + + +@pytest.fixture +def dask_array_tsd(): + """Fixture for a Dask array.""" + array = da.random.random((100,), chunks=(10,)) + return array + + +@pytest.fixture +def dask_array_tsdtensor(): + """Fixture for a Dask array.""" + array = da.random.random((100, 10, 5), chunks=(10, 1, 2)) + return array + + +@pytest.fixture +def zarr_tsd(): + """Fixture for a Zarr array.""" + with TemporaryDirectory() as tmpdir: + store = zarr.DirectoryStore(tmpdir) + root = zarr.open(store, mode="w") + array = root.create_dataset("data", shape=(100,), chunks=(10,), dtype="f8") + array[:] = np.random.random((100,)) + yield array + + +@pytest.fixture +def zarr_tsdframe(): + """Fixture for a Zarr array.""" + with TemporaryDirectory() as tmpdir: + store = zarr.DirectoryStore(tmpdir) + root = zarr.open(store, mode="w") + array = root.create_dataset( + "data", shape=(100, 100), chunks=(10, 10), dtype="f8" + ) + array[:] = np.random.random((100, 100)) + yield array + + +@pytest.fixture +def zarr_tsdtensor(): + """Fixture for a Zarr array.""" + with TemporaryDirectory() as tmpdir: + store = zarr.DirectoryStore(tmpdir) + root = zarr.open(store, mode="w") + array = root.create_dataset( + "data", shape=(100, 10, 2), chunks=(10, 9, 1), dtype="f8" + ) + array[:] = np.random.random((100, 10, 2)) + yield array + + @pytest.mark.parametrize( "time, data, expectation", [ @@ -301,3 +362,164 @@ def test_tsgroup_no_warnings(tmp_path): # default fixture # file_path = Path(f'data_{k}.h5') # if file_path.exists(): # file_path.unlink() + + +def test_dask_lazy_loading_tsd(dask_array_tsd): + tsd = nap.Tsd( + t=np.arange(dask_array_tsd.shape[0]), d=dask_array_tsd, load_array=False + ) + assert isinstance(tsd.d, da.Array) + assert isinstance(tsd.restrict(nap.IntervalSet(0, 10)).d, np.ndarray) + repr(tsd) + assert isinstance(tsd.d, da.Array) + assert isinstance(tsd[1:10].d, np.ndarray) + tsd = nap.Tsd( + t=np.arange(dask_array_tsd.shape[0]), d=dask_array_tsd, load_array=True + ) + assert isinstance(tsd.d, np.ndarray) + + +def test_dask_lazy_compute_tsd(dask_array_tsd): + tsd = nap.Tsd( + t=np.arange(dask_array_tsd.shape[0]), d=dask_array_tsd, load_array=False + ) + tsd = tsd + 1 + assert isinstance(tsd.d, da.Array) + assert isinstance(tsd[:10].d, np.ndarray) + assert tsd[:10]._load_array is True + + out = tsd.compute() + assert isinstance(out.d, np.ndarray) + assert isinstance(tsd.chunks, tuple) + assert tsd._load_array is False + + out2 = tsd.map_blocks(np.exp) + assert isinstance(out2.d, da.Array) + assert out2._load_array is False + + assert isinstance(np.exp(tsd).d, da.Array) + + +def test_dask_lazy_loading_tsdframe(dask_array_tsdframe): + tsdframe = nap.TsdFrame( + t=np.arange(dask_array_tsdframe.shape[0]), + d=dask_array_tsdframe, + load_array=False, + ) + assert isinstance(tsdframe.d, da.Array) + assert isinstance(tsdframe.restrict(nap.IntervalSet(0, 10)).d, np.ndarray) + repr(tsdframe) + assert isinstance(tsdframe.d, da.Array) + assert isinstance(tsdframe[1:10].d, np.ndarray) + assert isinstance(tsdframe.loc[1].d, np.ndarray) + tsdframe = nap.TsdFrame( + t=np.arange(dask_array_tsdframe.shape[0]), + d=dask_array_tsdframe, + load_array=True, + ) + assert isinstance(tsdframe.d, np.ndarray) + + +def test_dask_lazy_compute_tsdframe(dask_array_tsdframe): + tsdframe = nap.TsdFrame( + t=np.arange(dask_array_tsdframe.shape[0]), + d=dask_array_tsdframe, + load_array=False, + ) + tsdframe = tsdframe**2 + assert isinstance(tsdframe.d, da.Array) + assert isinstance(tsdframe[:10].d, np.ndarray) + assert tsdframe[:10]._load_array is True + + out = tsdframe.compute() + assert isinstance(out.d, np.ndarray) + assert isinstance(tsdframe.chunks, tuple) + assert tsdframe._load_array is False + + out2 = tsdframe.map_blocks(np.exp) + assert isinstance(out2.d, da.Array) + assert out2._load_array is False + assert isinstance(np.exp(tsdframe).d, da.Array) + + +def test_dask_lazy_loading_tsdtensor(dask_array_tsdtensor): + tsdtensor = nap.TsdTensor( + t=np.arange(dask_array_tsdtensor.shape[0]), + d=dask_array_tsdtensor, + load_array=False, + ) + assert isinstance(tsdtensor.d, da.Array) + assert isinstance(tsdtensor.restrict(nap.IntervalSet(0, 10)).d, np.ndarray) + repr(tsdtensor) + assert isinstance(tsdtensor.d, da.Array) + assert isinstance(tsdtensor[1:10].d, np.ndarray) + tsdtensor = nap.TsdTensor( + t=np.arange(dask_array_tsdtensor.shape[0]), + d=dask_array_tsdtensor, + load_array=True, + ) + assert isinstance(tsdtensor.d, np.ndarray) + + +def test_dask_lazy_compute_tsdtensor(dask_array_tsdtensor): + tsdtensor = nap.TsdTensor( + t=np.arange(dask_array_tsdtensor.shape[0]), + d=dask_array_tsdtensor, + load_array=False, + ) + tsdtensor = tsdtensor + 1 + assert isinstance(tsdtensor.d, da.Array) + assert isinstance(tsdtensor[:10].d, np.ndarray) + assert tsdtensor[:10]._load_array is True + + out = tsdtensor.compute() + assert isinstance(out.d, np.ndarray) + assert isinstance(tsdtensor.chunks, tuple) + assert tsdtensor._load_array is False + + out2 = tsdtensor.map_blocks(np.exp) + assert isinstance(out2.d, da.Array) + assert out2._load_array is False + + assert isinstance(np.exp(tsdtensor).d, da.Array) + + +def test_lazy_load_zarr_tsd(zarr_tsd): + tsd = nap.Tsd(t=np.arange(zarr_tsd.shape[0]), d=zarr_tsd, load_array=False) + assert isinstance(tsd.d, zarr.Array) + assert isinstance(tsd.restrict(nap.IntervalSet(0, 10)).d, np.ndarray) + repr(tsd) + assert isinstance(tsd.d, zarr.Array) + assert isinstance(tsd[1:10].d, np.ndarray) + tsd = nap.TsdFrame(t=np.arange(zarr_tsd.shape[0]), d=zarr_tsd, load_array=True) + assert isinstance(tsd.d, np.ndarray) + + +def test_lazy_load_zarr_tsdframe(zarr_tsdframe): + tsdframe = nap.TsdFrame( + t=np.arange(zarr_tsdframe.shape[0]), d=zarr_tsdframe, load_array=False + ) + assert isinstance(tsdframe.d, zarr.Array) + assert isinstance(tsdframe.restrict(nap.IntervalSet(0, 10)).d, np.ndarray) + repr(tsdframe) + assert isinstance(tsdframe.d, zarr.Array) + assert isinstance(tsdframe[1:10].d, np.ndarray) + tsdframe = nap.TsdFrame( + t=np.arange(zarr_tsdframe.shape[0]), d=zarr_tsdframe, load_array=True + ) + assert isinstance(tsdframe.d, np.ndarray) + + +def test_lazy_load_zarr_tsdtensor(zarr_tsdtensor): + tsdtensor = nap.TsdTensor( + t=np.arange(zarr_tsdtensor.shape[0]), d=zarr_tsdtensor, load_array=False + ) + assert isinstance(tsdtensor.d, zarr.Array) + assert isinstance(tsdtensor.restrict(nap.IntervalSet(0, 10)).d, np.ndarray) + repr(tsdtensor) + assert isinstance(tsdtensor.d, zarr.Array) + assert isinstance(tsdtensor[1:10].d, np.ndarray) + tsdtensor = nap.TsdTensor( + t=np.arange(zarr_tsdtensor.shape[0]), d=zarr_tsdtensor, load_array=True + ) + assert isinstance(tsdtensor.d, np.ndarray)