Skip to content

Commit 3f59ca9

Browse files
authored
Merge pull request #382 from pynapple-org/bugfix
Bugfix on TsdFrame getitem
2 parents ce27247 + c8d21f3 commit 3f59ca9

File tree

3 files changed

+25
-16
lines changed

3 files changed

+25
-16
lines changed

.github/workflows/main.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ jobs:
8585
with:
8686
directory: "./doc/_build/html"
8787
# The directory to scan
88-
arguments: --checks Links,Scripts --ignore-urls "https://fonts.gstatic.com,https://mkdocs-gallery.github.io,./doc/_build/html/_static/" --assume-extension --check-external-hash --ignore-status-codes 403 --ignore-files "/.+\/html\/_static\/.+/"
88+
arguments: --checks Links,Scripts --ignore-urls "https://fonts.gstatic.com,https://mkdocs-gallery.github.io,./doc/_build/html/_static/,https://www.nature.com/articles/s41593-022-01020-w" --assume-extension --check-external-hash --ignore-status-codes 403 --ignore-files "/.+\/html\/_static\/.+/"
8989
# The arguments to pass to HTMLProofer
9090

9191

pynapple/core/time_series.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,23 +1169,24 @@ def __getitem__(self, key, *args, **kwargs):
11691169
index = np.array([index])
11701170

11711171
if all(is_array_like(a) for a in [index, output]):
1172-
if (
1173-
(len(index) == 1)
1174-
and (output.ndim == 1)
1175-
and ((len(output) > 1) or isinstance(key[1], (list, np.ndarray)))
1176-
):
1177-
# reshape output of single index to preserve column axis if there are more than one columns being indexed
1178-
# or if column key is a list or array
1172+
if isinstance(key, tuple):
1173+
if (
1174+
len(index) == 1
1175+
and output.ndim == 1
1176+
and not isinstance(key[1], int)
1177+
):
1178+
output = output[None, :]
1179+
elif (
1180+
(output.ndim == 1)
1181+
and isinstance(key[1], (list, np.ndarray))
1182+
and (len(columns) == 1)
1183+
):
1184+
# reshape output of single column if column key is a list or array
1185+
output = output[:, None]
1186+
# if getting a row (1 dim implied)
1187+
elif isinstance(key, Number):
11791188
output = output[None, :]
11801189

1181-
elif (
1182-
(output.ndim == 1)
1183-
and isinstance(key[1], (list, np.ndarray))
1184-
and (len(columns) == 1)
1185-
):
1186-
# reshape output of single column if column key is a list or array
1187-
output = output[:, None]
1188-
11891190
kwargs["columns"] = columns
11901191
kwargs["metadata"] = self._metadata.loc[columns]
11911192
return _initialize_tsd_output(

tests/test_time_series.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -954,6 +954,7 @@ def test_interpolate_with_ep(self, tsd):
954954
@pytest.mark.parametrize(
955955
"tsdframe",
956956
[
957+
nap.TsdFrame(t=np.arange(100), d=np.random.rand(100, 1), time_units="s"),
957958
nap.TsdFrame(t=np.arange(100), d=np.random.rand(100, 3), time_units="s"),
958959
nap.TsdFrame(
959960
t=np.arange(100),
@@ -997,6 +998,7 @@ def test_copy(self, tsdframe):
997998
],
998999
)
9991000
def test_horizontal_slicing(self, tsdframe, index, nap_type):
1001+
index = index if isinstance(index, int) else index[: tsdframe.shape[1]]
10001002
assert isinstance(tsdframe[:, index], nap_type)
10011003
np.testing.assert_array_almost_equal(
10021004
tsdframe[:, index].values, tsdframe.values[:, index]
@@ -1067,6 +1069,12 @@ def test_vertical_slicing(self, tsdframe, index):
10671069
],
10681070
)
10691071
def test_vert_and_horz_slicing(self, tsdframe, row, col, expected):
1072+
if tsdframe.shape[1] == 1:
1073+
if isinstance(col, list) and isinstance(col[0], int):
1074+
col = [0]
1075+
elif isinstance(col, list) and isinstance(col[0], bool):
1076+
col = [col[0]]
1077+
10701078
# get details about row index
10711079
row_array = isinstance(row, (list, np.ndarray))
10721080
if row_array and isinstance(row[0], (bool, np.bool_)):

0 commit comments

Comments
 (0)