Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
gviejo committed Jan 11, 2025
1 parent c81c632 commit 235169c
Show file tree
Hide file tree
Showing 6 changed files with 248 additions and 65 deletions.
4 changes: 2 additions & 2 deletions pynapple/core/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,8 @@ def count(self, bin_size=None, ep=None, time_units="s", dtype=None):
if not isinstance(bin_size, float):
raise TypeError("bin_size argument should be float or int.")

if not isinstance(time_units, str) or time_units not in ["s", "ms", "us"]:
raise ValueError("time_units argument should be 's', 'ms' or 'us'.")
if not isinstance(time_units, str) or time_units not in ["s", "ms", "us"]:
raise ValueError("time_units argument should be 's', 'ms' or 'us'.")

if ep is None:
ep = self.time_support
Expand Down
45 changes: 14 additions & 31 deletions pynapple/core/ts_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ def value_from(self, tsd, ep=None):
cols = self._metadata.columns.drop("rate")
return TsGroup(newgr, time_support=ep, metadata=self._metadata[cols])

def count(self, *args, dtype=None, **kwargs):
def count(self, bin_size=None, ep=None, time_units="s", dtype=None):
"""
Count occurences of events within bin_size or within a set of bins defined as an IntervalSet.
You can call this function in multiple ways :
Expand Down Expand Up @@ -652,39 +652,23 @@ def count(self, *args, dtype=None, **kwargs):
[1000 rows x 3 columns]
"""
bin_size = None
if "bin_size" in kwargs:
bin_size = kwargs["bin_size"]
if bin_size is not None:
if isinstance(bin_size, int):
bin_size = float(bin_size)
if not isinstance(bin_size, float):
raise ValueError("bin_size argument should be float.")
else:
for a in args:
if isinstance(a, (float, int)):
bin_size = float(a)

time_units = "s"
if "time_units" in kwargs:
time_units = kwargs["time_units"]
if not isinstance(time_units, str):
raise ValueError("time_units argument should be 's', 'ms' or 'us'.")
else:
for a in args:
if isinstance(a, str) and a in ["s", "ms", "us"]:
time_units = a

ep = self.time_support
if "ep" in kwargs:
ep = kwargs["ep"]
if not isinstance(ep, IntervalSet):
raise ValueError("ep argument should be IntervalSet")
else:
for a in args:
if isinstance(a, IntervalSet):
ep = a
raise TypeError("bin_size argument should be float or int.")

if not isinstance(time_units, str) or time_units not in ["s", "ms", "us"]:
raise ValueError("time_units argument should be 's', 'ms' or 'us'.")

if dtype:
if ep is None:
ep = self.time_support
if not isinstance(ep, IntervalSet):
raise TypeError("ep argument should be of type IntervalSet")

if dtype is None:
dtype = np.dtype(np.int64)
else:
try:
dtype = np.dtype(dtype)
except Exception:
Expand All @@ -694,7 +678,6 @@ def count(self, *args, dtype=None, **kwargs):
ends = ep.end

if isinstance(bin_size, (float, int)):
bin_size = float(bin_size)
bin_size = TsIndex.format_timestamps(np.array([bin_size]), time_units)[0]

# Call it on first element to pre-allocate the array
Expand Down
2 changes: 1 addition & 1 deletion pynapple/process/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,5 @@
compute_2d_tuning_curves_continuous,
compute_discrete_tuning_curves,
)
from .wavelets import compute_wavelet_transform, generate_morlet_filterbank
from .warping import build_tensor
from .wavelets import compute_wavelet_transform, generate_morlet_filterbank
106 changes: 87 additions & 19 deletions pynapple/process/warping.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,20 +42,42 @@ def wrapper(*args, **kwargs):
def _build_tensor_from_tsgroup(input, ep, binsize, align, padding_value):
# Determine size of tensor
n_t = int(np.max(np.ceil((ep.end + binsize - ep.start) / binsize)))

output = np.ones(shape=(len(input), len(ep), n_t)) * padding_value

count = input.count(bin_size=binsize, ep=ep)

for i in range(len(ep)):
tmp = count.get(ep.start[i], ep.end[i]).values # Time by neuron
output[:, i, 0 : tmp.shape[0]] = np.transpose(tmp)
if align == "start":
for i in range(len(ep)):
tmp = count.get(ep.start[i], ep.end[i]).values
output[:, i, 0 : tmp.shape[0]] = np.transpose(tmp)
if np.all(np.isnan(output[:, :, -1])):
output = output[:, :, 0:-1]

if align == "end":
for i in range(len(ep)):
tmp = count.get(ep.start[i], ep.end[i]).values
output[:, i, -tmp.shape[0] :] = np.transpose(tmp)
if np.all(np.isnan(output[:, :, 0])):
output = output[:, :, 1:]

return output


def _build_tensor_from_tsd(input, ep, binsize, align, padding_value):
pass
def _build_tensor_from_tsd(input, ep, align, padding_value):
slices = [input.get_slice(s, e) for s, e in ep.values]
lengths = list(map(lambda sl: sl.stop - sl.start, slices))
n_t = max(lengths)
output = np.ones(shape=(len(ep), n_t, *input.shape[1:])) * padding_value
if align == "start":
for i, sl in enumerate(slices):
output[i, 0 : lengths[i]] = input[sl].values
if align == "end":
for i, sl in enumerate(slices):
output[i, -lengths[i] :] = input[sl].values

if output.ndim > 2:
output = np.moveaxis(output, source=[0, 1], destination=[-2, -1])

return output


@_validate_warping_inputs
Expand All @@ -65,23 +87,26 @@ def build_tensor(
"""
Return trial-based tensor from an IntervalSet object.
- if `input` is a `TsGroup`, returns a numpy array of shape (number of trial, number of group element, number of time bins).
The `binsize` parameter determines the number of time bins.
- If `input` is a `TsGroup`, returns a numpy array of shape (number of group element, number of trial, number of time bins). The `binsize` parameter determines the number of time bins.
- If `input` is `Tsd`, `TsdFrame` or `TsdTensor`, returns a numpy array of shape (shape of time series, number of trial, number of time points).
- if `input` is `Tsd`, `TsdFrame` or `TsdTensor`, returns a numpy array of shape
(number of trial, shape of time series, number of time points).
If the parameter `binsize` is used, the data are "bin-averaged".
The `align` parameter controls how the time series are aligned. If `align="start"`, the time
series are aligned to the start of the trials. If `align="end"`, the time series are aligned
to the end of the trials.
If trials are uneven durations, the returned array is padded. The parameter `padding_value`
determine which value is used to pad the array. Default is NaN.
Parameters
----------
input : Tsd, TsdFrame, TsdTensor or TsGroup
Returns a numpy array.
Input to slice and align to the trials within the `ep` parameter.
ep : IntervalSet
Epochs holding the trials. Each interval can be of unequal size.
binsize : Number, optional
align: str, optional
How to align the time series ('start' [default], 'end', 'both')
How to align the time series ('start' [default], 'end')
padding_value: Number, optional
How to pad the array if unequal intervals. Default is np.nan.
time_unit : str, optional
Expand All @@ -105,13 +130,56 @@ def build_tensor(
"""
if time_unit not in ["s", "ms", "us"]:
raise RuntimeError("time_unit should be 's', 'ms' or 'us'")
if align not in ["start", "end", "both"]:
raise RuntimeError("align should be 'start', 'end' or 'both'")

binsize = np.abs(nap.TsIndex.format_timestamps(np.array([binsize]), time_unit))[0]
if align not in ["start", "end"]:
raise RuntimeError("align should be 'start' or 'end'")

if isinstance(input, nap.TsGroup):
if not isinstance(binsize, Number):
raise RuntimeError("When input is a TsGroup, binsize should be specified")
return _build_tensor_from_tsgroup(input, ep, binsize, align, padding_value)

if isinstance(input, (nap.Tsd, nap.TsdFrame, nap.TsdTensor)):
return _build_tensor_from_tsd(input, ep, binsize, align, padding_value)
return _build_tensor_from_tsd(input, ep, align, padding_value)


@_validate_warping_inputs
def warp_tensor(input, ep, num_bin=None, align="start"):
"""
Return time-warped trial-based tensor from an IntervalSet object.
- If `input` is a `TsGroup`, returns a numpy array of shape (number of group element, number of trial, number of time bins). The `binsize` parameter determines the number of time bins.
- If `input` is `Tsd`, `TsdFrame` or `TsdTensor`, returns a numpy array of shape (shape of time series, number of trial, number of time points).
Parameters
----------
input : Tsd, TsdFrame, TsdTensor or TsGroup
Returns a numpy array.
ep : IntervalSet
Epochs holding the trials. Each interval can be of unequal size.
binsize : Number, optional
align: str, optional
How to align the time series ('start' [default], 'end')
padding_value: Number, optional
How to pad the array if unequal intervals. Default is np.nan.
time_unit : str, optional
Time units of the binsize parameter ('s' [default], 'ms', 'us').
Returns
-------
numpy.ndarray
Raises
------
RuntimeError
If `time_unit` not in ["s", "ms", "us"]
Examples
--------
"""
pass
14 changes: 2 additions & 12 deletions tests/test_ts_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,23 +306,13 @@ def test_count_time_units(self, group):
np.testing.assert_array_almost_equal(
count.loc[2].values[0:-1].flatten(), np.ones(len(count) - 1) * 5
)
count = tsgroup.count(b, tu)
np.testing.assert_array_almost_equal(
count.loc[0].values[0:-1].flatten(), np.ones(len(count) - 1)
)
np.testing.assert_array_almost_equal(
count.loc[1].values[0:-1].flatten(), np.ones(len(count) - 1) * 2
)
np.testing.assert_array_almost_equal(
count.loc[2].values[0:-1].flatten(), np.ones(len(count) - 1) * 5
)

def test_count_errors(self, group):
tsgroup = nap.TsGroup(group)
with pytest.raises(ValueError):
with pytest.raises(TypeError):
tsgroup.count(bin_size={})

with pytest.raises(ValueError):
with pytest.raises(TypeError):
tsgroup.count(ep={})

with pytest.raises(ValueError):
Expand Down
Loading

0 comments on commit 235169c

Please sign in to comment.