Skip to content

Commit

Permalink
Merge pull request #263 from pynapple-org/ts_group_getitem
Browse files Browse the repository at this point in the history
Ts group getitem
  • Loading branch information
gviejo authored Apr 10, 2024
2 parents 65c904d + ba4f40a commit cf46f90
Show file tree
Hide file tree
Showing 4 changed files with 338 additions and 47 deletions.
138 changes: 108 additions & 30 deletions pynapple/core/ts_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import os
import warnings
from collections import UserDict
from collections.abc import Hashable

import numpy
import numpy as np
import pandas as pd
from tabulate import tabulate
Expand Down Expand Up @@ -169,40 +171,86 @@ def __init__(
Base functions
"""

def __getattr__(self, name):
    """
    Allow dynamic access to metadata columns as attributes.

    Parameters
    ----------
    name : str
        The name of the metadata column to access.

    Returns
    -------
    pandas.Series
        The series of values for the requested metadata column.

    Raises
    ------
    AttributeError
        If the requested attribute is not a metadata column, or if the
        metadata table has not been set on the object yet.
    """
    # __getattr__ only runs once normal attribute lookup has failed. Reading
    # `self._metadata` here while `_metadata` is not set yet would re-enter
    # this method and recurse until RecursionError, so fetch it without the
    # __getattr__ fallback.
    try:
        metadata = object.__getattribute__(self, "_metadata")
    except AttributeError:
        metadata = None

    # Check if the requested attribute is part of the metadata
    if metadata is not None and name in metadata.columns:
        return metadata[name]

    # If the attribute is not part of the metadata, raise AttributeError
    raise AttributeError(
        f"'{type(self).__name__}' object has no attribute '{name}'"
    )

def __setitem__(self, key, value):
    """
    Insert an element while the group is being built, or set a metadata
    column once it is initialized.

    Parameters
    ----------
    key : int or str
        While ``self._initialized`` is falsy, the group index of the element
        (converted with ``int``). Afterwards, the name of a metadata column.
    value : object
        While ``self._initialized`` is falsy, an object exposing a ``rate``
        attribute. Afterwards, the column values, forwarded to ``set_info``.

    Raises
    ------
    ValueError
        If the group is initialized and ``key`` is not a string.
    """
    if not self._initialized:
        self._metadata.loc[int(key), "rate"] = float(value.rate)
        super().__setitem__(int(key), value)
        return

    if not isinstance(key, str):
        raise ValueError("Metadata keys must be strings!")

    if key not in self._metadata.columns:
        self.set_info(**{key: value})
        return

    # Replicate pandas behavior of over-writing columns: drop the existing
    # column, then add it again through `set_info`. Keep a copy so the
    # metadata can be rolled back if `set_info` rejects the new values.
    old_meta = self._metadata.copy()
    self._metadata.pop(key)
    try:
        self.set_info(**{key: value})
    except Exception:
        self._metadata = old_meta
        raise

def __getitem__(self, key):
    """
    Retrieve an element, a metadata column, or a sub-group.

    Parameters
    ----------
    key : hashable or array-like
        A hashable key is looked up first in the group, then among the
        metadata column names. A non-hashable key is treated as a
        collection of group keys, or as a 1-dimensional boolean mask
        over the group when its dtype is bool.

    Returns
    -------
    object
        ``self.data[key]`` for a key in the group, ``self.get_info(key)``
        for a metadata column name, otherwise the result of
        ``self._ts_group_from_keys`` for the selected keys.

    Raises
    ------
    KeyError
        If a hashable key is neither in the group nor a metadata column.
    IndexError
        If a boolean mask is not 1-dimensional or its length differs from
        the number of elements in the group.
    """
    # Standard dict keys are Hashable
    if isinstance(key, Hashable):
        if self.__contains__(key):
            return self.data[key]
        if key in self._metadata.columns:
            return self.get_info(key)
        raise KeyError(f"Can't find key {key} in group index.")

    # Boolean arrays are converted into group indices. A bare bool is
    # hashable and was already handled above.
    mask = np.asarray(key)
    if mask.dtype == bool:
        if mask.ndim != 1:
            raise IndexError("Only 1-dimensional boolean indices are allowed!")
        if len(mask) != len(self):
            raise IndexError(
                "Boolean index length must be equal to the number of Ts in the group! "
                f"The number of Ts is {len(self)}, but the boolean array "
                f"has length {len(mask)} instead!"
            )
        key = self.index[mask]

    return self._ts_group_from_keys(key)

def _ts_group_from_keys(self, keys):
    """
    Build a new TsGroup from a subset of this group's keys.

    Parameters
    ----------
    keys : iterable
        Keys of the elements to keep; each one is fetched with ``self[k]``.

    Returns
    -------
    TsGroup
        A group holding the selected elements, created with this group's
        ``time_support`` and every metadata column except "rate".
    """
    # Metadata rows are selected in sorted key order.
    ordered_keys = np.sort(keys)
    info_columns = self._metadata.columns.drop("rate")
    sub_metadata = self._metadata.loc[ordered_keys, info_columns]

    members = {}
    for k in keys:
        members[k] = self[k]

    return TsGroup(members, time_support=self.time_support, **sub_metadata)

def __repr__(self):
cols = self._metadata.columns.drop("rate")
Expand Down Expand Up @@ -270,9 +318,25 @@ def metadata_columns(self):
"""
return list(self._metadata.columns)

def _check_metadata_column_names(self, *args, **kwargs):
    """
    Reject metadata names that clash with existing attributes of the object.

    Metadata columns can be read with attribute syntax, so a column whose
    name is already an attribute of the object would behave unexpectedly.

    Parameters
    ----------
    *args
        Positional metadata; only ``pandas.DataFrame`` arguments are
        inspected, through their column labels.
    **kwargs
        Keyword metadata; only values that are a ``list``, ``tuple``,
        ``numpy.ndarray`` or ``pandas.Series`` are inspected.

    Raises
    ------
    ValueError
        If at least one inspected name is already an attribute of the object.
    """
    invalid_cols = []
    for arg in args:
        if isinstance(arg, pd.DataFrame):
            # `hasattr` raises TypeError for non-string labels; such labels
            # can never clash with an attribute name, so skip them.
            invalid_cols.extend(
                col
                for col in arg.columns
                if isinstance(col, str) and hasattr(self, col)
            )

    # Tuples are accepted as metadata values too, so they are checked as well.
    invalid_cols.extend(
        k
        for k, v in kwargs.items()
        if isinstance(v, (list, tuple, np.ndarray, pd.Series)) and hasattr(self, k)
    )

    if invalid_cols:
        raise ValueError(
            f"Invalid metadata name(s) {invalid_cols}. Metadata name must differ from "
            f"TsGroup attribute names!"
        )

def set_info(self, *args, **kwargs):
"""
Add metadata information about the TsGroup.
Metadata are saved as a DataFrame.
Parameters
Expand All @@ -289,6 +353,8 @@ def set_info(self, *args, **kwargs):
no column labels are found when passing simple arguments,
indexes are not equal for a pandas Series,
not the same length when passing numpy array.
TypeError
If some of the provided metadata could not be set.
Examples
--------
Expand Down Expand Up @@ -324,6 +390,10 @@ def set_info(self, *args, **kwargs):
2 4 ca1 1
"""
# check for duplicate names, otherwise "self.metadata_name"
# syntax would behave unexpectedly.
self._check_metadata_column_names(*args, **kwargs)
not_set = []
if len(args):
for arg in args:
if isinstance(arg, pd.DataFrame):
Expand All @@ -333,6 +403,8 @@ def set_info(self, *args, **kwargs):
raise RuntimeError("Index are not equals")
elif isinstance(arg, (pd.Series, np.ndarray, list)):
raise RuntimeError("Argument should be passed as keyword argument.")
else:
not_set.append(arg)
if len(kwargs):
for k, v in kwargs.items():
if isinstance(v, pd.Series):
Expand All @@ -347,7 +419,13 @@ def set_info(self, *args, **kwargs):
self._metadata[k] = np.asarray(v)
else:
raise RuntimeError("Array is not the same length.")
return
else:
not_set.append({k: v})
if not_set:
raise TypeError(
f"Cannot set the following metadata:\n{not_set}.\nMetadata columns provided must be "
f"of type `panda.Series`, `tuple`, `list`, or `numpy.ndarray`."
)

def get_info(self, key):
"""
Expand Down
1 change: 1 addition & 0 deletions pynapple/io/interface_npz.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def load(self):
"index",
"d",
"rate",
"keys",
}:
tmp = self.file[k]
if len(tmp) == len(tsgroup):
Expand Down
63 changes: 57 additions & 6 deletions tests/test_npz_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,63 @@ def test_init(path):
assert file.type == "Tsd"

@pytest.mark.parametrize("path", [path])
@pytest.mark.parametrize("k", ['tsd', 'ts', 'tsdframe', 'tsgroup', 'iset'])
def test_load(path, k):
    """Loading ``<k>.npz`` from ``path`` yields the same type as ``data[k]``."""
    # One parametrized case per object kind, so a failure names the kind
    # instead of aborting a shared loop on the first mismatch.
    file_path = os.path.join(path, k + ".npz")
    file = nap.NPZFile(file_path)
    tmp = file.load()
    assert type(tmp) == type(data[k])

@pytest.mark.parametrize("path", [path])
@pytest.mark.parametrize("k", ['tsgroup'])
def test_load_tsgroup(path, k):
    """The group loaded from ``<k>.npz`` matches ``data[k]`` in keys, metadata, elements and time support."""
    file_path = os.path.join(path, k + ".npz")
    file = nap.NPZFile(file_path)
    tmp = file.load()
    assert type(tmp) == type(data[k])
    assert tmp.keys() == data[k].keys()
    assert np.all(tmp._metadata == data[k]._metadata)
    # `np.all(<generator>)` is always truthy, so reduce with the builtin
    # `all` and compare each element explicitly.
    # NOTE(review): assumes each group element exposes its timestamps as
    # `.t`, like the objects checked in the sibling tests -- confirm.
    assert all(
        np.array_equal(tmp[neu].t, data[k][neu].t) for neu in tmp.keys()
    )
    assert np.all(tmp.time_support == data[k].time_support)


@pytest.mark.parametrize("path", [path])
@pytest.mark.parametrize("k", ['tsd'])
def test_load_tsd(path, k):
    """The object loaded from ``<k>.npz`` matches ``data[k]`` in type, values, timestamps and time support."""
    loaded = nap.NPZFile(os.path.join(path, k + ".npz")).load()
    expected = data[k]

    assert type(loaded) == type(expected)
    assert np.all(loaded.d == expected.d)
    assert np.all(loaded.t == expected.t)
    assert np.all(loaded.time_support == expected.time_support)


@pytest.mark.parametrize("path", [path])
@pytest.mark.parametrize("k", ['ts'])
def test_load_ts(path, k):
    """The object loaded from ``<k>.npz`` matches ``data[k]`` in type, timestamps and time support."""
    loaded = nap.NPZFile(os.path.join(path, k + ".npz")).load()
    expected = data[k]

    assert type(loaded) == type(expected)
    assert np.all(loaded.t == expected.t)
    assert np.all(loaded.time_support == expected.time_support)



@pytest.mark.parametrize("path", [path])
@pytest.mark.parametrize("k", ['tsdframe'])
def test_load_tsdframe(path, k):
    """The object loaded from ``<k>.npz`` matches ``data[k]`` in type, timestamps, time support, columns and values."""
    loaded = nap.NPZFile(os.path.join(path, k + ".npz")).load()
    expected = data[k]

    assert type(loaded) == type(expected)
    assert np.all(loaded.t == expected.t)
    assert np.all(loaded.time_support == expected.time_support)
    assert np.all(loaded.columns == expected.columns)
    assert np.all(loaded.d == expected.d)



@pytest.mark.parametrize("path", [path])
def test_load_non_npz(path):
Expand Down
Loading

0 comments on commit cf46f90

Please sign in to comment.