Skip to content

Commit 1b2be85

Browse files
committed
Decoder with TsdFrame
1 parent 286e64b commit 1b2be85

File tree

2 files changed

+97
-41
lines changed

2 files changed

+97
-41
lines changed

pynapple/process/decoding.py

Lines changed: 57 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ def decode_1d(tuning_curves, group, ep, bin_size, time_units="s", feature=None):
2222
tuning_curves : pandas.DataFrame
2323
Each column is the tuning curve of one neuron relative to the feature.
2424
Index should be the center of the bin.
25-
group : TsGroup or dict of Ts/Tsd object.
25+
group : TsGroup, TsdFrame or dict of Ts/Tsd object.
2626
A group of neurons with the same index as tuning curves column names.
27+
You may also pass a TsdFrame with smoothed rates (recommended).
2728
ep : IntervalSet
2829
The epoch on which decoding is computed
2930
bin_size : float
@@ -48,22 +49,36 @@ def decode_1d(tuning_curves, group, ep, bin_size, time_units="s", feature=None):
4849
If different size of neurons for tuning_curves and group.
4950
If indexes don't match between tuning_curves and group.
5051
"""
51-
if isinstance(group, dict):
52-
newgroup = nap.TsGroup(group, time_support=ep)
52+
if isinstance(group, nap.TsdFrame):
53+
newgroup = group.restrict(ep)
54+
55+
if tuning_curves.shape[1] != newgroup.shape[1]:
56+
raise RuntimeError("Different shapes for tuning_curves and group")
57+
58+
if not np.all(tuning_curves.columns.values == np.array(newgroup.columns)):
59+
raise RuntimeError("Different indices for tuning curves and group keys")
60+
61+
count = group
62+
5363
elif isinstance(group, nap.TsGroup):
5464
newgroup = group.restrict(ep)
65+
66+
if tuning_curves.shape[1] != len(newgroup):
67+
raise RuntimeError("Different shapes for tuning_curves and group")
68+
69+
if not np.all(tuning_curves.columns.values == np.array(newgroup.keys())):
70+
raise RuntimeError("Different indices for tuning curves and group keys")
71+
72+
# Bin spikes
73+
count = newgroup.count(bin_size, ep, time_units)
74+
75+
elif isinstance(group, dict):
76+
newgroup = nap.TsGroup(group, time_support=ep)
77+
count = newgroup.count(bin_size, ep, time_units)
78+
5579
else:
5680
raise RuntimeError("Unknown format for group")
57-
58-
if tuning_curves.shape[1] != len(newgroup):
59-
raise RuntimeError("Different shapes for tuning_curves and group")
60-
61-
if not np.all(tuning_curves.columns.values == np.array(newgroup.keys())):
62-
raise RuntimeError("Difference indexes for tuning curves and group keys")
63-
64-
# Bin spikes
65-
count = newgroup.count(bin_size, ep, time_units)
66-
81+
6782
# Occupancy
6883
if feature is None:
6984
occupancy = np.ones(tuning_curves.shape[0])
@@ -122,9 +137,10 @@ def decode_2d(tuning_curves, group, ep, bin_size, xy, time_units="s", features=N
122137
Parameters
123138
----------
124139
tuning_curves : dict
125-
Dictionnay of 2d tuning curves (one for each neuron).
126-
group : TsGroup or dict of Ts/Tsd object.
140+
Dictionary of 2d tuning curves (one for each neuron).
141+
group : TsGroup, TsdFrame or dict of Ts/Tsd object.
127142
A group of neurons with the same keys as tuning_curves dictionary.
143+
You may also pass a TsdFrame with smoothed rates (recommended).
128144
ep : IntervalSet
129145
The epoch on which decoding is computed
130146
bin_size : float
@@ -153,28 +169,37 @@ def decode_2d(tuning_curves, group, ep, bin_size, xy, time_units="s", features=N
153169
154170
"""
155171

156-
if type(group) is dict:
157-
newgroup = nap.TsGroup(group, time_support=ep)
158-
numcells = len(newgroup)
172+
if type(group) is nap.TsdFrame:
173+
newgroup = group.restrict(ep)
174+
numcells = newgroup.shape[1]
175+
176+
if len(tuning_curves) != numcells:
177+
raise RuntimeError("Different shapes for tuning_curves and group")
178+
179+
if not np.all(np.array(list(tuning_curves.keys())) == np.array(newgroup.columns)):
180+
raise RuntimeError("Different indices for tuning curves and group keys")
181+
182+
count = group
183+
159184
elif type(group) is nap.TsGroup:
160185
newgroup = group.restrict(ep)
161186
numcells = len(newgroup)
187+
188+
if len(tuning_curves) != numcells:
189+
raise RuntimeError("Different shapes for tuning_curves and group")
190+
191+
if not np.all(np.array(list(tuning_curves.keys())) == np.array(newgroup.keys())):
192+
raise RuntimeError("Different indices for tuning curves and group keys")
193+
194+
count = newgroup.count(bin_size, ep, time_units)
195+
196+
elif type(group) is dict:
197+
newgroup = nap.TsGroup(group, time_support=ep)
198+
count = newgroup.count(bin_size, ep, time_units)
199+
162200
else:
163201
raise RuntimeError("Unknown format for group")
164-
165-
if len(tuning_curves) != numcells:
166-
raise RuntimeError("Different shapes for tuning_curves and group")
167-
168-
if not np.all(np.array(list(tuning_curves.keys())) == np.array(newgroup.keys())):
169-
raise RuntimeError("Difference indexes for tuning curves and group keys")
170-
171-
# Bin spikes
172-
# if type(newgroup) is not nap.TsdFrame:
173-
count = newgroup.count(bin_size, ep, time_units)
174-
# else:
175-
# #Spikes already "binned" with continuous TsdFrame input
176-
# count = newgroup
177-
202+
178203
indexes = list(tuning_curves.keys())
179204

180205
# Occupancy
@@ -199,9 +224,7 @@ def decode_2d(tuning_curves, group, ep, bin_size, xy, time_units="s", features=N
199224
tc = np.array([tuning_curves[i] for i in tuning_curves.keys()])
200225
tc = tc.reshape(tc.shape[0], np.prod(tc.shape[1:]))
201226
tc = tc.T
202-
203227
ct = count.values
204-
205228
bin_size_s = nap.TsIndex.format_timestamps(
206229
np.array([bin_size], dtype=np.float64), time_units
207230
)[0]

tests/test_decoding.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ def get_testing_set_1d():
2626
def test_decode_1d():
2727
feature, group, tc, ep = get_testing_set_1d()
2828
decoded, proba = nap.decode_1d(tc, group, ep, bin_size=1)
29-
3029
assert isinstance(decoded, nap.Tsd)
3130
assert isinstance(proba, nap.TsdFrame)
3231
np.testing.assert_array_almost_equal(feature.values, decoded.values)
@@ -37,11 +36,10 @@ def test_decode_1d():
3736
tmp[0:50, 1] = 0.0
3837
np.testing.assert_array_almost_equal(proba.values, tmp)
3938

40-
def test_decode_1d_with_dict():
39+
def test_decode_1d_with_TsdFrame():
4140
feature, group, tc, ep = get_testing_set_1d()
42-
group = dict(group)
43-
decoded, proba = nap.decode_1d(tc, group, ep, bin_size=1)
44-
41+
count = group.count(bin_size=1, ep = ep)
42+
decoded, proba = nap.decode_1d(tc, count, ep, bin_size=1)
4543
assert isinstance(decoded, nap.Tsd)
4644
assert isinstance(proba, nap.TsdFrame)
4745
np.testing.assert_array_almost_equal(feature.values, decoded.values)
@@ -65,6 +63,21 @@ def test_decode_1d_with_feature():
6563
tmp[50:, 0] = 0.0
6664
tmp[0:50, 1] = 0.0
6765
np.testing.assert_array_almost_equal(proba.values, tmp)
66+
67+
def test_decode_1d_with_dict():
68+
feature, group, tc, ep = get_testing_set_1d()
69+
group = dict(group)
70+
decoded, proba = nap.decode_1d(tc, group, ep, bin_size=1, feature=feature)
71+
np.testing.assert_array_almost_equal(feature.values, decoded.values)
72+
assert isinstance(decoded, nap.Tsd)
73+
assert isinstance(proba, nap.TsdFrame)
74+
np.testing.assert_array_almost_equal(feature.values, decoded.values)
75+
assert len(decoded) == 100
76+
assert len(proba) == 100
77+
tmp = np.ones((100, 2))
78+
tmp[50:, 0] = 0.0
79+
tmp[0:50, 1] = 0.0
80+
np.testing.assert_array_almost_equal(proba.values, tmp)
6881

6982
def test_decode_1d_with_wrong_feature():
7083
feature, group, tc, ep = get_testing_set_1d()
@@ -94,7 +107,7 @@ def test_decoded_1d_raise_errors():
94107
tc.columns = [0, 2]
95108
with pytest.raises(Exception) as e_info:
96109
nap.decode_1d(tc, group, ep, 1)
97-
assert str(e_info.value) == "Difference indexes for tuning curves and group keys"
110+
assert str(e_info.value) == "Different indices for tuning curves and group keys"
98111

99112

100113
def get_testing_set_2d():
@@ -137,6 +150,26 @@ def test_decode_2d():
137150
tmp[51:100:2, 1] = 1
138151
np.testing.assert_array_almost_equal(proba[:, :, 1], tmp)
139152

153+
def test_decode_2d_with_TsdFrame():
154+
features, group, tc, ep, xy = get_testing_set_2d()
155+
count = group.count(bin_size=1, ep = ep)
156+
decoded, proba = nap.decode_2d(tc, count, ep, 1, xy)
157+
158+
assert isinstance(decoded, nap.TsdFrame)
159+
assert isinstance(proba, np.ndarray)
160+
np.testing.assert_array_almost_equal(features.values, decoded.values)
161+
assert len(decoded) == 100
162+
assert len(proba) == 100
163+
tmp = np.zeros((100, 2))
164+
tmp[0:50:2, 0] = 1
165+
tmp[50:100:2, 1] = 1
166+
np.testing.assert_array_almost_equal(proba[:, :, 0], tmp)
167+
168+
tmp = np.zeros((100, 2))
169+
tmp[1:50:2, 0] = 1
170+
tmp[51:100:2, 1] = 1
171+
np.testing.assert_array_almost_equal(proba[:, :, 1], tmp)
172+
140173
def test_decode_2d_with_dict():
141174
features, group, tc, ep, xy = get_testing_set_2d()
142175
group = dict(group)
@@ -186,4 +219,4 @@ def test_decoded_2d_raise_errors():
186219
tc = {k: tc[i] for k, i in zip(np.arange(0, 40, 10), tc.keys())}
187220
with pytest.raises(Exception) as e_info:
188221
nap.decode_2d(tc, group, ep, 1, xy)
189-
assert str(e_info.value) == "Difference indexes for tuning curves and group keys"
222+
assert str(e_info.value) == "Different indices for tuning curves and group keys"

0 commit comments

Comments
 (0)