Skip to content

Commit f29e0cf

Browse files
authored
Merge pull request #257 from pynapple-org/tuning_from_tsd
Tuning from tsd
2 parents 135a3f4 + abb4c7e commit f29e0cf

File tree

3 files changed

+114
-39
lines changed

3 files changed

+114
-39
lines changed

pynapple/io/interface_nwb.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,6 @@
1919
import numpy as np
2020
import pynwb
2121
from pynwb import NWBHDF5IO
22-
23-
# from rich.console import Console
24-
# from rich.table import Table
2522
from tabulate import tabulate
2623

2724
from .. import core as nap

pynapple/process/tuning_curves.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,8 @@ def compute_1d_tuning_curves_continuous(
426426
"""
427427
if not isinstance(tsdframe, (nap.Tsd, nap.TsdFrame)):
428428
raise RuntimeError("Unknown format for tsdframe.")
429+
elif isinstance(tsdframe, nap.Tsd):
430+
tsdframe = tsdframe[:, np.newaxis]
429431

430432
assert isinstance(
431433
feature, (nap.Tsd, nap.TsdFrame)
@@ -496,6 +498,8 @@ def compute_2d_tuning_curves_continuous(
496498
"""
497499
if not isinstance(tsdframe, (nap.Tsd, nap.TsdFrame)):
498500
raise RuntimeError("Unknown format for tsdframe.")
501+
elif isinstance(tsdframe, nap.Tsd):
502+
tsdframe = tsdframe[:, np.newaxis]
499503

500504
assert isinstance(
501505
features, nap.TsdFrame

tests/test_tuning_curves.py

Lines changed: 110 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# @Last Modified time: 2024-01-29 11:05:11
66

77
"""Tests of tuning curves for `pynapple` package."""
8-
8+
from contextlib import nullcontext as does_not_raise
99
import pynapple as nap
1010
import numpy as np
1111
import pandas as pd
@@ -214,22 +214,69 @@ def test_compute_2d_mutual_info():
214214
np.testing.assert_approx_equal(si.loc[0, "SI"], 2.0)
215215

216216

217-
def test_compute_1d_tuning_curves_continuous():
218-
tsdframe = nap.TsdFrame(t=np.arange(0, 100), d=np.ones((100, 1)))
217+
@pytest.mark.parametrize(
218+
"tsd, expected_columns",
219+
[
220+
(nap.TsdFrame(t=np.arange(0, 100), d=np.ones((100, 1))), [0]),
221+
(nap.TsdFrame(t=np.arange(0, 100), d=np.ones((100, 2))), [0, 1]),
222+
(nap.Tsd(t=np.arange(0, 100), d=np.ones((100, ))), [0])
223+
]
224+
)
225+
def test_compute_1d_tuning_curves_continuous(tsd, expected_columns):
219226
feature = nap.Tsd(t=np.arange(0, 100, 0.1), d=np.arange(0, 100, 0.1) % 1.0)
220-
tc = nap.compute_1d_tuning_curves_continuous(tsdframe, feature, nb_bins=10)
227+
tc = nap.compute_1d_tuning_curves_continuous(tsd, feature, nb_bins=10)
221228

222229
assert len(tc) == 10
223-
assert list(tc.columns) == list(tsdframe.columns)
230+
assert list(tc.columns) == expected_columns
224231
np.testing.assert_array_almost_equal(tc[0].values[1:], np.zeros(9))
225232
assert int(tc[0].values[0]) == 1.0
226233

227-
def test_compute_1d_tuning_curves_continuous_error():
228-
tsdframe = nap.TsdFrame(t=np.arange(0, 100), d=np.ones((100, 1)))
229-
feature = nap.Tsd(t=np.arange(0, 100, 0.1), d=np.arange(0, 100, 0.1) % 1.0)
230-
with pytest.raises(RuntimeError) as e_info:
231-
nap.compute_1d_tuning_curves_continuous([1,2,3], feature, nb_bins=10)
232-
assert str(e_info.value) == "Unknown format for tsdframe."
234+
235+
@pytest.mark.parametrize(
236+
"tsdframe, expectation",
237+
[
238+
(nap.TsdFrame(t=np.arange(0, 100), d=np.ones((100, 1))), does_not_raise()),
239+
([1, 2, 3], pytest.raises(RuntimeError, match="Unknown format for tsdframe.")),
240+
(
241+
nap.TsdTensor(t=np.arange(0, 100), d=np.ones((100, 1, 1))),
242+
pytest.raises(RuntimeError, match="Unknown format for tsdframe.")
243+
),
244+
]
245+
)
246+
@pytest.mark.parametrize(
247+
"feature",
248+
[
249+
nap.Tsd(t=np.arange(0, 100, 0.1), d=np.arange(0, 100, 0.1) % 1.0),
250+
nap.TsdFrame(t=np.arange(0, 100, 0.1), d=np.arange(0, 100, 0.1)[:, None] % 1.0)
251+
]
252+
)
253+
def test_compute_1d_tuning_curves_continuous_error_tsdframe(tsdframe, expectation, feature):
254+
with expectation:
255+
nap.compute_1d_tuning_curves_continuous(tsdframe, feature, nb_bins=10)
256+
257+
258+
@pytest.mark.parametrize(
259+
"feature, expectation",
260+
[
261+
(nap.Tsd(t=np.arange(0, 100, 0.1), d=np.arange(0, 100, 0.1) % 1.0), does_not_raise()),
262+
(nap.TsdFrame(t=np.arange(0, 100, 0.1), d=np.arange(0, 100, 0.1)[:, None] % 1.0), does_not_raise()),
263+
(
264+
nap.TsdFrame(t=np.arange(0, 100, 0.1), d=np.arange(0, 200, 0.1).reshape(1000, 2) % 1.0),
265+
pytest.raises(AssertionError, match=r"feature should be a Tsd \(or TsdFrame with 1 column only\)"))
266+
]
267+
)
268+
@pytest.mark.parametrize(
269+
"tsd",
270+
[
271+
nap.TsdFrame(t=np.arange(0, 100), d=np.ones((100, 1))),
272+
nap.TsdFrame(t=np.arange(0, 100), d=np.ones((100, 2))),
273+
nap.Tsd(t=np.arange(0, 100), d=np.ones((100, )))
274+
]
275+
)
276+
def test_compute_1d_tuning_curves_continuous_error_featues(tsd, feature, expectation):
277+
with expectation:
278+
nap.compute_1d_tuning_curves_continuous(tsd, feature, nb_bins=10)
279+
233280

234281
def test_compute_1d_tuning_curves_continuous_with_ep():
235282
tsdframe = nap.TsdFrame(t=np.arange(0, 100), d=np.ones((100, 1)))
@@ -250,31 +297,51 @@ def test_compute_1d_tuning_curves_continuous_with_min_max():
250297
np.testing.assert_array_almost_equal(tc[0].values[1:], np.zeros(9))
251298
assert tc[0].values[0] == 1.0
252299

300+
@pytest.mark.parametrize(
301+
"tsdframe, expected_columns",
302+
[
303+
(nap.TsdFrame(t=np.arange(0, 100), d=np.hstack((np.ones((100, 1)), np.ones((100, 1)) * 2))), [0, 1]),
304+
(
305+
nap.TsdFrame(t=np.arange(0, 100), d=np.hstack((np.ones((100, 1)), np.ones((100, 1)) * 2)),
306+
columns=["x", "y"]),
307+
["x", "y"]
308+
),
309+
(nap.Tsd(t=np.arange(0, 100), d=np.hstack((np.ones((100, )) * 2))), [0])
310+
311+
]
312+
)
313+
@pytest.mark.parametrize("nb_bins", [1, 2, 3])
314+
def test_compute_2d_tuning_curves_continuous(nb_bins, tsdframe, expected_columns):
253315

254-
def test_compute_2d_tuning_curves_continuous():
255-
tsdframe = nap.TsdFrame(
256-
t=np.arange(0, 100), d=np.hstack((np.ones((100, 1)), np.ones((100, 1)) * 2))
257-
)
258316
features = nap.TsdFrame(
259317
t=np.arange(100), d=np.tile(np.array([[0, 0, 1, 1], [0, 1, 0, 1]]), 25).T
260318
)
261-
tc, xy = nap.compute_2d_tuning_curves_continuous(tsdframe, features, 2)
319+
tc, xy = nap.compute_2d_tuning_curves_continuous(tsdframe, features, nb_bins)
262320

263321
assert isinstance(tc, dict)
264-
assert list(tc.keys()) == list(tsdframe.columns)
322+
assert list(tc.keys()) == expected_columns
265323
for i in tc.keys():
266-
assert tc[i].shape == (2, 2)
267-
tmp = np.zeros((2, 2, 2))
268-
tmp[:, 0, 0] = [1, 2]
269-
for i in range(2):
270-
np.testing.assert_array_almost_equal(tc[i], tmp[i])
324+
assert tc[i].shape == (nb_bins, nb_bins)
325+
271326
assert isinstance(xy, list)
272327
assert len(xy) == 2
273328
for i in range(2):
274329
assert np.min(xy) > 0
275330
assert np.max(xy) < 1
276331

277332

333+
def test_compute_2d_tuning_curves_continuous_output_value():
334+
tsdframe = nap.TsdFrame(t=np.arange(0, 100), d=np.hstack((np.ones((100, 1)), np.ones((100, 1)) * 2)))
335+
features = nap.TsdFrame(
336+
t=np.arange(100), d=np.tile(np.array([[0, 0, 1, 1], [0, 1, 0, 1]]), 25).T
337+
)
338+
tc, xy = nap.compute_2d_tuning_curves_continuous(tsdframe, features, 2)
339+
tmp = np.zeros((2, 2, 2))
340+
tmp[:, 0, 0] = [1, 2]
341+
for i in range(2):
342+
np.testing.assert_array_almost_equal(tc[i], tmp[i])
343+
344+
278345
def test_compute_2d_tuning_curves_continuous_with_ep():
279346
tsdframe = nap.TsdFrame(
280347
t=np.arange(0, 100), d=np.hstack((np.ones((100, 1)), np.ones((100, 1)) * 2))
@@ -289,28 +356,35 @@ def test_compute_2d_tuning_curves_continuous_with_ep():
289356
for i in tc1.keys():
290357
np.testing.assert_array_almost_equal(tc1[i], tc2[i])
291358

292-
def test_compute_2d_tuning_curves_continuous_error():
293-
tsdframe = nap.TsdFrame(
294-
t=np.arange(0, 100), d=np.hstack((np.ones((100, 1)), np.ones((100, 1)) * 2))
295-
)
359+
360+
def test_compute_2d_tuning_curves_continuous_error_tsdframe():
296361
features = nap.TsdFrame(
297362
t=np.arange(100), d=np.tile(np.array([[0, 0, 1, 1], [0, 1, 0, 1]]), 25).T
298363
)
299364
with pytest.raises(RuntimeError) as e_info:
300365
nap.compute_2d_tuning_curves_continuous([1,2,3], features, 2)
301366
assert str(e_info.value) == "Unknown format for tsdframe."
302367

303-
with pytest.raises(AssertionError) as e_info:
304-
nap.compute_2d_tuning_curves_continuous(tsdframe, [1,2,3], 2)
305-
assert str(e_info.value) == "features should be a TsdFrame with 2 columns"
306-
307-
features = nap.TsdFrame(
308-
t=np.arange(100), d=np.tile(np.array([[0, 0, 1, 1], [0, 1, 0, 1], [0,0,0,0]]), 25).T
309-
)
310-
with pytest.raises(AssertionError) as e_info:
368+
@pytest.mark.parametrize(
369+
"features, expectation",
370+
[
371+
([1, 2, 3], pytest.raises(AssertionError, match="features should be a TsdFrame with 2 columns")),
372+
(nap.TsdFrame(t=np.arange(100), d=np.tile(np.array([[0, 0, 1, 1], [0, 1, 0, 1], [0,0,0,0]]), 25).T),
373+
pytest.raises(AssertionError, match="features should have 2 columns only")),
374+
375+
]
376+
377+
)
378+
@pytest.mark.parametrize(
379+
"tsdframe",
380+
[
381+
nap.TsdFrame(t=np.arange(0, 100), d=np.hstack((np.ones((100, 1)), np.ones((100, 1)) * 2))),
382+
nap.Tsd(t=np.arange(0, 100), d=np.ones((100, )))
383+
]
384+
)
385+
def test_compute_2d_tuning_curves_continuous_error_feature(tsdframe, features, expectation):
386+
with expectation:
311387
nap.compute_2d_tuning_curves_continuous(tsdframe, features, 2)
312-
assert str(e_info.value) == "features should have 2 columns only."
313-
314388

315389
@pytest.mark.filterwarnings("ignore")
316390
def test_compute_2d_tuning_curves_with_minmax():

0 commit comments

Comments
 (0)