diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index 44103d80..e2692121 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -197,6 +197,15 @@ def __init__( ), """ DataFrame must contain columns name "start" and "end" for start and end times. """ + # try sorting the DataFrame by start times, preserving its end pair, as an effort to preserve metadata + # since metadata would be dropped if starts and ends are sorted separately + # note that if end times are still not sorted, metadata will be dropped + if np.any(start["start"].diff() < 0): + warnings.warn( + "DataFrame is not sorted by start times. Sorting it.", stacklevel=2 + ) + start = start.sort_values("start").reset_index(drop=True) + metadata = start.drop(columns=["start", "end"]) end = start["end"].values.astype(np.float64) start = start["start"].values.astype(np.float64) diff --git a/pynapple/io/interface_nwb.py b/pynapple/io/interface_nwb.py index e2388920..ad217a9d 100644 --- a/pynapple/io/interface_nwb.py +++ b/pynapple/io/interface_nwb.py @@ -85,10 +85,9 @@ def _make_interval_set(obj, **kwargs): df = obj.to_dataframe() if hasattr(df, "start_time") and hasattr(df, "stop_time"): - data = nap.IntervalSet(start=df["start_time"], end=df["stop_time"]) - if df.shape[1] > 2: - metadata = df.drop(columns=["start_time", "stop_time"]) - data.set_info(metadata) + df = df.rename(columns={"start_time": "start", "stop_time": "end"}) + # create from full dataframe to ensure that metadata is associated correctly + data = nap.IntervalSet(df) return data else: diff --git a/tests/test_decoding.py b/tests/test_decoding.py index 240c7f4d..c5f95802 100644 --- a/tests/test_decoding.py +++ b/tests/test_decoding.py @@ -7,11 +7,12 @@ """Tests of decoding for `pynapple` package.""" -import pynapple as nap import numpy as np import pandas as pd import pytest +import pynapple as nap + def get_testing_set_1d(): feature = nap.Tsd(t=np.arange(0, 100, 1), d=np.repeat(np.arange(0, 2), 50)) @@ -36,9 +37,10 @@ def test_decode_1d(): tmp[0:50, 1] = 0.0 np.testing.assert_array_almost_equal(proba.values, tmp) + def test_decode_1d_with_TsdFrame(): feature, group, tc, ep = get_testing_set_1d() - count = group.count(bin_size=1, ep = ep) + count = group.count(bin_size=1, ep=ep) decoded, proba = nap.decode_1d(tc, count, ep, bin_size=1) assert isinstance(decoded, nap.Tsd) assert isinstance(proba, nap.TsdFrame) @@ -50,6 +52,7 @@ def test_decode_1d_with_TsdFrame(): tmp[0:50, 1] = 0.0 np.testing.assert_array_almost_equal(proba.values, tmp) + def test_decode_1d_with_feature(): feature, group, tc, ep = get_testing_set_1d() decoded, proba = nap.decode_1d(tc, group, ep, bin_size=1, feature=feature) @@ -63,7 +66,8 @@ def test_decode_1d_with_feature(): tmp[50:, 0] = 0.0 tmp[0:50, 1] = 0.0 np.testing.assert_array_almost_equal(proba.values, tmp) - + + def test_decode_1d_with_dict(): feature, group, tc, ep = get_testing_set_1d() group = dict(group) @@ -79,18 +83,21 @@ def test_decode_1d_with_dict(): tmp[0:50, 1] = 0.0 np.testing.assert_array_almost_equal(proba.values, tmp) + def test_decode_1d_with_wrong_feature(): feature, group, tc, ep = get_testing_set_1d() with pytest.raises(RuntimeError) as e_info: - nap.decode_1d(tc, group, ep, bin_size=1, feature=[1,2,3]) + nap.decode_1d(tc, group, ep, bin_size=1, feature=[1, 2, 3]) assert str(e_info.value) == "Unknown format for feature in decode_1d" + def test_decode_1d_with_time_units(): feature, group, tc, ep = get_testing_set_1d() for t, tu in zip([1, 1e3, 1e6], ["s", "ms", "us"]): decoded, proba = nap.decode_1d(tc, group, ep, 1.0 * t, time_units=tu) np.testing.assert_array_almost_equal(feature.values, decoded.values) + def test_decoded_1d_raise_errors(): feature, group, tc, ep = get_testing_set_1d() with pytest.raises(Exception) as e_info: @@ -150,9 +157,10 @@ def test_decode_2d(): tmp[51:100:2, 1] = 1 np.testing.assert_array_almost_equal(proba[:, :, 1], tmp) + def test_decode_2d_with_TsdFrame(): features, group, tc, ep, xy = get_testing_set_2d() - count = group.count(bin_size=1, ep = ep) + count = group.count(bin_size=1, ep=ep) decoded, proba = nap.decode_2d(tc, count, ep, 1, xy) assert isinstance(decoded, nap.TsdFrame) @@ -169,7 +177,8 @@ def test_decode_2d_with_TsdFrame(): tmp[1:50:2, 0] = 1 tmp[51:100:2, 1] = 1 np.testing.assert_array_almost_equal(proba[:, :, 1], tmp) - + + def test_decode_2d_with_dict(): features, group, tc, ep, xy = get_testing_set_2d() group = dict(group) @@ -190,6 +199,7 @@ def test_decode_2d_with_dict(): tmp[51:100:2, 1] = 1 np.testing.assert_array_almost_equal(proba[:, :, 1], tmp) + def test_decode_2d_with_feature(): features, group, tc, ep, xy = get_testing_set_2d() decoded, proba = nap.decode_2d(tc, group, ep, 1, xy) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index e4403001..1f7819ff 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -111,6 +111,44 @@ def test_create_iset_from_df_with_metadata(): np.testing.assert_array_almost_equal(df.end.values, ep.end) +@pytest.mark.parametrize( + "df, expected", + [ + # dataframe is sorted and metadata is kept + ( + pd.DataFrame( + { + "start": [25.0, 0.0, 10.0, 16.0], + "end": [40.0, 5.0, 15.0, 20.0], + "label": np.arange(4), + } + ), + ["DataFrame is not sorted by start times"], + ), + ( + # dataframe is sorted and and metadata is dropped + pd.DataFrame( + { + "start": [25, 0, 10, 16], + "end": [40, 20, 15, 20], + "label": np.arange(4), + } + ), + ["DataFrame is not sorted by start times", "dropping metadata"], + ), + ], +) +def test_create_iset_from_df_with_metadata_sort(df, expected): + with warnings.catch_warnings(record=True) as w: + ep = nap.IntervalSet(df) + for e in expected: + assert np.any([e in str(w.message) for w in w]) + if "dropping metadata" not in expected: + pd.testing.assert_frame_equal( + ep.as_dataframe(), df.sort_values("start").reset_index(drop=True) + ) + + @pytest.mark.parametrize( "index", [