Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix loading and IntervalSet from NWB file with metadata #386

Merged
merged 4 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions pynapple/core/interval_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,15 @@ def __init__(
), """
DataFrame must contain columns name "start" and "end" for start and end times.
"""
# try sorting the DataFrame by start times, preserving its end pair, as an effort to preserve metadata
# since metadata would be dropped if starts and ends are sorted separately
# note that if end times are still not sorted, metadata will be dropped
if np.any(start["start"].diff() < 0):
warnings.warn(
"DataFrame is not sorted by start times. Sorting it.", stacklevel=2
)
start = start.sort_values("start").reset_index(drop=True)

metadata = start.drop(columns=["start", "end"])
end = start["end"].values.astype(np.float64)
start = start["start"].values.astype(np.float64)
Expand Down
7 changes: 3 additions & 4 deletions pynapple/io/interface_nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,9 @@ def _make_interval_set(obj, **kwargs):
df = obj.to_dataframe()

if hasattr(df, "start_time") and hasattr(df, "stop_time"):
data = nap.IntervalSet(start=df["start_time"], end=df["stop_time"])
if df.shape[1] > 2:
metadata = df.drop(columns=["start_time", "stop_time"])
data.set_info(metadata)
df = df.rename(columns={"start_time": "start", "stop_time": "end"})
# create from full dataframe to ensure that metadata is associated correctly
data = nap.IntervalSet(df)
return data

else:
Expand Down
22 changes: 16 additions & 6 deletions tests/test_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@

"""Tests of decoding for `pynapple` package."""

import pynapple as nap
import numpy as np
import pandas as pd
import pytest

import pynapple as nap


def get_testing_set_1d():
feature = nap.Tsd(t=np.arange(0, 100, 1), d=np.repeat(np.arange(0, 2), 50))
Expand All @@ -36,9 +37,10 @@ def test_decode_1d():
tmp[0:50, 1] = 0.0
np.testing.assert_array_almost_equal(proba.values, tmp)


def test_decode_1d_with_TsdFrame():
feature, group, tc, ep = get_testing_set_1d()
count = group.count(bin_size=1, ep = ep)
count = group.count(bin_size=1, ep=ep)
decoded, proba = nap.decode_1d(tc, count, ep, bin_size=1)
assert isinstance(decoded, nap.Tsd)
assert isinstance(proba, nap.TsdFrame)
Expand All @@ -50,6 +52,7 @@ def test_decode_1d_with_TsdFrame():
tmp[0:50, 1] = 0.0
np.testing.assert_array_almost_equal(proba.values, tmp)


def test_decode_1d_with_feature():
feature, group, tc, ep = get_testing_set_1d()
decoded, proba = nap.decode_1d(tc, group, ep, bin_size=1, feature=feature)
Expand All @@ -63,7 +66,8 @@ def test_decode_1d_with_feature():
tmp[50:, 0] = 0.0
tmp[0:50, 1] = 0.0
np.testing.assert_array_almost_equal(proba.values, tmp)



def test_decode_1d_with_dict():
feature, group, tc, ep = get_testing_set_1d()
group = dict(group)
Expand All @@ -79,18 +83,21 @@ def test_decode_1d_with_dict():
tmp[0:50, 1] = 0.0
np.testing.assert_array_almost_equal(proba.values, tmp)


def test_decode_1d_with_wrong_feature():
feature, group, tc, ep = get_testing_set_1d()
with pytest.raises(RuntimeError) as e_info:
nap.decode_1d(tc, group, ep, bin_size=1, feature=[1,2,3])
nap.decode_1d(tc, group, ep, bin_size=1, feature=[1, 2, 3])
assert str(e_info.value) == "Unknown format for feature in decode_1d"


def test_decode_1d_with_time_units():
feature, group, tc, ep = get_testing_set_1d()
for t, tu in zip([1, 1e3, 1e6], ["s", "ms", "us"]):
decoded, proba = nap.decode_1d(tc, group, ep, 1.0 * t, time_units=tu)
np.testing.assert_array_almost_equal(feature.values, decoded.values)


def test_decoded_1d_raise_errors():
feature, group, tc, ep = get_testing_set_1d()
with pytest.raises(Exception) as e_info:
Expand Down Expand Up @@ -150,9 +157,10 @@ def test_decode_2d():
tmp[51:100:2, 1] = 1
np.testing.assert_array_almost_equal(proba[:, :, 1], tmp)


def test_decode_2d_with_TsdFrame():
features, group, tc, ep, xy = get_testing_set_2d()
count = group.count(bin_size=1, ep = ep)
count = group.count(bin_size=1, ep=ep)
decoded, proba = nap.decode_2d(tc, count, ep, 1, xy)

assert isinstance(decoded, nap.TsdFrame)
Expand All @@ -169,7 +177,8 @@ def test_decode_2d_with_TsdFrame():
tmp[1:50:2, 0] = 1
tmp[51:100:2, 1] = 1
np.testing.assert_array_almost_equal(proba[:, :, 1], tmp)



def test_decode_2d_with_dict():
features, group, tc, ep, xy = get_testing_set_2d()
group = dict(group)
Expand All @@ -190,6 +199,7 @@ def test_decode_2d_with_dict():
tmp[51:100:2, 1] = 1
np.testing.assert_array_almost_equal(proba[:, :, 1], tmp)


def test_decode_2d_with_feature():
features, group, tc, ep, xy = get_testing_set_2d()
decoded, proba = nap.decode_2d(tc, group, ep, 1, xy)
Expand Down
38 changes: 38 additions & 0 deletions tests/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,44 @@ def test_create_iset_from_df_with_metadata():
np.testing.assert_array_almost_equal(df.end.values, ep.end)


@pytest.mark.parametrize(
"df, expected",
[
# dataframe is sorted and metadata is kept
(
pd.DataFrame(
{
"start": [25.0, 0.0, 10.0, 16.0],
"end": [40.0, 5.0, 15.0, 20.0],
"label": np.arange(4),
}
),
["DataFrame is not sorted by start times"],
),
(
# dataframe is sorted and and metadata is dropped
pd.DataFrame(
{
"start": [25, 0, 10, 16],
"end": [40, 20, 15, 20],
"label": np.arange(4),
}
),
["DataFrame is not sorted by start times", "dropping metadata"],
),
],
)
def test_create_iset_from_df_with_metadata_sort(df, expected):
with warnings.catch_warnings(record=True) as w:
ep = nap.IntervalSet(df)
for e in expected:
assert np.any([e in str(w.message) for w in w])
if "dropping metadata" not in expected:
pd.testing.assert_frame_equal(
ep.as_dataframe(), df.sort_values("start").reset_index(drop=True)
)


@pytest.mark.parametrize(
"index",
[
Expand Down
Loading