diff --git a/pynapple/process/decoding.py b/pynapple/process/decoding.py index 13900d3e..e730138b 100644 --- a/pynapple/process/decoding.py +++ b/pynapple/process/decoding.py @@ -22,8 +22,9 @@ def decode_1d(tuning_curves, group, ep, bin_size, time_units="s", feature=None): tuning_curves : pandas.DataFrame Each column is the tuning curve of one neuron relative to the feature. Index should be the center of the bin. - group : TsGroup or dict of Ts/Tsd object. + group : TsGroup, TsdFrame or dict of Ts/Tsd object. A group of neurons with the same index as tuning curves column names. + You may also pass a TsdFrame with smoothed rates (recommended). ep : IntervalSet The epoch on which decoding is computed bin_size : float @@ -48,21 +49,35 @@ def decode_1d(tuning_curves, group, ep, bin_size, time_units="s", feature=None): If different size of neurons for tuning_curves and group. If indexes don't match between tuning_curves and group. """ - if isinstance(group, dict): - newgroup = nap.TsGroup(group, time_support=ep) + if isinstance(group, nap.TsdFrame): + newgroup = group.restrict(ep) + + if tuning_curves.shape[1] != newgroup.shape[1]: + raise RuntimeError("Different shapes for tuning_curves and group") + + if not np.all(tuning_curves.columns.values == np.array(newgroup.columns)): + raise RuntimeError("Different indices for tuning curves and group keys") + + count = group + elif isinstance(group, nap.TsGroup): newgroup = group.restrict(ep) - else: - raise RuntimeError("Unknown format for group") - if tuning_curves.shape[1] != len(newgroup): - raise RuntimeError("Different shapes for tuning_curves and group") + if tuning_curves.shape[1] != len(newgroup): + raise RuntimeError("Different shapes for tuning_curves and group") - if not np.all(tuning_curves.columns.values == np.array(newgroup.keys())): - raise RuntimeError("Difference indexes for tuning curves and group keys") + if not np.all(tuning_curves.columns.values == np.array(newgroup.keys())): + raise RuntimeError("Different indices for tuning curves and group keys") - # Bin spikes - count = newgroup.count(bin_size, ep, time_units) + # Bin spikes + count = newgroup.count(bin_size, ep, time_units) + + elif isinstance(group, dict): + newgroup = nap.TsGroup(group, time_support=ep) + count = newgroup.count(bin_size, ep, time_units) + + else: + raise RuntimeError("Unknown format for group") # Occupancy if feature is None: @@ -122,9 +137,10 @@ def decode_2d(tuning_curves, group, ep, bin_size, xy, time_units="s", features=N Parameters ---------- tuning_curves : dict - Dictionnay of 2d tuning curves (one for each neuron). - group : TsGroup or dict of Ts/Tsd object. + Dictionary of 2d tuning curves (one for each neuron). + group : TsGroup, TsdFrame or dict of Ts/Tsd object. A group of neurons with the same keys as tuning_curves dictionary. + You may also pass a TsdFrame with smoothed rates (recommended). ep : IntervalSet The epoch on which decoding is computed bin_size : float @@ -153,27 +169,40 @@ def decode_2d(tuning_curves, group, ep, bin_size, xy, time_units="s", features=N """ - if type(group) is dict: - newgroup = nap.TsGroup(group, time_support=ep) - numcells = len(newgroup) + if type(group) is nap.TsdFrame: + newgroup = group.restrict(ep) + numcells = newgroup.shape[1] + + if len(tuning_curves) != numcells: + raise RuntimeError("Different shapes for tuning_curves and group") + + if not np.all( + np.array(list(tuning_curves.keys())) == np.array(newgroup.columns) + ): + raise RuntimeError("Different indices for tuning curves and group keys") + + count = group + elif type(group) is nap.TsGroup: newgroup = group.restrict(ep) numcells = len(newgroup) - else: - raise RuntimeError("Unknown format for group") - if len(tuning_curves) != numcells: - raise RuntimeError("Different shapes for tuning_curves and group") + if len(tuning_curves) != numcells: + raise RuntimeError("Different shapes for tuning_curves and group") - if not np.all(np.array(list(tuning_curves.keys())) == np.array(newgroup.keys())): - raise RuntimeError("Difference indexes for tuning curves and group keys") + if not np.all( + np.array(list(tuning_curves.keys())) == np.array(newgroup.keys()) + ): + raise RuntimeError("Different indices for tuning curves and group keys") - # Bin spikes - # if type(newgroup) is not nap.TsdFrame: - count = newgroup.count(bin_size, ep, time_units) - # else: - # #Spikes already "binned" with continuous TsdFrame input - # count = newgroup + count = newgroup.count(bin_size, ep, time_units) + + elif type(group) is dict: + newgroup = nap.TsGroup(group, time_support=ep) + count = newgroup.count(bin_size, ep, time_units) + + else: + raise RuntimeError("Unknown format for group") indexes = list(tuning_curves.keys()) @@ -199,9 +228,7 @@ def decode_2d(tuning_curves, group, ep, bin_size, xy, time_units="s", features=N tc = np.array([tuning_curves[i] for i in tuning_curves.keys()]) tc = tc.reshape(tc.shape[0], np.prod(tc.shape[1:])) tc = tc.T - ct = count.values - bin_size_s = nap.TsIndex.format_timestamps( np.array([bin_size], dtype=np.float64), time_units )[0] diff --git a/tests/test_decoding.py b/tests/test_decoding.py index f181badf..240c7f4d 100644 --- a/tests/test_decoding.py +++ b/tests/test_decoding.py @@ -26,7 +26,6 @@ def get_testing_set_1d(): def test_decode_1d(): feature, group, tc, ep = get_testing_set_1d() decoded, proba = nap.decode_1d(tc, group, ep, bin_size=1) - assert isinstance(decoded, nap.Tsd) assert isinstance(proba, nap.TsdFrame) np.testing.assert_array_almost_equal(feature.values, decoded.values) @@ -37,11 +36,10 @@ def test_decode_1d(): tmp[0:50, 1] = 0.0 np.testing.assert_array_almost_equal(proba.values, tmp) -def test_decode_1d_with_dict(): +def test_decode_1d_with_TsdFrame(): feature, group, tc, ep = get_testing_set_1d() - group = dict(group) - decoded, proba = nap.decode_1d(tc, group, ep, bin_size=1) - + count = group.count(bin_size=1, ep = ep) + decoded, proba = nap.decode_1d(tc, count, ep, bin_size=1) assert isinstance(decoded, nap.Tsd) assert isinstance(proba, nap.TsdFrame) np.testing.assert_array_almost_equal(feature.values, decoded.values) @@ -65,6 +63,21 @@ def test_decode_1d_with_feature(): tmp[50:, 0] = 0.0 tmp[0:50, 1] = 0.0 np.testing.assert_array_almost_equal(proba.values, tmp) + +def test_decode_1d_with_dict(): + feature, group, tc, ep = get_testing_set_1d() + group = dict(group) + decoded, proba = nap.decode_1d(tc, group, ep, bin_size=1, feature=feature) + np.testing.assert_array_almost_equal(feature.values, decoded.values) + assert isinstance(decoded, nap.Tsd) + assert isinstance(proba, nap.TsdFrame) + np.testing.assert_array_almost_equal(feature.values, decoded.values) + assert len(decoded) == 100 + assert len(proba) == 100 + tmp = np.ones((100, 2)) + tmp[50:, 0] = 0.0 + tmp[0:50, 1] = 0.0 + np.testing.assert_array_almost_equal(proba.values, tmp) def test_decode_1d_with_wrong_feature(): feature, group, tc, ep = get_testing_set_1d() @@ -94,7 +107,7 @@ def test_decoded_1d_raise_errors(): tc.columns = [0, 2] with pytest.raises(Exception) as e_info: nap.decode_1d(tc, group, ep, 1) - assert str(e_info.value) == "Difference indexes for tuning curves and group keys" + assert str(e_info.value) == "Different indices for tuning curves and group keys" def get_testing_set_2d(): @@ -137,6 +150,26 @@ def test_decode_2d(): tmp[51:100:2, 1] = 1 np.testing.assert_array_almost_equal(proba[:, :, 1], tmp) +def test_decode_2d_with_TsdFrame(): + features, group, tc, ep, xy = get_testing_set_2d() + count = group.count(bin_size=1, ep = ep) + decoded, proba = nap.decode_2d(tc, count, ep, 1, xy) + + assert isinstance(decoded, nap.TsdFrame) + assert isinstance(proba, np.ndarray) + np.testing.assert_array_almost_equal(features.values, decoded.values) + assert len(decoded) == 100 + assert len(proba) == 100 + tmp = np.zeros((100, 2)) + tmp[0:50:2, 0] = 1 + tmp[50:100:2, 1] = 1 + np.testing.assert_array_almost_equal(proba[:, :, 0], tmp) + + tmp = np.zeros((100, 2)) + tmp[1:50:2, 0] = 1 + tmp[51:100:2, 1] = 1 + np.testing.assert_array_almost_equal(proba[:, :, 1], tmp) + def test_decode_2d_with_dict(): features, group, tc, ep, xy = get_testing_set_2d() group = dict(group) @@ -186,4 +219,4 @@ def test_decoded_2d_raise_errors(): tc = {k: tc[i] for k, i in zip(np.arange(0, 40, 10), tc.keys())} with pytest.raises(Exception) as e_info: nap.decode_2d(tc, group, ep, 1, xy) - assert str(e_info.value) == "Difference indexes for tuning curves and group keys" + assert str(e_info.value) == "Different indices for tuning curves and group keys"