Skip to content

Commit c81c632

Browse files
committed
adding warping.py
1 parent 7e70faa commit c81c632

File tree

3 files changed

+119
-0
lines changed

3 files changed

+119
-0
lines changed

doc/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,5 +111,6 @@ Analysis modules
111111
spectrum
112112
tuning_curves
113113
wavelets
114+
warping
114115

115116

pynapple/process/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,4 @@
3636
compute_discrete_tuning_curves,
3737
)
3838
from .wavelets import compute_wavelet_transform, generate_morlet_filterbank
39+
from .warping import build_tensor

pynapple/process/warping.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""
2+
Functions to create trial-based tensors and warp times
3+
"""
4+
5+
import inspect
6+
from functools import wraps
7+
from numbers import Number
8+
9+
import numpy as np
10+
11+
from .. import core as nap
12+
13+
14+
def _validate_warping_inputs(func):
15+
@wraps(func)
16+
def wrapper(*args, **kwargs):
17+
# Validate each positional argument
18+
sig = inspect.signature(func)
19+
kwargs = sig.bind_partial(*args, **kwargs).arguments
20+
21+
parameters_type = {
22+
"input": (nap.Ts, nap.Tsd, nap.TsdFrame, nap.TsdTensor, nap.TsGroup),
23+
"ep": (nap.IntervalSet,),
24+
"binsize": (Number,),
25+
"time_unit": (str,),
26+
"align": (str,),
27+
"padding_value": (Number,),
28+
}
29+
for param, param_type in parameters_type.items():
30+
if param in kwargs:
31+
if not isinstance(kwargs[param], param_type):
32+
raise TypeError(
33+
f"Invalid type. Parameter {param} must be of type {[p.__name__ for p in param_type]}."
34+
)
35+
36+
# Call the original function with validated inputs
37+
return func(**kwargs)
38+
39+
return wrapper
40+
41+
42+
def _build_tensor_from_tsgroup(input, ep, binsize, align, padding_value):
43+
# Determine size of tensor
44+
n_t = int(np.max(np.ceil((ep.end + binsize - ep.start) / binsize)))
45+
46+
output = np.ones(shape=(len(input), len(ep), n_t)) * padding_value
47+
48+
count = input.count(bin_size=binsize, ep=ep)
49+
50+
for i in range(len(ep)):
51+
tmp = count.get(ep.start[i], ep.end[i]).values # Time by neuron
52+
output[:, i, 0 : tmp.shape[0]] = np.transpose(tmp)
53+
54+
return output
55+
56+
57+
def _build_tensor_from_tsd(input, ep, binsize, align, padding_value):
58+
pass
59+
60+
61+
@_validate_warping_inputs
62+
def build_tensor(
63+
input, ep, binsize=None, align="start", padding_value=np.nan, time_unit="s"
64+
):
65+
"""
66+
Return trial-based tensor from an IntervalSet object.
67+
68+
- if `input` is a `TsGroup`, returns a numpy array of shape (number of trial, number of group element, number of time bins).
69+
The `binsize` parameter determines the number of time bins.
70+
71+
- if `input` is `Tsd`, `TsdFrame` or `TsdTensor`, returns a numpy array of shape
72+
(number of trial, shape of time series, number of time points).
73+
If the parameter `binsize` is used, the data are "bin-averaged".
74+
75+
76+
Parameters
77+
----------
78+
input : Tsd, TsdFrame, TsdTensor or TsGroup
79+
Returns a numpy array.
80+
ep : IntervalSet
81+
Epochs holding the trials. Each interval can be of unequal size.
82+
binsize : Number, optional
83+
align: str, optional
84+
How to align the time series ('start' [default], 'end', 'both')
85+
padding_value: Number, optional
86+
How to pad the array if unequal intervals. Default is np.nan.
87+
time_unit : str, optional
88+
Time units of the binsize parameter ('s' [default], 'ms', 'us').
89+
90+
Returns
91+
-------
92+
numpy.ndarray
93+
94+
Raises
95+
------
96+
RuntimeError
97+
If `time_unit` not in ["s", "ms", "us"]
98+
99+
100+
Examples
101+
--------
102+
103+
104+
105+
"""
106+
if time_unit not in ["s", "ms", "us"]:
107+
raise RuntimeError("time_unit should be 's', 'ms' or 'us'")
108+
if align not in ["start", "end", "both"]:
109+
raise RuntimeError("align should be 'start', 'end' or 'both'")
110+
111+
binsize = np.abs(nap.TsIndex.format_timestamps(np.array([binsize]), time_unit))[0]
112+
113+
if isinstance(input, nap.TsGroup):
114+
return _build_tensor_from_tsgroup(input, ep, binsize, align, padding_value)
115+
116+
if isinstance(input, (nap.Tsd, nap.TsdFrame, nap.TsdTensor)):
117+
return _build_tensor_from_tsd(input, ep, binsize, align, padding_value)

0 commit comments

Comments
 (0)