diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py
index 86213bfd..ceab66b1 100644
--- a/pynapple/core/interval_set.py
+++ b/pynapple/core/interval_set.py
@@ -22,6 +22,7 @@
 from .metadata_class import _MetadataMixin
 from .time_index import TsIndex
 from .utils import (
+    _convert_iter_to_str,
     _get_terminal_size,
     _IntervalSetSliceHelper,
     check_filename,
@@ -214,10 +215,12 @@ def __init__(
         self.index = np.arange(data.shape[0], dtype="int")
         self.columns = np.array(["start", "end"])
         self.nap_class = self.__class__.__name__
-        if drop_meta:
-            _MetadataMixin.__init__(self)
-        else:
-            _MetadataMixin.__init__(self, metadata)
+        # initialize metadata to get all attributes before setting metadata
+        _MetadataMixin.__init__(self)
+        self._class_attributes = self.__dir__()  # get list of all attributes
+        self._class_attributes.append("_class_attributes")  # add this property
+        if drop_meta is False:
+            self.set_info(metadata)
         self._initialized = True
 
     def __repr__(self):
@@ -229,7 +232,14 @@ def __repr__(self):
         # By default, the first three columns should always show.
 
         # Adding an extra column between actual values and metadata
-        col_names = self._metadata.columns
+        try:
+            metadata = self._metadata
+            col_names = metadata.columns
+        except Exception:
+            # Necessary for backward compatibility when saving IntervalSet as pickle
+            metadata = pd.DataFrame(index=self.index)
+            col_names = []
+
         headers = ["index", "start", "end"]
         if len(col_names):
             headers += [""] + [c for c in col_names]
@@ -249,7 +259,7 @@ def __repr__(self):
                             self.index[0:n_rows, None],
                             self.values[0:n_rows],
                             separator,
-                            self._metadata.values[0:n_rows],
+                            _convert_iter_to_str(metadata.values[0:n_rows]),
                         ),
                         dtype=object,
                     ),
@@ -259,7 +269,7 @@ def __repr__(self):
                             self.index[-n_rows:, None],
                             self.values[0:n_rows],
                             separator,
-                            self._metadata.values[-n_rows:],
+                            _convert_iter_to_str(metadata.values[-n_rows:]),
                         ),
                         dtype=object,
                     ),
@@ -271,7 +281,12 @@ def __repr__(self):
             else:
                 separator = np.empty((len(self), 0))
                 data = np.hstack(
-                    (self.index[:, None], self.values, separator, self._metadata.values),
+                    (
+                        self.index[:, None],
+                        self.values,
+                        separator,
+                        _convert_iter_to_str(metadata.values),
+                    ),
                     dtype=object,
                 )
@@ -286,12 +301,41 @@ def __len__(self):
     def __setattr__(self, name, value):
         # necessary setter to allow metadata to be set as an attribute
         if self._initialized:
-            _MetadataMixin.__setattr__(self, name, value)
+            if name in self._class_attributes:
+                raise AttributeError(
+                    f"Cannot set attribute '{name}'; IntervalSet is immutable. Use 'set_info()' to set '{name}' as metadata."
+                )
+            else:
+                _MetadataMixin.__setattr__(self, name, value)
         else:
             object.__setattr__(self, name, value)
 
+    def __getattr__(self, name):
+        # Necessary for backward compatibility with pickle
+
+        # avoid infinite recursion when pickling due to
+        # self._metadata.column having attributes '__reduce__', '__reduce_ex__'
+        if name in ("__getstate__", "__setstate__", "__reduce__", "__reduce_ex__"):
+            raise AttributeError(name)
+
+        try:
+            metadata = self._metadata
+        except Exception:
+            metadata = pd.DataFrame(index=self.index)
+
+        if name == "_metadata":
+            return metadata
+        elif name in metadata.columns:
+            return _MetadataMixin.__getattr__(self, name)
+        else:
+            return super().__getattr__(name)
+
     def __setitem__(self, key, value):
-        if (isinstance(key, str)) and (key not in self.columns):
+        if key in self.columns:
+            raise RuntimeError(
+                "IntervalSet is immutable. Starts and ends have been already sorted."
+            )
+        elif isinstance(key, str):
             _MetadataMixin.__setitem__(self, key, value)
         else:
             raise RuntimeError(
@@ -299,6 +343,11 @@ def __setitem__(self, key, value):
             )
 
     def __getitem__(self, key):
+        try:
+            metadata = _MetadataMixin.__getitem__(self, key)
+        except Exception:
+            metadata = pd.DataFrame(index=self.index)
+
         if isinstance(key, str):
             # self[str]
             if key == "start":
@@ -323,7 +372,6 @@
         elif isinstance(key, Number):
             # self[Number]
             output = self.values.__getitem__(key)
-            metadata = _MetadataMixin.__getitem__(self, key)
             return IntervalSet(start=output[0], end=output[1], metadata=metadata)
         elif isinstance(key, (slice, list, np.ndarray, pd.Series)):
             # self[array_like]
diff --git a/pynapple/core/metadata_class.py b/pynapple/core/metadata_class.py
index dd5126ab..da67d7c8 100644
--- a/pynapple/core/metadata_class.py
+++ b/pynapple/core/metadata_class.py
@@ -11,7 +11,7 @@ class _MetadataMixin:
     """
 
-    def __init__(self, metadata=None, **kwargs):
+    def __init__(self, metadata=None):
         """
         Metadata initializer
@@ -21,7 +21,6 @@ def __init__(self, metadata=None, **kwargs):
             List of pandas.DataFrame
         **kwargs : dict
             Dictionary containing metadata information
-
         """
         if self.__class__.__name__ == "TsdFrame":
             # metadata index is the same as the columns for TsdFrame
@@ -31,12 +30,8 @@ def __init__(self, metadata=None, **kwargs):
             self.metadata_index = self.index
 
         self._metadata = pd.DataFrame(index=self.metadata_index)
-        if len(kwargs):
-            warnings.warn(
-                "initializing metadata with variable keyword arguments may be unsupported in a future version of Pynapple. Instead, initialize using the metadata argument.",
-                FutureWarning,
-            )
-        self.set_info(metadata, **kwargs)
+
+        self.set_info(metadata)
 
     def __dir__(self):
         """
@@ -115,32 +110,48 @@ def _raise_invalid_metadata_column_name(self, name):
             raise TypeError(
                 f"Invalid metadata type {type(name)}. Metadata column names must be strings!"
             )
-        if hasattr(self, name) and (name not in self.metadata_columns):
-            # existing non-metadata attribute
-            raise ValueError(
-                f"Invalid metadata name '{name}'. Metadata name must differ from "
-                f"{type(self).__dict__.keys()} attribute names!"
-            )
-        if hasattr(self, "columns") and name in self.columns:
-            # existing column (since TsdFrame columns are not attributes)
-            raise ValueError(
-                f"Invalid metadata name '{name}'. Metadata name must differ from "
-                f"{self.columns} column names!"
-            )
-        if name[0].isalpha() is False:
-            # starts with a number
-            raise ValueError(
-                f"Invalid metadata name '{name}'. Metadata name cannot start with a number"
-            )
+        # warnings for metadata names that cannot be accessed as attributes or keys
+        if name in self._class_attributes:
+            if (self.nap_class == "TsGroup") and (name == "rate"):
+                # special exception for TsGroup rate attribute
+                raise ValueError(
+                    f"Invalid metadata name '{name}'. Cannot overwrite TsGroup 'rate'!"
+                )
+            else:
+                # existing non-metadata attribute
+                warnings.warn(
+                    f"Metadata name '{name}' overlaps with an existing attribute, and cannot be accessed as an attribute or key. Use 'get_info()' to access metadata."
+                )
+        elif hasattr(self, "columns") and name in self.columns:
+            if self.nap_class == "TsdFrame":
+                # special exception for TsdFrame columns
+                raise ValueError(
+                    f"Invalid metadata name '{name}'. Metadata name must differ from {list(self.columns)} column names!"
+                )
+            else:
+                # existing non-metadata column
+                warnings.warn(
+                    f"Metadata name '{name}' overlaps with an existing property, and cannot be accessed as an attribute or key. Use 'get_info()' to access metadata."
+                )
+        # elif name in self.metadata_columns:
+        #     # warnings for metadata that already exists
+        #     warnings.warn(f"Overwriting existing metadata column '{name}'.")
+
+        # warnings for metadata that cannot be accessed as attributes
         if name.replace("_", "").isalnum() is False:
             # contains invalid characters
-            raise ValueError(
-                f"Invalid metadata name '{name}'. Metadata name cannot contain special characters except for underscores"
+            warnings.warn(
+                f"Metadata name '{name}' contains a special character, and cannot be accessed as an attribute. Use 'get_info()' or key indexing to access metadata."
+            )
+        elif (name[0].isalpha() is False) and (name[0] != "_"):
+            # starts with a number
+            warnings.warn(
+                f"Metadata name '{name}' starts with a number, and cannot be accessed as an attribute. Use 'get_info()' or key indexing to access metadata."
             )
 
     def _check_metadata_column_names(self, metadata=None, **kwargs):
         """
-        Check that metadata column names don't conflict with existing attributes, don't start with a number, and don't contain invalid characters.
+        Throw warnings when metadata names cannot be accessed as attributes or keys.
         """
         if metadata is not None:
diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py
index c61b3654..389432b4 100644
--- a/pynapple/core/time_series.py
+++ b/pynapple/core/time_series.py
@@ -33,6 +33,7 @@
 from .time_index import TsIndex
 from .utils import (
     _concatenate_tsd,
+    _convert_iter_to_str,
     _get_terminal_size,
     _split_tsd,
     _TsdFrameSliceHelper,
@@ -949,11 +950,22 @@ def __init__(
         self.columns = pd.Index(c)
         self.nap_class = self.__class__.__name__
-        _MetadataMixin.__init__(self, metadata)
+        # initialize metadata for class attributes
+        _MetadataMixin.__init__(self)
+        # get current list of attributes
+        self._class_attributes = self.__dir__()
+        self._class_attributes.append("_class_attributes")
+        # set metadata
+        self.set_info(metadata)
         self._initialized = True
 
     @property
     def loc(self):
+        # add deprecation warning
+        warnings.warn(
+            "'loc' will be deprecated in a future version. Use bracket indexing instead.",
+            DeprecationWarning,
+        )
         return _TsdFrameSliceHelper(self)
 
     def __repr__(self):
@@ -1030,7 +1042,9 @@ def __repr__(self):
                         np.hstack(
                             (
                                 col_names[:, None],
-                                self._metadata.values[0:max_cols].T,
+                                _convert_iter_to_str(
+                                    self._metadata.values[0:max_cols].T
+                                ),
                                 ends,
                             ),
                             dtype=object,
@@ -1045,7 +1059,12 @@ def __repr__(self):
     def __setattr__(self, name, value):
         # necessary setter to allow metadata to be set as an attribute
         if self._initialized:
-            _MetadataMixin.__setattr__(self, name, value)
+            if name in self._class_attributes:
+                raise AttributeError(
+                    f"Cannot set attribute: '{name}' is a reserved attribute. Use 'set_info()' to set '{name}' as metadata."
+                )
+            else:
+                _MetadataMixin.__setattr__(self, name, value)
         else:
             super().__setattr__(name, value)
 
@@ -1056,7 +1075,15 @@ def __getattr__(self, name):
         # self._metadata.column having attributes '__reduce__', '__reduce_ex__'
         if name in ("__getstate__", "__setstate__", "__reduce__", "__reduce_ex__"):
             raise AttributeError(name)
-        if name in self._metadata.columns:
+
+        try:
+            metadata = self._metadata
+        except (AttributeError, RecursionError):
+            metadata = pd.DataFrame(index=self.columns)
+
+        if name == "_metadata":
+            return metadata
+        elif name in metadata.columns:
             return _MetadataMixin.__getattr__(self, name)
         else:
             return super().__getattr__(name)
@@ -1096,18 +1123,18 @@ def __getitem__(self, key, *args, **kwargs):
                     "When indexing with a Tsd, it must contain boolean values"
                 )
             key = key.d
-        elif isinstance(key, str) and (key in self.metadata_columns):
-            return _MetadataMixin.__getitem__(self, key)
         elif (
             isinstance(key, str)
             or hasattr(key, "__iter__")
            and all([isinstance(k, str) for k in key])
         ):
-            if all(k in self.metadata_columns for k in key):
-                return _MetadataMixin.__getitem__(self, key)
+            if all(k in self.columns for k in key):
+                with warnings.catch_warnings():
+                    # ignore deprecated warning for loc
+                    warnings.simplefilter("ignore")
+                    return self.loc[key]
             else:
-                return self.loc[key]
-
+                return _MetadataMixin.__getitem__(self, key)
         else:
             if isinstance(key, pd.Series) and key.index.equals(self.columns):
                 # if indexing with a pd.Series from metadata, transform it to tuple with slice(None) in first position
diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py
index ba4b55cb..270f9532 100644
--- a/pynapple/core/ts_group.py
+++ b/pynapple/core/ts_group.py
@@ -22,7 +22,12 @@
 from .metadata_class import _MetadataMixin
 from .time_index import TsIndex
 from .time_series import Ts, Tsd, TsdFrame, _BaseTsd, is_array_like
-from .utils import _get_terminal_size, check_filename, convert_to_numpy_array
+from .utils import (
+    _convert_iter_to_str,
+    _get_terminal_size,
+    check_filename,
+    convert_to_numpy_array,
+)
 
 
 def _union_intervals(i_sets):
@@ -188,11 +193,20 @@ def __init__(
             data = {k: data[k].restrict(self.time_support) for k in self.index}
 
         UserDict.__init__(self, data)
+        self.nap_class = self.__class__.__name__
+        # grab current attributes before adding metadata
+        self._class_attributes = self.__dir__()
+        self._class_attributes.append("_class_attributes")  # add this property
 
         # Making the TsGroup non mutable
         self._initialized = True
 
         # Trying to add argument as metainfo
+        if len(kwargs):
+            warnings.warn(
+                "initializing metadata with variable keyword arguments may be unsupported in a future version of Pynapple. Instead, initialize using the metadata argument.",
+                FutureWarning,
+            )
         self.set_info(metadata, **kwargs)
 
         """
@@ -202,10 +216,35 @@ def __init__(
     def __setattr__(self, name, value):
         # necessary setter to allow metadata to be set as an attribute
         if self._initialized:
-            _MetadataMixin.__setattr__(self, name, value)
+            if name in self._class_attributes:
+                raise AttributeError(
+                    f"Cannot set attribute: '{name}' is a reserved attribute. Use 'set_info()' to set '{name}' as metadata."
+                )
+            else:
+                _MetadataMixin.__setattr__(self, name, value)
         else:
             object.__setattr__(self, name, value)
 
+    def __getattr__(self, name):
+        # Necessary for backward compatibility with pickle
+
+        # avoid infinite recursion when pickling due to
+        # self._metadata.column having attributes '__reduce__', '__reduce_ex__'
+        if name in ("__getstate__", "__setstate__", "__reduce__", "__reduce_ex__"):
+            raise AttributeError(name)
+
+        try:
+            metadata = self._metadata
+        except Exception:
+            metadata = pd.DataFrame(index=self.index)
+
+        if name == "_metadata":
+            return metadata
+        elif name in metadata.columns:
+            return _MetadataMixin.__getattr__(self, name)
+        else:
+            return super().__getattr__(name)
+
     def __setitem__(self, key, value):
         if not self._initialized:
             self._metadata.loc[int(key), "rate"] = float(value.rate)
@@ -280,7 +319,9 @@ def __repr__(self):
                     (
                         self.index[0:n_rows, None],
                         np.round(self._metadata[["rate"]].values[0:n_rows], 5),
-                        self._metadata[col_names].values[0:n_rows, 0:max_cols],
+                        _convert_iter_to_str(
+                            self._metadata[col_names].values[0:n_rows, 0:max_cols]
+                        ),
                         ends,
                     ),
                     dtype=object,
@@ -293,7 +334,9 @@ def __repr__(self):
                    (
                        self.index[-n_rows:, None],
                        np.round(self._metadata[["rate"]].values[-n_rows:], 5),
-                        self._metadata[col_names].values[-n_rows:, 0:max_cols],
+                        _convert_iter_to_str(
+                            self._metadata[col_names].values[-n_rows:, 0:max_cols]
+                        ),
                        ends,
                    ),
                    dtype=object,
@@ -306,7 +349,9 @@ def __repr__(self):
                 (
                     self.index[:, None],
                     np.round(self._metadata[["rate"]].values, 5),
-                    self._metadata[col_names].values[:, 0:max_cols],
+                    _convert_iter_to_str(
+                        self._metadata[col_names].values[:, 0:max_cols]
+                    ),
                     ends,
                 ),
                 dtype=object,
diff --git a/pynapple/core/utils.py b/pynapple/core/utils.py
index 84293949..eb8410be 100644
--- a/pynapple/core/utils.py
+++ b/pynapple/core/utils.py
@@ -384,7 +384,12 @@ def __getitem__(self, key):
         IndexError
 
         """
-        if key in ["start", "end"] + self.intervalset.metadata_columns:
+        # Pickle backward compatibility
+        try:
+            metadata_columns = self.intervalset.metadata_columns
+        except Exception:
+            metadata_columns = []
+        if key in ["start", "end"] + metadata_columns:
             return self.intervalset[key]
         elif isinstance(key, list):
             return self.intervalset[key]
@@ -393,10 +398,7 @@ def __getitem__(self, key):
         else:
             if isinstance(key, tuple):
                 if len(key) == 2:
-                    if (
-                        key[1]
-                        not in ["start", "end"] + self.intervalset.metadata_columns
-                    ):
+                    if key[1] not in ["start", "end"] + metadata_columns:
                         raise IndexError
                     out = self.intervalset[key[0]][key[1]]
                     if len(out) == 1:
@@ -439,3 +441,19 @@ def check_filename(filename):
             raise RuntimeError("Path {} does not exist.".format(parent_folder))
 
     return filename
+
+
+def _convert_iter_to_str(array):
+    """
+    This function converts an array of arrays to an array of strings.
+    This helps avoid a DeprecationWarning from numpy when printing an object with metadata.
+    """
+    try:
+        shape = array.shape
+        array = array.flatten()
+        for i in range(len(array)):
+            if isinstance(array[i], np.ndarray):
+                array[i] = np.array2string(array[i])
+        return array.reshape(shape)
+    except Exception:
+        return array
diff --git a/pynapple/io/interface_nwb.py b/pynapple/io/interface_nwb.py
index d6b941fc..7fc30f33 100644
--- a/pynapple/io/interface_nwb.py
+++ b/pynapple/io/interface_nwb.py
@@ -1,9 +1,3 @@
-# -*- coding: utf-8 -*-
-# @Author: Guillaume Viejo
-# @Date: 2023-08-01 11:54:45
-# @Last Modified by: Guillaume Viejo
-# @Last Modified time: 2024-05-21 15:28:27
-
 """
 Pynapple class to interface with NWB files. Data are always lazy-loaded.
@@ -458,3 +452,36 @@ def __getitem__(self, key):
     def close(self):
         """Close the NWB file"""
         self.io.close()
+
+    def keys(self):
+        """
+        Return keys of NWBFile
+
+        Returns
+        -------
+        list
+            List of keys
+        """
+        return list(self.data.keys())
+
+    def items(self):
+        """
+        Return a list of key/object.
+
+        Returns
+        -------
+        list
+            List of tuples
+        """
+        return list(self.data.items())
+
+    def values(self):
+        """
+        Return a list of all the objects
+
+        Returns
+        -------
+        list
+            List of objects
+        """
+        return list(self.data.values())
diff --git a/tests/test_metadata.py b/tests/test_metadata.py
index 44d7c6ef..acde65ee 100644
--- a/tests/test_metadata.py
+++ b/tests/test_metadata.py
@@ -312,6 +312,91 @@ def test_drop_metadata_warnings(iset_meta):
         iset_meta.time_span()
 
+
+@pytest.mark.parametrize(
+    "name, set_exp, set_attr_exp, set_key_exp, get_attr_exp, get_key_exp",
+    [
+        # existing attribute and key
+        (
+            "start",
+            # warn with set_info
+            pytest.warns(UserWarning, match="overlaps with an existing"),
+            # error with setattr
+            pytest.raises(AttributeError, match="IntervalSet is immutable"),
+            # error with setitem
+            pytest.raises(RuntimeError, match="IntervalSet is immutable"),
+            # attr should not match metadata
+            pytest.raises(AssertionError),
+            # key should not match metadata
+            pytest.raises(AssertionError),
+        ),
+        # existing attribute and key
+        (
+            "end",
+            # warn with set_info
+            pytest.warns(UserWarning, match="overlaps with an existing"),
+            # error with setattr
+            pytest.raises(AttributeError, match="IntervalSet is immutable"),
+            # error with setitem
+            pytest.raises(RuntimeError, match="IntervalSet is immutable"),
+            # attr should not match metadata
+            pytest.raises(AssertionError),
+            # key should not match metadata
+            pytest.raises(AssertionError),
+        ),
+        # existing attribute
+        (
+            "values",
+            # warn with set_info
+            pytest.warns(UserWarning, match="overlaps with an existing"),
+            # error with setattr
+            pytest.raises(AttributeError, match="IntervalSet is immutable"),
+            # warn with setitem
+            pytest.warns(UserWarning, match="overlaps with an existing"),
+            # attr should not match metadata
+            pytest.raises(AssertionError),
+            # key should match metadata
+            does_not_raise(),
+        ),
+        # existing metadata
+        (
+            "label",
+            # no warning with set_info
+            does_not_raise(),
+            # no warning with setattr
+            does_not_raise(),
+            # no warning with setitem
+            does_not_raise(),
+            # attr should match metadata
+            does_not_raise(),
+            # key should match metadata
+            does_not_raise(),
+        ),
+    ],
+)
+def test_iset_metadata_overlapping_names(
+    iset_meta, name, set_exp, set_attr_exp, set_key_exp, get_attr_exp, get_key_exp
+):
+    assert hasattr(iset_meta, name)
+
+    # warning when set
+    with set_exp:
+        iset_meta.set_info({name: np.ones(4)})
+    # error when set as attribute
+    with set_attr_exp:
+        setattr(iset_meta, name, np.ones(4))
+    # error when set as key
+    with set_key_exp:
+        iset_meta[name] = np.ones(4)
+    # retrieve with get_info
+    np.testing.assert_array_almost_equal(iset_meta.get_info(name), np.ones(4))
+    # make sure it doesn't access metadata if it's an existing attribute or key
+    with get_attr_exp:
+        np.testing.assert_array_almost_equal(getattr(iset_meta, name), np.ones(4))
+    # make sure it doesn't access metadata if it's an existing key
+    with get_key_exp:
+        np.testing.assert_array_almost_equal(iset_meta[name], np.ones(4))
+
+
 ##############
 ## TsdFrame ##
 ##############
@@ -347,30 +432,209 @@ def test_tsdframe_metadata_slicing(tsdframe_meta):
 @pytest.mark.parametrize(
-    "args, kwargs, expected",
+    "name, attr_exp, set_exp, set_attr_exp, set_key_exp, get_exp, get_attr_exp, get_key_exp",
     [
+        # existing data column
         (
-            # invalid metadata names that are the same as column names
-            [
-                pd.DataFrame(
-                    index=["a", "b", "c"],
-                    columns=["a", "b", "c"],
-                    data=np.random.randint(0, 5, size=(3, 3)),
-                )
-            ],
-            {},
+            "a",
+            # not attribute
+            pytest.raises(AssertionError),
+            # error with set_info
+            pytest.raises(ValueError, match="Invalid metadata name"),
+            # error with setattr
             pytest.raises(ValueError, match="Invalid metadata name"),
+            # shape mismatch with setitem
+            pytest.raises(ValueError),
+            # key error with get_info
+            pytest.raises(KeyError),
+            # attribute should raise error
+            pytest.raises(AttributeError),
+            # key should not match metadata
+            pytest.raises(AssertionError),
+        ),
+        (
+            "columns",
+            # attribute exists
+            does_not_raise(),
+            # warn with set_info
+            pytest.warns(UserWarning, match="overlaps with an existing"),
+            # cannot be set as attribute
+            pytest.raises(AttributeError, match="Cannot set attribute"),
+            # warn when set as key
+            pytest.warns(UserWarning, match="overlaps with an existing"),
+            # no error with get_info
+            does_not_raise(),
+            # attribute should not match metadata
+            pytest.raises(TypeError),
+            # key should match metadata
+            does_not_raise(),
+        ),
+        # existing metadata
+        (
+            "l1",
+            # attribute exists
+            does_not_raise(),
+            # no warning with set_info
+            does_not_raise(),
+            # no warning with setattr
+            does_not_raise(),
+            # no warning with setitem
+            does_not_raise(),
+            # no error with get_info
+            does_not_raise(),
+            # attr should match metadata
+            does_not_raise(),
+            # key should match metadata
+            does_not_raise(),
         ),
     ],
 )
-def test_tsdframe_add_metadata_error(tsdframe_meta, args, kwargs, expected):
-    with expected:
-        tsdframe_meta.set_info(*args, **kwargs)
+def test_tsdframe_metadata_overlapping_names(
+    tsdframe_meta,
+    name,
+    attr_exp,
+    set_exp,
+    set_attr_exp,
+    get_exp,
+    set_key_exp,
+    get_attr_exp,
+    get_key_exp,
+):
+    with attr_exp:
+        assert hasattr(tsdframe_meta, name)
+    # warning when set
+    with set_exp:
+        # warnings.simplefilter("error")
+        tsdframe_meta.set_info({name: np.ones(4)})
+    # error when set as attribute
+    with set_attr_exp:
+        setattr(tsdframe_meta, name, np.ones(4))
+    # error when set as key
+    with set_key_exp:
+        tsdframe_meta[name] = np.ones(4)
+    # retrieve with get_info
+    with get_exp:
+        np.testing.assert_array_almost_equal(tsdframe_meta.get_info(name), np.ones(4))
+    # make sure it doesn't access metadata if it's an existing attribute or key
+    with get_attr_exp:
+        np.testing.assert_array_almost_equal(getattr(tsdframe_meta, name), np.ones(4))
+    # make sure it doesn't access metadata if it's an existing key
+    with get_key_exp:
+        np.testing.assert_array_almost_equal(tsdframe_meta[name], np.ones(4))
 
 
 #############
 ## TsGroup ##
 #############
+@pytest.fixture
+def tsgroup_meta():
+    return nap.TsGroup(
+        {
+            0: nap.Ts(t=np.arange(0, 200)),
+            1: nap.Ts(t=np.arange(0, 200, 0.5), time_units="s"),
+            2: nap.Ts(t=np.arange(0, 300, 0.2), time_units="s"),
+            3: nap.Ts(t=np.arange(0, 400, 1), time_units="s"),
+        },
+        metadata={"label": [1, 2, 3, 4]},
+    )
+
+
+@pytest.mark.parametrize(
+    "name, set_exp, set_attr_exp, set_key_exp, get_exp, get_attr_exp, get_key_exp",
+    [
+        # pre-computed rate metadata
+        (
+            "rate",
+            # error with set_info
+            pytest.raises(ValueError, match="Invalid metadata name"),
+            # error with setattr
+            pytest.raises(AttributeError, match="Cannot set attribute"),
+            # error with setitem
+            pytest.raises(ValueError, match="Invalid metadata name"),
+            # value mismatch with get_info
+            pytest.raises(AssertionError),
+            # value mismatch with getattr
+            pytest.raises(AssertionError),
+            # value mismatch with getitem
+            pytest.raises(AssertionError),
+        ),
+        # 'rates' attribute
+        (
+            "rates",
+            # warning with set_info
+            pytest.warns(UserWarning, match="overlaps with an existing"),
+            # error with setattr
+            pytest.raises(AttributeError, match="Cannot set attribute"),
+            # warn with setitem
+            pytest.warns(UserWarning, match="overlaps with an existing"),
+            # no error with get_info
+            does_not_raise(),
+            # get attribute is not metadata
+            pytest.raises(AssertionError),
+            # get key is metadata
+            does_not_raise(),
+        ),
+        # existing metadata
+        (
+            "label",
+            # no warning with set_info
+            does_not_raise(),
+            # no warning with setattr
+            does_not_raise(),
+            # no warning with setitem
+            does_not_raise(),
+            # no error with get_info
+            does_not_raise(),
+            # attr should match metadata
+            does_not_raise(),
+            # key should match metadata
+            does_not_raise(),
+        ),
+    ],
+)
+def test_tsgroup_metadata_overlapping_names(
+    tsgroup_meta,
+    name,
+    set_exp,
+    set_attr_exp,
+    set_key_exp,
+    get_exp,
+    get_attr_exp,
+    get_key_exp,
+):
+    assert hasattr(tsgroup_meta, name)
+
+    # warning when set
+    with set_exp:
+        tsgroup_meta.set_info({name: np.ones(4)})
+    # error when set as attribute
+    with set_attr_exp:
+        setattr(tsgroup_meta, name, np.ones(4))
+    # error when set as key
+    with set_key_exp:
+        tsgroup_meta[name] = np.ones(4)
+    # retrieve with get_info
+    with get_exp:
+        np.testing.assert_array_almost_equal(tsgroup_meta.get_info(name), np.ones(4))
+    # make sure it doesn't access metadata if it's an existing attribute or key
+    with get_attr_exp:
+        np.testing.assert_array_almost_equal(getattr(tsgroup_meta, name), np.ones(4))
+    # make sure it doesn't access metadata if it's an existing key
+    with get_key_exp:
+        np.testing.assert_array_almost_equal(tsgroup_meta[name], np.ones(4))
+
+
+def test_tsgroup_metadata_future_warnings():
+    with pytest.warns(FutureWarning, match="may be unsupported"):
+        tsgroup = nap.TsGroup(
+            {
+                0: nap.Ts(t=np.arange(0, 200)),
+                1: nap.Ts(t=np.arange(0, 200, 0.5), time_units="s"),
+                2: nap.Ts(t=np.arange(0, 300, 0.2), time_units="s"),
+                3: nap.Ts(t=np.arange(0, 400, 1), time_units="s"),
+            },
+            label=[1, 2, 3, 4],
+        )
 
 
 ##################
@@ -574,52 +838,26 @@ def test_add_metadata_df(self, obj, info, obj_len):
         (
             # invalid names as strings starting with a number
             [
-                pd.DataFrame(
-                    columns=["1"],
-                    data=np.ones((4, 1)),
-                )
+                {"1": np.ones(4)},
             ],
             {},
-            pytest.raises(ValueError, match="Invalid metadata name"),
+            pytest.warns(UserWarning, match="starts with a number"),
         ),
         (
             # invalid names with spaces
             [
-                pd.DataFrame(
-                    columns=["l 1"],
-                    data=np.ones((4, 1)),
-                )
+                {"l 1": np.ones(4)},
             ],
             {},
-            pytest.raises(ValueError, match="Invalid metadata name"),
+            pytest.warns(UserWarning, match="contains a special character"),
         ),
         (
             # invalid names with periods
             [
-                pd.DataFrame(
-                    columns=["l.1"],
-                    data=np.ones((4, 1)),
-                )
-            ],
-            {},
-            pytest.raises(ValueError, match="Invalid metadata name"),
-        ),
-        (
-            # invalid names with dashes
-            [
-                pd.DataFrame(
-                    columns=["l-1"],
-                    data=np.ones((4, 1)),
-                )
+                {"1.1": np.ones(4)},
             ],
             {},
-            pytest.raises(ValueError, match="Invalid metadata name"),
-        ),
-        (
-            # name that overlaps with existing attribute
-            [],
-            {"__dir__": np.zeros(4)},
-            pytest.raises(ValueError, match="Invalid metadata name"),
+            pytest.warns(UserWarning, match="contains a special character"),
         ),
         (
             # metadata with wrong length
             [
                 pd.DataFrame(index=np.arange(100)),
             ],
             {},
             pytest.raises(
                 ValueError,
                 match="input array length 100 does not match",
             ),
         ),
     ],
 )
-    def test_add_metadata_error(self, obj, args, kwargs, expected):
+    def test_add_metadata_error(self, obj, obj_len, args, kwargs, expected):
+        # trim to appropriate length
+        if len(args):
+            if isinstance(args[0], pd.DataFrame):
+                metadata = args[0].iloc[:obj_len]
+            elif isinstance(args[0], dict):
+                metadata = {k: v[:obj_len] for k, v in args[0].items()}
+        else:
+            metadata = None
         with expected:
-            obj.set_info(*args, **kwargs)
+            obj.set_info(metadata, **kwargs)
 
     def test_add_metadata_key_error(self, obj, obj_len):
         # type specific key errors
@@ -666,6 +912,62 @@ def test_overwrite_metadata(self, obj, obj_len):
         obj.label = [4] * obj_len
         assert np.all(obj.label == 4)
 
+    # test naming overlap of shared attributes
+    @pytest.mark.parametrize(
+        "name",
+        [
+            "set_info",
+            "_metadata",
+            "_class_attributes",
+        ],
+    )
+    def test_metadata_overlapping_names(self, obj, obj_len, name):
+        values = np.ones(obj_len)
+
+        # set some metadata to force assertion error with "_metadata" case
+        obj.set_info(label=values)
+
+        # assert attribute exists
+        assert hasattr(obj, name)
+
+        # warning when set
+        with pytest.warns(UserWarning, match="overlaps with an existing"):
+            obj.set_info({name: values})
+        # error when set as attribute
+        with pytest.raises(AttributeError, match="Cannot set attribute"):
+            setattr(obj, name, values)
+        # warning when set as key
+        with pytest.warns(UserWarning, match="overlaps with an existing"):
+            obj[name] = values
+        # retrieve with get_info
+        np.testing.assert_array_almost_equal(obj.get_info(name), values)
+        # make sure it doesn't access metadata if it's an existing attribute or key
+        with pytest.raises((AssertionError, ValueError, TypeError)):
+            np.testing.assert_array_almost_equal(getattr(obj, name), values)
+        # access metadata as key
+        np.testing.assert_array_almost_equal(obj[name], values)
+
+    # test metadata that can only be accessed as key
+    @pytest.mark.parametrize(
+        "name",
+        [
+            "l.1",
+            "l 1",
+            "0",
+        ],
+    )
+    def test_metadata_nonattribute_names(self, obj, obj_len, name):
+        values = np.ones(obj_len)
+
+        # set some metadata to force assertion error with "_metadata" case
+        with pytest.warns(UserWarning, match="cannot be accessed as an attribute"):
+            obj.set_info({name: values})
+
+        # make sure it can be accessed with get_info
+        np.testing.assert_array_almost_equal(obj.get_info(name), values)
+        # make sure it can be accessed as key
+        np.testing.assert_array_almost_equal(obj[name], values)
+
     @pytest.mark.parametrize("label, val", [([1, 1, 2, 2], 2)])
     def test_metadata_slicing(self, obj, label, val, obj_len):
         # slicing not relevant for length 1 objects
diff --git a/tests/test_nwb.py b/tests/test_nwb.py
index 8b9db257..4e685644 100644
--- a/tests/test_nwb.py
+++ b/tests/test_nwb.py
@@ -89,6 +89,23 @@ def test_NWBFile():
     assert isinstance(nwb.io, pynwb.NWBHDF5IO)
     nwb.close()
 
+    assert nwb.keys() == [
+        "position_time_support",
+        "epochs",
+        "z",
+        "y",
+        "x",
+        "rz",
+        "ry",
+        "rx",
+    ]
+
+    for a, b in zip(nwb.items(), nwb.data.items()):
+        assert a == b
+
+    for a, b in zip(nwb.values(), nwb.data.values()):
+        assert a == b
+
 
 def test_NWBFile_missing_file():
     with pytest.raises(FileNotFoundError) as e_info:
diff --git a/tests/test_time_series.py b/tests/test_time_series.py
index a53612fd..e1064da9 100755
--- a/tests/test_time_series.py
+++ b/tests/test_time_series.py
@@ -1,6 +1,7 @@
 """Tests of time series for `pynapple` package."""
 
 import pickle
+import warnings
 from contextlib import nullcontext as does_not_raise
 from numbers import Number
 from pathlib import Path
@@ -1381,6 +1382,17 @@ def test_convolve_keep_columns(self, tsdframe):
         assert isinstance(tsd2, nap.TsdFrame)
         np.testing.assert_array_equal(tsd2.columns, tsdframe.columns)
 
+    def test_deprecation_warning(self, tsdframe):
+        columns = tsdframe.columns
+        # warning using loc
+        with pytest.warns(DeprecationWarning):
+            tsdframe.loc[columns[0]]
+        if isinstance(columns[0], str):
+            # suppressed warning with getitem, which implicitly uses loc
+            with warnings.catch_warnings():
+                warnings.simplefilter("error")
+                tsdframe[columns[0]]
+
 
 ####################################################
 # Test for ts
@@ -1898,6 +1910,8 @@ def test_pickling(obj):
     assert np.all(obj.time_support == unpickled_obj.time_support)
 
 
+#
+
 ####################################################
 # Test for slicing
 ####################################################
diff --git a/tests/test_ts_group.py b/tests/test_ts_group.py
index 25308baf..bd5b3cc4 100644
--- a/tests/test_ts_group.py
+++ b/tests/test_ts_group.py
@@ -628,10 +628,6 @@ def test_setitem_metadata_twice(self, group):
         group["a"] = np.arange(len(group)) + 10
         assert all(group._metadata["a"] == np.arange(len(group)) + 10)
 
-    def test_prevent_overwriting_existing_methods(self, ts_group):
-        with pytest.raises(ValueError, match=r"Invalid metadata name"):
-            ts_group["set_info"] = np.arange(2)
-
     def test_getitem_ts_object(self, ts_group):
        assert isinstance(ts_group[1], nap.Ts)
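A minimal usage sketch of the behavior this patch introduces (not part of the diff itself; the variable names and example values are illustrative, but the calls mirror the API exercised by the tests above):

# sketch: metadata names that overlap existing attributes now warn instead of raising,
# and stay reachable through get_info(); the TsGroup 'rate' column remains protected.
import numpy as np
import pynapple as nap

tsgroup = nap.TsGroup(
    {0: nap.Ts(t=np.arange(0, 100)), 1: nap.Ts(t=np.arange(0, 100, 0.5))},
    metadata={"label": [1, 2]},
)

# "rates" overlaps an existing attribute: set_info emits a UserWarning rather than an
# error, and the values are still accessible via get_info() or key indexing.
tsgroup.set_info({"rates": np.ones(2)})
print(tsgroup.get_info("rates"))

# "rate" is the pre-computed rate column and still raises a ValueError.
try:
    tsgroup.set_info({"rate": np.ones(2)})
except ValueError as e:
    print(e)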