Skip to content

Fixing issue with event trigger average #209

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions pynapple/process/perievent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# @Author: gviejo
# @Date: 2022-01-30 22:59:00
# @Last Modified by: Guillaume Viejo
# @Last Modified time: 2023-11-19 19:13:24
# @Last Modified time: 2023-11-20 12:08:15

import numpy as np
from scipy.linalg import hankel
Expand Down Expand Up @@ -121,7 +121,7 @@ def compute_event_trigger_average(
group : TsGroup
The group of Ts/Tsd objects that hold the trigger time.
feature : Tsd
The 1-dimensional feature to average
The 1-dimensional feature to average. Can be a TsdFrame with one column only.
binsize : float
The bin size. Default is second.
If different, specify with the parameter time_units ('s' [default], 'ms', 'us').
Expand All @@ -147,6 +147,13 @@ def compute_event_trigger_average(
if type(group) is not nap.TsGroup:
raise RuntimeError("Unknown format for group")

if isinstance(feature, nap.TsdFrame):
if feature.shape[1] == 1:
feature = feature[:, 0]

if type(feature) is not nap.Tsd:
raise RuntimeError("Feature should be a Tsd or a TsdFrame with one column")

binsize = nap.TsIndex.format_timestamps(
np.array([binsize], dtype=np.float64), time_units
)[0]
Expand Down
20 changes: 19 additions & 1 deletion tests/test_spike_trigger_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# @Author: gviejo
# @Date: 2022-08-29 17:27:02
# @Last Modified by: Guillaume Viejo
# @Last Modified time: 2023-09-18 16:33:39
# @Last Modified time: 2023-11-20 12:07:53
#!/usr/bin/env python

"""Tests of spike trigger average for `pynapple` package."""
Expand Down Expand Up @@ -37,6 +37,12 @@ def test_compute_spike_trigger_average():
assert sta.shape == output.shape
np.testing.assert_array_almost_equal(sta, output)

feature = nap.TsdFrame(
t=feature.index.values, d=feature.values[:,None], time_support=ep
)
sta = nap.compute_event_trigger_average(spikes, feature, 0.2, (0.6, 0.6), ep)
np.testing.assert_array_almost_equal(sta, output)


def test_compute_spike_trigger_average_raise_error():
ep = nap.IntervalSet(0, 101)
Expand All @@ -51,6 +57,18 @@ def test_compute_spike_trigger_average_raise_error():
nap.compute_event_trigger_average(feature, feature, 0.1, (0.5, 0.5), ep)
assert str(e_info.value) == "Unknown format for group"

feature = nap.TsdFrame(
t=np.arange(0, 101, 0.01), d=np.random.rand(int(101 / 0.01), 3), time_support=ep
)
spikes = nap.TsGroup(
{0: nap.Ts(t1), 1: nap.Ts(t1 - 0.1), 2: nap.Ts(t1 + 0.2)}, time_support=ep
)
with pytest.raises(Exception) as e_info:
nap.compute_event_trigger_average(spikes, feature, 0.1, (0.5, 0.5), ep)
assert str(e_info.value) == "Feature should be a Tsd or a TsdFrame with one column"




def test_compute_spike_trigger_average_time_units():
ep = nap.IntervalSet(0, 100)
Expand Down