Skip to content

Commit

Permalink
Adding perievent validators
Browse files Browse the repository at this point in the history
  • Loading branch information
gviejo committed Jan 9, 2025
1 parent 0688345 commit fef038a
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 155 deletions.
227 changes: 129 additions & 98 deletions pynapple/process/perievent.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,47 @@
"""Functions to realign time series relative to a reference time.
"""
Functions to realign time series relative to a reference time.
"""

import numpy as np

from .. import core as nap
from ._process_functions import _perievent_continuous, _perievent_trigger_average

import inspect
from functools import wraps
from numbers import Number


def _validate_perievent_inputs(func):
@wraps(func)
def wrapper(*args, **kwargs):
# Validate each positional argument
sig = inspect.signature(func)
kwargs = sig.bind_partial(*args, **kwargs).arguments

parameters_type = {
"timestamps": (nap.Ts, nap.Tsd, nap.TsdFrame, nap.TsdTensor, nap.TsGroup),
"timeseries": (nap.Tsd, nap.TsdFrame, nap.TsdTensor),
"tref": (nap.Ts, nap.Tsd, nap.TsdFrame, nap.TsdTensor),
"group": nap.TsGroup,
"ep": (nap.IntervalSet, None),
"feature": (nap.Tsd, nap.TsdFrame, nap.TsdTensor),
"binsize": Number,
"windowsize": (tuple, Number),
"time_units": str,
}
for param, param_type in parameters_type.items():
if param in kwargs:
if not isinstance(kwargs[param], param_type):
raise TypeError(
f"Invalid type. Parameter {param} must be of type {param_type}."
)

# Call the original function with validated inputs
return func(**kwargs)

return wrapper


def _align_tsd(tsd, tref, window, time_support):
"""
Expand All @@ -21,7 +56,7 @@ def _align_tsd(tsd, tref, window, time_support):
The data to align
tref : numpy.ndarray
The reference times
windowsize : tuple
window : tuple
Start and end of the window size around tref
Returns
Expand Down Expand Up @@ -50,89 +85,94 @@ def _align_tsd(tsd, tref, window, time_support):
return group


def compute_perievent(data, tref, minmax, time_unit="s"):
@_validate_perievent_inputs
def compute_perievent(timestamps, tref, windowsize, ep=None, time_unit="s"):
"""
Center the timestamps of a time series object or a time series group around the timestamps given by the `tref` argument.
`minmax` indicates the start and end of the window. If `minmax=(-5, 10)`, the window will be from -5 second to 10 second.
If `minmax=10`, the window will be from -10 second to 10 second.
`windowsize` indicates the start and end of the window. If `windowsize=(-5, 10)`, the window will be from -5 second to 10 second.
If `windowsize=10`, the window will be from -10 second to 10 second.
To center continuous time series around a set of timestamps, you can use `compute_perievent_continuous`.
To center the values of a time series around a set of timestamps, you can use `compute_perievent_continuous`.
Parameters
----------
data : Ts, Tsd or TsGroup
The data to align to tref.
If Ts/Tsd, returns a TsGroup.
timestamps : Ts, Tsd, TsdFrame, TsdTensor or TsGroup
The timestamps to align to tref.
If Ts/Tsd/TsdFrame/TsdTensor, returns a TsGroup.
If TsGroup, returns a dictionary of TsGroup
tref : Ts or Tsd
The timestamps of the event to align to
minmax : tuple, int or float
tref : Ts, Tsd, TsdFrame or TsdTensor.
The time reference of the event to align to
windowsize : tuple of int/float or int or float
The window size. Can be unequal on each side i.e. (-500, 1000).
ep : IntervalSet, optional
The epochs to perform the operation. If None, the default is the time support of the `timestamps` object.
time_unit : str, optional
Time units of the minmax ('s' [default], 'ms', 'us').
Time units of the windowsize ('s' [default], 'ms', 'us').
Returns
-------
dict
A TsGroup if data is a Ts/Tsd or
a dictionary of TsGroup if data is a TsGroup.
A TsGroup if timestamps is a Ts/Tsd/TsdFrame/TsdTensor or
a dictionary of TsGroup if timestamps is a TsGroup.
Raises
------
RuntimeError
if tref is not a Ts/Tsd object or if data is not a Ts/Tsd or TsGroup
If `time_unit` not in ["s", "ms", "us"]
If `windowsize` is wrongly defined
"""
assert isinstance(tref, (nap.Ts, nap.Tsd)), "tref should be a Ts or Tsd object."
assert isinstance(
data, (nap.Ts, nap.Tsd, nap.TsGroup)
), "data should be a Ts, Tsd or TsGroup."
assert isinstance(
minmax, (float, int, tuple)
), "minmax should be a tuple or int or float."
assert isinstance(time_unit, str), "time_unit should be a str."
assert time_unit in ["s", "ms", "us"], "time_unit should be 's', 'ms' or 'us'"
if time_unit not in ["s", "ms", "us"]:
raise RuntimeError("time_unit should be 's', 'ms' or 'us'")

if isinstance(minmax, float) or isinstance(minmax, int):
minmax = np.array([minmax, minmax], dtype=np.float64)
if isinstance(windowsize, Number):
windowsize = np.array([windowsize, windowsize], dtype=np.float64)

window = np.abs(nap.TsIndex.format_timestamps(np.array(minmax), time_unit))
if len(windowsize) != 2:
raise RuntimeError(
"windowsize should be a tuple of 2 numbers or a single number."
)

time_support = nap.IntervalSet(start=-window[0], end=window[1])
if not all([isinstance(x, Number) for x in windowsize]):
raise RuntimeError(
"windowsize should be a tuple of 2 numbers or a single number."
)

if isinstance(data, nap.TsGroup):
toreturn = {}
if ep is None:
ep = timestamps.time_support

for n in data.index:
toreturn[n] = _align_tsd(data[n], tref, window, time_support)
window = np.abs(nap.TsIndex.format_timestamps(np.array(windowsize), time_unit))

if isinstance(timestamps, nap.TsGroup):
toreturn = {}
for n in timestamps.index:
toreturn[n] = _align_tsd(timestamps[n], tref, window, ep)
return toreturn

else:
return _align_tsd(data, tref, window, time_support)
return _align_tsd(timestamps, tref, window, ep)


def compute_perievent_continuous(data, tref, minmax, ep=None, time_unit="s"):
def compute_perievent_continuous(timeseries, tref, windowsize, ep=None, time_unit="s"):
"""
Center continuous time series around the timestamps given by the 'tref' argument.
`minmax` indicates the start and end of the window. If `minmax=(-5, 10)`, the window will be from -5 second to 10 second.
If `minmax=10`, the window will be from -10 second to 10 second.
`windowsize` indicates the start and end of the window. If `windowsize=(-5, 10)`, the window will be from -5 second to 10 second.
If `windowsize=10`, the window will be from -10 second to 10 second.
To realign timestamps around a set of timestamps, you can use `compute_perievent_continuous`.
To realign timestamps around a set of timestamps, you can use `compute_perievent`.
This function assumes a constant sampling rate of the time series.
Parameters
----------
data : Tsd, TsdFrame or TsdTensor
The data to align to tref.
tref : Ts or Tsd
The timestamps of the event to align to
minmax : tuple or int or float
timeseries : Tsd, TsdFrame or TsdTensor
The time series to align to tref.
tref : Ts, Tsd, TsdFrame or TsdTensor
The time reference of the event to align to
windowsize : tuple of int/float or int or float
The window size. Can be unequal on each side i.e. (-500, 1000).
ep : IntervalSet, optional
The epochs to perform the operation. If None, the default is the time support of the data.
time_unit : str, optional
Time units of the minmax ('s' [default], 'ms', 'us').
Time units of the windowsize ('s' [default], 'ms', 'us').
Returns
-------
Expand All @@ -143,31 +183,31 @@ def compute_perievent_continuous(data, tref, minmax, ep=None, time_unit="s"):
Raises
------
RuntimeError
if tref is not a Ts/Tsd object or if data is not a Tsd/TsdFrame/TsdTensor object.
If `time_unit` not in ["s", "ms", "us"]
"""
if time_unit not in ["s", "ms", "us"]:
raise RuntimeError("time_unit should be 's', 'ms' or 'us'")

assert isinstance(tref, (nap.Ts, nap.Tsd)), "tref should be a Ts or Tsd object."
assert isinstance(
data, (nap.Tsd, nap.TsdFrame, nap.TsdTensor)
), "data should be a Tsd, TsdFrame or TsdTensor."
assert isinstance(
minmax, (float, int, tuple)
), "minmax should be a tuple or int or float."
assert isinstance(time_unit, str), "time_unit should be a str."
assert time_unit in ["s", "ms", "us"], "time_unit should be 's', 'ms' or 'us'"
if isinstance(windowsize, Number):
windowsize = np.array([windowsize, windowsize], dtype=np.float64)

if ep is None:
ep = data.time_support
else:
assert isinstance(ep, (nap.IntervalSet)), "ep should be an IntervalSet object."
if len(windowsize) != 2:
raise RuntimeError(
"windowsize should be a tuple of 2 numbers or a single number."
)

if not all([isinstance(x, Number) for x in windowsize]):
raise RuntimeError(
"windowsize should be a tuple of 2 numbers or a single number."
)

if isinstance(minmax, float) or isinstance(minmax, int):
minmax = np.array([minmax, minmax], dtype=np.float64)
if ep is None:
ep = timeseries.time_support

window = np.abs(nap.TsIndex.format_timestamps(np.array(minmax), time_unit))
window = np.abs(nap.TsIndex.format_timestamps(np.array(windowsize), time_unit))

time_array = data.index.values
data_array = data.values
time_array = timeseries.index.values
data_array = timeseries.values
time_target_array = tref.index.values
starts = ep.start
ends = ep.end
Expand All @@ -194,18 +234,18 @@ def compute_event_trigger_average(
group,
feature,
binsize,
windowsize=None,
windowsize=0,
ep=None,
time_unit="s",
):
"""
Bin the event timestamps within binsize and compute the Event Trigger Average (ETA) within windowsize.
Bin the event timestamps within binsize and compute the Event Trigger Average (ETA) within `windowsize`.
If C is the event count matrix and `feature` is a Tsd array, the function computes
the Hankel matrix H from windowsize=(-t1,+t2) by offseting the Tsd array.
The ETA is then defined as the dot product between H and C divided by the number of events.
The object feature can be any dimensions.
The object `feature` can be any dimensions.
Parameters
----------
Expand All @@ -216,42 +256,33 @@ def compute_event_trigger_average(
binsize : float or int
The bin size. Default is second.
If different, specify with the parameter time_unit ('s' [default], 'ms', 'us').
windowsize : tuple of float/int or float/int
windowsize : tuple of float/int or float/int, optional
The window size. Default is second. For example windowsize = (-1, 1) is equivalent to windowsize = 1
If different, specify with the parameter time_unit ('s' [default], 'ms', 'us').
ep : IntervalSet
The epochs on which the average is computed
Default is (0, 0)
ep : IntervalSet, optional
The epochs on which the average is computed. If None, the time support of the feature is used.
time_unit : str, optional
The time unit of the parameters. They have to be consistent for binsize and windowsize.
('s' [default], 'ms', 'us').
"""
assert isinstance(group, nap.TsGroup), "group should be a TsGroup."
assert isinstance(
feature, (nap.Tsd, nap.TsdFrame, nap.TsdTensor)
), "Feature should be a Tsd, TsdFrame or TsdTensor"
assert isinstance(binsize, (float, int)), "binsize should be int or float."
assert isinstance(time_unit, str), "time_unit should be a str."
assert time_unit in ["s", "ms", "us"], "time_unit should be 's', 'ms' or 'us'"

if windowsize is not None:
if isinstance(windowsize, tuple):
assert (
len(windowsize) == 2
), "windowsize should be a tuple of 2 elements (-t, +t)"
assert all(
[isinstance(t, (float, int)) for t in windowsize]
), "windowsize should be a tuple of int/float"
else:
assert isinstance(
windowsize, (float, int)
), "windowsize should be a tuple of int/float or int/float."
windowsize = (windowsize, windowsize)
else:
windowsize = (0.0, 0.0)
if time_unit not in ["s", "ms", "us"]:
raise RuntimeError("time_unit should be 's', 'ms' or 'us'")

if ep is not None:
assert isinstance(ep, (nap.IntervalSet)), "ep should be an IntervalSet object."
else:
if isinstance(windowsize, Number):
windowsize = np.array([windowsize, windowsize], dtype=np.float64)

if len(windowsize) != 2:
raise RuntimeError(
"windowsize should be a tuple of 2 numbers or a single number."
)

if not all([isinstance(x, Number) for x in windowsize]):
raise RuntimeError(
"windowsize should be a tuple of 2 numbers or a single number."
)

if ep is None:
ep = feature.time_support

binsize = nap.TsIndex.format_timestamps(
Expand All @@ -272,7 +303,7 @@ def compute_event_trigger_average(
idx2 = np.arange(0, end + binsize, binsize)[1:]
time_idx = np.hstack((idx1, np.zeros(1), idx2))

eta = np.zeros((time_idx.shape[0], len(group), *feature.shape[1:]))
# eta = np.zeros((time_idx.shape[0], len(group), *feature.shape[1:]))

windows = np.array([len(idx1), len(idx2)])

Expand Down
Loading

0 comments on commit fef038a

Please sign in to comment.