Skip to content

Commit

Permalink
Merge pull request #374 from dhruvm9/main
Browse files Browse the repository at this point in the history
Decoder with TsdFrame
  • Loading branch information
gviejo authored Dec 11, 2024
2 parents 286e64b + 2db6598 commit 9bccc4b
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 37 deletions.
87 changes: 57 additions & 30 deletions pynapple/process/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ def decode_1d(tuning_curves, group, ep, bin_size, time_units="s", feature=None):
tuning_curves : pandas.DataFrame
Each column is the tuning curve of one neuron relative to the feature.
Index should be the center of the bin.
group : TsGroup or dict of Ts/Tsd object.
group : TsGroup, TsdFrame or dict of Ts/Tsd object.
A group of neurons with the same index as tuning curves column names.
You may also pass a TsdFrame with smoothed rates (recommended).
ep : IntervalSet
The epoch on which decoding is computed
bin_size : float
Expand All @@ -48,21 +49,35 @@ def decode_1d(tuning_curves, group, ep, bin_size, time_units="s", feature=None):
If different size of neurons for tuning_curves and group.
If indexes don't match between tuning_curves and group.
"""
if isinstance(group, dict):
newgroup = nap.TsGroup(group, time_support=ep)
if isinstance(group, nap.TsdFrame):
newgroup = group.restrict(ep)

if tuning_curves.shape[1] != newgroup.shape[1]:
raise RuntimeError("Different shapes for tuning_curves and group")

if not np.all(tuning_curves.columns.values == np.array(newgroup.columns)):
raise RuntimeError("Different indices for tuning curves and group keys")

count = group

elif isinstance(group, nap.TsGroup):
newgroup = group.restrict(ep)
else:
raise RuntimeError("Unknown format for group")

if tuning_curves.shape[1] != len(newgroup):
raise RuntimeError("Different shapes for tuning_curves and group")
if tuning_curves.shape[1] != len(newgroup):
raise RuntimeError("Different shapes for tuning_curves and group")

if not np.all(tuning_curves.columns.values == np.array(newgroup.keys())):
raise RuntimeError("Difference indexes for tuning curves and group keys")
if not np.all(tuning_curves.columns.values == np.array(newgroup.keys())):
raise RuntimeError("Different indices for tuning curves and group keys")

# Bin spikes
count = newgroup.count(bin_size, ep, time_units)
# Bin spikes
count = newgroup.count(bin_size, ep, time_units)

elif isinstance(group, dict):
newgroup = nap.TsGroup(group, time_support=ep)
count = newgroup.count(bin_size, ep, time_units)

else:
raise RuntimeError("Unknown format for group")

# Occupancy
if feature is None:
Expand Down Expand Up @@ -122,9 +137,10 @@ def decode_2d(tuning_curves, group, ep, bin_size, xy, time_units="s", features=N
Parameters
----------
tuning_curves : dict
Dictionnay of 2d tuning curves (one for each neuron).
group : TsGroup or dict of Ts/Tsd object.
Dictionary of 2d tuning curves (one for each neuron).
group : TsGroup, TsdFrame or dict of Ts/Tsd object.
A group of neurons with the same keys as tuning_curves dictionary.
You may also pass a TsdFrame with smoothed rates (recommended).
ep : IntervalSet
The epoch on which decoding is computed
bin_size : float
Expand Down Expand Up @@ -153,27 +169,40 @@ def decode_2d(tuning_curves, group, ep, bin_size, xy, time_units="s", features=N
"""

if type(group) is dict:
newgroup = nap.TsGroup(group, time_support=ep)
numcells = len(newgroup)
if type(group) is nap.TsdFrame:
newgroup = group.restrict(ep)
numcells = newgroup.shape[1]

if len(tuning_curves) != numcells:
raise RuntimeError("Different shapes for tuning_curves and group")

if not np.all(
np.array(list(tuning_curves.keys())) == np.array(newgroup.columns)
):
raise RuntimeError("Different indices for tuning curves and group keys")

count = group

elif type(group) is nap.TsGroup:
newgroup = group.restrict(ep)
numcells = len(newgroup)
else:
raise RuntimeError("Unknown format for group")

if len(tuning_curves) != numcells:
raise RuntimeError("Different shapes for tuning_curves and group")
if len(tuning_curves) != numcells:
raise RuntimeError("Different shapes for tuning_curves and group")

if not np.all(np.array(list(tuning_curves.keys())) == np.array(newgroup.keys())):
raise RuntimeError("Difference indexes for tuning curves and group keys")
if not np.all(
np.array(list(tuning_curves.keys())) == np.array(newgroup.keys())
):
raise RuntimeError("Different indices for tuning curves and group keys")

# Bin spikes
# if type(newgroup) is not nap.TsdFrame:
count = newgroup.count(bin_size, ep, time_units)
# else:
# #Spikes already "binned" with continuous TsdFrame input
# count = newgroup
count = newgroup.count(bin_size, ep, time_units)

elif type(group) is dict:
newgroup = nap.TsGroup(group, time_support=ep)
count = newgroup.count(bin_size, ep, time_units)

else:
raise RuntimeError("Unknown format for group")

indexes = list(tuning_curves.keys())

Expand All @@ -199,9 +228,7 @@ def decode_2d(tuning_curves, group, ep, bin_size, xy, time_units="s", features=N
tc = np.array([tuning_curves[i] for i in tuning_curves.keys()])
tc = tc.reshape(tc.shape[0], np.prod(tc.shape[1:]))
tc = tc.T

ct = count.values

bin_size_s = nap.TsIndex.format_timestamps(
np.array([bin_size], dtype=np.float64), time_units
)[0]
Expand Down
47 changes: 40 additions & 7 deletions tests/test_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def get_testing_set_1d():
def test_decode_1d():
feature, group, tc, ep = get_testing_set_1d()
decoded, proba = nap.decode_1d(tc, group, ep, bin_size=1)

assert isinstance(decoded, nap.Tsd)
assert isinstance(proba, nap.TsdFrame)
np.testing.assert_array_almost_equal(feature.values, decoded.values)
Expand All @@ -37,11 +36,10 @@ def test_decode_1d():
tmp[0:50, 1] = 0.0
np.testing.assert_array_almost_equal(proba.values, tmp)

def test_decode_1d_with_dict():
def test_decode_1d_with_TsdFrame():
feature, group, tc, ep = get_testing_set_1d()
group = dict(group)
decoded, proba = nap.decode_1d(tc, group, ep, bin_size=1)

count = group.count(bin_size=1, ep = ep)
decoded, proba = nap.decode_1d(tc, count, ep, bin_size=1)
assert isinstance(decoded, nap.Tsd)
assert isinstance(proba, nap.TsdFrame)
np.testing.assert_array_almost_equal(feature.values, decoded.values)
Expand All @@ -65,6 +63,21 @@ def test_decode_1d_with_feature():
tmp[50:, 0] = 0.0
tmp[0:50, 1] = 0.0
np.testing.assert_array_almost_equal(proba.values, tmp)

def test_decode_1d_with_dict():
feature, group, tc, ep = get_testing_set_1d()
group = dict(group)
decoded, proba = nap.decode_1d(tc, group, ep, bin_size=1, feature=feature)
np.testing.assert_array_almost_equal(feature.values, decoded.values)
assert isinstance(decoded, nap.Tsd)
assert isinstance(proba, nap.TsdFrame)
np.testing.assert_array_almost_equal(feature.values, decoded.values)
assert len(decoded) == 100
assert len(proba) == 100
tmp = np.ones((100, 2))
tmp[50:, 0] = 0.0
tmp[0:50, 1] = 0.0
np.testing.assert_array_almost_equal(proba.values, tmp)

def test_decode_1d_with_wrong_feature():
feature, group, tc, ep = get_testing_set_1d()
Expand Down Expand Up @@ -94,7 +107,7 @@ def test_decoded_1d_raise_errors():
tc.columns = [0, 2]
with pytest.raises(Exception) as e_info:
nap.decode_1d(tc, group, ep, 1)
assert str(e_info.value) == "Difference indexes for tuning curves and group keys"
assert str(e_info.value) == "Different indices for tuning curves and group keys"


def get_testing_set_2d():
Expand Down Expand Up @@ -137,6 +150,26 @@ def test_decode_2d():
tmp[51:100:2, 1] = 1
np.testing.assert_array_almost_equal(proba[:, :, 1], tmp)

def test_decode_2d_with_TsdFrame():
features, group, tc, ep, xy = get_testing_set_2d()
count = group.count(bin_size=1, ep = ep)
decoded, proba = nap.decode_2d(tc, count, ep, 1, xy)

assert isinstance(decoded, nap.TsdFrame)
assert isinstance(proba, np.ndarray)
np.testing.assert_array_almost_equal(features.values, decoded.values)
assert len(decoded) == 100
assert len(proba) == 100
tmp = np.zeros((100, 2))
tmp[0:50:2, 0] = 1
tmp[50:100:2, 1] = 1
np.testing.assert_array_almost_equal(proba[:, :, 0], tmp)

tmp = np.zeros((100, 2))
tmp[1:50:2, 0] = 1
tmp[51:100:2, 1] = 1
np.testing.assert_array_almost_equal(proba[:, :, 1], tmp)

def test_decode_2d_with_dict():
features, group, tc, ep, xy = get_testing_set_2d()
group = dict(group)
Expand Down Expand Up @@ -186,4 +219,4 @@ def test_decoded_2d_raise_errors():
tc = {k: tc[i] for k, i in zip(np.arange(0, 40, 10), tc.keys())}
with pytest.raises(Exception) as e_info:
nap.decode_2d(tc, group, ep, 1, xy)
assert str(e_info.value) == "Difference indexes for tuning curves and group keys"
assert str(e_info.value) == "Different indices for tuning curves and group keys"

0 comments on commit 9bccc4b

Please sign in to comment.