Skip to content

Commit

Permalink
prohibit setting of 'rate' for TsGroup
Browse files Browse the repository at this point in the history
  • Loading branch information
sjvenditto committed Nov 7, 2024
1 parent 1128fa5 commit a9ae984
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 20 deletions.
16 changes: 11 additions & 5 deletions pynapple/core/metadata_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,18 +112,24 @@ def _raise_invalid_metadata_column_name(self, name):
)
# warnings for metadata names that cannot be accessed as attributes or keys
if name in self._class_attributes:
# existing non-metadata attribute
warnings.warn(
f"Metadata name '{name}' overlaps with an existing attribute, and cannot be accessed as an attribute or key. Use 'get_info()' to access metadata."
)
if (self.nap_class == "TsGroup") and (name == "rate"):
# special exception for TsGroup rate attribute
raise ValueError(
f"Invalid metadata name '{name}'. Cannot overwrite TsGroup 'rate'!"
)
else:
# existing non-metadata attribute
warnings.warn(
f"Metadata name '{name}' overlaps with an existing attribute, and cannot be accessed as an attribute or key. Use 'get_info()' to access metadata."
)
elif hasattr(self, "columns") and name in self.columns:
if self.nap_class == "TsdFrame":
# special exception for TsdFrame columns
raise ValueError(
f"Invalid metadata name '{name}'. Metadata name must differ from {list(self.columns)} column names!"
)
else:
# existing non-metadata attribute
# existing non-metadata column
warnings.warn(
f"Metadata name '{name}' overlaps with an existing property, and cannot be accessed as an attribute or key. Use 'get_info()' to access metadata."
)
Expand Down
3 changes: 1 addition & 2 deletions pynapple/core/ts_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,9 @@ def __init__(
data = {k: data[k].restrict(self.time_support) for k in self.index}

UserDict.__init__(self, data)

self.nap_class = self.__class__.__name__
# grab current attributes before adding metadata
self._class_attributes = self.__dir__()
self._class_attributes.remove("rate") # remove rate metadata
self._class_attributes.append("_class_attributes") # add this property

# Making the TsGroup non mutable
Expand Down
40 changes: 27 additions & 13 deletions tests/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,21 +541,23 @@ def tsgroup_meta():


@pytest.mark.parametrize(
"name, set_exp, set_attr_exp, set_key_exp, get_attr_exp, get_key_exp",
"name, set_exp, set_attr_exp, set_key_exp, get_exp, get_attr_exp, get_key_exp",
[
# pre-computed rate metadata
(
"rate",
# no warning with set_info
warnings.catch_warnings(action="error"),
# warning with setattr
pytest.warns(UserWarning, match="Replacing TsGroup 'rate'"),
# warning with setitem
pytest.warns(UserWarning, match="Replacing TsGroup 'rate'"),
# get attribute is metadata
does_not_raise(),
# get key is metadata
does_not_raise(),
# error with set_info
pytest.raises(ValueError, match="Invalid metadata name"),
# error with setattr
pytest.raises(AttributeError, match="Cannot set attribute"),
# error with setitem
pytest.raises(ValueError, match="Invalid metadata name"),
# value mismatch with get_info
pytest.raises(AssertionError),
# value mismatch with getattr
pytest.raises(AssertionError),
# value mismatch with getitem
pytest.raises(AssertionError),
),
# 'rates' attribute
(
Expand All @@ -566,6 +568,8 @@ def tsgroup_meta():
pytest.raises(AttributeError, match="Cannot set attribute"),
# warn with setitem
pytest.warns(UserWarning, match="overlaps with an existing"),
# no error with get_info
does_not_raise(),
# get attribute is not metadata
pytest.raises(AssertionError),
# get key is metadata
Expand All @@ -580,6 +584,8 @@ def tsgroup_meta():
does_not_raise(),
# no warning with setitem
does_not_raise(),
# no error with get_info
does_not_raise(),
# attr should match metadata
does_not_raise(),
# key should match metadata
Expand All @@ -588,7 +594,14 @@ def tsgroup_meta():
],
)
def test_tsgroup_metadata_overlapping_names(
tsgroup_meta, name, set_exp, set_attr_exp, set_key_exp, get_attr_exp, get_key_exp
tsgroup_meta,
name,
set_exp,
set_attr_exp,
set_key_exp,
get_exp,
get_attr_exp,
get_key_exp,
):
assert hasattr(tsgroup_meta, name)

Expand All @@ -602,7 +615,8 @@ def test_tsgroup_metadata_overlapping_names(
with set_key_exp:
tsgroup_meta[name] = np.ones(4)
# retrieve with get_info
assert np.all(tsgroup_meta.get_info(name) == np.ones(4))
with get_exp:
assert np.all(tsgroup_meta.get_info(name) == np.ones(4))
# make sure it doesn't access metadata if its an existing attribute or key
with get_attr_exp:
assert np.all(getattr(tsgroup_meta, name) == np.ones(4))
Expand Down

0 comments on commit a9ae984

Please sign in to comment.