Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decoder with TsdFrame #374

Merged
merged 2 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 57 additions & 30 deletions pynapple/process/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ def decode_1d(tuning_curves, group, ep, bin_size, time_units="s", feature=None):
tuning_curves : pandas.DataFrame
Each column is the tuning curve of one neuron relative to the feature.
Index should be the center of the bin.
group : TsGroup or dict of Ts/Tsd object.
group : TsGroup, TsdFrame or dict of Ts/Tsd object.
A group of neurons with the same index as tuning curves column names.
You may also pass a TsdFrame with smoothed rates (recommended).
ep : IntervalSet
The epoch on which decoding is computed
bin_size : float
Expand All @@ -48,21 +49,35 @@ def decode_1d(tuning_curves, group, ep, bin_size, time_units="s", feature=None):
If different size of neurons for tuning_curves and group.
If indexes don't match between tuning_curves and group.
"""
if isinstance(group, dict):
newgroup = nap.TsGroup(group, time_support=ep)
if isinstance(group, nap.TsdFrame):
newgroup = group.restrict(ep)

if tuning_curves.shape[1] != newgroup.shape[1]:
raise RuntimeError("Different shapes for tuning_curves and group")

if not np.all(tuning_curves.columns.values == np.array(newgroup.columns)):
raise RuntimeError("Different indices for tuning curves and group keys")

count = group

elif isinstance(group, nap.TsGroup):
newgroup = group.restrict(ep)
else:
raise RuntimeError("Unknown format for group")

if tuning_curves.shape[1] != len(newgroup):
raise RuntimeError("Different shapes for tuning_curves and group")
if tuning_curves.shape[1] != len(newgroup):
raise RuntimeError("Different shapes for tuning_curves and group")

if not np.all(tuning_curves.columns.values == np.array(newgroup.keys())):
raise RuntimeError("Difference indexes for tuning curves and group keys")
if not np.all(tuning_curves.columns.values == np.array(newgroup.keys())):
raise RuntimeError("Different indices for tuning curves and group keys")

# Bin spikes
count = newgroup.count(bin_size, ep, time_units)
# Bin spikes
count = newgroup.count(bin_size, ep, time_units)

elif isinstance(group, dict):
newgroup = nap.TsGroup(group, time_support=ep)
count = newgroup.count(bin_size, ep, time_units)

else:
raise RuntimeError("Unknown format for group")

# Occupancy
if feature is None:
Expand Down Expand Up @@ -122,9 +137,10 @@ def decode_2d(tuning_curves, group, ep, bin_size, xy, time_units="s", features=N
Parameters
----------
tuning_curves : dict
Dictionnay of 2d tuning curves (one for each neuron).
group : TsGroup or dict of Ts/Tsd object.
Dictionary of 2d tuning curves (one for each neuron).
group : TsGroup, TsdFrame or dict of Ts/Tsd object.
A group of neurons with the same keys as tuning_curves dictionary.
You may also pass a TsdFrame with smoothed rates (recommended).
ep : IntervalSet
The epoch on which decoding is computed
bin_size : float
Expand Down Expand Up @@ -153,27 +169,40 @@ def decode_2d(tuning_curves, group, ep, bin_size, xy, time_units="s", features=N

"""

if type(group) is dict:
newgroup = nap.TsGroup(group, time_support=ep)
numcells = len(newgroup)
if type(group) is nap.TsdFrame:
newgroup = group.restrict(ep)
numcells = newgroup.shape[1]

if len(tuning_curves) != numcells:
raise RuntimeError("Different shapes for tuning_curves and group")

if not np.all(
np.array(list(tuning_curves.keys())) == np.array(newgroup.columns)
):
raise RuntimeError("Different indices for tuning curves and group keys")

count = group

elif type(group) is nap.TsGroup:
newgroup = group.restrict(ep)
numcells = len(newgroup)
else:
raise RuntimeError("Unknown format for group")

if len(tuning_curves) != numcells:
raise RuntimeError("Different shapes for tuning_curves and group")
if len(tuning_curves) != numcells:
raise RuntimeError("Different shapes for tuning_curves and group")

if not np.all(np.array(list(tuning_curves.keys())) == np.array(newgroup.keys())):
raise RuntimeError("Difference indexes for tuning curves and group keys")
if not np.all(
np.array(list(tuning_curves.keys())) == np.array(newgroup.keys())
):
raise RuntimeError("Different indices for tuning curves and group keys")

# Bin spikes
# if type(newgroup) is not nap.TsdFrame:
count = newgroup.count(bin_size, ep, time_units)
# else:
# #Spikes already "binned" with continuous TsdFrame input
# count = newgroup
count = newgroup.count(bin_size, ep, time_units)

elif type(group) is dict:
newgroup = nap.TsGroup(group, time_support=ep)
count = newgroup.count(bin_size, ep, time_units)

else:
raise RuntimeError("Unknown format for group")

indexes = list(tuning_curves.keys())

Expand All @@ -199,9 +228,7 @@ def decode_2d(tuning_curves, group, ep, bin_size, xy, time_units="s", features=N
tc = np.array([tuning_curves[i] for i in tuning_curves.keys()])
tc = tc.reshape(tc.shape[0], np.prod(tc.shape[1:]))
tc = tc.T

ct = count.values

bin_size_s = nap.TsIndex.format_timestamps(
np.array([bin_size], dtype=np.float64), time_units
)[0]
Expand Down
47 changes: 40 additions & 7 deletions tests/test_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def get_testing_set_1d():
def test_decode_1d():
feature, group, tc, ep = get_testing_set_1d()
decoded, proba = nap.decode_1d(tc, group, ep, bin_size=1)

assert isinstance(decoded, nap.Tsd)
assert isinstance(proba, nap.TsdFrame)
np.testing.assert_array_almost_equal(feature.values, decoded.values)
Expand All @@ -37,11 +36,10 @@ def test_decode_1d():
tmp[0:50, 1] = 0.0
np.testing.assert_array_almost_equal(proba.values, tmp)

def test_decode_1d_with_dict():
def test_decode_1d_with_TsdFrame():
feature, group, tc, ep = get_testing_set_1d()
group = dict(group)
decoded, proba = nap.decode_1d(tc, group, ep, bin_size=1)

count = group.count(bin_size=1, ep = ep)
decoded, proba = nap.decode_1d(tc, count, ep, bin_size=1)
assert isinstance(decoded, nap.Tsd)
assert isinstance(proba, nap.TsdFrame)
np.testing.assert_array_almost_equal(feature.values, decoded.values)
Expand All @@ -65,6 +63,21 @@ def test_decode_1d_with_feature():
tmp[50:, 0] = 0.0
tmp[0:50, 1] = 0.0
np.testing.assert_array_almost_equal(proba.values, tmp)

def test_decode_1d_with_dict():
feature, group, tc, ep = get_testing_set_1d()
group = dict(group)
decoded, proba = nap.decode_1d(tc, group, ep, bin_size=1, feature=feature)
np.testing.assert_array_almost_equal(feature.values, decoded.values)
assert isinstance(decoded, nap.Tsd)
assert isinstance(proba, nap.TsdFrame)
np.testing.assert_array_almost_equal(feature.values, decoded.values)
assert len(decoded) == 100
assert len(proba) == 100
tmp = np.ones((100, 2))
tmp[50:, 0] = 0.0
tmp[0:50, 1] = 0.0
np.testing.assert_array_almost_equal(proba.values, tmp)

def test_decode_1d_with_wrong_feature():
feature, group, tc, ep = get_testing_set_1d()
Expand Down Expand Up @@ -94,7 +107,7 @@ def test_decoded_1d_raise_errors():
tc.columns = [0, 2]
with pytest.raises(Exception) as e_info:
nap.decode_1d(tc, group, ep, 1)
assert str(e_info.value) == "Difference indexes for tuning curves and group keys"
assert str(e_info.value) == "Different indices for tuning curves and group keys"


def get_testing_set_2d():
Expand Down Expand Up @@ -137,6 +150,26 @@ def test_decode_2d():
tmp[51:100:2, 1] = 1
np.testing.assert_array_almost_equal(proba[:, :, 1], tmp)

def test_decode_2d_with_TsdFrame():
features, group, tc, ep, xy = get_testing_set_2d()
count = group.count(bin_size=1, ep = ep)
decoded, proba = nap.decode_2d(tc, count, ep, 1, xy)

assert isinstance(decoded, nap.TsdFrame)
assert isinstance(proba, np.ndarray)
np.testing.assert_array_almost_equal(features.values, decoded.values)
assert len(decoded) == 100
assert len(proba) == 100
tmp = np.zeros((100, 2))
tmp[0:50:2, 0] = 1
tmp[50:100:2, 1] = 1
np.testing.assert_array_almost_equal(proba[:, :, 0], tmp)

tmp = np.zeros((100, 2))
tmp[1:50:2, 0] = 1
tmp[51:100:2, 1] = 1
np.testing.assert_array_almost_equal(proba[:, :, 1], tmp)

def test_decode_2d_with_dict():
features, group, tc, ep, xy = get_testing_set_2d()
group = dict(group)
Expand Down Expand Up @@ -186,4 +219,4 @@ def test_decoded_2d_raise_errors():
tc = {k: tc[i] for k, i in zip(np.arange(0, 40, 10), tc.keys())}
with pytest.raises(Exception) as e_info:
nap.decode_2d(tc, group, ep, 1, xy)
assert str(e_info.value) == "Difference indexes for tuning curves and group keys"
assert str(e_info.value) == "Different indices for tuning curves and group keys"
Loading