Skip to content

Commit 9e3e53f

Browse files
committed
IntervalSet can take IntervalSet
1 parent 0c4bf42 commit 9e3e53f

File tree

5 files changed

+33
-18
lines changed

5 files changed

+33
-18
lines changed

pynapple/core/_jitted_functions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -886,9 +886,9 @@ def jitcontinuous_perievent(
886886
left = np.minimum(windowsize[0], t_pos - start_t[k, 0])
887887
right = np.minimum(windowsize[1], maxt - t_pos - 1)
888888
center = windowsize[0] + 1
889-
new_data_array[center - left - 1 : center + right, cnt_i] = (
890-
data_array[t_pos - left : t_pos + right + 1]
891-
)
889+
new_data_array[
890+
center - left - 1 : center + right, cnt_i
891+
] = data_array[t_pos - left : t_pos + right + 1]
892892

893893
t -= 1
894894
i += 1

pynapple/core/interval_set.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,11 @@ def __init__(self, start, end=None, time_units="s", **kwargs):
9696
If `start` and `end` arguments are of unknown type
9797
9898
"""
99-
if isinstance(start, pd.DataFrame):
99+
if isinstance(start, IntervalSet):
100+
end = start.values[:, 1].astype(np.float64)
101+
start = start.values[:, 0].astype(np.float64)
102+
103+
elif isinstance(start, pd.DataFrame):
100104
assert (
101105
"start" in start.columns
102106
and "end" in start.columns

pynapple/core/ts_group.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ def __init__(
8888
To avoid checking that each element is within time_support.
8989
Useful to speed up initialization of TsGroup when Ts/Tsd objects have already been restricted beforehand
9090
**kwargs
91-
Meta-info about the Ts/Tsd objects. Can be either pandas.Series or numpy.ndarray.
92-
Note that the index should match the index of the input dictionnary.
91+
Meta-info about the Ts/Tsd objects. Can be either pandas.Series, numpy.ndarray, list or tuple
92+
Note that the index should match the index of the input dictionnary if pandas Series
9393
9494
Raises
9595
------
@@ -264,7 +264,7 @@ def set_info(self, *args, **kwargs):
264264
RuntimeError
265265
Raise an error if
266266
no column labels are found when passing simple arguments,
267-
indexes are not equals for a pandas series,
267+
indexes are not equals for a pandas series,+
268268
not the same length when passing numpy array.
269269
270270
Examples
@@ -308,15 +308,17 @@ def set_info(self, *args, **kwargs):
308308
self._metadata = self._metadata.join(arg)
309309
else:
310310
raise RuntimeError("Index are not equals")
311-
elif isinstance(arg, (pd.Series, np.ndarray)):
312-
raise RuntimeError("Columns needs to be labelled for metadata")
311+
elif isinstance(arg, (pd.Series, np.ndarray, list)):
312+
raise RuntimeError("Argument should be passed as keyword argument.")
313313
if len(kwargs):
314314
for k, v in kwargs.items():
315315
if isinstance(v, pd.Series):
316316
if pd.Index.equals(self._metadata.index, v.index):
317317
self._metadata[k] = v
318318
else:
319-
raise RuntimeError("Index are not equals")
319+
raise RuntimeError(
320+
"Index are not equals for argument {}".format(k)
321+
)
320322
elif isinstance(v, (np.ndarray, list, tuple)):
321323
if len(self._metadata) == len(v):
322324
self._metadata[k] = np.asarray(v)

tests/test_interval_set.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
# -*- coding: utf-8 -*-
33
# @Author: gviejo
44
# @Date: 2022-03-30 11:15:02
5-
# @Last Modified by: gviejo
6-
# @Last Modified time: 2024-02-21 21:39:07
5+
# @Last Modified by: Guillaume Viejo
6+
# @Last Modified time: 2024-03-29 11:04:32
77

88
"""Tests for IntervalSet of `pynapple` package."""
99

@@ -58,6 +58,13 @@ def test_create_iset_from_scalars():
5858
np.testing.assert_approx_equal(ep.start[0], 0)
5959
np.testing.assert_approx_equal(ep.end[0], 10)
6060

61+
def test_create_iset_from_iset():
62+
start = np.array([0, 10, 16, 25])
63+
end = np.array([5, 15, 20, 40])
64+
ep = nap.IntervalSet(start=start, end=end)
65+
ep2 = nap.IntervalSet(ep)
66+
np.testing.assert_array_almost_equal(ep.start, ep2.start)
67+
np.testing.assert_array_almost_equal(ep.end, ep2.end)
6168

6269
def test_create_iset_from_df():
6370
df = pd.DataFrame(data=[[16, 100]], columns=["start", "end"])

tests/test_ts_group.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# -*- coding: utf-8 -*-
22
# @Author: gviejo
33
# @Date: 2022-03-30 11:14:41
4-
# @Last Modified by: gviejo
5-
# @Last Modified time: 2024-02-19 15:11:43
4+
# @Last Modified by: Guillaume Viejo
5+
# @Last Modified time: 2024-03-29 12:08:48
66

77
"""Tests of ts group for `pynapple` package."""
88

@@ -112,30 +112,32 @@ def test_add_metainfo_raise_error(self, group):
112112

113113
with pytest.raises(RuntimeError) as e_info:
114114
tsgroup.set_info(sr_info)
115-
assert str(e_info.value) == "Columns needs to be labelled for metadata"
115+
assert str(e_info.value) == "Argument should be passed as keyword argument."
116116

117117
tsgroup = nap.TsGroup(group)
118118
ar_info = np.ones(3) * 3
119119

120120
with pytest.raises(RuntimeError) as e_info:
121121
tsgroup.set_info(ar_info)
122-
assert str(e_info.value) == "Columns needs to be labelled for metadata"
122+
assert str(e_info.value) == "Argument should be passed as keyword argument."
123123

124124

125125
def test_add_metainfo_test_runtime_errors(self, group):
126126
tsgroup = nap.TsGroup(group)
127127
sr_info = pd.Series(index=[1, 2, 3], data=[1, 1, 1], name="sr")
128128
with pytest.raises(Exception) as e_info:
129129
tsgroup.set_info(sr=sr_info)
130-
assert str(e_info.value) == "Index are not equals"
130+
assert str(e_info.value) == "Index are not equals for argument sr"
131131
df_info = pd.DataFrame(index=[1, 2, 3], data=[1, 1, 1], columns=["df"])
132132
with pytest.raises(Exception) as e_info:
133133
tsgroup.set_info(df_info)
134134
assert str(e_info.value) == "Index are not equals"
135+
135136
sr_info = pd.Series(index=[1, 2, 3], data=[1, 1, 1], name="sr")
136137
with pytest.raises(Exception) as e_info:
137138
tsgroup.set_info(sr_info)
138-
assert str(e_info.value) == "Columns needs to be labelled for metadata"
139+
assert str(e_info.value) == "Argument should be passed as keyword argument."
140+
139141
ar_info = np.ones(4)
140142
with pytest.raises(Exception) as e_info:
141143
tsgroup.set_info(ar=ar_info)

0 commit comments

Comments
 (0)