Skip to content

Commit

Permalink
Merge pull request #265 from pynapple-org/ts_group_repr
Browse files Browse the repository at this point in the history
Ts group repr
  • Loading branch information
gviejo authored Apr 11, 2024
2 parents cf46f90 + f7362da commit 07b131e
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 32 deletions.
78 changes: 72 additions & 6 deletions pynapple/core/ts_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def __getitem__(self, key):
elif key in self._metadata.columns:
return self.get_info(key)
else:
raise KeyError(f"Can't find key {key} in group index.")
raise KeyError(r"Key {} not in group index.".format(key))

# array boolean are transformed into indices
# note that raw boolean are hashable, and won't be
Expand All @@ -242,6 +242,12 @@ def __getitem__(self, key):
f"has length {len(key)} instead!"
)
key = self.index[key]

keys_not_in = list(filter(lambda x: x not in self.index, key))

if len(keys_not_in):
raise KeyError(r"Key {} not in group index.".format(keys_not_in))

return self._ts_group_from_keys(key)

def _ts_group_from_keys(self, keys):
Expand All @@ -255,13 +261,73 @@ def _ts_group_from_keys(self, keys):
def __repr__(self):
cols = self._metadata.columns.drop("rate")
headers = ["Index", "rate"] + [c for c in cols]

max_cols = 6
max_rows = 2

try:
max_cols, max_rows = os.get_terminal_size()
max_cols = max_cols // 12
max_rows = max_rows - 10
except Exception:
import shutil

max_cols, max_rows = shutil.get_terminal_size()
max_cols = max_cols // 12
max_rows = max_rows - 10
else:
pass

max_rows = np.maximum(max_rows, 2)
max_cols = np.maximum(max_cols, 6)

end_line = []
lines = []

for i in self.data.keys():
lines.append(
[str(i), "%.2f" % self._metadata.loc[i, "rate"]]
+ [self._metadata.loc[i, c] for c in cols]
)
def round_if_float(x):
if isinstance(x, float):
return np.round(x, 5)
else:
return x

if len(headers) > max_cols:
headers = headers[0:max_cols] + ["..."]
end_line.append("...")

if len(self) > max_rows:
n_rows = max_rows // 2
index = self.keys()

for i in index[0:n_rows]:
lines.append(
[i, np.round(self._metadata.loc[i, "rate"], 5)]
+ [
round_if_float(self._metadata.loc[i, c])
for c in cols[0 : max_cols - 2]
]
+ end_line
)
lines.append(["..." for _ in range(len(headers))])
for i in index[-n_rows:]:
lines.append(
[i, np.round(self._metadata.loc[i, "rate"], 5)]
+ [
round_if_float(self._metadata.loc[i, c])
for c in cols[0 : max_cols - 2]
]
+ end_line
)
else:
for i in self.data.keys():
lines.append(
[i, np.round(self._metadata.loc[i, "rate"], 5)]
+ [
round_if_float(self._metadata.loc[i, c])
for c in cols[0 : max_cols - 2]
]
+ end_line
)

return tabulate(lines, headers=headers)

def __str__(self):
Expand Down
10 changes: 5 additions & 5 deletions tests/test_npz_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# @Author: Guillaume Viejo
# @Date: 2023-07-10 17:08:55
# @Last Modified by: Guillaume Viejo
# @Last Modified time: 2023-07-10 18:14:49
# @Last Modified time: 2024-04-11 13:13:37

"""Tests of NPZ file functions"""

Expand Down Expand Up @@ -73,7 +73,7 @@ def test_load_tsgroup(path, k):
assert tmp.keys() == data[k].keys()
assert np.all(tmp._metadata == data[k]._metadata)
assert np.all(tmp[neu] == data[k][neu] for neu in tmp.keys())
assert np.all(tmp.time_support == data[k].time_support)
np.testing.assert_array_almost_equal(tmp.time_support.values, data[k].time_support.values)


@pytest.mark.parametrize("path", [path])
Expand All @@ -85,7 +85,7 @@ def test_load_tsd(path, k):
assert type(tmp) == type(data[k])
assert np.all(tmp.d == data[k].d)
assert np.all(tmp.t == data[k].t)
assert np.all(tmp.time_support == data[k].time_support)
np.testing.assert_array_almost_equal(tmp.time_support.values, data[k].time_support.values)


@pytest.mark.parametrize("path", [path])
Expand All @@ -96,7 +96,7 @@ def test_load_ts(path, k):
tmp = file.load()
assert type(tmp) == type(data[k])
assert np.all(tmp.t == data[k].t)
assert np.all(tmp.time_support == data[k].time_support)
np.testing.assert_array_almost_equal(tmp.time_support.values, data[k].time_support.values)



Expand All @@ -108,7 +108,7 @@ def test_load_tsdframe(path, k):
tmp = file.load()
assert type(tmp) == type(data[k])
assert np.all(tmp.t == data[k].t)
assert np.all(tmp.time_support == data[k].time_support)
np.testing.assert_array_almost_equal(tmp.time_support.values, data[k].time_support.values)
assert np.all(tmp.columns == data[k].columns)
assert np.all(tmp.d == data[k].d)

Expand Down
39 changes: 18 additions & 21 deletions tests/test_ts_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# @Author: gviejo
# @Date: 2022-03-30 11:14:41
# @Last Modified by: Guillaume Viejo
# @Last Modified time: 2024-04-01 18:07:53
# @Last Modified time: 2024-04-11 14:42:50

"""Tests of ts group for `pynapple` package."""

Expand Down Expand Up @@ -426,33 +426,30 @@ def test_repr_(self, group):
from tabulate import tabulate

tsgroup = nap.TsGroup(group)
tsgroup.set_info(abc = ['a']*len(tsgroup))
tsgroup.set_info(bbb = [1]*len(tsgroup))
tsgroup.set_info(ccc = [np.pi]*len(tsgroup))

cols = tsgroup._metadata.columns.drop("rate")
headers = ["Index", "rate"] + [c for c in cols]
lines = []

def round_if_float(x):
if isinstance(x, float):
return np.round(x, 5)
else:
return x

for i in tsgroup.index:
lines.append(
[str(i), "%.2f" % tsgroup._metadata.loc[i, "rate"]]
+ [tsgroup._metadata.loc[i, c] for c in cols]
[str(i), np.round(tsgroup._metadata.loc[i, "rate"], 5)]
+ [round_if_float(tsgroup._metadata.loc[i, c]) for c in cols]
)
assert tabulate(lines, headers=headers) == tsgroup.__repr__()

def test_str_(self, group):
from tabulate import tabulate

tsgroup = nap.TsGroup(group)

cols = tsgroup._metadata.columns.drop("rate")
headers = ["Index", "rate"] + [c for c in cols]
lines = []

for i in tsgroup.index:
lines.append(
[str(i), "%.2f" % tsgroup._metadata.loc[i, "rate"]]
+ [tsgroup._metadata.loc[i, c] for c in cols]
)
assert tabulate(lines, headers=headers) == tsgroup.__str__()
tsgroup = nap.TsGroup(group)
assert tsgroup.__str__() == tsgroup.__repr__()

def test_to_tsd(self, group):
t = []
Expand Down Expand Up @@ -602,9 +599,9 @@ def test_save_npz(self, group):
(np.array([False, True, True]), does_not_raise()),
([False, True, True], does_not_raise()),
(True, does_not_raise()),
(4, pytest.raises(KeyError, match="Can't find key")),
([3, 4], pytest.raises(KeyError, match= r"None of \[Index\(\[3, 4\]")),
([2, 3], pytest.raises(KeyError, match=r"\[3\] not in index"))
(4, pytest.raises(KeyError, match="Key 4 not in group index.")),
([3, 4], pytest.raises(KeyError, match= r"Key \[3, 4\] not in group index.")),
([2, 3], pytest.raises(KeyError, match= r"Key \[3\] not in group index."))
]
)
def test_indexing_type(self, group, keys, expectation):
Expand Down Expand Up @@ -738,7 +735,7 @@ def test_getitem_metadata_direct(self, ts_group):
assert np.all(ts_group.rates == np.array([10/9, 5/9]))

def test_getitem_key_error(self, ts_group):
with pytest.raises(KeyError, match="Can\'t find key nonexistent"):
with pytest.raises(KeyError, match="Key nonexistent not in group index."):
_ = ts_group['nonexistent']

def test_getitem_attribute_error(self, ts_group):
Expand Down

0 comments on commit 07b131e

Please sign in to comment.