Skip to content

Commit

Permalink
added tests with some corner cases
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Apr 4, 2024
1 parent ee45765 commit 5116238
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 19 deletions.
7 changes: 3 additions & 4 deletions pynapple/core/ts_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(
self, data, time_support=None, time_units="s", bypass_check=False, **kwargs
):
"""
TsGroup Initializer
TsGroup Initializer.
Parameters
----------
Expand All @@ -90,7 +90,7 @@ def __init__(
Useful to speed up initialization of TsGroup when Ts/Tsd objects have already been restricted beforehand
**kwargs
Meta-info about the Ts/Tsd objects. Can be either pandas.Series, numpy.ndarray, list or tuple
Note that the index should match the index of the input dictionnary if pandas Series
Note that the index should match the index of the input dictionary if pandas Series
Raises
------
Expand All @@ -108,7 +108,7 @@ def __init__(
try:
keys = [int(k) for k in data.keys()]
except Exception:
raise ValueError("keys must be convertible to integer.")
raise ValueError("All keys must be convertible to integer.")

# check that there were no floats with decimal points in keys.i
# i.e. 0.5 is not a valid key
Expand All @@ -120,7 +120,6 @@ def __init__(
if len(keys) != len(np.unique(keys)):
raise ValueError("Two dictionary keys contain the same integer value!")


data = {keys[j]: data[k] for j, k in enumerate(data.keys())}
self.index = np.sort(keys)

Expand Down
61 changes: 46 additions & 15 deletions tests/test_ts_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,64 @@
import pytest
from collections import UserDict
import warnings
from contextlib import nullcontext as does_not_raise

@pytest.fixture
def group():
return {
0: nap.Ts(t=np.arange(0, 200)),
1: nap.Ts(t=np.arange(0, 200, 0.5), time_units="s"),
2: nap.Ts(t=np.arange(0, 300, 0.2), time_units="s"),
}


class TestTsGroup1:

@pytest.mark.parametrize(
"group",
[
{
0: nap.Ts(t=np.arange(0, 200)),
1: nap.Ts(t=np.arange(0, 200, 0.5), time_units="s"),
2: nap.Ts(t=np.arange(0, 300, 0.2), time_units="s"),
}
],
)
class Test_Ts_Group_1:
def test_create_ts_group(self, group):
tsgroup = nap.TsGroup(group)
assert isinstance(tsgroup, UserDict)
assert len(tsgroup) == 3

def test_create_ts_group_from_array(self, group):
@pytest.mark.parametrize(
"test_dict, expectation",
[
({"1": nap.Ts(np.arange(10)), "2":nap.Ts(np.arange(10))}, does_not_raise()),
({"1": nap.Ts(np.arange(10)), 2: nap.Ts(np.arange(10))}, does_not_raise()),
({"1": nap.Ts(np.arange(10)), 1: nap.Ts(np.arange(10))},
pytest.raises(ValueError, match="Two dictionary keys contain the same integer")),
({"1.": nap.Ts(np.arange(10)), 2: nap.Ts(np.arange(10))},
pytest.raises(ValueError, match="All keys must be convertible")),
({-1: nap.Ts(np.arange(10)), 1: nap.Ts(np.arange(10))}, does_not_raise()),
({1.5: nap.Ts(np.arange(10)), 1: nap.Ts(np.arange(10))},
pytest.raises(ValueError, match="All keys must have integer value"))
]
)
def test_initialize_from_dict(self, test_dict, expectation):
with expectation:
nap.TsGroup(test_dict)

@pytest.mark.parametrize(
"tsgroup",
[
nap.TsGroup({"1": nap.Ts(np.arange(10)), "2": nap.Ts(np.arange(10))}),
nap.TsGroup({"1": nap.Ts(np.arange(10)), 2: nap.Ts(np.arange(10))}),
nap.TsGroup({-1: nap.Ts(np.arange(10)), 1: nap.Ts(np.arange(10))})
]
)
def test_metadata_len_match(self, tsgroup):
assert len(tsgroup._metadata) == len(tsgroup)


def test_create_ts_group_from_array(self):
with warnings.catch_warnings(record=True) as w:
nap.TsGroup({
0: np.arange(0, 200),
1: np.arange(0, 200, 0.5),
2: np.arange(0, 300, 0.2),
})
assert str(w[0].message) == "Elements should not be passed as <class 'numpy.ndarray'>. Default time units is seconds when creating the Ts object."

def test_create_ts_group_with_time_support(self, group):
ep = nap.IntervalSet(start=0, end=100)
tsgroup = nap.TsGroup(group, time_support=ep)
Expand All @@ -48,7 +79,7 @@ def test_create_ts_group_with_time_support(self, group):
assert np.all(first >= ep[0, 0])
assert np.all(last <= ep[0, 1])

def test_create_ts_group_with_empty_time_support(self, group):
def test_create_ts_group_with_empty_time_support(self):
with pytest.raises(RuntimeError) as e_info:
tmp = nap.TsGroup({
0: nap.Ts(t=np.array([])),
Expand All @@ -57,7 +88,7 @@ def test_create_ts_group_with_empty_time_support(self, group):
})
assert str(e_info.value) == "Union of time supports is empty. Consider passing a time support as argument."

def test_create_ts_group_with_bypass_check(self, group):
def test_create_ts_group_with_bypass_check(self):
tmp = {
0: nap.Ts(t=np.arange(0, 100)),
1: nap.Ts(t=np.arange(0, 200, 0.5), time_units="s"),
Expand Down

0 comments on commit 5116238

Please sign in to comment.