Skip to content

Commit

Permalink
CHanged perivent continuous
Browse files Browse the repository at this point in the history
  • Loading branch information
gviejo committed May 7, 2024
1 parent bf866f0 commit ef0151c
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 74 deletions.
128 changes: 55 additions & 73 deletions pynapple/process/_process_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,82 +9,40 @@
"""

import numpy as np
from numba import jit, njit, prange
from numba import jit

from .. import core as nap


@njit(parallel=True)
def _jitcontinuous_perievent(
time_array, data_array, time_target_array, starts, ends, windowsize
):
N_samples = len(time_array)
N_target = len(time_target_array)
@jit(nopython=True)
def _jitcontinuous_perievent(time_array, time_target_array, starts, ends, windowsize):
N_epochs = len(starts)
count = np.zeros((N_epochs, 2), dtype=np.int64)
start_t = np.zeros((N_epochs, 2), dtype=np.int64)

k = 0 # Epochs
t = 0 # Samples
i = 0 # Target

while ends[k] < time_array[t] and ends[k] < time_target_array[i]:
k += 1

while k < N_epochs:
# Outside
while t < N_samples:
if time_array[t] >= starts[k]:
break
t += 1

while i < N_target:
if time_target_array[i] >= starts[k]:
break
i += 1

if time_array[t] <= ends[k]:
start_t[k, 0] = t

if time_target_array[i] <= ends[k]:
start_t[k, 1] = i

# Inside
while t < N_samples:
if time_array[t] > ends[k]:
break
else:
count[k, 0] += 1
t += 1

while i < N_target:
if time_target_array[i] > ends[k]:
break
else:
count[k, 1] += 1
i += 1

k += 1

if k == N_epochs:
break
if t == N_samples:
break
if i == N_target:
break

new_data_array = np.full(
(np.sum(windowsize) + 1, np.sum(count[:, 1]), *data_array.shape[1:]), np.nan
idx, count[:, 1] = nap._jitted_functions.jitrestrict_with_count(
time_target_array, starts, ends
)
time_target_array = time_target_array[idx]

idx, count[:, 0] = nap._jitted_functions.jitrestrict_with_count(
time_array, starts, ends
)
time_array = time_array[idx]

N_target = len(time_target_array)

slice_idx = np.zeros((N_target, 2), dtype=np.int64)
start_w = np.zeros(N_target, dtype=np.int64)

if np.any((count[:, 0] * count[:, 1]) > 0):
for k in prange(N_epochs):
for k in range(N_epochs):
if count[k, 0] > 0 and count[k, 1] > 0:
t = start_t[k, 0]
i = start_t[k, 1]
t = np.sum(count[0:k, 0])
i = np.sum(count[0:k, 1])
maxt = t + count[k, 0]
maxi = i + count[k, 1]
cnt_i = np.sum(count[0:k, 1])

start_t = t

while i < maxi:
interval = abs(time_array[t] - time_target_array[i])
Expand All @@ -99,18 +57,17 @@ def _jitcontinuous_perievent(
t_pos = t
t += 1

left = np.minimum(windowsize[0], t_pos - start_t[k, 0])
left = np.minimum(windowsize[0], t_pos - start_t)
right = np.minimum(windowsize[1], maxt - t_pos - 1)
center = windowsize[0] + 1
new_data_array[center - left - 1 : center + right, cnt_i] = (
data_array[t_pos - left : t_pos + right + 1]
)
# center = windowsize[0] + 1

slice_idx[i] = (t_pos - left, t_pos + right + 1)
start_w[i] = windowsize[0] - left

t -= 1
i += 1
cnt_i += 1

return new_data_array
return idx, slice_idx, np.sum(count[:, 1]), start_w


@jit(nopython=True)
Expand Down Expand Up @@ -270,13 +227,38 @@ def _perievent_trigger_average(
def _perievent_continuous(
time_array, data_array, time_target_array, starts, ends, windowsize
):

idx, slice_idx, N_target, w_starts = _jitcontinuous_perievent(
time_array, time_target_array, starts, ends, windowsize
)

data_array = data_array[idx]

if nap.utils.get_backend() == "jax":
from pynajax.jax_process_perievent import perievent_continuous

return perievent_continuous(
time_array, data_array, time_target_array, starts, ends, windowsize
data_array, np.sum(windowsize) + 1, N_target, slice_idx, w_starts
)
else:
return _jitcontinuous_perievent(
time_array, data_array, time_target_array, starts, ends, windowsize
new_data_array = np.full(
(np.sum(windowsize) + 1, N_target, *data_array.shape[1:]), np.nan
)

w_sizes = slice_idx[:, 1] - slice_idx[:, 0] # Different sizes

all_w_sizes = np.unique(w_sizes)
all_w_start = np.unique(w_starts)

for w_size in all_w_sizes:
for w_start in all_w_start:
col_idx = w_sizes == w_size
new_idx = np.zeros((w_size, np.sum(col_idx)), dtype=int)
for i, tmp in enumerate(slice_idx[col_idx]):
new_idx[:, i] = np.arange(tmp[0], tmp[1])

new_data_array[w_start : w_start + w_size, col_idx] = data_array[
new_idx
]

return new_data_array
30 changes: 29 additions & 1 deletion tests/test_perievent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# @Author: gviejo
# @Date: 2022-03-30 11:16:53
# @Last Modified by: Guillaume Viejo
# @Last Modified time: 2023-12-12 17:45:38
# @Last Modified time: 2024-05-07 15:22:24
#!/usr/bin/env python

"""Tests of perievent for `pynapple` package."""
Expand Down Expand Up @@ -109,7 +109,21 @@ def test_compute_perievent_continuous():
tsd = nap.Tsd(t=np.arange(100), d=np.arange(100))
tref = nap.Ts(t=np.array([20, 60]))
minmax=(-5, 10)

# time_array = tsd.t
# data_array = tsd.d
# time_target_array = tref.t
# starts = tsd.time_support.start
# ends = tsd.time_support.end
# window = np.abs(minmax)
# binsize = time_array[1] - time_array[0]
# idx1 = -np.arange(0, window[0] + binsize, binsize)[::-1][:-1]
# idx2 = np.arange(0, window[1] + binsize, binsize)[1:]
# time_idx = np.hstack((idx1, np.zeros(1), idx2))
# windowsize = np.array([idx1.shape[0], idx2.shape[0]])

pe = nap.compute_perievent_continuous(tsd, tref, minmax=minmax)

assert isinstance(pe, nap.TsdFrame)
assert pe.shape[1] == len(tref)
np.testing.assert_array_almost_equal(pe.index.values, np.arange(minmax[0], minmax[-1]+1))
Expand Down Expand Up @@ -182,6 +196,20 @@ def test_compute_perievent_continuous_with_ep():
np.testing.assert_array_almost_equal(pe.values, tmp)

tref = ep.starts

# time_array = tsd.t
# data_array = tsd.d
# time_target_array = tref.t
# starts = ep.start
# ends = ep.end
# window = np.abs(minmax)
# binsize = time_array[1] - time_array[0]
# idx1 = -np.arange(0, window[0] + binsize, binsize)[::-1][:-1]
# idx2 = np.arange(0, window[1] + binsize, binsize)[1:]
# time_idx = np.hstack((idx1, np.zeros(1), idx2))
# windowsize = np.array([idx1.shape[0], idx2.shape[0]])


pe = nap.compute_perievent_continuous(tsd, tref, minmax=minmax, ep=ep)
tmp = np.array([np.arange(t, t+minmax[1]+1) for t in tref.restrict(ep).t]).T
np.testing.assert_array_almost_equal(pe.values[abs(minmax[0]):], tmp)
Expand Down

0 comments on commit ef0151c

Please sign in to comment.