From 21d48935aac6f7f054c307cf587cf1432f4c8a4c Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Mon, 4 Nov 2024 16:27:48 -0500 Subject: [PATCH 01/18] allow metadata as any name for interval set with tests --- pynapple/core/interval_set.py | 22 ++-- pynapple/core/metadata_class.py | 32 +++--- tests/test_metadata.py | 182 ++++++++++++++++++++------------ tests/test_test.py | 86 +++++++++++++++ tests/test_ts_group.py | 4 +- 5 files changed, 233 insertions(+), 93 deletions(-) create mode 100644 tests/test_test.py diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index 86213bfd..9a893ca2 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -214,10 +214,11 @@ def __init__( self.index = np.arange(data.shape[0], dtype="int") self.columns = np.array(["start", "end"]) self.nap_class = self.__class__.__name__ - if drop_meta: - _MetadataMixin.__init__(self) - else: - _MetadataMixin.__init__(self, metadata) + # initialize metadata to get all attributes before setting metadata + _MetadataMixin.__init__(self) + self._class_attributes = self.__dir__() + if drop_meta is False: + self.set_info(metadata) self._initialized = True def __repr__(self): @@ -286,12 +287,21 @@ def __len__(self): def __setattr__(self, name, value): # necessary setter to allow metadata to be set as an attribute if self._initialized: - _MetadataMixin.__setattr__(self, name, value) + if name in self._class_attributes: + raise AttributeError( + f"Cannot set attribute '{name}'; IntervalSet is immutable. Use 'set_info()' to set '{name}' as metadata." + ) + else: + _MetadataMixin.__setattr__(self, name, value) else: object.__setattr__(self, name, value) def __setitem__(self, key, value): - if (isinstance(key, str)) and (key not in self.columns): + if key in self.columns: + raise RuntimeError( + "IntervalSet is immutable. Starts and ends have been already sorted." + ) + elif isinstance(key, str): _MetadataMixin.__setitem__(self, key, value) else: raise RuntimeError( diff --git a/pynapple/core/metadata_class.py b/pynapple/core/metadata_class.py index dd5126ab..077c84d8 100644 --- a/pynapple/core/metadata_class.py +++ b/pynapple/core/metadata_class.py @@ -115,32 +115,32 @@ def _raise_invalid_metadata_column_name(self, name): raise TypeError( f"Invalid metadata type {type(name)}. Metadata column names must be strings!" ) + # warnings for metadata names that cannot be accessed as attributes or keys if hasattr(self, name) and (name not in self.metadata_columns): # existing non-metadata attribute - raise ValueError( - f"Invalid metadata name '{name}'. Metadata name must differ from " - f"{type(self).__dict__.keys()} attribute names!" - ) - if hasattr(self, "columns") and name in self.columns: - # existing column (since TsdFrame columns are not attributes) - raise ValueError( - f"Invalid metadata name '{name}'. Metadata name must differ from " - f"{self.columns} column names!" + warnings.warn( + f"Metadata name '{name}' overlaps with an existing attribute, and cannot be accessed as an attribute or key. Use 'get_info()' to access metadata." ) - if name[0].isalpha() is False: - # starts with a number - raise ValueError( - f"Invalid metadata name '{name}'. Metadata name cannot start with a number" + elif hasattr(self, "columns") and name in self.columns: + # existing non-metadata attribute + warnings.warn( + f"Metadata name '{name}' overlaps with an existing property, and cannot be accessed as an attribute or key. Use 'get_info()' to access metadata." ) + # warnings for metadata that cannot be accessed as attributes if name.replace("_", "").isalnum() is False: # contains invalid characters - raise ValueError( - f"Invalid metadata name '{name}'. Metadata name cannot contain special characters except for underscores" + warnings.warn( + f"Metadata name '{name}' contains a special character, and cannot be accessed as an attribute. Use 'get_info()' or key indexing to access metadata." + ) + elif name[0].isalpha() is False: + # starts with a number + warnings.warn( + f"Metadata name '{name}' starts with a number, and cannot be accessed as an attribute. Use 'get_info()' or key indexing to access metadata." ) def _check_metadata_column_names(self, metadata=None, **kwargs): """ - Check that metadata column names don't conflict with existing attributes, don't start with a number, and don't contain invalid characters. + Throw warnings when metadata names cannot be accessed as attributes or keys. """ if metadata is not None: diff --git a/tests/test_metadata.py b/tests/test_metadata.py index bbcf3319..13dc478e 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -313,6 +313,71 @@ def test_drop_metadata_warnings(iset_meta): iset_meta.time_span() +@pytest.mark.parametrize( + "name, set_exp, set_attr_exp, set_key_exp, get_attr_exp, get_key_exp", + [ + # existing attribute and key + ( + "start", + pytest.warns(UserWarning, match="overlaps with an existing"), + pytest.raises(AttributeError, match="IntervalSet is immutable"), + pytest.raises(RuntimeError, match="IntervalSet is immutable"), + does_not_raise(), + does_not_raise(), + ), + # existing attribute and key + ( + "end", + pytest.warns(UserWarning, match="overlaps with an existing"), + pytest.raises(AttributeError, match="IntervalSet is immutable"), + pytest.raises(RuntimeError, match="IntervalSet is immutable"), + does_not_raise(), + does_not_raise(), + ), + # existing attribute + ( + "values", + pytest.warns(UserWarning, match="overlaps with an existing"), + pytest.raises(AttributeError, match="IntervalSet is immutable"), + does_not_raise(), + pytest.raises(ValueError), # shape mismatch + pytest.raises(AssertionError), # we do want metadata + ), + # existing metdata + ( + "label", + does_not_raise(), + does_not_raise(), + does_not_raise(), + pytest.raises(AssertionError), # we do want metadata + pytest.raises(AssertionError), # we do want metadata + ), + ], +) +def test_iset_metadata_overlapping_names( + iset_meta, name, set_exp, set_attr_exp, set_key_exp, get_attr_exp, get_key_exp +): + assert hasattr(iset_meta, name) + + # warning when set + with set_exp: + iset_meta.set_info({name: np.ones(4)}) + # error when set as attribute + with set_attr_exp: + setattr(iset_meta, name, np.ones(4)) + # error when set as key + with set_key_exp: + iset_meta[name] = np.ones(4) + # retrieve with get_info + assert np.all(iset_meta.get_info(name) == np.ones(4)) + # make sure it doesn't access metadata if its an existing attribute or key + with get_attr_exp: + assert np.all(getattr(iset_meta, name) == np.ones(4)) == False + # make sure it doesn't access metadata if its an existing key + with get_key_exp: + assert np.all(iset_meta[name] == np.ones(4)) == False + + ############## ## TsdFrame ## ############## @@ -347,26 +412,21 @@ def test_tsdframe_metadata_slicing(tsdframe_meta): ) -@pytest.mark.parametrize( - "args, kwargs, expected", - [ - ( - # invalid metadata names that are the same as column names - [ - pd.DataFrame( - index=["a", "b", "c"], - columns=["a", "b", "c"], - data=np.random.randint(0, 5, size=(3, 3)), - ) - ], - {}, - pytest.raises(ValueError, match="Invalid metadata name"), - ), - ], -) -def test_tsdframe_add_metadata_error(tsdframe_meta, args, kwargs, expected): - with expected: - tsdframe_meta.set_info(*args, **kwargs) +# @pytest.mark.parametrize( +# "name, set_exp, set_attr_exp, set_key_exp, get_attr_exp, get_key_exp", +# [ +# ( +# # invalid metadata names that are the same as column names +# "a", +# pytest.warns(UserWarning, match="overlaps with an existing"), +# ), +# ], +# ) +# def test_tsdframe_metadata_overlapping_names(tsdframe_meta, args, kwargs, expected): +# assert + +# with expected: +# tsdframe_meta.set_info(*args, **kwargs) ############# @@ -575,52 +635,26 @@ def test_add_metadata_df(self, obj, info, obj_len): ( # invalid names as strings starting with a number [ - pd.DataFrame( - columns=["1"], - data=np.ones((4, 1)), - ) + {"1": np.ones(4)}, ], {}, - pytest.raises(ValueError, match="Invalid metadata name"), + pytest.warns(UserWarning, match="starts with a number"), ), ( # invalid names with spaces [ - pd.DataFrame( - columns=["l 1"], - data=np.ones((4, 1)), - ) + {"l 1": np.ones(4)}, ], {}, - pytest.raises(ValueError, match="Invalid metadata name"), + pytest.warns(UserWarning, match="contains a special character"), ), ( # invalid names with periods [ - pd.DataFrame( - columns=["l.1"], - data=np.ones((4, 1)), - ) - ], - {}, - pytest.raises(ValueError, match="Invalid metadata name"), - ), - ( - # invalid names with dashes - [ - pd.DataFrame( - columns=["l-1"], - data=np.ones((4, 1)), - ) + {"1.1": np.ones(4)}, ], {}, - pytest.raises(ValueError, match="Invalid metadata name"), - ), - ( - # name that overlaps with existing attribute - [], - {"__dir__": np.zeros(4)}, - pytest.raises(ValueError, match="Invalid metadata name"), + pytest.warns(UserWarning, match="contains a special character"), ), ( # metadata with wrong length @@ -633,25 +667,33 @@ def test_add_metadata_df(self, obj, info, obj_len): ), ], ) - def test_add_metadata_error(self, obj, args, kwargs, expected): + def test_add_metadata_error(self, obj, obj_len, args, kwargs, expected): + # trim to appropriate length + if len(args): + if isinstance(args[0], pd.DataFrame): + metadata = args[0].iloc[:obj_len] + elif isinstance(args[0], dict): + metadata = {k: v[:obj_len] for k, v in args[0].items()} + else: + metadata = None with expected: - obj.set_info(*args, **kwargs) - - def test_add_metadata_key_error(self, obj, obj_len): - # type specific key errors - info = np.ones(obj_len) - if isinstance(obj, nap.IntervalSet): - with pytest.raises(RuntimeError, match="IntervalSet is immutable"): - obj[0] = info - with pytest.raises(RuntimeError, match="IntervalSet is immutable"): - obj["start"] = info - with pytest.raises(RuntimeError, match="IntervalSet is immutable"): - obj["end"] = info - - elif isinstance(obj, nap.TsGroup): - # currently obj[0] does not raise an error for TsdFrame - with pytest.raises(TypeError, match="Metadata keys must be strings!"): - obj[0] = info + obj.set_info(metadata, **kwargs) + + # def test_add_metadata_key_error(self, obj, obj_len): + # # type specific key errors + # info = np.ones(obj_len) + # if isinstance(obj, nap.IntervalSet): + # with pytest.raises(RuntimeError, match="IntervalSet is immutable"): + # obj[0] = info + # with pytest.raises(RuntimeError, match="IntervalSet is immutable"): + # obj["start"] = info + # with pytest.raises(RuntimeError, match="IntervalSet is immutable"): + # obj["end"] = info + + # elif isinstance(obj, nap.TsGroup): + # # currently obj[0] does not raise an error for TsdFrame + # with pytest.raises(TypeError, match="Metadata keys must be strings!"): + # obj[0] = info def test_overwrite_metadata(self, obj, obj_len): # add metadata diff --git a/tests/test_test.py b/tests/test_test.py new file mode 100644 index 00000000..e23e2959 --- /dev/null +++ b/tests/test_test.py @@ -0,0 +1,86 @@ +from numbers import Number +import inspect + + +import pickle +import numpy as np +import pandas as pd +import pytest +from pathlib import Path +from contextlib import nullcontext as does_not_raise +import warnings + +import pynapple as nap + + +@pytest.fixture +def iset_meta(): + start = np.array([0, 10, 16, 25]) + end = np.array([5, 15, 20, 40]) + metadata = {"label": ["a", "b", "c", "d"], "info": np.arange(4)} + return nap.IntervalSet(start=start, end=end, metadata=metadata) + + +@pytest.mark.parametrize( + "name, set_exp, set_attr_exp, set_key_exp, get_attr_exp, get_key_exp", + [ + # existing attribute and key + ( + "start", + pytest.warns(UserWarning, match="overlaps with an existing attribute"), + pytest.raises(AttributeError, match="IntervalSet is immutable"), + pytest.raises(RuntimeError, match="IntervalSet is immutable"), + does_not_raise(), + does_not_raise(), + ), + # existing attribute and key + ( + "end", + pytest.warns(UserWarning, match="overlaps with an existing attribute"), + pytest.raises(AttributeError, match="IntervalSet is immutable"), + pytest.raises(RuntimeError, match="IntervalSet is immutable"), + does_not_raise(), + does_not_raise(), + ), + # existing attribute + ( + "values", + pytest.warns(UserWarning, match="overlaps with an existing attribute"), + pytest.raises(AttributeError, match="IntervalSet is immutable"), + does_not_raise(), + pytest.raises(ValueError), # shape mismatch + pytest.raises(AssertionError), # we do want metadata + ), + # existing metdata + ( + "label", + does_not_raise(), + does_not_raise(), + does_not_raise(), + pytest.raises(AssertionError), # we do want metadata + pytest.raises(AssertionError), # we do want metadata + ), + ], +) +def test_iset_metadata_overlapping_names( + iset_meta, name, set_exp, set_attr_exp, set_key_exp, get_attr_exp, get_key_exp +): + assert hasattr(iset_meta, name) + + # warning when set + with set_exp: + iset_meta.set_info({name: np.ones(4)}) + # error when set as attribute + with set_attr_exp: + setattr(iset_meta, name, np.ones(4)) + # error when set as key + with set_key_exp: + iset_meta[name] = np.ones(4) + # retrieve with get_info + assert np.all(iset_meta.get_info(name) == np.ones(4)) + # make sure it doesn't access metadata if its an existing attribute or key + with get_attr_exp: + assert np.all(getattr(iset_meta, name) == np.ones(4)) == False + # make sure it doesn't access metadata if its an existing key + with get_key_exp: + assert np.all(iset_meta[name] == np.ones(4)) == False diff --git a/tests/test_ts_group.py b/tests/test_ts_group.py index 02b12fba..6d41b46b 100644 --- a/tests/test_ts_group.py +++ b/tests/test_ts_group.py @@ -629,7 +629,9 @@ def test_setitem_metadata_twice(self, group): assert all(group._metadata["a"] == np.arange(len(group)) + 10) def test_prevent_overwriting_existing_methods(self, ts_group): - with pytest.raises(ValueError, match=r"Invalid metadata name"): + # with pytest.raises(ValueError, match=r"Invalid metadata name"): + # ts_group["set_info"] = np.arange(2) + with pytest.warns(UserWarning, match=r"overlaps with an existing"): ts_group["set_info"] = np.arange(2) def test_getitem_ts_object(self, ts_group): From e4c3451b9a11f59b7c35d10f9b766477fd211cfb Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Tue, 5 Nov 2024 16:08:38 -0500 Subject: [PATCH 02/18] move metadata kwargs warning to TsGroup, remove kwargs from metadata init --- pynapple/core/metadata_class.py | 10 +++------- pynapple/core/ts_group.py | 5 +++++ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/pynapple/core/metadata_class.py b/pynapple/core/metadata_class.py index 077c84d8..10579d44 100644 --- a/pynapple/core/metadata_class.py +++ b/pynapple/core/metadata_class.py @@ -11,7 +11,7 @@ class _MetadataMixin: """ - def __init__(self, metadata=None, **kwargs): + def __init__(self, metadata=None): """ Metadata initializer @@ -31,12 +31,8 @@ def __init__(self, metadata=None, **kwargs): self.metadata_index = self.index self._metadata = pd.DataFrame(index=self.metadata_index) - if len(kwargs): - warnings.warn( - "initializing metadata with variable keyword arguments may be unsupported in a future version of Pynapple. Instead, initialize using the metadata argument.", - FutureWarning, - ) - self.set_info(metadata, **kwargs) + + self.set_info(metadata) def __dir__(self): """ diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index ba4b55cb..0383f8a0 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -193,6 +193,11 @@ def __init__( self._initialized = True # Trying to add argument as metainfo + if len(kwargs): + warnings.warn( + "initializing metadata with variable keyword arguments may be unsupported in a future version of Pynapple. Instead, initialize using the metadata argument.", + FutureWarning, + ) self.set_info(metadata, **kwargs) """ From 5c7d1dff8ef8738b97b5d70c1752c7bcc467e018 Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Tue, 5 Nov 2024 16:11:40 -0500 Subject: [PATCH 03/18] Delete test_test.py --- tests/test_test.py | 86 ---------------------------------------------- 1 file changed, 86 deletions(-) delete mode 100644 tests/test_test.py diff --git a/tests/test_test.py b/tests/test_test.py deleted file mode 100644 index e23e2959..00000000 --- a/tests/test_test.py +++ /dev/null @@ -1,86 +0,0 @@ -from numbers import Number -import inspect - - -import pickle -import numpy as np -import pandas as pd -import pytest -from pathlib import Path -from contextlib import nullcontext as does_not_raise -import warnings - -import pynapple as nap - - -@pytest.fixture -def iset_meta(): - start = np.array([0, 10, 16, 25]) - end = np.array([5, 15, 20, 40]) - metadata = {"label": ["a", "b", "c", "d"], "info": np.arange(4)} - return nap.IntervalSet(start=start, end=end, metadata=metadata) - - -@pytest.mark.parametrize( - "name, set_exp, set_attr_exp, set_key_exp, get_attr_exp, get_key_exp", - [ - # existing attribute and key - ( - "start", - pytest.warns(UserWarning, match="overlaps with an existing attribute"), - pytest.raises(AttributeError, match="IntervalSet is immutable"), - pytest.raises(RuntimeError, match="IntervalSet is immutable"), - does_not_raise(), - does_not_raise(), - ), - # existing attribute and key - ( - "end", - pytest.warns(UserWarning, match="overlaps with an existing attribute"), - pytest.raises(AttributeError, match="IntervalSet is immutable"), - pytest.raises(RuntimeError, match="IntervalSet is immutable"), - does_not_raise(), - does_not_raise(), - ), - # existing attribute - ( - "values", - pytest.warns(UserWarning, match="overlaps with an existing attribute"), - pytest.raises(AttributeError, match="IntervalSet is immutable"), - does_not_raise(), - pytest.raises(ValueError), # shape mismatch - pytest.raises(AssertionError), # we do want metadata - ), - # existing metdata - ( - "label", - does_not_raise(), - does_not_raise(), - does_not_raise(), - pytest.raises(AssertionError), # we do want metadata - pytest.raises(AssertionError), # we do want metadata - ), - ], -) -def test_iset_metadata_overlapping_names( - iset_meta, name, set_exp, set_attr_exp, set_key_exp, get_attr_exp, get_key_exp -): - assert hasattr(iset_meta, name) - - # warning when set - with set_exp: - iset_meta.set_info({name: np.ones(4)}) - # error when set as attribute - with set_attr_exp: - setattr(iset_meta, name, np.ones(4)) - # error when set as key - with set_key_exp: - iset_meta[name] = np.ones(4) - # retrieve with get_info - assert np.all(iset_meta.get_info(name) == np.ones(4)) - # make sure it doesn't access metadata if its an existing attribute or key - with get_attr_exp: - assert np.all(getattr(iset_meta, name) == np.ones(4)) == False - # make sure it doesn't access metadata if its an existing key - with get_key_exp: - assert np.all(iset_meta[name] == np.ones(4)) == False From 5d1d177d4bd8da7033d27f8e495a53b0fd977e9f Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Tue, 5 Nov 2024 16:29:14 -0500 Subject: [PATCH 04/18] allow any name for metadata in tsgroup. print warning if replacing "rate" --- pynapple/core/interval_set.py | 3 ++- pynapple/core/ts_group.py | 22 +++++++++++++++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index 9a893ca2..b782ab5a 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -216,7 +216,8 @@ def __init__( self.nap_class = self.__class__.__name__ # initialize metadata to get all attributes before setting metadata _MetadataMixin.__init__(self) - self._class_attributes = self.__dir__() + self._class_attributes = self.__dir__() # get list of all attributes + self._class_attributes.append("_class_attributes") # add this property if drop_meta is False: self.set_info(metadata) self._initialized = True diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 0383f8a0..375ee4c7 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -189,6 +189,11 @@ def __init__( UserDict.__init__(self, data) + # grab current attributes before adding metadata + self._class_attributes = self.__dir__() + self._class_attributes.remove("rate") # remove rate metadata + self._class_attributes.append("_class_attributes") # add this property + # Making the TsGroup non mutable self._initialized = True @@ -207,7 +212,17 @@ def __init__( def __setattr__(self, name, value): # necessary setter to allow metadata to be set as an attribute if self._initialized: - _MetadataMixin.__setattr__(self, name, value) + if name in self._class_attributes: + raise AttributeError( + f"Cannot set attribute: '{name}' is a reserved attribute. Use 'set_info()' to set '{name}' as metadata." + ) + else: + if name == "rate": + warnings.warn( + "Replacing TsGroup rate with user-defined metadata.", + UserWarning, + ) + _MetadataMixin.__setattr__(self, name, value) else: object.__setattr__(self, name, value) @@ -216,6 +231,11 @@ def __setitem__(self, key, value): self._metadata.loc[int(key), "rate"] = float(value.rate) super().__setitem__(int(key), value) else: + if key == "rate": + warnings.warn( + "Replacing TsGroup rate with user-defined metadata.", + UserWarning, + ) _MetadataMixin.__setitem__(self, key, value) def __getitem__(self, key): From 0e394b1f2e470ffdad7510d3dee92daae41c7126 Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Tue, 5 Nov 2024 16:41:20 -0500 Subject: [PATCH 05/18] tsgroup tests for metadata with overlapping names --- pynapple/core/ts_group.py | 4 +-- tests/test_metadata.py | 57 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 375ee4c7..3e5427eb 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -219,7 +219,7 @@ def __setattr__(self, name, value): else: if name == "rate": warnings.warn( - "Replacing TsGroup rate with user-defined metadata.", + "Replacing TsGroup 'rate' with user-defined metadata.", UserWarning, ) _MetadataMixin.__setattr__(self, name, value) @@ -233,7 +233,7 @@ def __setitem__(self, key, value): else: if key == "rate": warnings.warn( - "Replacing TsGroup rate with user-defined metadata.", + "Replacing TsGroup 'rate' with user-defined metadata.", UserWarning, ) _MetadataMixin.__setitem__(self, key, value) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 13dc478e..94ef01c9 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -432,6 +432,63 @@ def test_tsdframe_metadata_slicing(tsdframe_meta): ############# ## TsGroup ## ############# +@pytest.fixture +def tsgroup_meta(): + return nap.TsGroup( + { + 0: nap.Ts(t=np.arange(0, 200)), + 1: nap.Ts(t=np.arange(0, 200, 0.5), time_units="s"), + 2: nap.Ts(t=np.arange(0, 300, 0.2), time_units="s"), + 3: nap.Ts(t=np.arange(0, 400, 1), time_units="s"), + } + ) + + +@pytest.mark.parametrize( + "name, set_exp, set_attr_exp, set_key_exp, get_attr_exp, get_key_exp", + [ + # pre-computed rate metadata + ( + "rate", + does_not_raise(), + pytest.warns(UserWarning, match="Replacing TsGroup 'rate'"), + pytest.warns(UserWarning, match="Replacing TsGroup 'rate'"), + pytest.raises(AssertionError), # we do want metadata + pytest.raises(AssertionError), # we do want metadata + ), + # 'rates' attribute + ( + "rates", + pytest.warns(UserWarning, match="overlaps with an existing"), + pytest.raises(AttributeError, match="Cannot set attribute"), + does_not_raise(), + does_not_raise(), + pytest.raises(AssertionError), # we do want metadata + ), + ], +) +def test_tsgroup_metadata_overlapping_names( + tsgroup_meta, name, set_exp, set_attr_exp, set_key_exp, get_attr_exp, get_key_exp +): + assert hasattr(tsgroup_meta, name) + + # warning when set + with set_exp: + tsgroup_meta.set_info({name: np.ones(4)}) + # error when set as attribute + with set_attr_exp: + setattr(tsgroup_meta, name, np.ones(4)) + # error when set as key + with set_key_exp: + tsgroup_meta[name] = np.ones(4) + # retrieve with get_info + assert np.all(tsgroup_meta.get_info(name) == np.ones(4)) + # make sure it doesn't access metadata if its an existing attribute or key + with get_attr_exp: + assert np.all(getattr(tsgroup_meta, name) == np.ones(4)) == False + # make sure it doesn't access metadata if its an existing key + with get_key_exp: + assert np.all(tsgroup_meta[name] == np.ones(4)) == False ################## From 5693cf2cfee30ab7522a24dc166476f4df4fb1f9 Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Wed, 6 Nov 2024 14:45:18 -0500 Subject: [PATCH 06/18] allow any metadata name for tsdframe and tests --- pynapple/core/metadata_class.py | 9 +- pynapple/core/time_series.py | 28 ++-- tests/test_metadata.py | 226 +++++++++++++++++++++++++------- 3 files changed, 205 insertions(+), 58 deletions(-) diff --git a/pynapple/core/metadata_class.py b/pynapple/core/metadata_class.py index 10579d44..9e17b774 100644 --- a/pynapple/core/metadata_class.py +++ b/pynapple/core/metadata_class.py @@ -21,7 +21,6 @@ def __init__(self, metadata=None): List of pandas.DataFrame **kwargs : dict Dictionary containing metadata information - """ if self.__class__.__name__ == "TsdFrame": # metadata index is the same as the columns for TsdFrame @@ -112,7 +111,7 @@ def _raise_invalid_metadata_column_name(self, name): f"Invalid metadata type {type(name)}. Metadata column names must be strings!" ) # warnings for metadata names that cannot be accessed as attributes or keys - if hasattr(self, name) and (name not in self.metadata_columns): + if name in self._class_attributes: # existing non-metadata attribute warnings.warn( f"Metadata name '{name}' overlaps with an existing attribute, and cannot be accessed as an attribute or key. Use 'get_info()' to access metadata." @@ -122,13 +121,17 @@ def _raise_invalid_metadata_column_name(self, name): warnings.warn( f"Metadata name '{name}' overlaps with an existing property, and cannot be accessed as an attribute or key. Use 'get_info()' to access metadata." ) + # elif name in self.metadata_columns: + # # warnings for metadata that already exists + # warnings.warn(f"Overwriting existing metadata column '{name}'.") + # warnings for metadata that cannot be accessed as attributes if name.replace("_", "").isalnum() is False: # contains invalid characters warnings.warn( f"Metadata name '{name}' contains a special character, and cannot be accessed as an attribute. Use 'get_info()' or key indexing to access metadata." ) - elif name[0].isalpha() is False: + elif (name[0].isalpha() is False) and (name[0] != "_"): # starts with a number warnings.warn( f"Metadata name '{name}' starts with a number, and cannot be accessed as an attribute. Use 'get_info()' or key indexing to access metadata." diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 49c6cd67..dadee775 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -938,7 +938,13 @@ def __init__( self.columns = pd.Index(c) self.nap_class = self.__class__.__name__ - _MetadataMixin.__init__(self, metadata) + # initialize metadata for class attributes + _MetadataMixin.__init__(self) + # get current list of attributes + self._class_attributes = self.__dir__() + self._class_attributes.append("_class_attributes") + # set metadata + self.set_info(metadata) self._initialized = True @property @@ -1034,7 +1040,12 @@ def __repr__(self): def __setattr__(self, name, value): # necessary setter to allow metadata to be set as an attribute if self._initialized: - _MetadataMixin.__setattr__(self, name, value) + if name in self._class_attributes: + raise AttributeError( + f"Cannot set attribute: '{name}' is a reserved attribute. Use 'set_info()' to set '{name}' as metadata." + ) + else: + _MetadataMixin.__setattr__(self, name, value) else: super().__setattr__(name, value) @@ -1069,17 +1080,18 @@ def __setitem__(self, key, value): raise IndexError def __getitem__(self, key, *args, **kwargs): - if isinstance(key, str) and (key in self.metadata_columns): - return _MetadataMixin.__getitem__(self, key) - elif ( + if ( isinstance(key, str) or hasattr(key, "__iter__") and all([isinstance(k, str) for k in key]) ): - if all(k in self.metadata_columns for k in key): - return _MetadataMixin.__getitem__(self, key) - else: + if all(k in self.columns for k in key): return self.loc[key] + else: + # if all(k in self.metadata_columns for k in key): + return _MetadataMixin.__getitem__(self, key) + # else: + # return self.loc[key] else: if isinstance(key, pd.Series) and key.index.equals(self.columns): diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 94ef01c9..28a6f926 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -319,38 +319,58 @@ def test_drop_metadata_warnings(iset_meta): # existing attribute and key ( "start", + # warn with set_info pytest.warns(UserWarning, match="overlaps with an existing"), + # error with setattr pytest.raises(AttributeError, match="IntervalSet is immutable"), + # error with setitem pytest.raises(RuntimeError, match="IntervalSet is immutable"), - does_not_raise(), - does_not_raise(), + # attr should not match metadata + pytest.raises(AssertionError), + # key should not match metadata + pytest.raises(AssertionError), ), # existing attribute and key ( "end", + # warn with set_info pytest.warns(UserWarning, match="overlaps with an existing"), + # error with setattr pytest.raises(AttributeError, match="IntervalSet is immutable"), + # error with setitem pytest.raises(RuntimeError, match="IntervalSet is immutable"), - does_not_raise(), - does_not_raise(), + # attr should not match metadata + pytest.raises(AssertionError), + # key should not match metadata + pytest.raises(AssertionError), ), # existing attribute ( "values", + # warn with set_info pytest.warns(UserWarning, match="overlaps with an existing"), + # error with setattr pytest.raises(AttributeError, match="IntervalSet is immutable"), - does_not_raise(), + # warn with setitem + pytest.warns(UserWarning, match="overlaps with an existing"), + # attr should not match metadata pytest.raises(ValueError), # shape mismatch - pytest.raises(AssertionError), # we do want metadata + # key should match metadata + does_not_raise(), ), # existing metdata ( "label", + # no warning with set_info + does_not_raise(), + # no warning with setattr does_not_raise(), + # no warning with setitem does_not_raise(), + # attr should match metadata + does_not_raise(), + # key should match metadata does_not_raise(), - pytest.raises(AssertionError), # we do want metadata - pytest.raises(AssertionError), # we do want metadata ), ], ) @@ -372,10 +392,10 @@ def test_iset_metadata_overlapping_names( assert np.all(iset_meta.get_info(name) == np.ones(4)) # make sure it doesn't access metadata if its an existing attribute or key with get_attr_exp: - assert np.all(getattr(iset_meta, name) == np.ones(4)) == False + assert np.all(getattr(iset_meta, name) == np.ones(4)) # make sure it doesn't access metadata if its an existing key with get_key_exp: - assert np.all(iset_meta[name] == np.ones(4)) == False + assert np.all(iset_meta[name] == np.ones(4)) ############## @@ -412,21 +432,88 @@ def test_tsdframe_metadata_slicing(tsdframe_meta): ) -# @pytest.mark.parametrize( -# "name, set_exp, set_attr_exp, set_key_exp, get_attr_exp, get_key_exp", -# [ -# ( -# # invalid metadata names that are the same as column names -# "a", -# pytest.warns(UserWarning, match="overlaps with an existing"), -# ), -# ], -# ) -# def test_tsdframe_metadata_overlapping_names(tsdframe_meta, args, kwargs, expected): -# assert - -# with expected: -# tsdframe_meta.set_info(*args, **kwargs) +@pytest.mark.parametrize( + "name, attr_exp, set_exp, set_attr_exp, set_key_exp, get_attr_exp, get_key_exp", + [ + # pre-computed rate metadata + ( + "a", + # not attribute + pytest.raises(AssertionError), + # warn with set_info + pytest.warns(UserWarning, match="overlaps with an existing"), + # warn with setattr + pytest.warns(UserWarning, match="overlaps with an existing"), + # cannot be set as key + pytest.raises(ValueError), + # attribute should match metadata + does_not_raise(), + # key should not match metadata + pytest.raises(ValueError), + ), + ( + "columns", + # attribute exists + does_not_raise(), + # warn with set_info + pytest.warns(UserWarning, match="overlaps with an existing"), + # cannot be set as attribute + pytest.raises(AttributeError, match="Cannot set attribute"), + # warn when set as key + pytest.warns(UserWarning, match="overlaps with an existing"), + # attribute should not match metadata + pytest.raises(AssertionError), + # key should match metadata + does_not_raise(), + ), + # existing metdata + ( + "l1", + # attribute exists + does_not_raise(), + # no warning with set_info + does_not_raise(), + # no warning with setattr + does_not_raise(), + # no warning with setitem + does_not_raise(), + # attr should match metadata + does_not_raise(), + # key should match metadata + does_not_raise(), + ), + ], +) +def test_tsdframe_metadata_overlapping_names( + tsdframe_meta, + name, + attr_exp, + set_exp, + set_attr_exp, + set_key_exp, + get_attr_exp, + get_key_exp, +): + with attr_exp: + assert hasattr(tsdframe_meta, name) + # warning when set + with set_exp: + # warnings.simplefilter("error") + tsdframe_meta.set_info({name: np.ones(4)}) + # error when set as attribute + with set_attr_exp: + setattr(tsdframe_meta, name, np.ones(4)) + # error when set as key + with set_key_exp: + tsdframe_meta[name] = np.ones(4) + # retrieve with get_info + assert np.all(tsdframe_meta.get_info(name) == np.ones(4)) + # make sure it doesn't access metadata if its an existing attribute or key + with get_attr_exp: + assert np.all(getattr(tsdframe_meta, name) == np.ones(4)) + # make sure it doesn't access metadata if its an existing key + with get_key_exp: + assert np.all(tsdframe_meta[name] == np.ones(4)) ############# @@ -450,20 +537,30 @@ def tsgroup_meta(): # pre-computed rate metadata ( "rate", - does_not_raise(), + # no warning with set_info + warnings.catch_warnings(action="error"), + # warning with setattr pytest.warns(UserWarning, match="Replacing TsGroup 'rate'"), + # warning with setitem pytest.warns(UserWarning, match="Replacing TsGroup 'rate'"), - pytest.raises(AssertionError), # we do want metadata - pytest.raises(AssertionError), # we do want metadata + # get attribute is metadata + does_not_raise(), + # get key is metadata + does_not_raise(), ), # 'rates' attribute ( "rates", + # warning with set_info pytest.warns(UserWarning, match="overlaps with an existing"), + # error with setattr pytest.raises(AttributeError, match="Cannot set attribute"), + # warn with setitem + pytest.warns(UserWarning, match="overlaps with an existing"), + # get attribute is not metadata + pytest.raises(AssertionError), + # get key is metadata does_not_raise(), - does_not_raise(), - pytest.raises(AssertionError), # we do want metadata ), ], ) @@ -485,10 +582,10 @@ def test_tsgroup_metadata_overlapping_names( assert np.all(tsgroup_meta.get_info(name) == np.ones(4)) # make sure it doesn't access metadata if its an existing attribute or key with get_attr_exp: - assert np.all(getattr(tsgroup_meta, name) == np.ones(4)) == False + assert np.all(getattr(tsgroup_meta, name) == np.ones(4)) # make sure it doesn't access metadata if its an existing key with get_key_exp: - assert np.all(tsgroup_meta[name] == np.ones(4)) == False + assert np.all(tsgroup_meta[name] == np.ones(4)) ################## @@ -736,21 +833,21 @@ def test_add_metadata_error(self, obj, obj_len, args, kwargs, expected): with expected: obj.set_info(metadata, **kwargs) - # def test_add_metadata_key_error(self, obj, obj_len): - # # type specific key errors - # info = np.ones(obj_len) - # if isinstance(obj, nap.IntervalSet): - # with pytest.raises(RuntimeError, match="IntervalSet is immutable"): - # obj[0] = info - # with pytest.raises(RuntimeError, match="IntervalSet is immutable"): - # obj["start"] = info - # with pytest.raises(RuntimeError, match="IntervalSet is immutable"): - # obj["end"] = info - - # elif isinstance(obj, nap.TsGroup): - # # currently obj[0] does not raise an error for TsdFrame - # with pytest.raises(TypeError, match="Metadata keys must be strings!"): - # obj[0] = info + def test_add_metadata_key_error(self, obj, obj_len): + # type specific key errors + info = np.ones(obj_len) + if isinstance(obj, nap.IntervalSet): + with pytest.raises(RuntimeError, match="IntervalSet is immutable"): + obj[0] = info + with pytest.raises(RuntimeError, match="IntervalSet is immutable"): + obj["start"] = info + with pytest.raises(RuntimeError, match="IntervalSet is immutable"): + obj["end"] = info + + elif isinstance(obj, nap.TsGroup): + # currently obj[0] does not raise an error for TsdFrame + with pytest.raises(TypeError, match="Metadata keys must be strings!"): + obj[0] = info def test_overwrite_metadata(self, obj, obj_len): # add metadata @@ -766,6 +863,41 @@ def test_overwrite_metadata(self, obj, obj_len): obj.label = [4] * obj_len assert np.all(obj.label == 4) + # test naming overlap of shared attributes + @pytest.mark.parametrize( + "name", + [ + "set_info", + "_metadata", + "_class_attributes", + ], + ) + def test_metadata_overlapping_names(self, obj, obj_len, name): + values = np.ones(obj_len) + + # set some metadata to force assertion error with "_metadata" case + obj.set_info(label=values) + + # assert attribute exists + assert hasattr(obj, name) + + # warning when set + with pytest.warns(UserWarning, match="overlaps with an existing"): + obj.set_info({name: values}) + # error when set as attribute + with pytest.raises(AttributeError, match="Cannot set attribute"): + setattr(obj, name, values) + # warning when set as key + with pytest.warns(UserWarning, match="overlaps with an existing"): + obj[name] = values + # retrieve with get_info + assert np.all(obj.get_info(name) == values) + # make sure it doesn't access metadata if its an existing attribute or key + with pytest.raises((AssertionError, ValueError)): + assert np.all(getattr(obj, name) == values) + # access metadata as key + assert np.all(obj[name] == values) + @pytest.mark.parametrize("label, val", [([1, 1, 2, 2], 2)]) def test_metadata_slicing(self, obj, label, val, obj_len): # slicing not relevant for length 1 objects From 1128fa548811474ff63661c464475681da385306 Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Thu, 7 Nov 2024 11:24:50 -0500 Subject: [PATCH 07/18] deprecation warning for loc for tsdframe. disallow metadata names that overlap with data column names for tsdframe --- pynapple/core/metadata_class.py | 14 +++++++--- pynapple/core/time_series.py | 13 ++++++---- tests/test_metadata.py | 45 +++++++++++++++++++++++++-------- tests/test_time_series.py | 11 ++++++++ 4 files changed, 63 insertions(+), 20 deletions(-) diff --git a/pynapple/core/metadata_class.py b/pynapple/core/metadata_class.py index 9e17b774..a7fb1a7b 100644 --- a/pynapple/core/metadata_class.py +++ b/pynapple/core/metadata_class.py @@ -117,10 +117,16 @@ def _raise_invalid_metadata_column_name(self, name): f"Metadata name '{name}' overlaps with an existing attribute, and cannot be accessed as an attribute or key. Use 'get_info()' to access metadata." ) elif hasattr(self, "columns") and name in self.columns: - # existing non-metadata attribute - warnings.warn( - f"Metadata name '{name}' overlaps with an existing property, and cannot be accessed as an attribute or key. Use 'get_info()' to access metadata." - ) + if self.nap_class == "TsdFrame": + # special exception for TsdFrame columns + raise ValueError( + f"Invalid metadata name '{name}'. Metadata name must differ from {list(self.columns)} column names!" + ) + else: + # existing non-metadata attribute + warnings.warn( + f"Metadata name '{name}' overlaps with an existing property, and cannot be accessed as an attribute or key. Use 'get_info()' to access metadata." + ) # elif name in self.metadata_columns: # # warnings for metadata that already exists # warnings.warn(f"Overwriting existing metadata column '{name}'.") diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index dadee775..8c0a8fc5 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -949,6 +949,11 @@ def __init__( @property def loc(self): + # add deprecation warning + warnings.warn( + "'loc' will be deprecated in a future version. Use bracket indexing instead.", + DeprecationWarning, + ) return _TsdFrameSliceHelper(self) def __repr__(self): @@ -1086,13 +1091,11 @@ def __getitem__(self, key, *args, **kwargs): and all([isinstance(k, str) for k in key]) ): if all(k in self.columns for k in key): - return self.loc[key] + with warnings.catch_warnings(action="ignore"): + # ignore deprecated warning for loc + return self.loc[key] else: - # if all(k in self.metadata_columns for k in key): return _MetadataMixin.__getitem__(self, key) - # else: - # return self.loc[key] - else: if isinstance(key, pd.Series) and key.index.equals(self.columns): # if indexing with a pd.Series from metadata, transform it to tuple with slice(None) in first position diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 28a6f926..84570576 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -433,21 +433,23 @@ def test_tsdframe_metadata_slicing(tsdframe_meta): @pytest.mark.parametrize( - "name, attr_exp, set_exp, set_attr_exp, set_key_exp, get_attr_exp, get_key_exp", + "name, attr_exp, set_exp, set_attr_exp, set_key_exp, get_exp, get_attr_exp, get_key_exp", [ - # pre-computed rate metadata + # existing data column ( "a", # not attribute pytest.raises(AssertionError), - # warn with set_info - pytest.warns(UserWarning, match="overlaps with an existing"), - # warn with setattr - pytest.warns(UserWarning, match="overlaps with an existing"), - # cannot be set as key + # error with set_info + pytest.raises(ValueError, match="Invalid metadata name"), + # error with setattr + pytest.raises(ValueError, match="Invalid metadata name"), + # shape mismatch with setitem pytest.raises(ValueError), - # attribute should match metadata - does_not_raise(), + # key error with get_info + pytest.raises(KeyError), + # attribute should raise error + pytest.raises(AttributeError), # key should not match metadata pytest.raises(ValueError), ), @@ -461,6 +463,8 @@ def test_tsdframe_metadata_slicing(tsdframe_meta): pytest.raises(AttributeError, match="Cannot set attribute"), # warn when set as key pytest.warns(UserWarning, match="overlaps with an existing"), + # no error with get_info + does_not_raise(), # attribute should not match metadata pytest.raises(AssertionError), # key should match metadata @@ -477,6 +481,8 @@ def test_tsdframe_metadata_slicing(tsdframe_meta): does_not_raise(), # no warning with setitem does_not_raise(), + # no error with get_info + does_not_raise(), # attr should match metadata does_not_raise(), # key should match metadata @@ -490,6 +496,7 @@ def test_tsdframe_metadata_overlapping_names( attr_exp, set_exp, set_attr_exp, + get_exp, set_key_exp, get_attr_exp, get_key_exp, @@ -507,7 +514,8 @@ def test_tsdframe_metadata_overlapping_names( with set_key_exp: tsdframe_meta[name] = np.ones(4) # retrieve with get_info - assert np.all(tsdframe_meta.get_info(name) == np.ones(4)) + with get_exp: + assert np.all(tsdframe_meta.get_info(name) == np.ones(4)) # make sure it doesn't access metadata if its an existing attribute or key with get_attr_exp: assert np.all(getattr(tsdframe_meta, name) == np.ones(4)) @@ -527,7 +535,8 @@ def tsgroup_meta(): 1: nap.Ts(t=np.arange(0, 200, 0.5), time_units="s"), 2: nap.Ts(t=np.arange(0, 300, 0.2), time_units="s"), 3: nap.Ts(t=np.arange(0, 400, 1), time_units="s"), - } + }, + metadata={"label": [1, 2, 3, 4]}, ) @@ -562,6 +571,20 @@ def tsgroup_meta(): # get key is metadata does_not_raise(), ), + # existing metdata + ( + "label", + # no warning with set_info + does_not_raise(), + # no warning with setattr + does_not_raise(), + # no warning with setitem + does_not_raise(), + # attr should match metadata + does_not_raise(), + # key should match metadata + does_not_raise(), + ), ], ) def test_tsgroup_metadata_overlapping_names( diff --git a/tests/test_time_series.py b/tests/test_time_series.py index ac8956c3..f25f479c 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -1,6 +1,7 @@ """Tests of time series for `pynapple` package.""" from numbers import Number +import warnings import pickle import numpy as np @@ -1335,6 +1336,16 @@ def test_convolve_keep_columns(self, tsdframe): assert isinstance(tsd2, nap.TsdFrame) np.testing.assert_array_equal(tsd2.columns, tsdframe.columns) + def test_deprecation_warning(self, tsdframe): + columns = tsdframe.columns + # warning using loc + with pytest.warns(DeprecationWarning): + tsdframe.loc[columns[0]] + if isinstance(columns[0], str): + # suppressed warning with getitem, which implicitly uses loc + with warnings.catch_warnings(action="error"): + tsdframe[columns[0]] + #################################################### # Test for ts From a9ae984045582f160b207fb83e44230f2391f762 Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Thu, 7 Nov 2024 11:35:47 -0500 Subject: [PATCH 08/18] prohibit setting of 'rate' for TsGroup --- pynapple/core/metadata_class.py | 16 ++++++++----- pynapple/core/ts_group.py | 3 +-- tests/test_metadata.py | 40 ++++++++++++++++++++++----------- 3 files changed, 39 insertions(+), 20 deletions(-) diff --git a/pynapple/core/metadata_class.py b/pynapple/core/metadata_class.py index a7fb1a7b..da67d7c8 100644 --- a/pynapple/core/metadata_class.py +++ b/pynapple/core/metadata_class.py @@ -112,10 +112,16 @@ def _raise_invalid_metadata_column_name(self, name): ) # warnings for metadata names that cannot be accessed as attributes or keys if name in self._class_attributes: - # existing non-metadata attribute - warnings.warn( - f"Metadata name '{name}' overlaps with an existing attribute, and cannot be accessed as an attribute or key. Use 'get_info()' to access metadata." - ) + if (self.nap_class == "TsGroup") and (name == "rate"): + # special exception for TsGroup rate attribute + raise ValueError( + f"Invalid metadata name '{name}'. Cannot overwrite TsGroup 'rate'!" + ) + else: + # existing non-metadata attribute + warnings.warn( + f"Metadata name '{name}' overlaps with an existing attribute, and cannot be accessed as an attribute or key. Use 'get_info()' to access metadata." + ) elif hasattr(self, "columns") and name in self.columns: if self.nap_class == "TsdFrame": # special exception for TsdFrame columns @@ -123,7 +129,7 @@ def _raise_invalid_metadata_column_name(self, name): f"Invalid metadata name '{name}'. Metadata name must differ from {list(self.columns)} column names!" ) else: - # existing non-metadata attribute + # existing non-metadata column warnings.warn( f"Metadata name '{name}' overlaps with an existing property, and cannot be accessed as an attribute or key. Use 'get_info()' to access metadata." ) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 3e5427eb..b01c60f4 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -188,10 +188,9 @@ def __init__( data = {k: data[k].restrict(self.time_support) for k in self.index} UserDict.__init__(self, data) - + self.nap_class = self.__class__.__name__ # grab current attributes before adding metadata self._class_attributes = self.__dir__() - self._class_attributes.remove("rate") # remove rate metadata self._class_attributes.append("_class_attributes") # add this property # Making the TsGroup non mutable diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 84570576..590dab6a 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -541,21 +541,23 @@ def tsgroup_meta(): @pytest.mark.parametrize( - "name, set_exp, set_attr_exp, set_key_exp, get_attr_exp, get_key_exp", + "name, set_exp, set_attr_exp, set_key_exp, get_exp, get_attr_exp, get_key_exp", [ # pre-computed rate metadata ( "rate", - # no warning with set_info - warnings.catch_warnings(action="error"), - # warning with setattr - pytest.warns(UserWarning, match="Replacing TsGroup 'rate'"), - # warning with setitem - pytest.warns(UserWarning, match="Replacing TsGroup 'rate'"), - # get attribute is metadata - does_not_raise(), - # get key is metadata - does_not_raise(), + # error with set_info + pytest.raises(ValueError, match="Invalid metadata name"), + # error with setattr + pytest.raises(AttributeError, match="Cannot set attribute"), + # error with setitem + pytest.raises(ValueError, match="Invalid metadata name"), + # value mismatch with get_info + pytest.raises(AssertionError), + # value mismatch with getattr + pytest.raises(AssertionError), + # value mismatch with getitem + pytest.raises(AssertionError), ), # 'rates' attribute ( @@ -566,6 +568,8 @@ def tsgroup_meta(): pytest.raises(AttributeError, match="Cannot set attribute"), # warn with setitem pytest.warns(UserWarning, match="overlaps with an existing"), + # no error with get_info + does_not_raise(), # get attribute is not metadata pytest.raises(AssertionError), # get key is metadata @@ -580,6 +584,8 @@ def tsgroup_meta(): does_not_raise(), # no warning with setitem does_not_raise(), + # no error with get_info + does_not_raise(), # attr should match metadata does_not_raise(), # key should match metadata @@ -588,7 +594,14 @@ def tsgroup_meta(): ], ) def test_tsgroup_metadata_overlapping_names( - tsgroup_meta, name, set_exp, set_attr_exp, set_key_exp, get_attr_exp, get_key_exp + tsgroup_meta, + name, + set_exp, + set_attr_exp, + set_key_exp, + get_exp, + get_attr_exp, + get_key_exp, ): assert hasattr(tsgroup_meta, name) @@ -602,7 +615,8 @@ def test_tsgroup_metadata_overlapping_names( with set_key_exp: tsgroup_meta[name] = np.ones(4) # retrieve with get_info - assert np.all(tsgroup_meta.get_info(name) == np.ones(4)) + with get_exp: + assert np.all(tsgroup_meta.get_info(name) == np.ones(4)) # make sure it doesn't access metadata if its an existing attribute or key with get_attr_exp: assert np.all(getattr(tsgroup_meta, name) == np.ones(4)) From 0c126303e2936089d2e9fcec95dc2b1724913dc3 Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Thu, 7 Nov 2024 11:43:37 -0500 Subject: [PATCH 09/18] remove redundant test --- tests/test_ts_group.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/test_ts_group.py b/tests/test_ts_group.py index 6d41b46b..c58b0c95 100644 --- a/tests/test_ts_group.py +++ b/tests/test_ts_group.py @@ -628,12 +628,6 @@ def test_setitem_metadata_twice(self, group): group["a"] = np.arange(len(group)) + 10 assert all(group._metadata["a"] == np.arange(len(group)) + 10) - def test_prevent_overwriting_existing_methods(self, ts_group): - # with pytest.raises(ValueError, match=r"Invalid metadata name"): - # ts_group["set_info"] = np.arange(2) - with pytest.warns(UserWarning, match=r"overlaps with an existing"): - ts_group["set_info"] = np.arange(2) - def test_getitem_ts_object(self, ts_group): assert isinstance(ts_group[1], nap.Ts) From f5f15e4292ce213e3750d935e1dfb96ee70930d9 Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Thu, 7 Nov 2024 11:56:34 -0500 Subject: [PATCH 10/18] fix sort --- tests/test_time_series.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_time_series.py b/tests/test_time_series.py index 8be5c2df..c20d9578 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -1,9 +1,7 @@ """Tests of time series for `pynapple` package.""" -from numbers import Number -import warnings - import pickle +import warnings from contextlib import nullcontext as does_not_raise from numbers import Number from pathlib import Path From bc506d324b76365241fcb4b6305eea1b064a761b Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Thu, 7 Nov 2024 12:09:32 -0500 Subject: [PATCH 11/18] update use of catch_warnings for backwards compatibility --- pynapple/core/time_series.py | 3 ++- tests/test_time_series.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 72b1b234..7c49d773 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -1118,8 +1118,9 @@ def __getitem__(self, key, *args, **kwargs): and all([isinstance(k, str) for k in key]) ): if all(k in self.columns for k in key): - with warnings.catch_warnings(action="ignore"): + with warnings.catch_warnings(): # ignore deprecated warning for loc + warnings.simplefilter("ignore") return self.loc[key] else: return _MetadataMixin.__getitem__(self, key) diff --git a/tests/test_time_series.py b/tests/test_time_series.py index c20d9578..4c7f05af 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -1389,7 +1389,8 @@ def test_deprecation_warning(self, tsdframe): tsdframe.loc[columns[0]] if isinstance(columns[0], str): # suppressed warning with getitem, which implicitly uses loc - with warnings.catch_warnings(action="error"): + with warnings.catch_warnings(): + warnings.simplefilter("error") tsdframe[columns[0]] From e7319f423b281bb5a81f1be9955a178d0a0102a3 Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Thu, 7 Nov 2024 12:38:08 -0500 Subject: [PATCH 12/18] remove unnecessary warnings, fix tests for python 3.8 --- pynapple/core/ts_group.py | 10 ---------- tests/test_metadata.py | 32 ++++++++++++++++---------------- 2 files changed, 16 insertions(+), 26 deletions(-) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index b01c60f4..a29297e9 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -216,11 +216,6 @@ def __setattr__(self, name, value): f"Cannot set attribute: '{name}' is a reserved attribute. Use 'set_info()' to set '{name}' as metadata." ) else: - if name == "rate": - warnings.warn( - "Replacing TsGroup 'rate' with user-defined metadata.", - UserWarning, - ) _MetadataMixin.__setattr__(self, name, value) else: object.__setattr__(self, name, value) @@ -230,11 +225,6 @@ def __setitem__(self, key, value): self._metadata.loc[int(key), "rate"] = float(value.rate) super().__setitem__(int(key), value) else: - if key == "rate": - warnings.warn( - "Replacing TsGroup 'rate' with user-defined metadata.", - UserWarning, - ) _MetadataMixin.__setitem__(self, key, value) def __getitem__(self, key): diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 197ab2fc..8200b2ed 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -353,7 +353,7 @@ def test_drop_metadata_warnings(iset_meta): # warn with setitem pytest.warns(UserWarning, match="overlaps with an existing"), # attr should not match metadata - pytest.raises(ValueError), # shape mismatch + pytest.raises(AssertionError), # key should match metadata does_not_raise(), ), @@ -388,13 +388,13 @@ def test_iset_metadata_overlapping_names( with set_key_exp: iset_meta[name] = np.ones(4) # retrieve with get_info - assert np.all(iset_meta.get_info(name) == np.ones(4)) + np.testing.assert_array_almost_equal(iset_meta.get_info(name), np.ones(4)) # make sure it doesn't access metadata if its an existing attribute or key with get_attr_exp: - assert np.all(getattr(iset_meta, name) == np.ones(4)) + np.testing.assert_array_almost_equal(getattr(iset_meta, name), np.ones(4)) # make sure it doesn't access metadata if its an existing key with get_key_exp: - assert np.all(iset_meta[name] == np.ones(4)) + np.testing.assert_array_almost_equal(iset_meta[name], np.ones(4)) ############## @@ -450,7 +450,7 @@ def test_tsdframe_metadata_slicing(tsdframe_meta): # attribute should raise error pytest.raises(AttributeError), # key should not match metadata - pytest.raises(ValueError), + pytest.raises(AssertionError), ), ( "columns", @@ -465,7 +465,7 @@ def test_tsdframe_metadata_slicing(tsdframe_meta): # no error with get_info does_not_raise(), # attribute should not match metadata - pytest.raises(AssertionError), + pytest.raises(TypeError), # key should match metadata does_not_raise(), ), @@ -514,13 +514,13 @@ def test_tsdframe_metadata_overlapping_names( tsdframe_meta[name] = np.ones(4) # retrieve with get_info with get_exp: - assert np.all(tsdframe_meta.get_info(name) == np.ones(4)) + np.testing.assert_array_almost_equal(tsdframe_meta.get_info(name), np.ones(4)) # make sure it doesn't access metadata if its an existing attribute or key with get_attr_exp: - assert np.all(getattr(tsdframe_meta, name) == np.ones(4)) + np.testing.assert_array_almost_equal(getattr(tsdframe_meta, name), np.ones(4)) # make sure it doesn't access metadata if its an existing key with get_key_exp: - assert np.all(tsdframe_meta[name] == np.ones(4)) + np.testing.assert_array_almost_equal(tsdframe_meta[name], np.ones(4)) ############# @@ -615,13 +615,13 @@ def test_tsgroup_metadata_overlapping_names( tsgroup_meta[name] = np.ones(4) # retrieve with get_info with get_exp: - assert np.all(tsgroup_meta.get_info(name) == np.ones(4)) + np.testing.assert_array_almost_equal(tsgroup_meta.get_info(name), np.ones(4)) # make sure it doesn't access metadata if its an existing attribute or key with get_attr_exp: - assert np.all(getattr(tsgroup_meta, name) == np.ones(4)) + np.testing.assert_array_almost_equal(getattr(tsgroup_meta, name), np.ones(4)) # make sure it doesn't access metadata if its an existing key with get_key_exp: - assert np.all(tsgroup_meta[name] == np.ones(4)) + np.testing.assert_array_almost_equal(tsgroup_meta[name], np.ones(4)) ################## @@ -927,12 +927,12 @@ def test_metadata_overlapping_names(self, obj, obj_len, name): with pytest.warns(UserWarning, match="overlaps with an existing"): obj[name] = values # retrieve with get_info - assert np.all(obj.get_info(name) == values) + np.testing.assert_array_almost_equal(obj.get_info(name), values) # make sure it doesn't access metadata if its an existing attribute or key - with pytest.raises((AssertionError, ValueError)): - assert np.all(getattr(obj, name) == values) + with pytest.raises((AssertionError, ValueError, TypeError)): + np.testing.assert_array_almost_equal(getattr(obj, name), values) # access metadata as key - assert np.all(obj[name] == values) + np.testing.assert_array_almost_equal(obj[name], values) @pytest.mark.parametrize("label, val", [([1, 1, 2, 2], 2)]) def test_metadata_slicing(self, obj, label, val, obj_len): From d7ff683807d123f9a73cbd911e8f2b3a87889eca Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 7 Nov 2024 14:38:13 -0500 Subject: [PATCH 13/18] Fixing repr --- pynapple/core/interval_set.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index b782ab5a..93f138e6 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -231,7 +231,14 @@ def __repr__(self): # By default, the first three columns should always show. # Adding an extra column between actual values and metadata - col_names = self._metadata.columns + try: + metadata = self._metadata + col_names = metadata.columns + except: + # Necessary for backward compatibility when saving IntervalSet as pickle + metadata = pd.DataFrame(index=self.index) + col_names = [] + headers = ["index", "start", "end"] if len(col_names): headers += [""] + [c for c in col_names] @@ -251,7 +258,7 @@ def __repr__(self): self.index[0:n_rows, None], self.values[0:n_rows], separator, - self._metadata.values[0:n_rows], + metadata.values[0:n_rows], ), dtype=object, ), @@ -261,7 +268,7 @@ def __repr__(self): self.index[-n_rows:, None], self.values[0:n_rows], separator, - self._metadata.values[-n_rows:], + metadata.values[-n_rows:], ), dtype=object, ), @@ -273,7 +280,7 @@ def __repr__(self): else: separator = np.empty((len(self), 0)) data = np.hstack( - (self.index[:, None], self.values, separator, self._metadata.values), + (self.index[:, None], self.values, separator, metadata.values), dtype=object, ) From cf120d2d9e417d17c460da6c2736030680b42b98 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 7 Nov 2024 15:59:06 -0500 Subject: [PATCH 14/18] Trying to fix pickling --- pynapple/core/interval_set.py | 20 +++++++++++++++++++- pynapple/core/time_series.py | 10 +++++++++- pynapple/core/utils.py | 12 +++++++----- tests/test_time_series.py | 2 +- 4 files changed, 36 insertions(+), 8 deletions(-) diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index 93f138e6..9d9612f5 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -304,6 +304,20 @@ def __setattr__(self, name, value): else: object.__setattr__(self, name, value) + def __getattr__(self, name): + # Necessary for backward compatibility with pickle + try: + metadata = self._metadata + except: + metadata = pd.DataFrame(index=self.index) + + if name == "_metadata": + return metadata + elif name in metadata.columns: + return _MetadataMixin.__getattr__(self, name) + else: + return super().__getattr__(name) + def __setitem__(self, key, value): if key in self.columns: raise RuntimeError( @@ -317,6 +331,11 @@ def __setitem__(self, key, value): ) def __getitem__(self, key): + try: + metadata = _MetadataMixin.__getitem__(self, key) + except: + metadata = pd.DataFrame(index=self.index) + if isinstance(key, str): # self[str] if key == "start": @@ -341,7 +360,6 @@ def __getitem__(self, key): elif isinstance(key, Number): # self[Number] output = self.values.__getitem__(key) - metadata = _MetadataMixin.__getitem__(self, key) return IntervalSet(start=output[0], end=output[1], metadata=metadata) elif isinstance(key, (slice, list, np.ndarray, pd.Series)): # self[array_like] diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 7c49d773..8b7966ca 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -1072,7 +1072,15 @@ def __getattr__(self, name): # self._metadata.column having attributes '__reduce__', '__reduce_ex__' if name in ("__getstate__", "__setstate__", "__reduce__", "__reduce_ex__"): raise AttributeError(name) - if name in self._metadata.columns: + + try: + metadata = self._metadata + except: + metadata = pd.DataFrame(index=self.columns) + + if name == "_metadata": + return metadata + elif name in metadata.columns: return _MetadataMixin.__getattr__(self, name) else: return super().__getattr__(name) diff --git a/pynapple/core/utils.py b/pynapple/core/utils.py index 84293949..40108bfb 100644 --- a/pynapple/core/utils.py +++ b/pynapple/core/utils.py @@ -384,7 +384,12 @@ def __getitem__(self, key): IndexError """ - if key in ["start", "end"] + self.intervalset.metadata_columns: + # Pickle backward compatibility + try: + metadata_columns = self.intervalset.metadata_columns + except: + metadata_columns = [] + if key in ["start", "end"] + metadata_columns: return self.intervalset[key] elif isinstance(key, list): return self.intervalset[key] @@ -393,10 +398,7 @@ def __getitem__(self, key): else: if isinstance(key, tuple): if len(key) == 2: - if ( - key[1] - not in ["start", "end"] + self.intervalset.metadata_columns - ): + if key[1] not in ["start", "end"] + metadata_columns: raise IndexError out = self.intervalset[key[0]][key[1]] if len(out) == 1: diff --git a/tests/test_time_series.py b/tests/test_time_series.py index 4c7f05af..7b810f79 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -1908,7 +1908,7 @@ def test_pickling(obj): # Ensure time support is the same assert np.all(obj.time_support == unpickled_obj.time_support) - +# #################################################### # Test for slicing From 07a1d69f503710644e96c24e8745c8c797749e9b Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 7 Nov 2024 16:23:36 -0500 Subject: [PATCH 15/18] fixing pickilong --- pynapple/core/interval_set.py | 12 +++++++++--- pynapple/core/time_series.py | 2 +- pynapple/core/utils.py | 2 +- tests/test_time_series.py | 2 ++ 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index 9d9612f5..784f2126 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -234,7 +234,7 @@ def __repr__(self): try: metadata = self._metadata col_names = metadata.columns - except: + except Exception: # Necessary for backward compatibility when saving IntervalSet as pickle metadata = pd.DataFrame(index=self.index) col_names = [] @@ -306,9 +306,15 @@ def __setattr__(self, name, value): def __getattr__(self, name): # Necessary for backward compatibility with pickle + + # avoid infinite recursion when pickling due to + # self._metadata.column having attributes '__reduce__', '__reduce_ex__' + if name in ("__getstate__", "__setstate__", "__reduce__", "__reduce_ex__"): + raise AttributeError(name) + try: metadata = self._metadata - except: + except Exception: metadata = pd.DataFrame(index=self.index) if name == "_metadata": @@ -333,7 +339,7 @@ def __setitem__(self, key, value): def __getitem__(self, key): try: metadata = _MetadataMixin.__getitem__(self, key) - except: + except Exception: metadata = pd.DataFrame(index=self.index) if isinstance(key, str): diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 8b7966ca..22450c2b 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -1075,7 +1075,7 @@ def __getattr__(self, name): try: metadata = self._metadata - except: + except (AttributeError, RecursionError): metadata = pd.DataFrame(index=self.columns) if name == "_metadata": diff --git a/pynapple/core/utils.py b/pynapple/core/utils.py index 40108bfb..96afc6f6 100644 --- a/pynapple/core/utils.py +++ b/pynapple/core/utils.py @@ -387,7 +387,7 @@ def __getitem__(self, key): # Pickle backward compatibility try: metadata_columns = self.intervalset.metadata_columns - except: + except Exception: metadata_columns = [] if key in ["start", "end"] + metadata_columns: return self.intervalset[key] diff --git a/tests/test_time_series.py b/tests/test_time_series.py index 7b810f79..e1064da9 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -1908,6 +1908,8 @@ def test_pickling(obj): # Ensure time support is the same assert np.all(obj.time_support == unpickled_obj.time_support) + + # #################################################### From 0731e41dcabc41834311fdbafb8f69d59de61f4b Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Fri, 8 Nov 2024 12:44:56 -0500 Subject: [PATCH 16/18] test for tsgroup future warnings, and warnings for names that cannot be accessed as an attribute --- tests/test_metadata.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 8200b2ed..acde65ee 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -624,6 +624,19 @@ def test_tsgroup_metadata_overlapping_names( np.testing.assert_array_almost_equal(tsgroup_meta[name], np.ones(4)) +def test_tsgroup_metadata_future_warnings(): + with pytest.warns(FutureWarning, match="may be unsupported"): + tsgroup = nap.TsGroup( + { + 0: nap.Ts(t=np.arange(0, 200)), + 1: nap.Ts(t=np.arange(0, 200, 0.5), time_units="s"), + 2: nap.Ts(t=np.arange(0, 300, 0.2), time_units="s"), + 3: nap.Ts(t=np.arange(0, 400, 1), time_units="s"), + }, + label=[1, 2, 3, 4], + ) + + ################## ## Shared tests ## ################## @@ -934,6 +947,27 @@ def test_metadata_overlapping_names(self, obj, obj_len, name): # access metadata as key np.testing.assert_array_almost_equal(obj[name], values) + # test metadata that can only be accessed as key + @pytest.mark.parametrize( + "name", + [ + "l.1", + "l 1", + "0", + ], + ) + def test_metadata_nonattribute_names(self, obj, obj_len, name): + values = np.ones(obj_len) + + # set some metadata to force assertion error with "_metadata" case + with pytest.warns(UserWarning, match="cannot be accessed as an attribute"): + obj.set_info({name: values}) + + # make sure it can be accessed with get_info + np.testing.assert_array_almost_equal(obj.get_info(name), values) + # make sure it can be accessed as key + np.testing.assert_array_almost_equal(obj[name], values) + @pytest.mark.parametrize("label, val", [([1, 1, 2, 2], 2)]) def test_metadata_slicing(self, obj, label, val, obj_len): # slicing not relevant for length 1 objects From c9484c6b7179972910fe20e86d471d0ecf611a9a Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 14 Nov 2024 12:51:13 -0500 Subject: [PATCH 17/18] Fixing some repr issues --- pynapple/core/interval_set.py | 12 ++++++++--- pynapple/core/time_series.py | 5 ++++- pynapple/core/ts_group.py | 39 +++++++++++++++++++++++++++++++---- pynapple/core/utils.py | 16 ++++++++++++++ pynapple/io/interface_nwb.py | 39 +++++++++++++++++++++++++++++------ 5 files changed, 97 insertions(+), 14 deletions(-) diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index 784f2126..ceab66b1 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -22,6 +22,7 @@ from .metadata_class import _MetadataMixin from .time_index import TsIndex from .utils import ( + _convert_iter_to_str, _get_terminal_size, _IntervalSetSliceHelper, check_filename, @@ -258,7 +259,7 @@ def __repr__(self): self.index[0:n_rows, None], self.values[0:n_rows], separator, - metadata.values[0:n_rows], + _convert_iter_to_str(metadata.values[0:n_rows]), ), dtype=object, ), @@ -268,7 +269,7 @@ def __repr__(self): self.index[-n_rows:, None], self.values[0:n_rows], separator, - metadata.values[-n_rows:], + _convert_iter_to_str(metadata.values[-n_rows:]), ), dtype=object, ), @@ -280,7 +281,12 @@ def __repr__(self): else: separator = np.empty((len(self), 0)) data = np.hstack( - (self.index[:, None], self.values, separator, metadata.values), + ( + self.index[:, None], + self.values, + separator, + _convert_iter_to_str(metadata.values), + ), dtype=object, ) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 22450c2b..389432b4 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -33,6 +33,7 @@ from .time_index import TsIndex from .utils import ( _concatenate_tsd, + _convert_iter_to_str, _get_terminal_size, _split_tsd, _TsdFrameSliceHelper, @@ -1041,7 +1042,9 @@ def __repr__(self): np.hstack( ( col_names[:, None], - self._metadata.values[0:max_cols].T, + _convert_iter_to_str( + self._metadata.values[0:max_cols].T + ), ends, ), dtype=object, diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index a29297e9..270f9532 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -22,7 +22,12 @@ from .metadata_class import _MetadataMixin from .time_index import TsIndex from .time_series import Ts, Tsd, TsdFrame, _BaseTsd, is_array_like -from .utils import _get_terminal_size, check_filename, convert_to_numpy_array +from .utils import ( + _convert_iter_to_str, + _get_terminal_size, + check_filename, + convert_to_numpy_array, +) def _union_intervals(i_sets): @@ -220,6 +225,26 @@ def __setattr__(self, name, value): else: object.__setattr__(self, name, value) + def __getattr__(self, name): + # Necessary for backward compatibility with pickle + + # avoid infinite recursion when pickling due to + # self._metadata.column having attributes '__reduce__', '__reduce_ex__' + if name in ("__getstate__", "__setstate__", "__reduce__", "__reduce_ex__"): + raise AttributeError(name) + + try: + metadata = self._metadata + except Exception: + metadata = pd.DataFrame(index=self.index) + + if name == "_metadata": + return metadata + elif name in metadata.columns: + return _MetadataMixin.__getattr__(self, name) + else: + return super().__getattr__(name) + def __setitem__(self, key, value): if not self._initialized: self._metadata.loc[int(key), "rate"] = float(value.rate) @@ -294,7 +319,9 @@ def __repr__(self): ( self.index[0:n_rows, None], np.round(self._metadata[["rate"]].values[0:n_rows], 5), - self._metadata[col_names].values[0:n_rows, 0:max_cols], + _convert_iter_to_str( + self._metadata[col_names].values[0:n_rows, 0:max_cols] + ), ends, ), dtype=object, @@ -307,7 +334,9 @@ def __repr__(self): ( self.index[-n_rows:, None], np.round(self._metadata[["rate"]].values[-n_rows:], 5), - self._metadata[col_names].values[-n_rows:, 0:max_cols], + _convert_iter_to_str( + self._metadata[col_names].values[-n_rows:, 0:max_cols] + ), ends, ), dtype=object, @@ -320,7 +349,9 @@ def __repr__(self): ( self.index[:, None], np.round(self._metadata[["rate"]].values, 5), - self._metadata[col_names].values[:, 0:max_cols], + _convert_iter_to_str( + self._metadata[col_names].values[:, 0:max_cols] + ), ends, ), dtype=object, diff --git a/pynapple/core/utils.py b/pynapple/core/utils.py index 96afc6f6..eb8410be 100644 --- a/pynapple/core/utils.py +++ b/pynapple/core/utils.py @@ -441,3 +441,19 @@ def check_filename(filename): raise RuntimeError("Path {} does not exist.".format(parent_folder)) return filename + + +def _convert_iter_to_str(array): + """ + This function converts an array of arrays to array of strings. + This help avoids a DeprecationWarning from numpy when printing an object with metadata + """ + try: + shape = array.shape + array = array.flatten() + for i in range(len(array)): + if isinstance(array[i], np.ndarray): + array[i] = np.array2string(array[i]) + return array.reshape(shape) + except Exception: + return array diff --git a/pynapple/io/interface_nwb.py b/pynapple/io/interface_nwb.py index d6b941fc..7fc30f33 100644 --- a/pynapple/io/interface_nwb.py +++ b/pynapple/io/interface_nwb.py @@ -1,9 +1,3 @@ -# -*- coding: utf-8 -*- -# @Author: Guillaume Viejo -# @Date: 2023-08-01 11:54:45 -# @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-05-21 15:28:27 - """ Pynapple class to interface with NWB files. Data are always lazy-loaded. @@ -458,3 +452,36 @@ def __getitem__(self, key): def close(self): """Close the NWB file""" self.io.close() + + def keys(self): + """ + Return keys of NWBFile + + Returns + ------- + list + List of keys + """ + return list(self.data.keys()) + + def items(self): + """ + Return a list of key/object. + + Returns + ------- + list + List of tuples + """ + return list(self.data.items()) + + def values(self): + """ + Return a list of all the objects + + Returns + ------- + list + List of objects + """ + return list(self.data.values()) From 33c6ff44b1da4bb75307cd5b006a24ae56cba83d Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 14 Nov 2024 15:20:31 -0500 Subject: [PATCH 18/18] Update --- tests/test_nwb.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/test_nwb.py b/tests/test_nwb.py index 8b9db257..4e685644 100644 --- a/tests/test_nwb.py +++ b/tests/test_nwb.py @@ -89,6 +89,23 @@ def test_NWBFile(): assert isinstance(nwb.io, pynwb.NWBHDF5IO) nwb.close() + assert nwb.keys() == [ + "position_time_support", + "epochs", + "z", + "y", + "x", + "rz", + "ry", + "rx", + ] + + for a, b in zip(nwb.items(), nwb.data.items()): + assert a == b + + for a, b in zip(nwb.values(), nwb.data.values()): + assert a == b + def test_NWBFile_missing_file(): with pytest.raises(FileNotFoundError) as e_info: