From a9ae984045582f160b207fb83e44230f2391f762 Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Thu, 7 Nov 2024 11:35:47 -0500 Subject: [PATCH] prohibit setting of 'rate' for TsGroup --- pynapple/core/metadata_class.py | 16 ++++++++----- pynapple/core/ts_group.py | 3 +-- tests/test_metadata.py | 40 ++++++++++++++++++++++----------- 3 files changed, 39 insertions(+), 20 deletions(-) diff --git a/pynapple/core/metadata_class.py b/pynapple/core/metadata_class.py index a7fb1a7b..da67d7c8 100644 --- a/pynapple/core/metadata_class.py +++ b/pynapple/core/metadata_class.py @@ -112,10 +112,16 @@ def _raise_invalid_metadata_column_name(self, name): ) # warnings for metadata names that cannot be accessed as attributes or keys if name in self._class_attributes: - # existing non-metadata attribute - warnings.warn( - f"Metadata name '{name}' overlaps with an existing attribute, and cannot be accessed as an attribute or key. Use 'get_info()' to access metadata." - ) + if (self.nap_class == "TsGroup") and (name == "rate"): + # special exception for TsGroup rate attribute + raise ValueError( + f"Invalid metadata name '{name}'. Cannot overwrite TsGroup 'rate'!" + ) + else: + # existing non-metadata attribute + warnings.warn( + f"Metadata name '{name}' overlaps with an existing attribute, and cannot be accessed as an attribute or key. Use 'get_info()' to access metadata." + ) elif hasattr(self, "columns") and name in self.columns: if self.nap_class == "TsdFrame": # special exception for TsdFrame columns @@ -123,7 +129,7 @@ def _raise_invalid_metadata_column_name(self, name): f"Invalid metadata name '{name}'. Metadata name must differ from {list(self.columns)} column names!" ) else: - # existing non-metadata attribute + # existing non-metadata column warnings.warn( f"Metadata name '{name}' overlaps with an existing property, and cannot be accessed as an attribute or key. Use 'get_info()' to access metadata." ) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 3e5427eb..b01c60f4 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -188,10 +188,9 @@ def __init__( data = {k: data[k].restrict(self.time_support) for k in self.index} UserDict.__init__(self, data) - + self.nap_class = self.__class__.__name__ # grab current attributes before adding metadata self._class_attributes = self.__dir__() - self._class_attributes.remove("rate") # remove rate metadata self._class_attributes.append("_class_attributes") # add this property # Making the TsGroup non mutable diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 84570576..590dab6a 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -541,21 +541,23 @@ def tsgroup_meta(): @pytest.mark.parametrize( - "name, set_exp, set_attr_exp, set_key_exp, get_attr_exp, get_key_exp", + "name, set_exp, set_attr_exp, set_key_exp, get_exp, get_attr_exp, get_key_exp", [ # pre-computed rate metadata ( "rate", - # no warning with set_info - warnings.catch_warnings(action="error"), - # warning with setattr - pytest.warns(UserWarning, match="Replacing TsGroup 'rate'"), - # warning with setitem - pytest.warns(UserWarning, match="Replacing TsGroup 'rate'"), - # get attribute is metadata - does_not_raise(), - # get key is metadata - does_not_raise(), + # error with set_info + pytest.raises(ValueError, match="Invalid metadata name"), + # error with setattr + pytest.raises(AttributeError, match="Cannot set attribute"), + # error with setitem + pytest.raises(ValueError, match="Invalid metadata name"), + # value mismatch with get_info + pytest.raises(AssertionError), + # value mismatch with getattr + pytest.raises(AssertionError), + # value mismatch with getitem + pytest.raises(AssertionError), ), # 'rates' attribute ( @@ -566,6 +568,8 @@ def tsgroup_meta(): pytest.raises(AttributeError, match="Cannot set attribute"), # warn with setitem pytest.warns(UserWarning, match="overlaps with an existing"), + # no error with get_info + does_not_raise(), # get attribute is not metadata pytest.raises(AssertionError), # get key is metadata @@ -580,6 +584,8 @@ def tsgroup_meta(): does_not_raise(), # no warning with setitem does_not_raise(), + # no error with get_info + does_not_raise(), # attr should match metadata does_not_raise(), # key should match metadata @@ -588,7 +594,14 @@ def tsgroup_meta(): ], ) def test_tsgroup_metadata_overlapping_names( - tsgroup_meta, name, set_exp, set_attr_exp, set_key_exp, get_attr_exp, get_key_exp + tsgroup_meta, + name, + set_exp, + set_attr_exp, + set_key_exp, + get_exp, + get_attr_exp, + get_key_exp, ): assert hasattr(tsgroup_meta, name) @@ -602,7 +615,8 @@ def test_tsgroup_metadata_overlapping_names( with set_key_exp: tsgroup_meta[name] = np.ones(4) # retrieve with get_info - assert np.all(tsgroup_meta.get_info(name) == np.ones(4)) + with get_exp: + assert np.all(tsgroup_meta.get_info(name) == np.ones(4)) # make sure it doesn't access metadata if its an existing attribute or key with get_attr_exp: assert np.all(getattr(tsgroup_meta, name) == np.ones(4))