@@ -22,8 +22,9 @@ def decode_1d(tuning_curves, group, ep, bin_size, time_units="s", feature=None):
22
22
tuning_curves : pandas.DataFrame
23
23
Each column is the tuning curve of one neuron relative to the feature.
24
24
Index should be the center of the bin.
25
- group : TsGroup or dict of Ts/Tsd object.
25
+ group : TsGroup, TsdFrame or dict of Ts/Tsd object.
26
26
A group of neurons with the same index as tuning curves column names.
27
+ You may also pass a TsdFrame with smoothed rates (recommended).
27
28
ep : IntervalSet
28
29
The epoch on which decoding is computed
29
30
bin_size : float
@@ -48,22 +49,36 @@ def decode_1d(tuning_curves, group, ep, bin_size, time_units="s", feature=None):
48
49
If different size of neurons for tuning_curves and group.
49
50
If indexes don't match between tuning_curves and group.
50
51
"""
51
- if isinstance (group , dict ):
52
- newgroup = nap .TsGroup (group , time_support = ep )
52
+ if isinstance (group , nap .TsdFrame ):
53
+ newgroup = group .restrict (ep )
54
+
55
+ if tuning_curves .shape [1 ] != newgroup .shape [1 ]:
56
+ raise RuntimeError ("Different shapes for tuning_curves and group" )
57
+
58
+ if not np .all (tuning_curves .columns .values == np .array (newgroup .columns )):
59
+ raise RuntimeError ("Different indices for tuning curves and group keys" )
60
+
61
+ count = group
62
+
53
63
elif isinstance (group , nap .TsGroup ):
54
64
newgroup = group .restrict (ep )
65
+
66
+ if tuning_curves .shape [1 ] != len (newgroup ):
67
+ raise RuntimeError ("Different shapes for tuning_curves and group" )
68
+
69
+ if not np .all (tuning_curves .columns .values == np .array (newgroup .keys ())):
70
+ raise RuntimeError ("Different indices for tuning curves and group keys" )
71
+
72
+ # Bin spikes
73
+ count = newgroup .count (bin_size , ep , time_units )
74
+
75
+ elif isinstance (group , dict ):
76
+ newgroup = nap .TsGroup (group , time_support = ep )
77
+ count = newgroup .count (bin_size , ep , time_units )
78
+
55
79
else :
56
80
raise RuntimeError ("Unknown format for group" )
57
-
58
- if tuning_curves .shape [1 ] != len (newgroup ):
59
- raise RuntimeError ("Different shapes for tuning_curves and group" )
60
-
61
- if not np .all (tuning_curves .columns .values == np .array (newgroup .keys ())):
62
- raise RuntimeError ("Difference indexes for tuning curves and group keys" )
63
-
64
- # Bin spikes
65
- count = newgroup .count (bin_size , ep , time_units )
66
-
81
+
67
82
# Occupancy
68
83
if feature is None :
69
84
occupancy = np .ones (tuning_curves .shape [0 ])
@@ -122,9 +137,10 @@ def decode_2d(tuning_curves, group, ep, bin_size, xy, time_units="s", features=N
122
137
Parameters
123
138
----------
124
139
tuning_curves : dict
125
- Dictionnay of 2d tuning curves (one for each neuron).
126
- group : TsGroup or dict of Ts/Tsd object.
140
+ Dictionary of 2d tuning curves (one for each neuron).
141
+ group : TsGroup, TsdFrame or dict of Ts/Tsd object.
127
142
A group of neurons with the same keys as tuning_curves dictionary.
143
+ You may also pass a TsdFrame with smoothed rates (recommended).
128
144
ep : IntervalSet
129
145
The epoch on which decoding is computed
130
146
bin_size : float
@@ -153,28 +169,37 @@ def decode_2d(tuning_curves, group, ep, bin_size, xy, time_units="s", features=N
153
169
154
170
"""
155
171
156
- if type (group ) is dict :
157
- newgroup = nap .TsGroup (group , time_support = ep )
158
- numcells = len (newgroup )
172
+ if type (group ) is nap .TsdFrame :
173
+ newgroup = group .restrict (ep )
174
+ numcells = newgroup .shape [1 ]
175
+
176
+ if len (tuning_curves ) != numcells :
177
+ raise RuntimeError ("Different shapes for tuning_curves and group" )
178
+
179
+ if not np .all (np .array (list (tuning_curves .keys ())) == np .array (newgroup .columns )):
180
+ raise RuntimeError ("Different indices for tuning curves and group keys" )
181
+
182
+ count = group
183
+
159
184
elif type (group ) is nap .TsGroup :
160
185
newgroup = group .restrict (ep )
161
186
numcells = len (newgroup )
187
+
188
+ if len (tuning_curves ) != numcells :
189
+ raise RuntimeError ("Different shapes for tuning_curves and group" )
190
+
191
+ if not np .all (np .array (list (tuning_curves .keys ())) == np .array (newgroup .keys ())):
192
+ raise RuntimeError ("Different indices for tuning curves and group keys" )
193
+
194
+ count = newgroup .count (bin_size , ep , time_units )
195
+
196
+ elif type (group ) is dict :
197
+ newgroup = nap .TsGroup (group , time_support = ep )
198
+ count = newgroup .count (bin_size , ep , time_units )
199
+
162
200
else :
163
201
raise RuntimeError ("Unknown format for group" )
164
-
165
- if len (tuning_curves ) != numcells :
166
- raise RuntimeError ("Different shapes for tuning_curves and group" )
167
-
168
- if not np .all (np .array (list (tuning_curves .keys ())) == np .array (newgroup .keys ())):
169
- raise RuntimeError ("Difference indexes for tuning curves and group keys" )
170
-
171
- # Bin spikes
172
- # if type(newgroup) is not nap.TsdFrame:
173
- count = newgroup .count (bin_size , ep , time_units )
174
- # else:
175
- # #Spikes already "binned" with continuous TsdFrame input
176
- # count = newgroup
177
-
202
+
178
203
indexes = list (tuning_curves .keys ())
179
204
180
205
# Occupancy
@@ -199,9 +224,7 @@ def decode_2d(tuning_curves, group, ep, bin_size, xy, time_units="s", features=N
199
224
tc = np .array ([tuning_curves [i ] for i in tuning_curves .keys ()])
200
225
tc = tc .reshape (tc .shape [0 ], np .prod (tc .shape [1 :]))
201
226
tc = tc .T
202
-
203
227
ct = count .values
204
-
205
228
bin_size_s = nap .TsIndex .format_timestamps (
206
229
np .array ([bin_size ], dtype = np .float64 ), time_units
207
230
)[0 ]
0 commit comments