Skip to content

Commit

Permalink
some test updates
Browse files Browse the repository at this point in the history
  • Loading branch information
sjvenditto committed Feb 28, 2025
1 parent df5fb82 commit b061f5c
Showing 1 changed file with 31 additions and 2 deletions.
33 changes: 31 additions & 2 deletions tests/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -1332,7 +1332,7 @@ def tsgroup_gba():
3: np.geomspace(1, 100, 3000),
4: np.geomspace(1, 100, 4000),
}
return nap.TsGroup(units, metadata={"label": [1, 1, 2, 2]})
return nap.TsGroup(units, metadata={"label": ["A", "A", "B", "B"]})


@pytest.fixture
Expand All @@ -1349,14 +1349,15 @@ def tsdframe_gba():
t=np.linspace(1, 100, 1000),
d=np.random.rand(1000, 4),
time_units="s",
metadata={"label": [1, 1, 2, 2]},
metadata={"label": ["x", "x", "y", "y"]},
)


def test_metadata_groupby_apply_tuning_curves(tsgroup_gba, iset_gba):

feature = nap.Tsd(t=np.linspace(1, 100, 100), d=np.tile(np.arange(5), 20))

# apply to intervalset
out = iset_gba.groupby_apply(
"label",
nap.compute_1d_tuning_curves,
Expand All @@ -1371,6 +1372,19 @@ def test_metadata_groupby_apply_tuning_curves(tsgroup_gba, iset_gba):
)
pd.testing.assert_frame_equal(out[grp], tmp)

# apply to tsgroup
out2 = tsgroup_gba.groupby_apply(
"label",
nap.compute_1d_tuning_curves,
feature=feature,
nb_bins=5,
)
# make sure groups are different
assert out2.keys() != out.keys()
for grp, idx in tsgroup_gba.groupby("label").items():
tmp = nap.compute_1d_tuning_curves(tsgroup_gba[idx], feature, nb_bins=5)
pd.testing.assert_frame_equal(out2[grp], tmp)


def test_metadata_groupby_apply_tsgroup_lambda(tsgroup_gba):
func = lambda x: np.mean(x.rate)
Expand All @@ -1382,6 +1396,7 @@ def test_metadata_groupby_apply_tsgroup_lambda(tsgroup_gba):


def test_metadata_groupby_apply_compute_mean_psd(tsdframe_gba, iset_gba):
# test on iset
out = iset_gba.groupby_apply(
"label",
nap.compute_mean_power_spectral_density,
Expand All @@ -1395,6 +1410,20 @@ def test_metadata_groupby_apply_compute_mean_psd(tsdframe_gba, iset_gba):
)
pd.testing.assert_frame_equal(out[grp], tmp)

# test on tsdframe
out2 = tsdframe_gba.groupby_apply(
"label",
nap.compute_mean_power_spectral_density,
interval_size=1,
)
# make sure groups are different
assert out2.keys() != out.keys()
for grp, idx in tsdframe_gba.groupby("label").items():
tmp = nap.compute_mean_power_spectral_density(
tsdframe_gba[:, idx], interval_size=1
)
pd.testing.assert_frame_equal(out2[grp], tmp)


@pytest.mark.parametrize(
"obj",
Expand Down

0 comments on commit b061f5c

Please sign in to comment.