From 522cebbaf3e506a4a880562e3baf6472e3b64f47 Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Tue, 7 Jan 2025 11:48:14 -0500 Subject: [PATCH 1/4] create intervalset from full dataframe with metadata when loading in an nwb file to ensure it is aligned correctly. also try an initial sort of an input dataframe to preserve metadata --- pynapple/core/interval_set.py | 6 ++++++ pynapple/io/interface_nwb.py | 7 +++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index 44103d80..25f0944e 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -197,6 +197,12 @@ def __init__( ), """ DataFrame must contain columns name "start" and "end" for start and end times. """ + # try sorting the DataFrame by start times, preserving its end pair, as an effort to preserve metadata + # since metadata would be dropped if starts and ends are sorted separately + # note that if end times are still not sorted, metadata will be dropped + if np.any(start["start"].diff() < 0): + start = start.sort_values("start").reset_index(drop=True) + metadata = start.drop(columns=["start", "end"]) end = start["end"].values.astype(np.float64) start = start["start"].values.astype(np.float64) diff --git a/pynapple/io/interface_nwb.py b/pynapple/io/interface_nwb.py index e2388920..ad217a9d 100644 --- a/pynapple/io/interface_nwb.py +++ b/pynapple/io/interface_nwb.py @@ -85,10 +85,9 @@ def _make_interval_set(obj, **kwargs): df = obj.to_dataframe() if hasattr(df, "start_time") and hasattr(df, "stop_time"): - data = nap.IntervalSet(start=df["start_time"], end=df["stop_time"]) - if df.shape[1] > 2: - metadata = df.drop(columns=["start_time", "stop_time"]) - data.set_info(metadata) + df = df.rename(columns={"start_time": "start", "stop_time": "end"}) + # create from full dataframe to ensure that metadata is associated correctly + data = nap.IntervalSet(df) return data else: From 
83a8182ab31473b09b8ac3e9c65b2de6969c04be Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Tue, 7 Jan 2025 12:15:45 -0500 Subject: [PATCH 2/4] added warning when dataframe is being sorted --- pynapple/core/interval_set.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index 25f0944e..e2692121 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -201,6 +201,9 @@ def __init__( # since metadata would be dropped if starts and ends are sorted separately # note that if end times are still not sorted, metadata will be dropped if np.any(start["start"].diff() < 0): + warnings.warn( + "DataFrame is not sorted by start times. Sorting it.", stacklevel=2 + ) start = start.sort_values("start").reset_index(drop=True) metadata = start.drop(columns=["start", "end"]) From 0abd0d831efe060b9b52290784be9e6ee40860d2 Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Tue, 7 Jan 2025 15:15:59 -0500 Subject: [PATCH 3/4] tests for sorting dataframe in intervalset constructor --- tests/test_metadata.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index e4403001..1f7819ff 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -111,6 +111,44 @@ def test_create_iset_from_df_with_metadata(): np.testing.assert_array_almost_equal(df.end.values, ep.end) +@pytest.mark.parametrize( + "df, expected", + [ + # dataframe is sorted and metadata is kept + ( + pd.DataFrame( + { + "start": [25.0, 0.0, 10.0, 16.0], + "end": [40.0, 5.0, 15.0, 20.0], + "label": np.arange(4), + } + ), + ["DataFrame is not sorted by start times"], + ), + ( + # dataframe is sorted and metadata is dropped + pd.DataFrame( + { + "start": [25, 0, 10, 16], + "end": [40, 20, 15, 20], + "label": np.arange(4), + } + ), + ["DataFrame is not sorted by start times", "dropping metadata"], + ), + ], +) +def
test_create_iset_from_df_with_metadata_sort(df, expected): + with warnings.catch_warnings(record=True) as w: + ep = nap.IntervalSet(df) + for e in expected: + assert np.any([e in str(w.message) for w in w]) + if "dropping metadata" not in expected: + pd.testing.assert_frame_equal( + ep.as_dataframe(), df.sort_values("start").reset_index(drop=True) + ) + + @pytest.mark.parametrize( "index", [ From 68b638a845b147949a28e985cb3a921191456e2e Mon Sep 17 00:00:00 2001 From: gviejo Date: Tue, 7 Jan 2025 15:46:39 -0500 Subject: [PATCH 4/4] Fixing linting --- tests/test_decoding.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/tests/test_decoding.py b/tests/test_decoding.py index 240c7f4d..c5f95802 100644 --- a/tests/test_decoding.py +++ b/tests/test_decoding.py @@ -7,11 +7,12 @@ """Tests of decoding for `pynapple` package.""" -import pynapple as nap import numpy as np import pandas as pd import pytest +import pynapple as nap + def get_testing_set_1d(): feature = nap.Tsd(t=np.arange(0, 100, 1), d=np.repeat(np.arange(0, 2), 50)) @@ -36,9 +37,10 @@ def test_decode_1d(): tmp[0:50, 1] = 0.0 np.testing.assert_array_almost_equal(proba.values, tmp) + def test_decode_1d_with_TsdFrame(): feature, group, tc, ep = get_testing_set_1d() - count = group.count(bin_size=1, ep = ep) + count = group.count(bin_size=1, ep=ep) decoded, proba = nap.decode_1d(tc, count, ep, bin_size=1) assert isinstance(decoded, nap.Tsd) assert isinstance(proba, nap.TsdFrame) @@ -50,6 +52,7 @@ def test_decode_1d_with_TsdFrame(): tmp[0:50, 1] = 0.0 np.testing.assert_array_almost_equal(proba.values, tmp) + def test_decode_1d_with_feature(): feature, group, tc, ep = get_testing_set_1d() decoded, proba = nap.decode_1d(tc, group, ep, bin_size=1, feature=feature) @@ -63,7 +66,8 @@ def test_decode_1d_with_feature(): tmp[50:, 0] = 0.0 tmp[0:50, 1] = 0.0 np.testing.assert_array_almost_equal(proba.values, tmp) - + + def test_decode_1d_with_dict(): feature, group, tc, ep = 
get_testing_set_1d() group = dict(group) @@ -79,18 +83,21 @@ def test_decode_1d_with_dict(): tmp[0:50, 1] = 0.0 np.testing.assert_array_almost_equal(proba.values, tmp) + def test_decode_1d_with_wrong_feature(): feature, group, tc, ep = get_testing_set_1d() with pytest.raises(RuntimeError) as e_info: - nap.decode_1d(tc, group, ep, bin_size=1, feature=[1,2,3]) + nap.decode_1d(tc, group, ep, bin_size=1, feature=[1, 2, 3]) assert str(e_info.value) == "Unknown format for feature in decode_1d" + def test_decode_1d_with_time_units(): feature, group, tc, ep = get_testing_set_1d() for t, tu in zip([1, 1e3, 1e6], ["s", "ms", "us"]): decoded, proba = nap.decode_1d(tc, group, ep, 1.0 * t, time_units=tu) np.testing.assert_array_almost_equal(feature.values, decoded.values) + def test_decoded_1d_raise_errors(): feature, group, tc, ep = get_testing_set_1d() with pytest.raises(Exception) as e_info: @@ -150,9 +157,10 @@ def test_decode_2d(): tmp[51:100:2, 1] = 1 np.testing.assert_array_almost_equal(proba[:, :, 1], tmp) + def test_decode_2d_with_TsdFrame(): features, group, tc, ep, xy = get_testing_set_2d() - count = group.count(bin_size=1, ep = ep) + count = group.count(bin_size=1, ep=ep) decoded, proba = nap.decode_2d(tc, count, ep, 1, xy) assert isinstance(decoded, nap.TsdFrame) @@ -169,7 +177,8 @@ def test_decode_2d_with_TsdFrame(): tmp[1:50:2, 0] = 1 tmp[51:100:2, 1] = 1 np.testing.assert_array_almost_equal(proba[:, :, 1], tmp) - + + def test_decode_2d_with_dict(): features, group, tc, ep, xy = get_testing_set_2d() group = dict(group) @@ -190,6 +199,7 @@ def test_decode_2d_with_dict(): tmp[51:100:2, 1] = 1 np.testing.assert_array_almost_equal(proba[:, :, 1], tmp) + def test_decode_2d_with_feature(): features, group, tc, ep, xy = get_testing_set_2d() decoded, proba = nap.decode_2d(tc, group, ep, 1, xy)