Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ts group init #259

Merged
merged 4 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions pynapple/core/ts_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,13 @@ def __init__(
self, data, time_support=None, time_units="s", bypass_check=False, **kwargs
):
"""
TsGroup Initializer
TsGroup Initializer.

Parameters
----------
data : dict
Dictionnary containing Ts/Tsd objects
Dictionary containing Ts/Tsd objects, keys should contain integer values or should be convertible
to integer.
time_support : IntervalSet, optional
The time support of the TsGroup. Ts/Tsd objects will be restricted to the time support if passed.
If no time support is specified, TsGroup will merge time supports from all the Ts/Tsd objects in data.
Expand All @@ -89,16 +90,38 @@ def __init__(
Useful to speed up initialization of TsGroup when Ts/Tsd objects have already been restricted beforehand
**kwargs
Meta-info about the Ts/Tsd objects. Can be either pandas.Series, numpy.ndarray, list or tuple
Note that the index should match the index of the input dictionnary if pandas Series
Note that the index should match the index of the input dictionary if pandas Series

Raises
------
RuntimeError
Raise error if the union of time support of Ts/Tsd object is empty.
ValueError
- If a key cannot be converted to integer.
- If a key was a floating point with non-negligible decimal part.
- If the converted keys are not unique, i.e. {1: ts_2, "2": ts_2} is valid,
{1: ts_2, "1": ts_2} is invalid.
"""
self._initialized = False

self.index = np.sort(list(data.keys()))
# convert all keys to integer
try:
keys = [int(k) for k in data.keys()]
except Exception:
raise ValueError("All keys must be convertible to integer.")

# check that there were no floats with decimal points in keys.i
# i.e. 0.5 is not a valid key
if not all(np.allclose(keys[j], float(k)) for j, k in enumerate(data.keys())):
raise ValueError("All keys must have integer value!}")

# check that we have the same num of unique keys
# {"0":val, 0:val} would be a problem...
if len(keys) != len(np.unique(keys)):
raise ValueError("Two dictionary keys contain the same integer value!")

data = {keys[j]: data[k] for j, k in enumerate(data.keys())}
self.index = np.sort(keys)

self._metadata = pd.DataFrame(index=self.index, columns=["rate"], dtype="float")

Expand Down
64 changes: 48 additions & 16 deletions tests/test_ts_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,59 @@
import pytest
from collections import UserDict
import warnings
from contextlib import nullcontext as does_not_raise


@pytest.mark.parametrize(
"group",
[
{
0: nap.Ts(t=np.arange(0, 200)),
1: nap.Ts(t=np.arange(0, 200, 0.5), time_units="s"),
2: nap.Ts(t=np.arange(0, 300, 0.2), time_units="s"),
}
],
)
class Test_Ts_Group_1:
@pytest.fixture
def group():
"""Fixture to be used in all tests."""
return {
0: nap.Ts(t=np.arange(0, 200)),
1: nap.Ts(t=np.arange(0, 200, 0.5), time_units="s"),
2: nap.Ts(t=np.arange(0, 300, 0.2), time_units="s"),
}


class TestTsGroup1:

def test_create_ts_group(self, group):
tsgroup = nap.TsGroup(group)
assert isinstance(tsgroup, UserDict)
assert len(tsgroup) == 3

def test_create_ts_group_from_array(self, group):
@pytest.mark.parametrize(
"test_dict, expectation",
[
({"1": nap.Ts(np.arange(10)), "2":nap.Ts(np.arange(10))}, does_not_raise()),
({"1": nap.Ts(np.arange(10)), 2: nap.Ts(np.arange(10))}, does_not_raise()),
({"1": nap.Ts(np.arange(10)), 1: nap.Ts(np.arange(10))},
pytest.raises(ValueError, match="Two dictionary keys contain the same integer")),
({"1.": nap.Ts(np.arange(10)), 2: nap.Ts(np.arange(10))},
pytest.raises(ValueError, match="All keys must be convertible")),
({-1: nap.Ts(np.arange(10)), 1: nap.Ts(np.arange(10))}, does_not_raise()),
({1.5: nap.Ts(np.arange(10)), 1: nap.Ts(np.arange(10))},
pytest.raises(ValueError, match="All keys must have integer value"))

]
)
def test_initialize_from_dict(self, test_dict, expectation):
with expectation:
nap.TsGroup(test_dict)

@pytest.mark.parametrize(
"tsgroup",
[
nap.TsGroup({"1": nap.Ts(np.arange(10)), "2": nap.Ts(np.arange(10))}),
nap.TsGroup({"1": nap.Ts(np.arange(10)), 2: nap.Ts(np.arange(10))}),
nap.TsGroup({-1: nap.Ts(np.arange(10)), 1: nap.Ts(np.arange(10))})

]
)
def test_metadata_len_match(self, tsgroup):
assert len(tsgroup._metadata) == len(tsgroup)


def test_create_ts_group_from_array(self):
with warnings.catch_warnings(record=True) as w:
nap.TsGroup({
0: np.arange(0, 200),
Expand All @@ -48,7 +82,7 @@ def test_create_ts_group_with_time_support(self, group):
assert np.all(first >= ep[0, 0])
assert np.all(last <= ep[0, 1])

def test_create_ts_group_with_empty_time_support(self, group):
def test_create_ts_group_with_empty_time_support(self):
with pytest.raises(RuntimeError) as e_info:
tmp = nap.TsGroup({
0: nap.Ts(t=np.array([])),
Expand All @@ -57,7 +91,7 @@ def test_create_ts_group_with_empty_time_support(self, group):
})
assert str(e_info.value) == "Union of time supports is empty. Consider passing a time support as argument."

def test_create_ts_group_with_bypass_check(self, group):
def test_create_ts_group_with_bypass_check(self):
tmp = {
0: nap.Ts(t=np.arange(0, 100)),
1: nap.Ts(t=np.arange(0, 200, 0.5), time_units="s"),
Expand Down Expand Up @@ -373,7 +407,6 @@ def test_threshold_error(self, group):
tsgroup.getby_threshold("sr", 1, op)
assert str(e_info.value) == "Operation {} not recognized.".format(op)


def test_intervals_slicing(self, group):
sr_info = pd.Series(index=[0, 1, 2], data=[0, 1, 2], name="sr")
tsgroup = nap.TsGroup(group, sr=sr_info)
Expand Down Expand Up @@ -460,7 +493,6 @@ def test_to_tsd(self, group):
np.testing.assert_array_almost_equal(tsd4.index, times)
np.testing.assert_array_almost_equal(tsd4.values, np.array([beta[int(i)] for i in data]))


def test_to_tsd_runtime_errors(self, group):

tsgroup = nap.TsGroup(group)
Expand Down
Loading