Skip to content

Commit

Permalink
added dask array compat
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Dec 10, 2024
1 parent adec14f commit 212f83d
Show file tree
Hide file tree
Showing 3 changed files with 261 additions and 39 deletions.
75 changes: 37 additions & 38 deletions pynapple/core/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import importlib
import warnings
from numbers import Number
from typing import Callable

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -63,6 +64,27 @@ def _get_class(data):
return TsdTensor


def _initialize_tsd_output(inp, out):

if isinstance(out, np.ndarray) or is_array_like(out):
# # if dims increased in any case, we can't return safely a time series
# if out.ndim > self.ndim:
# return out
if out.shape[0] == inp.index.shape[0]:
kwargs = {"load_array": inp._load_array}
if (inp.ndim == 2) and (out.ndim == 2) and (out.shape[1] == inp.shape[1]):
# only pass columns and metadata if number of columns is preserved
if hasattr(inp, "columns"):
kwargs["columns"] = inp.columns
if hasattr(inp, "_metadata"):
kwargs["metadata"] = inp._metadata
return _get_class(out)(
t=inp.index, d=out, time_support=inp.time_support, **kwargs
)

return out


class _BaseTsd(_Base, NDArrayOperatorsMixin, abc.ABC):
"""
Abstract base class for time series objects.
Expand Down Expand Up @@ -125,6 +147,19 @@ def method(*args, **kwargs):
return np_func(self, *args, **kwargs)

return method
if name not in ("__getstate__", "__setstate__", "__reduce__", "__reduce_ex__"):
# apply array specific methods
attr = getattr(self.d, name, None)

if isinstance(attr, Callable):

def method(*args, **kwargs):
out = attr(*args, **kwargs)
return _initialize_tsd_output(self, out)

return method
elif attr:
return attr

raise AttributeError(
"Time series object does not have the attribute {}".format(name)
Expand Down Expand Up @@ -167,20 +202,7 @@ def __array_ufunc__(self, ufunc, method, *args, **kwargs):
else:
out = ufunc(*new_args, **kwargs)

if isinstance(out, np.ndarray) or is_array_like(out):
if out.shape[0] == self.index.shape[0]:
kwargs = {"load_array": self._load_array}
if hasattr(self, "columns"):
kwargs["columns"] = self.columns
if hasattr(self, "_metadata"):
kwargs["metadata"] = self._metadata
return _get_class(out)(
t=self.index, d=out, time_support=self.time_support, **kwargs
)
else:
return out
else:
return out
return _initialize_tsd_output(self, out)
else:
return NotImplemented

Expand Down Expand Up @@ -211,30 +233,7 @@ def __array_function__(self, func, types, args, kwargs):
new_args.append(a)

out = func._implementation(*new_args, **kwargs)

if isinstance(out, np.ndarray) or is_array_like(out):
# # if dims increased in any case, we can't return safely a time series
# if out.ndim > self.ndim:
# return out
if out.shape[0] == self.index.shape[0]:
kwargs = {"load_array": self._load_array}
if (
(self.ndim == 2)
and (out.ndim == 2)
and (out.shape[1] == self.shape[1])
):
# only pass columns and metadata if number of columns is preserved
if hasattr(self, "columns"):
kwargs["columns"] = self.columns
if hasattr(self, "_metadata"):
kwargs["metadata"] = self._metadata
return _get_class(out)(
t=self.index, d=out, time_support=self.time_support, **kwargs
)
else:
return out
else:
return out
return _initialize_tsd_output(self, out)

def as_array(self):
"""
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ docs = [
"matplotlib",
"seaborn",
"zarr",
"dandi"
"dandi",
"dask",
]
dandi = [
"dandi", # Dandi package
Expand Down
222 changes: 222 additions & 0 deletions tests/test_lazy_loading.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,78 @@
import warnings
from contextlib import nullcontext as does_not_raise
from pathlib import Path
from tempfile import TemporaryDirectory

import dask.array as da
import h5py
import numpy as np
import pandas as pd
import pytest
import zarr
from pynwb.testing.mock.base import mock_TimeSeries
from pynwb.testing.mock.file import mock_NWBFile

import pynapple as nap


@pytest.fixture
def dask_array_tsdframe():
"""Fixture for a Dask array."""
array = da.random.random((100, 100), chunks=(10, 10))
return array


@pytest.fixture
def dask_array_tsd():
"""Fixture for a Dask array."""
array = da.random.random((100,), chunks=(10,))
return array


@pytest.fixture
def dask_array_tsdtensor():
"""Fixture for a Dask array."""
array = da.random.random((100, 10, 5), chunks=(10, 1, 2))
return array


@pytest.fixture
def zarr_tsd():
"""Fixture for a Zarr array."""
with TemporaryDirectory() as tmpdir:
store = zarr.DirectoryStore(tmpdir)
root = zarr.open(store, mode="w")
array = root.create_dataset("data", shape=(100,), chunks=(10,), dtype="f8")
array[:] = np.random.random((100,))
yield array


@pytest.fixture
def zarr_tsdframe():
"""Fixture for a Zarr array."""
with TemporaryDirectory() as tmpdir:
store = zarr.DirectoryStore(tmpdir)
root = zarr.open(store, mode="w")
array = root.create_dataset(
"data", shape=(100, 100), chunks=(10, 10), dtype="f8"
)
array[:] = np.random.random((100, 100))
yield array


@pytest.fixture
def zarr_tsdtensor():
"""Fixture for a Zarr array."""
with TemporaryDirectory() as tmpdir:
store = zarr.DirectoryStore(tmpdir)
root = zarr.open(store, mode="w")
array = root.create_dataset(
"data", shape=(100, 10, 2), chunks=(10, 9, 1), dtype="f8"
)
array[:] = np.random.random((100, 10, 2))
yield array


@pytest.mark.parametrize(
"time, data, expectation",
[
Expand Down Expand Up @@ -301,3 +362,164 @@ def test_tsgroup_no_warnings(tmp_path): # default fixture
# file_path = Path(f'data_{k}.h5')
# if file_path.exists():
# file_path.unlink()


def test_dask_lazy_loading_tsd(dask_array_tsd):
tsd = nap.Tsd(
t=np.arange(dask_array_tsd.shape[0]), d=dask_array_tsd, load_array=False
)
assert isinstance(tsd.d, da.Array)
assert isinstance(tsd.restrict(nap.IntervalSet(0, 10)).d, np.ndarray)
repr(tsd)
assert isinstance(tsd.d, da.Array)
assert isinstance(tsd[1:10].d, np.ndarray)
tsd = nap.Tsd(
t=np.arange(dask_array_tsd.shape[0]), d=dask_array_tsd, load_array=True
)
assert isinstance(tsd.d, np.ndarray)


def test_dask_lazy_compute_tsd(dask_array_tsd):
tsd = nap.Tsd(
t=np.arange(dask_array_tsd.shape[0]), d=dask_array_tsd, load_array=False
)
tsd = tsd + 1
assert isinstance(tsd.d, da.Array)
assert isinstance(tsd[:10].d, np.ndarray)
assert tsd[:10]._load_array is True

out = tsd.compute()
assert isinstance(out.d, np.ndarray)
assert isinstance(tsd.chunks, tuple)
assert tsd._load_array is False

out2 = tsd.map_blocks(np.exp)
assert isinstance(out2.d, da.Array)
assert out2._load_array is False

assert isinstance(np.exp(tsd).d, da.Array)


def test_dask_lazy_loading_tsdframe(dask_array_tsdframe):
tsdframe = nap.TsdFrame(
t=np.arange(dask_array_tsdframe.shape[0]),
d=dask_array_tsdframe,
load_array=False,
)
assert isinstance(tsdframe.d, da.Array)
assert isinstance(tsdframe.restrict(nap.IntervalSet(0, 10)).d, np.ndarray)
repr(tsdframe)
assert isinstance(tsdframe.d, da.Array)
assert isinstance(tsdframe[1:10].d, np.ndarray)
assert isinstance(tsdframe.loc[1].d, np.ndarray)
tsdframe = nap.TsdFrame(
t=np.arange(dask_array_tsdframe.shape[0]),
d=dask_array_tsdframe,
load_array=True,
)
assert isinstance(tsdframe.d, np.ndarray)


def test_dask_lazy_compute_tsdframe(dask_array_tsdframe):
tsdframe = nap.TsdFrame(
t=np.arange(dask_array_tsdframe.shape[0]),
d=dask_array_tsdframe,
load_array=False,
)
tsdframe = tsdframe**2
assert isinstance(tsdframe.d, da.Array)
assert isinstance(tsdframe[:10].d, np.ndarray)
assert tsdframe[:10]._load_array is True

out = tsdframe.compute()
assert isinstance(out.d, np.ndarray)
assert isinstance(tsdframe.chunks, tuple)
assert tsdframe._load_array is False

out2 = tsdframe.map_blocks(np.exp)
assert isinstance(out2.d, da.Array)
assert out2._load_array is False
assert isinstance(np.exp(tsdframe).d, da.Array)


def test_dask_lazy_loading_tsdtensor(dask_array_tsdtensor):
tsdtensor = nap.TsdTensor(
t=np.arange(dask_array_tsdtensor.shape[0]),
d=dask_array_tsdtensor,
load_array=False,
)
assert isinstance(tsdtensor.d, da.Array)
assert isinstance(tsdtensor.restrict(nap.IntervalSet(0, 10)).d, np.ndarray)
repr(tsdtensor)
assert isinstance(tsdtensor.d, da.Array)
assert isinstance(tsdtensor[1:10].d, np.ndarray)
tsdtensor = nap.TsdTensor(
t=np.arange(dask_array_tsdtensor.shape[0]),
d=dask_array_tsdtensor,
load_array=True,
)
assert isinstance(tsdtensor.d, np.ndarray)


def test_dask_lazy_compute_tsdtensor(dask_array_tsdtensor):
tsdtensor = nap.TsdTensor(
t=np.arange(dask_array_tsdtensor.shape[0]),
d=dask_array_tsdtensor,
load_array=False,
)
tsdtensor = tsdtensor + 1
assert isinstance(tsdtensor.d, da.Array)
assert isinstance(tsdtensor[:10].d, np.ndarray)
assert tsdtensor[:10]._load_array is True

out = tsdtensor.compute()
assert isinstance(out.d, np.ndarray)
assert isinstance(tsdtensor.chunks, tuple)
assert tsdtensor._load_array is False

out2 = tsdtensor.map_blocks(np.exp)
assert isinstance(out2.d, da.Array)
assert out2._load_array is False

assert isinstance(np.exp(tsdtensor).d, da.Array)


def test_lazy_load_zarr_tsd(zarr_tsd):
tsd = nap.Tsd(t=np.arange(zarr_tsd.shape[0]), d=zarr_tsd, load_array=False)
assert isinstance(tsd.d, zarr.Array)
assert isinstance(tsd.restrict(nap.IntervalSet(0, 10)).d, np.ndarray)
repr(tsd)
assert isinstance(tsd.d, zarr.Array)
assert isinstance(tsd[1:10].d, np.ndarray)
tsd = nap.TsdFrame(t=np.arange(zarr_tsd.shape[0]), d=zarr_tsd, load_array=True)
assert isinstance(tsd.d, np.ndarray)


def test_lazy_load_zarr_tsdframe(zarr_tsdframe):
tsdframe = nap.TsdFrame(
t=np.arange(zarr_tsdframe.shape[0]), d=zarr_tsdframe, load_array=False
)
assert isinstance(tsdframe.d, zarr.Array)
assert isinstance(tsdframe.restrict(nap.IntervalSet(0, 10)).d, np.ndarray)
repr(tsdframe)
assert isinstance(tsdframe.d, zarr.Array)
assert isinstance(tsdframe[1:10].d, np.ndarray)
tsdframe = nap.TsdFrame(
t=np.arange(zarr_tsdframe.shape[0]), d=zarr_tsdframe, load_array=True
)
assert isinstance(tsdframe.d, np.ndarray)


def test_lazy_load_zarr_tsdtensor(zarr_tsdtensor):
tsdtensor = nap.TsdTensor(
t=np.arange(zarr_tsdtensor.shape[0]), d=zarr_tsdtensor, load_array=False
)
assert isinstance(tsdtensor.d, zarr.Array)
assert isinstance(tsdtensor.restrict(nap.IntervalSet(0, 10)).d, np.ndarray)
repr(tsdtensor)
assert isinstance(tsdtensor.d, zarr.Array)
assert isinstance(tsdtensor[1:10].d, np.ndarray)
tsdtensor = nap.TsdTensor(
t=np.arange(zarr_tsdtensor.shape[0]), d=zarr_tsdtensor, load_array=True
)
assert isinstance(tsdtensor.d, np.ndarray)

0 comments on commit 212f83d

Please sign in to comment.