From 7e085a1daad72869d5d5241026036133bd934db6 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 11 Apr 2024 11:56:39 -0400 Subject: [PATCH 1/3] Improving repr of TsGroup --- pynapple/core/ts_group.py | 70 ++++++++++++++++++++++++++++++++++++--- tests/test_ts_group.py | 31 ++++++++--------- 2 files changed, 79 insertions(+), 22 deletions(-) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index abd3254c..7ac04208 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -255,13 +255,73 @@ def _ts_group_from_keys(self, keys): def __repr__(self): cols = self._metadata.columns.drop("rate") headers = ["Index", "rate"] + [c for c in cols] + + max_cols = 6 + max_rows = 2 + + try: + max_cols, max_rows = os.get_terminal_size() + max_cols = max_cols // 12 + max_rows = max_rows - 10 + except Exception: + import shutil + + max_cols, max_rows = shutil.get_terminal_size() + max_cols = max_cols // 12 + max_rows = max_rows - 10 + else: + pass + + max_rows = np.maximum(max_rows, 2) + max_cols = np.maximum(max_cols, 6) + + end_line = [] lines = [] - for i in self.data.keys(): - lines.append( - [str(i), "%.2f" % self._metadata.loc[i, "rate"]] - + [self._metadata.loc[i, c] for c in cols] - ) + def round_if_float(x): + if isinstance(x, float): + return np.round(x, 5) + else: + return x + + if len(headers) > max_cols: + headers = headers[0:max_cols] + ["..."] + end_line.append("...") + + if len(self) > max_rows: + n_rows = max_rows // 2 + index = self.keys() + + for i in index[0:n_rows]: + lines.append( + [i, np.round(self._metadata.loc[i, "rate"], 5)] + + [ + round_if_float(self._metadata.loc[i, c]) + for c in cols[0 : max_cols - 2] + ] + + end_line + ) + lines.append(["..." for _ in range(len(headers))]) + for i in index[-n_rows:]: + lines.append( + [i, np.round(self._metadata.loc[i, "rate"], 5)] + + [ + round_if_float(self._metadata.loc[i, c]) + for c in cols[0 : max_cols - 2] + ] + + end_line + ) + else: + for i in self.data.keys(): + lines.append( + [i, np.round(self._metadata.loc[i, "rate"], 5)] + + [ + round_if_float(self._metadata.loc[i, c]) + for c in cols[0 : max_cols - 2] + ] + + end_line + ) + return tabulate(lines, headers=headers) def __str__(self): diff --git a/tests/test_ts_group.py b/tests/test_ts_group.py index cd073ef0..23f33ba6 100644 --- a/tests/test_ts_group.py +++ b/tests/test_ts_group.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-03-30 11:14:41 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-04-01 18:07:53 +# @Last Modified time: 2024-04-11 11:42:01 """Tests of ts group for `pynapple` package.""" @@ -426,33 +426,30 @@ def test_repr_(self, group): from tabulate import tabulate tsgroup = nap.TsGroup(group) + tsgroup.set_info(abc = ['a']*len(tsgroup)) + tsgroup.set_info(bbb = [1]*len(tsgroup)) + tsgroup.set_info(ccc = [np.pi]*len(tsgroup)) cols = tsgroup._metadata.columns.drop("rate") headers = ["Index", "rate"] + [c for c in cols] lines = [] + def round_if_float(x): + if isinstance(x, float): + return np.round(x, 5) + else: + return x + for i in tsgroup.index: lines.append( - [str(i), "%.2f" % tsgroup._metadata.loc[i, "rate"]] - + [tsgroup._metadata.loc[i, c] for c in cols] + [str(i), np.round(tsgroup._metadata.loc[i, "rate"], 5)] + + [round_if_float(tsgroup._metadata.loc[i, c]) for c in cols] ) assert tabulate(lines, headers=headers) == tsgroup.__repr__() def test_str_(self, group): - from tabulate import tabulate - - tsgroup = nap.TsGroup(group) - - cols = tsgroup._metadata.columns.drop("rate") - headers = ["Index", "rate"] + [c for c in cols] - lines = [] - - for i in tsgroup.index: - lines.append( - [str(i), "%.2f" % tsgroup._metadata.loc[i, "rate"]] - + [tsgroup._metadata.loc[i, c] for c in cols] - ) - assert tabulate(lines, headers=headers) == tsgroup.__str__() + tsgroup = nap.TsGroup(group) + assert tsgroup.__str__() == tsgroup.__repr__() def test_to_tsd(self, group): t = [] From 02afb43687a85978bae1639e3f118006121e74ba Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 11 Apr 2024 14:45:51 -0400 Subject: [PATCH 2/3] linting and fixing tests --- pynapple/core/ts_group.py | 8 +++++++- tests/test_npz_file.py | 10 +++++----- tests/test_ts_group.py | 10 +++++----- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 7ac04208..0be7b3f3 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -226,7 +226,7 @@ def __getitem__(self, key): elif key in self._metadata.columns: return self.get_info(key) else: - raise KeyError(f"Can't find key {key} in group index.") + raise KeyError(r"Key {} not in group index.".format(key)) # array boolean are transformed into indices # note that raw boolean are hashable, and won't be @@ -242,6 +242,12 @@ def __getitem__(self, key): f"has length {len(key)} instead!" ) key = self.index[key] + + keys_not_in = list(filter(lambda x: x not in self.index, key)) + + if len(keys_not_in): + raise KeyError(r"Key {} not in group index.".format(keys_not_in)) + return self._ts_group_from_keys(key) def _ts_group_from_keys(self, keys): diff --git a/tests/test_npz_file.py b/tests/test_npz_file.py index 0a36cdd4..74eec96f 100644 --- a/tests/test_npz_file.py +++ b/tests/test_npz_file.py @@ -2,7 +2,7 @@ # @Author: Guillaume Viejo # @Date: 2023-07-10 17:08:55 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-07-10 18:14:49 +# @Last Modified time: 2024-04-11 13:13:37 """Tests of NPZ file functions""" @@ -73,7 +73,7 @@ def test_load_tsgroup(path, k): assert tmp.keys() == data[k].keys() assert np.all(tmp._metadata == data[k]._metadata) assert np.all(tmp[neu] == data[k][neu] for neu in tmp.keys()) - assert np.all(tmp.time_support == data[k].time_support) + np.testing.assert_array_almost_equal(tmp.time_support.values, data[k].time_support.values) @pytest.mark.parametrize("path", [path]) @@ -85,7 +85,7 @@ def test_load_tsd(path, k): assert type(tmp) == type(data[k]) assert np.all(tmp.d == data[k].d) assert np.all(tmp.t == data[k].t) - assert np.all(tmp.time_support == data[k].time_support) + np.testing.assert_array_almost_equal(tmp.time_support.values, data[k].time_support.values) @pytest.mark.parametrize("path", [path]) @@ -96,7 +96,7 @@ def test_load_ts(path, k): tmp = file.load() assert type(tmp) == type(data[k]) assert np.all(tmp.t == data[k].t) - assert np.all(tmp.time_support == data[k].time_support) + np.testing.assert_array_almost_equal(tmp.time_support.values, data[k].time_support.values) @@ -108,7 +108,7 @@ def test_load_tsdframe(path, k): tmp = file.load() assert type(tmp) == type(data[k]) assert np.all(tmp.t == data[k].t) - assert np.all(tmp.time_support == data[k].time_support) + np.testing.assert_array_almost_equal(tmp.time_support.values, data[k].time_support.values) assert np.all(tmp.columns == data[k].columns) assert np.all(tmp.d == data[k].d) diff --git a/tests/test_ts_group.py b/tests/test_ts_group.py index 23f33ba6..65d23af4 100644 --- a/tests/test_ts_group.py +++ b/tests/test_ts_group.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-03-30 11:14:41 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-04-11 11:42:01 +# @Last Modified time: 2024-04-11 14:42:50 """Tests of ts group for `pynapple` package.""" @@ -599,9 +599,9 @@ def test_save_npz(self, group): (np.array([False, True, True]), does_not_raise()), ([False, True, True], does_not_raise()), (True, does_not_raise()), - (4, pytest.raises(KeyError, match="Can't find key")), - ([3, 4], pytest.raises(KeyError, match= r"None of \[Index\(\[3, 4\]")), - ([2, 3], pytest.raises(KeyError, match=r"\[3\] not in index")) + (4, pytest.raises(KeyError, match="Key 4 not in group index.")), + ([3, 4], pytest.raises(KeyError, match= r"Key \[3, 4\] not in group index.")), + ([2, 3], pytest.raises(KeyError, match= r"Key \[3\] not in group index.")) ] ) def test_indexing_type(self, group, keys, expectation): @@ -735,7 +735,7 @@ def test_getitem_metadata_direct(self, ts_group): assert np.all(ts_group.rates == np.array([10/9, 5/9])) def test_getitem_key_error(self, ts_group): - with pytest.raises(KeyError, match="Can\'t find key nonexistent"): + with pytest.raises(KeyError, match="Key nonexistent not in group index."): _ = ts_group['nonexistent'] def test_getitem_attribute_error(self, ts_group): From f7362daf05ae8ee4df6822d8a2bc82a5a46ee9b7 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 11 Apr 2024 14:46:07 -0400 Subject: [PATCH 3/3] linting --- pynapple/core/ts_group.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 0be7b3f3..4e97aeed 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -247,7 +247,7 @@ def __getitem__(self, key): if len(keys_not_in): raise KeyError(r"Key {} not in group index.".format(keys_not_in)) - + return self._ts_group_from_keys(key) def _ts_group_from_keys(self, keys):