replaced all __class__(...) calls to enforce that the metadata is kept
BalzaniEdoardo committed Dec 10, 2024
1 parent 766a469 commit 8339e40
Showing 6 changed files with 52 additions and 65 deletions.
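Summary of the change: rebuild paths such as restrict, copy, bin_average, dropna, interpolate and value_from previously called self.__class__(...) directly, and each call site had to re-collect columns, metadata, and load_array by hand, so anything a call site forgot could be silently dropped. This commit funnels every rebuild through a single _define_instance hook that each class implements once. A rough, self-contained sketch of the pattern (simplified names, not the actual pynapple classes):

import abc
import numpy as np

class Base(abc.ABC):
    def __init__(self, t, d):
        self.t = np.asarray(t)
        self.values = np.asarray(d)

    @abc.abstractmethod
    def _define_instance(self, time, iset, data=None, **kwargs):
        """Rebuild an instance of the same class, forwarding its class-specific extras."""

    def restrict(self, idx, iset=None):
        # no per-class hasattr("columns") / hasattr("_metadata") checks needed here
        return self._define_instance(self.t[idx], iset, data=self.values[idx])

class Frame(Base):
    def __init__(self, t, d, columns=None):
        super().__init__(t, d)
        self.columns = columns

    def _define_instance(self, time, iset, data=None, **kwargs):
        kwargs.setdefault("columns", self.columns)  # the extra this class owns
        return self.__class__(t=time, d=data, **kwargs)  # time support omitted for brevity

f = Frame(t=np.arange(4), d=np.arange(8).reshape(4, 2), columns=["a", "b"])
print(f.restrict(slice(0, 2)).columns)  # ['a', 'b'], carried over automatically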
30 changes: 14 additions & 16 deletions pynapple/core/base_class.py
@@ -57,6 +57,14 @@ def __init__(self, t, time_units="s", time_support=None):
self.rate = np.nan
self.time_support = IntervalSet(start=[], end=[])

@abc.abstractmethod
def _define_instance(self, time, iset, data=None, **kwargs):
"""Return a new class instance.
Grab "columns", "metadata" and other and other
"""
pass

@property
def t(self):
"""The time index of the time series"""
@@ -368,25 +376,15 @@ def restrict(self, iset):
ends = iset.end

idx = _restrict(time_array, starts, ends)

kwargs = {}
if hasattr(self, "columns"):
kwargs["columns"] = self.columns

if hasattr(self, "_metadata"):
kwargs["metadata"] = self._metadata

if hasattr(self, "values"):
data_array = self.values
return self.__class__(
t=time_array[idx], d=data_array[idx], time_support=iset, **kwargs
)
else:
return self.__class__(t=time_array[idx], time_support=iset)
data = None if not hasattr(self, "values") else self.values[idx]
return self._define_instance(time_array[idx], iset, data=data)

def copy(self):
"""Copy the data, index and time support"""
return self.__class__(t=self.index.copy(), time_support=self.time_support)
data = getattr(self, "values", None)
if data is not None:
data = data.copy() if hasattr(data, "copy") else data[:].copy()
return self._define_instance(self.index.copy(), self.time_support, data=data)

def find_support(self, min_gap, time_units="s"):
"""
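The rewritten generic copy() above no longer assumes the backing array exposes a .copy() method: if it does not (for example a lazily loaded dataset object), the data is first pulled with [:] and then copied. A small self-contained sketch of that fallback, with a hypothetical ArrayView standing in for such a backend:

import numpy as np

class ArrayView:
    """Hypothetical sliceable backend with no .copy() method."""
    def __init__(self, arr):
        self._arr = arr
    def __getitem__(self, key):
        return np.asarray(self._arr[key])

def copy_data(data):
    # mirrors the fallback used by the rewritten copy() above
    if data is None:
        return None
    return data.copy() if hasattr(data, "copy") else data[:].copy()

print(type(copy_data(np.arange(5))))             # ndarray, copied directly
print(type(copy_data(ArrayView(np.arange(5)))))  # pulled via [:] and then copied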
67 changes: 29 additions & 38 deletions pynapple/core/time_series.py
@@ -129,6 +129,26 @@ def __init__(self, t, d, time_units="s", time_support=None, load_array=True):
self.dtype = self.values.dtype
self._load_array = load_array


def _define_instance(self, time, iset, data=None, **kwargs):
"""
Define a new class instance.
Optional parameters for initialization are either passed to the function or are grabbed from self.
"""
for key in ["columns", "metadata", "load_array"]:
if hasattr(self, key):
kwargs[key] = kwargs.get(key, getattr(self, key))
return self.__class__(
t=time, d=data, time_support=iset, **kwargs
)


@property
def load_array(self):
"""Read-only property load-array."""
return self._load_array

def __setitem__(self, key, value):
"""setter for time series"""
if isinstance(key, _BaseTsd):
@@ -265,12 +285,6 @@ def to_numpy(self):
"""
return np.asarray(self.values)

def copy(self):
"""Copy the data, index and time support"""
return self.__class__(
t=self.index.copy(), d=self.values[:].copy(), time_support=self.time_support
)

def value_from(self, data, ep=None):
"""
Replace the value with the closest value from Tsd/TsdFrame/TsdTensor argument
@@ -314,7 +328,7 @@ def value_from(self, data, ep=None):
), "First argument should be an instance of Tsd, TsdFrame or TsdTensor"

t, d, time_support, kwargs = super().value_from(data, ep)
return data.__class__(t=t, d=d, time_support=time_support, **kwargs)
return data._define_instance(t, time_support, data=d, **kwargs)

def count(self, *args, dtype=None, **kwargs):
"""
@@ -428,13 +442,7 @@ def bin_average(self, bin_size, ep=None, time_units="s"):

t, d = _bin_average(time_array, data_array, starts, ends, bin_size)

kwargs = {}
if hasattr(self, "columns"):
kwargs["columns"] = self.columns
if hasattr(self, "_metadata"):
kwargs["metadata"] = self._metadata

return self.__class__(t=t, d=d, time_support=ep, **kwargs)
return self._define_instance(t, ep, data=d)

def dropna(self, update_time_support=True):
"""Drop every rows containing NaNs. By default, the time support is updated to start and end around the time points that are non NaNs.
@@ -468,13 +476,7 @@ def dropna(self, update_time_support=True):
else:
ep = self.time_support

kwargs = {}
if hasattr(self, "columns"):
kwargs["columns"] = self.columns
if hasattr(self, "_metadata"):
kwargs["metadata"] = self._metadata

return self.__class__(t=t, d=d, time_support=ep, **kwargs)
return self._define_instance(t, ep, data=d)

def convolve(self, array, ep=None, trim="both"):
"""Return the discrete linear convolution of the time series with a one dimensional sequence.
@@ -698,12 +700,8 @@ def interpolate(self, ts, ep=None, left=None, right=None):
new_d[start : start + len(t), ...] = interpolated_values

start += len(t)
kwargs_dict = dict(time_support=ep)
if hasattr(self, "columns"):
kwargs_dict["columns"] = self.columns
if hasattr(self, "_metadata"):
kwargs_dict["metadata"] = self._metadata
return self.__class__(t=new_t, d=new_d, **kwargs_dict)

return self._define_instance(new_t, ep, data=new_d)


class TsdTensor(_BaseTsd):
@@ -1351,16 +1349,6 @@ def as_units(self, units="s"):
df.columns = self.columns.copy()
return df

def copy(self):
"""Copy the data, index, time support, columns and metadata of the TsdFrame object."""
return self.__class__(
t=self.index.copy(),
d=self.values[:].copy(),
time_support=self.time_support,
columns=self.columns.copy(),
metadata=self._metadata,
)

def save(self, filename):
"""
Save TsdFrame object in npz format. The file will contain the timestamps, the
@@ -2025,6 +2013,9 @@ def __init__(self, t, time_units="s", time_support=None):
self.nap_class = self.__class__.__name__
self._initialized = True

def _define_instance(self, time, iset, data=None, **kwargs):
return self.__class__(t=time, time_support=iset)

def __repr__(self):
upper = "Time (s)"
rows = _get_terminal_size()[1]
@@ -2130,7 +2121,7 @@ def value_from(self, data, ep=None):

t, d, time_support, kwargs = super().value_from(data, ep)

return data.__class__(t, d, time_support=time_support, **kwargs)
return data._define_instance(t, time_support, data=d, **kwargs)

def count(self, *args, dtype=None, **kwargs):
"""
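Taken together, the time_series.py changes mean that every path that rebuilds a Tsd/TsdFrame/TsdTensor (value_from, bin_average, dropna, interpolate, plus the inherited restrict and copy) now goes through _define_instance, so columns and metadata no longer have to be threaded through each method by hand. A hedged usage sketch, assuming the metadata constructor argument supported on this branch:

import numpy as np
import pynapple as nap

tsdframe = nap.TsdFrame(
    t=np.arange(100) / 10.0,
    d=np.random.rand(100, 2),
    columns=["lfp", "emg"],
    metadata={"channel": [1, 2]},  # assumes the metadata kwarg from this branch
)

binned = tsdframe.bin_average(1.0)
clean = tsdframe.dropna()
print(binned.columns.tolist())  # ['lfp', 'emg'], preserved via _define_instance
print(clean.columns.tolist())   # ['lfp', 'emg']; metadata rides along the same way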
6 changes: 2 additions & 4 deletions pynapple/core/utils.py
@@ -233,12 +233,10 @@ def _split_tsd(func, tsd, indices_or_sections, axis=0):
if func in [np.split, np.array_split, np.vsplit] and axis == 0:
out = func._implementation(tsd.values, indices_or_sections)
index_list = np.split(tsd.index.values, indices_or_sections)
kwargs = {"columns": tsd.columns.values} if hasattr(tsd, "columns") else {}
return [tsd.__class__(t=t, d=d, **kwargs) for t, d in zip(index_list, out)]
return [tsd._define_instance(t, None, data=d) for t, d in zip(index_list, out)]
elif func in [np.dsplit, np.hsplit]:
out = func._implementation(tsd.values, indices_or_sections)
kwargs = {"columns": tsd.columns.values} if hasattr(tsd, "columns") else {}
return [tsd.__class__(t=tsd.index, d=d, **kwargs) for d in out]
return [tsd._define_instance(tsd.index, None, data=d) for d in out]
else:
return func._implementation(tsd.values, indices_or_sections, axis)

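With the split helper above rebuilding each chunk via _define_instance, numpy's split functions dispatched on pynapple objects keep column labels without the old per-class kwargs handling. A brief hedged example:

import numpy as np
import pynapple as nap

tsdframe = nap.TsdFrame(t=np.arange(6), d=np.arange(12).reshape(6, 2), columns=["x", "y"])
first, second = np.split(tsdframe, 2)
print(first.columns.tolist(), second.columns.tolist())  # ['x', 'y'] ['x', 'y']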
5 changes: 1 addition & 4 deletions pynapple/process/filtering.py
@@ -89,10 +89,7 @@ def _compute_butterworth_filter(
slc = data.get_slice(start=ep.start[0], end=ep.end[0])
out[slc] = sosfiltfilt(sos, data.d[slc], axis=0)

kwargs = dict(t=data.t, d=out, time_support=data.time_support)
if isinstance(data, nap.TsdFrame):
kwargs["columns"] = data.columns
return data.__class__(**kwargs)
return data._define_instance(data.t, data.time_support, data=out)


def _compute_spectral_inversion(kernel):
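The Butterworth helper now rebuilds its output the same way, so a filtered TsdFrame keeps its columns without the explicit isinstance check. A hedged usage sketch; apply_lowpass_filter is assumed here to be the public wrapper around _compute_butterworth_filter:

import numpy as np
import pynapple as nap

tsdframe = nap.TsdFrame(
    t=np.arange(1000) / 100.0,
    d=np.random.rand(1000, 2),
    columns=["ch0", "ch1"],
)
filtered = nap.apply_lowpass_filter(tsdframe, 10.0, fs=100.0)  # name/signature assumed
print(filtered.columns.tolist())  # ['ch0', 'ch1']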
3 changes: 3 additions & 0 deletions tests/test_abstract_tsd.py
@@ -37,6 +37,9 @@ def __str__(self):
def __repr__(self):
return "In repr"

def _define_instance(self, time, iset, data=None, **kwargs):
pass


def test_create_atsd():
a = MyClass(t=np.arange(10), d=np.arange(10))
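The test helper gains this stub because _define_instance is declared as an abstractmethod: a concrete subclass that omits it can no longer be instantiated. A self-contained illustration of that mechanism, using plain abc rather than the pynapple base class:

import abc

class Base(abc.ABC):
    @abc.abstractmethod
    def _define_instance(self, time, iset, data=None, **kwargs):
        ...

class Incomplete(Base):
    pass

try:
    Incomplete()
except TypeError as err:
    print(err)  # Can't instantiate abstract class Incomplete ...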
6 changes: 3 additions & 3 deletions tests/test_lazy_loading.py
@@ -369,7 +369,7 @@ def test_dask_lazy_loading_tsd(dask_array_tsd):
t=np.arange(dask_array_tsd.shape[0]), d=dask_array_tsd, load_array=False
)
assert isinstance(tsd.d, da.Array)
assert isinstance(tsd.restrict(nap.IntervalSet(0, 10)).d, np.ndarray)
assert isinstance(tsd.restrict(nap.IntervalSet(0, 10)).d, da.Array)
repr(tsd)
assert isinstance(tsd.d, da.Array)
assert isinstance(tsd[1:10].d, np.ndarray)
@@ -407,7 +407,7 @@ def test_dask_lazy_loading_tsdframe(dask_array_tsdframe):
load_array=False,
)
assert isinstance(tsdframe.d, da.Array)
assert isinstance(tsdframe.restrict(nap.IntervalSet(0, 10)).d, np.ndarray)
assert isinstance(tsdframe.restrict(nap.IntervalSet(0, 10)).d, da.Array)
repr(tsdframe)
assert isinstance(tsdframe.d, da.Array)
assert isinstance(tsdframe[1:10].d, np.ndarray)
@@ -449,7 +449,7 @@ def test_dask_lazy_loading_tsdtensor(dask_array_tsdtensor):
load_array=False,
)
assert isinstance(tsdtensor.d, da.Array)
assert isinstance(tsdtensor.restrict(nap.IntervalSet(0, 10)).d, np.ndarray)
assert isinstance(tsdtensor.restrict(nap.IntervalSet(0, 10)).d, da.Array)
repr(tsdtensor)
assert isinstance(tsdtensor.d, da.Array)
assert isinstance(tsdtensor[1:10].d, np.ndarray)
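The updated assertions capture a behavior change: because restrict() now rebuilds through _define_instance, which forwards the load_array flag, a lazily loaded (dask-backed) object stays lazy after restriction instead of being materialized to numpy. A hedged sketch of what the tests check:

import numpy as np
import dask.array as da
import pynapple as nap

lazy = da.ones(100, chunks=10)
tsd = nap.Tsd(t=np.arange(100), d=lazy, load_array=False)
restricted = tsd.restrict(nap.IntervalSet(0, 10))
print(type(restricted.d))  # dask Array, no longer np.ndarray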
