diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 1c9891f3..8d05b186 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -1332,7 +1332,7 @@ def tsgroup_gba(): 3: np.geomspace(1, 100, 3000), 4: np.geomspace(1, 100, 4000), } - return nap.TsGroup(units, metadata={"label": [1, 1, 2, 2]}) + return nap.TsGroup(units, metadata={"label": ["A", "A", "B", "B"]}) @pytest.fixture @@ -1349,7 +1349,7 @@ def tsdframe_gba(): t=np.linspace(1, 100, 1000), d=np.random.rand(1000, 4), time_units="s", - metadata={"label": [1, 1, 2, 2]}, + metadata={"label": ["x", "x", "y", "y"]}, ) @@ -1357,6 +1357,7 @@ def test_metadata_groupby_apply_tuning_curves(tsgroup_gba, iset_gba): feature = nap.Tsd(t=np.linspace(1, 100, 100), d=np.tile(np.arange(5), 20)) + # apply to intervalset out = iset_gba.groupby_apply( "label", nap.compute_1d_tuning_curves, @@ -1371,6 +1372,19 @@ def test_metadata_groupby_apply_tuning_curves(tsgroup_gba, iset_gba): ) pd.testing.assert_frame_equal(out[grp], tmp) + # apply to tsgroup + out2 = tsgroup_gba.groupby_apply( + "label", + nap.compute_1d_tuning_curves, + feature=feature, + nb_bins=5, + ) + # make sure groups are different + assert out2.keys() != out.keys() + for grp, idx in tsgroup_gba.groupby("label").items(): + tmp = nap.compute_1d_tuning_curves(tsgroup_gba[idx], feature, nb_bins=5) + pd.testing.assert_frame_equal(out2[grp], tmp) + def test_metadata_groupby_apply_tsgroup_lambda(tsgroup_gba): func = lambda x: np.mean(x.rate) @@ -1382,6 +1396,7 @@ def test_metadata_groupby_apply_tsgroup_lambda(tsgroup_gba): def test_metadata_groupby_apply_compute_mean_psd(tsdframe_gba, iset_gba): + # test on iset out = iset_gba.groupby_apply( "label", nap.compute_mean_power_spectral_density, @@ -1395,6 +1410,20 @@ def test_metadata_groupby_apply_compute_mean_psd(tsdframe_gba, iset_gba): ) pd.testing.assert_frame_equal(out[grp], tmp) + # test on tsdframe + out2 = tsdframe_gba.groupby_apply( + "label", + nap.compute_mean_power_spectral_density, + interval_size=1, + ) + # make sure groups are different + assert out2.keys() != out.keys() + for grp, idx in tsdframe_gba.groupby("label").items(): + tmp = nap.compute_mean_power_spectral_density( + tsdframe_gba[:, idx], interval_size=1 + ) + pd.testing.assert_frame_equal(out2[grp], tmp) + @pytest.mark.parametrize( "obj",