Skip to content

Commit 7e70faa

Browse files
committed
Cleaning perievent
1 parent fef038a commit 7e70faa

File tree

3 files changed

+155
-64
lines changed

3 files changed

+155
-64
lines changed

pynapple/process/perievent.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
Functions to realign time series relative to a reference time.
33
"""
44

5+
import inspect
6+
from functools import wraps
7+
from numbers import Number
8+
59
import numpy as np
610

711
from .. import core as nap
812
from ._process_functions import _perievent_continuous, _perievent_trigger_average
913

10-
import inspect
11-
from functools import wraps
12-
from numbers import Number
13-
1414

1515
def _validate_perievent_inputs(func):
1616
@wraps(func)
@@ -23,18 +23,18 @@ def wrapper(*args, **kwargs):
2323
"timestamps": (nap.Ts, nap.Tsd, nap.TsdFrame, nap.TsdTensor, nap.TsGroup),
2424
"timeseries": (nap.Tsd, nap.TsdFrame, nap.TsdTensor),
2525
"tref": (nap.Ts, nap.Tsd, nap.TsdFrame, nap.TsdTensor),
26-
"group": nap.TsGroup,
27-
"ep": (nap.IntervalSet, None),
26+
"group": (nap.TsGroup,),
27+
"ep": (nap.IntervalSet,),
2828
"feature": (nap.Tsd, nap.TsdFrame, nap.TsdTensor),
29-
"binsize": Number,
29+
"binsize": (Number,),
3030
"windowsize": (tuple, Number),
31-
"time_units": str,
31+
"time_unit": (str,),
3232
}
3333
for param, param_type in parameters_type.items():
3434
if param in kwargs:
3535
if not isinstance(kwargs[param], param_type):
3636
raise TypeError(
37-
f"Invalid type. Parameter {param} must be of type {param_type}."
37+
f"Invalid type. Parameter {param} must be of type {[p.__name__ for p in param_type]}."
3838
)
3939

4040
# Call the original function with validated inputs
@@ -43,7 +43,7 @@ def wrapper(*args, **kwargs):
4343
return wrapper
4444

4545

46-
def _align_tsd(tsd, tref, window, time_support):
46+
def _align_tsd(tsd, tref, window, new_time_support):
4747
"""
4848
Helper function compiled with numba for aligning times.
4949
See compute_perievent for using this function
@@ -72,21 +72,21 @@ def _align_tsd(tsd, tref, window, time_support):
7272
if isinstance(tsd, nap.Ts):
7373
for i in range(len(tref)):
7474
tmp = tsd.index[lbounds[i] : rbounds[i]] - tref.index[i]
75-
group[i] = nap.Ts(t=tmp, time_support=time_support)
75+
group[i] = nap.Ts(t=tmp, time_support=new_time_support)
7676
else:
7777
for i in range(len(tref)):
7878
tmp = tsd.index[lbounds[i] : rbounds[i]] - tref.index[i]
7979
tmp2 = tsd.values[lbounds[i] : rbounds[i]]
80-
group[i] = nap.Tsd(t=tmp, d=tmp2, time_support=time_support)
80+
group[i] = nap.Tsd(t=tmp, d=tmp2, time_support=new_time_support)
8181

82-
group = nap.TsGroup(group, time_support=time_support, bypass_check=True)
82+
group = nap.TsGroup(group, time_support=new_time_support, bypass_check=True)
8383
group.set_info(ref_times=tref.index)
8484

8585
return group
8686

8787

8888
@_validate_perievent_inputs
89-
def compute_perievent(timestamps, tref, windowsize, ep=None, time_unit="s"):
89+
def compute_perievent(timestamps, tref, windowsize, time_unit="s"):
9090
"""
9191
Center the timestamps of a time series object or a time series group around the timestamps given by the `tref` argument.
9292
`windowsize` indicates the start and end of the window. If `windowsize=(-5, 10)`, the window will be from -5 second to 10 second.
@@ -104,8 +104,6 @@ def compute_perievent(timestamps, tref, windowsize, ep=None, time_unit="s"):
104104
The time reference of the event to align to
105105
windowsize : tuple of int/float or int or float
106106
The window size. Can be unequal on each side i.e. (-500, 1000).
107-
ep : IntervalSet, optional
108-
The epochs to perform the operation. If None, the default is the time support of the `timestamps` object.
109107
time_unit : str, optional
110108
Time units of the windowsize ('s' [default], 'ms', 'us').
111109
@@ -137,20 +135,20 @@ def compute_perievent(timestamps, tref, windowsize, ep=None, time_unit="s"):
137135
"windowsize should be a tuple of 2 numbers or a single number."
138136
)
139137

140-
if ep is None:
141-
ep = timestamps.time_support
142-
143138
window = np.abs(nap.TsIndex.format_timestamps(np.array(windowsize), time_unit))
144139

140+
new_time_support = nap.IntervalSet(start=-window[0], end=window[1])
141+
145142
if isinstance(timestamps, nap.TsGroup):
146143
toreturn = {}
147144
for n in timestamps.index:
148-
toreturn[n] = _align_tsd(timestamps[n], tref, window, ep)
145+
toreturn[n] = _align_tsd(timestamps[n], tref, window, new_time_support)
149146
return toreturn
150147
else:
151-
return _align_tsd(timestamps, tref, window, ep)
148+
return _align_tsd(timestamps, tref, window, new_time_support)
152149

153150

151+
@_validate_perievent_inputs
154152
def compute_perievent_continuous(timeseries, tref, windowsize, ep=None, time_unit="s"):
155153
"""
156154
Center continuous time series around the timestamps given by the 'tref' argument.
@@ -230,6 +228,7 @@ def compute_perievent_continuous(timeseries, tref, windowsize, ep=None, time_uni
230228
return nap.TsdTensor(t=time_idx, d=new_data_array, time_support=time_support)
231229

232230

231+
@_validate_perievent_inputs
233232
def compute_event_trigger_average(
234233
group,
235234
feature,

tests/test_perievent.py

Lines changed: 94 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -54,26 +54,52 @@ def test_compute_perievent_windowsize():
5454
def test_compute_perievent_raise_error():
5555
tsd = nap.Ts(t=np.arange(100))
5656
tref = nap.Ts(t=np.arange(10, 100, 10))
57-
with pytest.raises(AssertionError) as e_info:
58-
nap.compute_perievent(tsd, [0, 1, 2], windowsize=(-10, 10))
59-
assert str(e_info.value) == "tref should be a Ts or Tsd object."
57+
with pytest.raises(TypeError) as e_info:
58+
nap.compute_perievent(tsd, tref=[0, 1, 2], windowsize=(-10, 10))
59+
assert (
60+
str(e_info.value)
61+
== "Invalid type. Parameter tref must be of type ['Ts', 'Tsd', 'TsdFrame', 'TsdTensor']."
62+
)
6063

61-
with pytest.raises(AssertionError) as e_info:
62-
nap.compute_perievent([0, 1, 2], tref, windowsize=(-10, 10))
63-
assert str(e_info.value) == "data should be a Ts, Tsd or TsGroup."
64+
with pytest.raises(TypeError) as e_info:
65+
nap.compute_perievent(timestamps=[0, 1, 2], tref=tref, windowsize=(-10, 10))
66+
assert (
67+
str(e_info.value)
68+
== "Invalid type. Parameter timestamps must be of type ['Ts', 'Tsd', 'TsdFrame', 'TsdTensor', 'TsGroup']."
69+
)
6470

65-
with pytest.raises(AssertionError) as e_info:
71+
with pytest.raises(TypeError) as e_info:
6672
nap.compute_perievent(tsd, tref, windowsize={0: 1})
67-
assert str(e_info.value) == "windowsize should be a tuple or int or float."
73+
assert (
74+
str(e_info.value)
75+
== "Invalid type. Parameter windowsize must be of type ['tuple', 'Number']."
76+
)
6877

69-
with pytest.raises(AssertionError) as e_info:
78+
with pytest.raises(TypeError) as e_info:
7079
nap.compute_perievent(tsd, tref, windowsize=10, time_unit=1)
71-
assert str(e_info.value) == "time_unit should be a str."
80+
assert (
81+
str(e_info.value)
82+
== "Invalid type. Parameter time_unit must be of type ['str']."
83+
)
7284

73-
with pytest.raises(AssertionError) as e_info:
85+
with pytest.raises(RuntimeError) as e_info:
7486
nap.compute_perievent(tsd, tref, windowsize=10, time_unit="a")
7587
assert str(e_info.value) == "time_unit should be 's', 'ms' or 'us'"
7688

89+
with pytest.raises(RuntimeError) as e_info:
90+
nap.compute_perievent(tsd, tref, windowsize=(1, 2, 3))
91+
assert (
92+
str(e_info.value)
93+
== "windowsize should be a tuple of 2 numbers or a single number."
94+
)
95+
96+
with pytest.raises(RuntimeError) as e_info:
97+
nap.compute_perievent(tsd, tref, windowsize=(1, "2"))
98+
assert (
99+
str(e_info.value)
100+
== "windowsize should be a tuple of 2 numbers or a single number."
101+
)
102+
77103

78104
def test_compute_perievent_with_tsgroup():
79105
tsgroup = nap.TsGroup(
@@ -97,7 +123,9 @@ def test_compute_perievent_time_units():
97123
tsd = nap.Tsd(t=np.arange(100), d=np.arange(100))
98124
tref = nap.Ts(t=np.arange(10, 100, 10))
99125
for tu, fa in zip(["s", "ms", "us"], [1, 1e3, 1e6]):
100-
peth = nap.compute_perievent(tsd, tref, windowsize=(-10 * fa, 10 * fa), time_unit=tu)
126+
peth = nap.compute_perievent(
127+
tsd, tref, windowsize=(-10 * fa, 10 * fa), time_unit=tu
128+
)
101129
for i, j in zip(peth.keys(), np.arange(0, 100, 10)):
102130
np.testing.assert_array_almost_equal(peth[i].index, np.arange(-10, 10))
103131
np.testing.assert_array_almost_equal(peth[i].values, np.arange(j, j + 20))
@@ -127,7 +155,9 @@ def test_compute_perievent_continuous():
127155
np.testing.assert_array_almost_equal(
128156
pe.index.values, np.arange(windowsize[0], windowsize[-1] + 1)
129157
)
130-
tmp = np.array([np.arange(t + windowsize[0], t + windowsize[1] + 1) for t in tref.t]).T
158+
tmp = np.array(
159+
[np.arange(t + windowsize[0], t + windowsize[1] + 1) for t in tref.t]
160+
).T
131161
np.testing.assert_array_almost_equal(pe.values, tmp)
132162

133163
windowsize = (5, 10)
@@ -184,11 +214,15 @@ def test_compute_perievent_continuous_time_units():
184214
tref = nap.Ts(t=np.array([20, 60]))
185215
windowsize = (-5, 10)
186216
for tu, fa in zip(["s", "ms", "us"], [1, 1e3, 1e6]):
187-
pe = nap.compute_perievent_continuous(tsd, tref, windowsize=(windowsize[0] * fa, windowsize[1] * fa), time_unit=tu)
217+
pe = nap.compute_perievent_continuous(
218+
tsd, tref, windowsize=(windowsize[0] * fa, windowsize[1] * fa), time_unit=tu
219+
)
188220
np.testing.assert_array_almost_equal(
189221
pe.index.values, np.arange(windowsize[0], windowsize[1] + 1)
190222
)
191-
tmp = np.array([np.arange(t + windowsize[0], t + windowsize[1] + 1) for t in tref.t]).T
223+
tmp = np.array(
224+
[np.arange(t + windowsize[0], t + windowsize[1] + 1) for t in tref.t]
225+
).T
192226
np.testing.assert_array_almost_equal(pe.values, tmp)
193227

194228

@@ -201,7 +235,10 @@ def test_compute_perievent_continuous_with_ep():
201235

202236
assert pe.shape[1] == len(tref) - 1
203237
tmp = np.array(
204-
[np.arange(t + windowsize[0], t + windowsize[1] + 1) for t in tref.restrict(ep).t]
238+
[
239+
np.arange(t + windowsize[0], t + windowsize[1] + 1)
240+
for t in tref.restrict(ep).t
241+
]
205242
).T
206243
np.testing.assert_array_almost_equal(pe.values, tmp)
207244

@@ -237,26 +274,55 @@ def test_compute_perievent_continuous_with_ep():
237274
def test_compute_perievent_continuous_raise_error():
238275
tsd = nap.Tsd(t=np.arange(100), d=np.arange(100))
239276
tref = nap.Ts(t=np.arange(10, 100, 10))
240-
with pytest.raises(AssertionError) as e_info:
241-
nap.compute_perievent_continuous(tsd, [0, 1, 2], windowsize=(-10, 10))
242-
assert str(e_info.value) == "tref should be a Ts or Tsd object."
277+
with pytest.raises(TypeError) as e_info:
278+
nap.compute_perievent_continuous(tsd, tref=[0, 1, 2], windowsize=(-10, 10))
279+
assert (
280+
str(e_info.value)
281+
== "Invalid type. Parameter tref must be of type ['Ts', 'Tsd', 'TsdFrame', 'TsdTensor']."
282+
)
243283

244-
with pytest.raises(AssertionError) as e_info:
284+
with pytest.raises(TypeError) as e_info:
245285
nap.compute_perievent_continuous([0, 1, 2], tref, windowsize=(-10, 10))
246-
assert str(e_info.value) == "data should be a Tsd, TsdFrame or TsdTensor."
286+
assert (
287+
str(e_info.value)
288+
== "Invalid type. Parameter timeseries must be of type ['Tsd', 'TsdFrame', 'TsdTensor']."
289+
)
247290

248-
with pytest.raises(AssertionError) as e_info:
291+
with pytest.raises(TypeError) as e_info:
249292
nap.compute_perievent_continuous(tsd, tref, windowsize={0: 1})
250-
assert str(e_info.value) == "windowsize should be a tuple or int or float."
293+
assert (
294+
str(e_info.value)
295+
== "Invalid type. Parameter windowsize must be of type ['tuple', 'Number']."
296+
)
251297

252-
with pytest.raises(AssertionError) as e_info:
298+
with pytest.raises(TypeError) as e_info:
253299
nap.compute_perievent_continuous(tsd, tref, windowsize=10, time_unit=1)
254-
assert str(e_info.value) == "time_unit should be a str."
300+
assert (
301+
str(e_info.value)
302+
== "Invalid type. Parameter time_unit must be of type ['str']."
303+
)
255304

256-
with pytest.raises(AssertionError) as e_info:
305+
with pytest.raises(RuntimeError) as e_info:
257306
nap.compute_perievent_continuous(tsd, tref, windowsize=10, time_unit="a")
258307
assert str(e_info.value) == "time_unit should be 's', 'ms' or 'us'"
259308

260-
with pytest.raises(AssertionError) as e_info:
309+
with pytest.raises(TypeError) as e_info:
261310
nap.compute_perievent_continuous(tsd, tref, windowsize=10, ep="a")
262-
assert str(e_info.value) == "ep should be an IntervalSet object."
311+
assert (
312+
str(e_info.value)
313+
== "Invalid type. Parameter ep must be of type ['IntervalSet']."
314+
)
315+
316+
with pytest.raises(RuntimeError) as e_info:
317+
nap.compute_perievent_continuous(tsd, tref, windowsize=(1, 2, 3))
318+
assert (
319+
str(e_info.value)
320+
== "windowsize should be a tuple of 2 numbers or a single number."
321+
)
322+
323+
with pytest.raises(RuntimeError) as e_info:
324+
nap.compute_perievent_continuous(tsd, tref, windowsize=(1, "2"))
325+
assert (
326+
str(e_info.value)
327+
== "windowsize should be a tuple of 2 numbers or a single number."
328+
)

tests/test_spike_trigger_average.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@ def test_compute_spike_trigger_average_tsd():
3333
assert sta.shape == output.shape
3434
np.testing.assert_array_almost_equal(sta.values, output)
3535

36+
sta = nap.compute_event_trigger_average(spikes, feature, 0.2, (0.6, 0.6))
37+
assert isinstance(sta, nap.TsdFrame)
38+
assert sta.shape == output.shape
39+
np.testing.assert_array_almost_equal(sta.values, output)
40+
3641

3742
def test_compute_spike_trigger_average_tsdframe():
3843
ep = nap.IntervalSet(0, 100)
@@ -149,41 +154,62 @@ def test_compute_spike_trigger_average_raise_error():
149154
{0: nap.Ts(t1), 1: nap.Ts(t1 - 0.1), 2: nap.Ts(t1 + 0.2)}, time_support=ep
150155
)
151156

152-
with pytest.raises(Exception) as e_info:
157+
with pytest.raises(TypeError) as e_info:
153158
nap.compute_event_trigger_average(feature, feature, 0.1, (0.5, 0.5), ep)
154-
assert str(e_info.value) == "group should be a TsGroup."
159+
assert (
160+
str(e_info.value)
161+
== "Invalid type. Parameter group must be of type ['TsGroup']."
162+
)
155163

156-
with pytest.raises(Exception) as e_info:
164+
with pytest.raises(TypeError) as e_info:
157165
nap.compute_event_trigger_average(spikes, np.array(10), 0.1, (0.5, 0.5), ep)
158-
assert str(e_info.value) == "Feature should be a Tsd, TsdFrame or TsdTensor"
166+
assert (
167+
str(e_info.value)
168+
== "Invalid type. Parameter feature must be of type ['Tsd', 'TsdFrame', 'TsdTensor']."
169+
)
159170

160-
with pytest.raises(Exception) as e_info:
171+
with pytest.raises(TypeError) as e_info:
161172
nap.compute_event_trigger_average(spikes, feature, "0.1", (0.5, 0.5), ep)
162-
assert str(e_info.value) == "binsize should be int or float."
173+
assert (
174+
str(e_info.value)
175+
== "Invalid type. Parameter binsize must be of type ['Number']."
176+
)
163177

164-
with pytest.raises(Exception) as e_info:
178+
with pytest.raises(TypeError) as e_info:
165179
nap.compute_event_trigger_average(
166180
spikes, feature, 0.1, (0.5, 0.5), ep, time_unit=1
167181
)
168-
assert str(e_info.value) == "time_unit should be a str."
182+
assert (
183+
str(e_info.value)
184+
== "Invalid type. Parameter time_unit must be of type ['str']."
185+
)
169186

170-
with pytest.raises(Exception) as e_info:
187+
with pytest.raises(RuntimeError) as e_info:
171188
nap.compute_event_trigger_average(
172189
spikes, feature, 0.1, (0.5, 0.5), ep, time_unit="a"
173190
)
174191
assert str(e_info.value) == "time_unit should be 's', 'ms' or 'us'"
175192

176-
with pytest.raises(Exception) as e_info:
193+
with pytest.raises(RuntimeError) as e_info:
177194
nap.compute_event_trigger_average(spikes, feature, 0.1, (0.5, 0.5, 0.5), ep)
178-
assert str(e_info.value) == "windowsize should be a tuple of 2 elements (-t, +t)"
195+
assert (
196+
str(e_info.value)
197+
== "windowsize should be a tuple of 2 numbers or a single number."
198+
)
179199

180-
with pytest.raises(Exception) as e_info:
200+
with pytest.raises(RuntimeError) as e_info:
181201
nap.compute_event_trigger_average(spikes, feature, 0.1, ("a", "b"), ep)
182-
assert str(e_info.value) == "windowsize should be a tuple of int/float"
202+
assert (
203+
str(e_info.value)
204+
== "windowsize should be a tuple of 2 numbers or a single number."
205+
)
183206

184-
with pytest.raises(Exception) as e_info:
207+
with pytest.raises(TypeError) as e_info:
185208
nap.compute_event_trigger_average(spikes, feature, 0.1, (0.5, 0.5), [1, 2, 3])
186-
assert str(e_info.value) == "ep should be an IntervalSet object."
209+
assert (
210+
str(e_info.value)
211+
== "Invalid type. Parameter ep must be of type ['IntervalSet']."
212+
)
187213

188214

189215
def test_compute_spike_trigger_average_time_unit():

0 commit comments

Comments
 (0)