Skip to content

Commit

Permalink
Merge pull request #363 from pynapple-org/metadata
Browse files Browse the repository at this point in the history
Allow any string name for metadata
  • Loading branch information
gviejo authored Nov 14, 2024
2 parents 93b9cd3 + 33c6ff4 commit 41da12f
Show file tree
Hide file tree
Showing 10 changed files with 621 additions and 116 deletions.
70 changes: 59 additions & 11 deletions pynapple/core/interval_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .metadata_class import _MetadataMixin
from .time_index import TsIndex
from .utils import (
_convert_iter_to_str,
_get_terminal_size,
_IntervalSetSliceHelper,
check_filename,
Expand Down Expand Up @@ -214,10 +215,12 @@ def __init__(
self.index = np.arange(data.shape[0], dtype="int")
self.columns = np.array(["start", "end"])
self.nap_class = self.__class__.__name__
if drop_meta:
_MetadataMixin.__init__(self)
else:
_MetadataMixin.__init__(self, metadata)
# initialize metadata to get all attributes before setting metadata
_MetadataMixin.__init__(self)
self._class_attributes = self.__dir__() # get list of all attributes
self._class_attributes.append("_class_attributes") # add this property
if drop_meta is False:
self.set_info(metadata)
self._initialized = True

def __repr__(self):
Expand All @@ -229,7 +232,14 @@ def __repr__(self):
# By default, the first three columns should always show.

# Adding an extra column between actual values and metadata
col_names = self._metadata.columns
try:
metadata = self._metadata
col_names = metadata.columns
except Exception:
# Necessary for backward compatibility when saving IntervalSet as pickle
metadata = pd.DataFrame(index=self.index)
col_names = []

headers = ["index", "start", "end"]
if len(col_names):
headers += [""] + [c for c in col_names]
Expand All @@ -249,7 +259,7 @@ def __repr__(self):
self.index[0:n_rows, None],
self.values[0:n_rows],
separator,
self._metadata.values[0:n_rows],
_convert_iter_to_str(metadata.values[0:n_rows]),
),
dtype=object,
),
Expand All @@ -259,7 +269,7 @@ def __repr__(self):
self.index[-n_rows:, None],
self.values[0:n_rows],
separator,
self._metadata.values[-n_rows:],
_convert_iter_to_str(metadata.values[-n_rows:]),
),
dtype=object,
),
Expand All @@ -271,7 +281,12 @@ def __repr__(self):
else:
separator = np.empty((len(self), 0))
data = np.hstack(
(self.index[:, None], self.values, separator, self._metadata.values),
(
self.index[:, None],
self.values,
separator,
_convert_iter_to_str(metadata.values),
),
dtype=object,
)

Expand All @@ -286,19 +301,53 @@ def __len__(self):
def __setattr__(self, name, value):
# necessary setter to allow metadata to be set as an attribute
if self._initialized:
_MetadataMixin.__setattr__(self, name, value)
if name in self._class_attributes:
raise AttributeError(
f"Cannot set attribute '{name}'; IntervalSet is immutable. Use 'set_info()' to set '{name}' as metadata."
)
else:
_MetadataMixin.__setattr__(self, name, value)
else:
object.__setattr__(self, name, value)

def __getattr__(self, name):
    """
    Fallback attribute lookup used to expose metadata columns as attributes.

    Python only calls this hook after normal attribute lookup has failed,
    so it is also the place where objects restored from pickles that lack
    `_metadata` are given an empty metadata table.

    Parameters
    ----------
    name : str
        Name of the attribute that normal lookup could not find.

    Returns
    -------
    object
        An empty ``pandas.DataFrame`` indexed like the object when `name` is
        ``"_metadata"``; the result of ``_MetadataMixin.__getattr__`` when
        `name` is a metadata column; otherwise whatever the next
        ``__getattr__`` in the MRO returns.

    Raises
    ------
    AttributeError
        For the pickle protocol hooks, and when `index` itself is missing
        while building the fallback metadata table.
    """
    # avoid infinite recursion when pickling due to
    # self._metadata.column having attributes '__reduce__', '__reduce_ex__'
    if name in ("__getstate__", "__setstate__", "__reduce__", "__reduce_ex__"):
        raise AttributeError(name)

    if name == "_metadata":
        # Getting here means `_metadata` is genuinely absent (backward
        # compatibility with older pickles). Answer directly instead of
        # re-entering this hook until a RecursionError is swallowed.
        # object.__getattribute__ bypasses this hook, so a missing `index`
        # surfaces as a plain AttributeError rather than unbounded recursion.
        return pd.DataFrame(index=object.__getattribute__(self, "index"))

    # Found normally when present; otherwise resolved once by the branch above.
    metadata = self._metadata
    if name in metadata.columns:
        return _MetadataMixin.__getattr__(self, name)
    return super().__getattr__(name)

def __setitem__(self, key, value):
if (isinstance(key, str)) and (key not in self.columns):
if key in self.columns:
raise RuntimeError(
"IntervalSet is immutable. Starts and ends have been already sorted."
)
elif isinstance(key, str):
_MetadataMixin.__setitem__(self, key, value)
else:
raise RuntimeError(
"IntervalSet is immutable. Starts and ends have been already sorted."
)

def __getitem__(self, key):
try:
metadata = _MetadataMixin.__getitem__(self, key)
except Exception:
metadata = pd.DataFrame(index=self.index)

if isinstance(key, str):
# self[str]
if key == "start":
Expand All @@ -323,7 +372,6 @@ def __getitem__(self, key):
elif isinstance(key, Number):
# self[Number]
output = self.values.__getitem__(key)
metadata = _MetadataMixin.__getitem__(self, key)
return IntervalSet(start=output[0], end=output[1], metadata=metadata)
elif isinstance(key, (slice, list, np.ndarray, pd.Series)):
# self[array_like]
Expand Down
67 changes: 39 additions & 28 deletions pynapple/core/metadata_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class _MetadataMixin:
"""

def __init__(self, metadata=None, **kwargs):
def __init__(self, metadata=None):
"""
Metadata initializer
Expand All @@ -21,7 +21,6 @@ def __init__(self, metadata=None, **kwargs):
List of pandas.DataFrame
**kwargs : dict
Dictionary containing metadata information
"""
if self.__class__.__name__ == "TsdFrame":
# metadata index is the same as the columns for TsdFrame
Expand All @@ -31,12 +30,8 @@ def __init__(self, metadata=None, **kwargs):
self.metadata_index = self.index

self._metadata = pd.DataFrame(index=self.metadata_index)
if len(kwargs):
warnings.warn(
"initializing metadata with variable keyword arguments may be unsupported in a future version of Pynapple. Instead, initialize using the metadata argument.",
FutureWarning,
)
self.set_info(metadata, **kwargs)

self.set_info(metadata)

def __dir__(self):
"""
Expand Down Expand Up @@ -115,32 +110,48 @@ def _raise_invalid_metadata_column_name(self, name):
raise TypeError(
f"Invalid metadata type {type(name)}. Metadata column names must be strings!"
)
if hasattr(self, name) and (name not in self.metadata_columns):
# existing non-metadata attribute
raise ValueError(
f"Invalid metadata name '{name}'. Metadata name must differ from "
f"{type(self).__dict__.keys()} attribute names!"
)
if hasattr(self, "columns") and name in self.columns:
# existing column (since TsdFrame columns are not attributes)
raise ValueError(
f"Invalid metadata name '{name}'. Metadata name must differ from "
f"{self.columns} column names!"
)
if name[0].isalpha() is False:
# starts with a number
raise ValueError(
f"Invalid metadata name '{name}'. Metadata name cannot start with a number"
)
# warnings for metadata names that cannot be accessed as attributes or keys
if name in self._class_attributes:
if (self.nap_class == "TsGroup") and (name == "rate"):
# special exception for TsGroup rate attribute
raise ValueError(
f"Invalid metadata name '{name}'. Cannot overwrite TsGroup 'rate'!"
)
else:
# existing non-metadata attribute
warnings.warn(
f"Metadata name '{name}' overlaps with an existing attribute, and cannot be accessed as an attribute or key. Use 'get_info()' to access metadata."
)
elif hasattr(self, "columns") and name in self.columns:
if self.nap_class == "TsdFrame":
# special exception for TsdFrame columns
raise ValueError(
f"Invalid metadata name '{name}'. Metadata name must differ from {list(self.columns)} column names!"
)
else:
# existing non-metadata column
warnings.warn(
f"Metadata name '{name}' overlaps with an existing property, and cannot be accessed as an attribute or key. Use 'get_info()' to access metadata."
)
# elif name in self.metadata_columns:
# # warnings for metadata that already exists
# warnings.warn(f"Overwriting existing metadata column '{name}'.")

# warnings for metadata that cannot be accessed as attributes
if name.replace("_", "").isalnum() is False:
# contains invalid characters
raise ValueError(
f"Invalid metadata name '{name}'. Metadata name cannot contain special characters except for underscores"
warnings.warn(
f"Metadata name '{name}' contains a special character, and cannot be accessed as an attribute. Use 'get_info()' or key indexing to access metadata."
)
elif (name[0].isalpha() is False) and (name[0] != "_"):
# starts with a number
warnings.warn(
f"Metadata name '{name}' starts with a number, and cannot be accessed as an attribute. Use 'get_info()' or key indexing to access metadata."
)

def _check_metadata_column_names(self, metadata=None, **kwargs):
"""
Check that metadata column names don't conflict with existing attributes, don't start with a number, and don't contain invalid characters.
Throw warnings when metadata names cannot be accessed as attributes or keys.
"""

if metadata is not None:
Expand Down
47 changes: 37 additions & 10 deletions pynapple/core/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from .time_index import TsIndex
from .utils import (
_concatenate_tsd,
_convert_iter_to_str,
_get_terminal_size,
_split_tsd,
_TsdFrameSliceHelper,
Expand Down Expand Up @@ -949,11 +950,22 @@ def __init__(

self.columns = pd.Index(c)
self.nap_class = self.__class__.__name__
_MetadataMixin.__init__(self, metadata)
# initialize metadata for class attributes
_MetadataMixin.__init__(self)
# get current list of attributes
self._class_attributes = self.__dir__()
self._class_attributes.append("_class_attributes")
# set metadata
self.set_info(metadata)
self._initialized = True

@property
def loc(self):
    """
    Return a `_TsdFrameSliceHelper` wrapping this object.

    Accessing this property emits a ``DeprecationWarning`` advising the use
    of bracket indexing instead.

    Returns
    -------
    _TsdFrameSliceHelper
        Helper constructed from this object.
    """
    # stacklevel=2 attributes the warning to the code that accessed `.loc`
    # rather than to this line, so the report points at the call site that
    # needs to be migrated.
    warnings.warn(
        "'loc' will be deprecated in a future version. Use bracket indexing instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return _TsdFrameSliceHelper(self)

def __repr__(self):
Expand Down Expand Up @@ -1030,7 +1042,9 @@ def __repr__(self):
np.hstack(
(
col_names[:, None],
self._metadata.values[0:max_cols].T,
_convert_iter_to_str(
self._metadata.values[0:max_cols].T
),
ends,
),
dtype=object,
Expand All @@ -1045,7 +1059,12 @@ def __repr__(self):
def __setattr__(self, name, value):
# necessary setter to allow metadata to be set as an attribute
if self._initialized:
_MetadataMixin.__setattr__(self, name, value)
if name in self._class_attributes:
raise AttributeError(
f"Cannot set attribute: '{name}' is a reserved attribute. Use 'set_info()' to set '{name}' as metadata."
)
else:
_MetadataMixin.__setattr__(self, name, value)
else:
super().__setattr__(name, value)

Expand All @@ -1056,7 +1075,15 @@ def __getattr__(self, name):
# self._metadata.column having attributes '__reduce__', '__reduce_ex__'
if name in ("__getstate__", "__setstate__", "__reduce__", "__reduce_ex__"):
raise AttributeError(name)
if name in self._metadata.columns:

try:
metadata = self._metadata
except (AttributeError, RecursionError):
metadata = pd.DataFrame(index=self.columns)

if name == "_metadata":
return metadata
elif name in metadata.columns:
return _MetadataMixin.__getattr__(self, name)
else:
return super().__getattr__(name)
Expand Down Expand Up @@ -1096,18 +1123,18 @@ def __getitem__(self, key, *args, **kwargs):
"When indexing with a Tsd, it must contain boolean values"
)
key = key.d
elif isinstance(key, str) and (key in self.metadata_columns):
return _MetadataMixin.__getitem__(self, key)
elif (
isinstance(key, str)
or hasattr(key, "__iter__")
and all([isinstance(k, str) for k in key])
):
if all(k in self.metadata_columns for k in key):
return _MetadataMixin.__getitem__(self, key)
if all(k in self.columns for k in key):
with warnings.catch_warnings():
# ignore deprecated warning for loc
warnings.simplefilter("ignore")
return self.loc[key]
else:
return self.loc[key]

return _MetadataMixin.__getitem__(self, key)
else:
if isinstance(key, pd.Series) and key.index.equals(self.columns):
# if indexing with a pd.Series from metadata, transform it to tuple with slice(None) in first position
Expand Down
Loading

0 comments on commit 41da12f

Please sign in to comment.