Skip to content

Commit

Permalink
Merge pull request #26 from pynapple-org/filter
Browse files Browse the repository at this point in the history
IIR Filter
  • Loading branch information
gviejo authored Sep 16, 2024
2 parents 164c047 + 8d26c20 commit 5e30a18
Show file tree
Hide file tree
Showing 18 changed files with 771 additions and 105 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
*.npz

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
4 changes: 3 additions & 1 deletion docs/examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,6 @@ The functions that have been optimized with `pynajax` are :

- [`threshold`](https://pynapple-org.github.io/pynapple/reference/core/time_series/#pynapple.core.time_series.Tsd.threshold)

- [`event_trigger_average`](https://pynapple-org.github.io/pynapple/reference/process/perievent/#pynapple.process.perievent.compute_event_trigger_average)
- [`event_trigger_average`](https://pynapple-org.github.io/pynapple/reference/process/perievent/#pynapple.process.perievent.compute_event_trigger_average)

- filtering
125 changes: 125 additions & 0 deletions docs/examples/plot_benchmark_filtering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""
# filtering
This notebook compare the jax implementation of Butterworth filter with [scipy sosfiltfilt](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.sosfiltfilt.html).
Performances of the `'sinc'` mode can be found in the convolve benchmark as it is the function being called underneath.
⚠️ **Warning:** We do not recommend using GPU for filtering as it is much slower for the moment compared to CPU.
"""
import os
import numpy as np
import pynapple as nap
from time import perf_counter
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore")



# %%
# Machine Configuration
import jax
print(jax.devices())

# %%
def get_mean_perf(tsd, mode, n=10):
tmp = np.zeros(n)
_ = nap.apply_lowpass_filter(tsd, 0.25 * tsd.rate, mode=mode)
for i in range(n):
t1 = perf_counter()
_ = nap.apply_lowpass_filter(tsd, 0.25 * tsd.rate, mode=mode)
t2 = perf_counter()
tmp[i] = t2 - t1
return [np.mean(tmp), np.std(tmp)]

# %%
# # Increasing number of time points

def benchmark_time_points(mode):
times = []
for T in np.arange(1000, 100000, 20000):
time_array = np.arange(T)/1000
data_array = np.random.randn(len(time_array))
startend = np.linspace(0, time_array[-1], T//100).reshape(T//200, 2)
ep = nap.IntervalSet(start=startend[::2,0], end=startend[::2,1])
tsd = nap.Tsd(t=time_array, d=data_array)#, time_support=ep)
times.append([T]+get_mean_perf(tsd, mode))
return np.array(times)


# %%
# Calling with numba/scipy
nap.nap_config.set_backend("numba")
times_butter_scipy = benchmark_time_points(mode="butter")

# %%
# Calling with jax
nap.nap_config.set_backend("jax")
times_butter_jax = benchmark_time_points(mode="butter")

# %%
# Figure

plt.figure()
for arr, label in zip(
[times_butter_scipy, times_butter_jax],
["Butter (scipy)", "Butter (jax)"],
):
plt.plot(arr[:, 0], arr[:, 1], "o-", label=label)
plt.fill_between(arr[:, 0], arr[:, 1] - arr[:, 2], arr[:, 1] + arr[:, 2], alpha=0.2)

plt.legend()
plt.xlabel("Number of time points")
plt.ylabel("Time (s)")
plt.title("Butterworth filter low pass")
# plt.show()


# %%
# # Increasing number of dimensions

def benchmark_dimensions(mode):
times = []
T = 60000
for n in np.arange(1, 100, 20):
time_array = np.arange(T)/1000
data_array = np.random.randn(len(time_array), n)
startend = np.linspace(0, time_array[-1], T//100).reshape(T//200, 2)
ep = nap.IntervalSet(start=startend[::2,0], end=startend[::2,1])
tsd = nap.TsdFrame(t=time_array, d=data_array, time_support=ep)
times.append([n]+get_mean_perf(tsd, mode))
return np.array(times)

# %%
# Calling with numba/scipy
nap.nap_config.set_backend("numba")
dims_butter_scipy = benchmark_dimensions(mode="butter")

# %%
# Calling with jax
nap.nap_config.set_backend("jax")
dims_butter_jax = benchmark_dimensions(mode="butter")

# %%
# Figure


plt.figure()

for arr, label in zip(
[dims_butter_scipy, dims_butter_jax],
["Butter (scipy)", "Butter (jax)"],
):
plt.plot(arr[:, 0], arr[:, 1], "o-", label=label)
plt.fill_between(arr[:, 0], arr[:, 1] - arr[:, 2], arr[:, 1] + arr[:, 2], alpha=0.2)

plt.legend()
plt.xlabel("Number of dimensions")
plt.ylabel("Time (s)")
plt.title("Butterworth filter low pass")
plt.show()

Binary file modified docs/images/convolve_benchmark.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 0 additions & 2 deletions src/pynajax/jax_core_bin_average.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import jax
import jax.numpy as jnp
import numpy as np

# import pynapple as nap
from numba import jit


Expand Down
8 changes: 6 additions & 2 deletions src/pynajax/jax_core_convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,12 @@ def convolve_intervals(time_array, data_array, starts, ends, kernel, trim="both"
extra = (extra[0], extra[1] + 1)

n = len(starts)
idx_start_shift = idx_start + np.arange(1, n + 1) * extra[0] + np.arange(0, n) * extra[1]
idx_end_shift = idx_end + np.arange(1, n + 1) * extra[0] + np.arange(0, n) * extra[1]
idx_start_shift = (
idx_start + np.arange(1, n + 1) * extra[0] + np.arange(0, n) * extra[1]
)
idx_end_shift = (
idx_end + np.arange(1, n + 1) * extra[0] + np.arange(0, n) * extra[1]
)

idx = _get_slicing(idx_start_shift, idx_end_shift)

Expand Down
6 changes: 4 additions & 2 deletions src/pynajax/jax_core_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,12 @@ def threshold(time_array, data_array, starts, ends, thr, method):
ix2 = jnp.diff(ix * 1)

new_starts = (
time_array[1:][ix2 == 1] - (time_array[1:][ix2 == 1] - time_array[0:-1][ix2 == 1]) / 2
time_array[1:][ix2 == 1]
- (time_array[1:][ix2 == 1] - time_array[0:-1][ix2 == 1]) / 2
)
new_ends = (
time_array[0:-1][ix2 == -1] + (time_array[1:][ix2 == -1] - time_array[0:-1][ix2 == -1]) / 2
time_array[0:-1][ix2 == -1]
+ (time_array[1:][ix2 == -1] - time_array[0:-1][ix2 == -1]) / 2
)

if ix[0]: # First element to keep as start
Expand Down
210 changes: 210 additions & 0 deletions src/pynajax/jax_process_filtering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import scipy.signal as signal

from .utils import (
_get_shifted_indices,
_get_slicing,
_odd_ext_multiepoch,
_revert_epochs,
)


@partial(jax.jit, static_argnums=(3, ))
def _recursion_loop_sos(signal, sos, zi, nan_function):
"""
Applies a recursive second-order section (SOS) filter to the input signal.
Parameters
----------
signal : jnp.ndarray
The input signal to be filtered, with shape (n_samples,).
sos : jnp.ndarray
Array of second-order filter coefficients in the 'sos' format, with shape (n_sections, 6).
zi : jnp.ndarray
Initial conditions for the filter, with shape (n_sections, 2, n_epochs).
nan_function : callable
A function that specifies how to re-initialize the initial conditions when a NaN is encountered in the signal.
It should take two arguments: the epoch number and the current filter state, and return a tuple of the updated
epoch number and the re-initialized filter state.
Returns
-------
jnp.ndarray
The filtered signal, with the same shape as the input signal.
"""

def internal_loop(s, x_zi):
x_cur, zi_slice = x_zi
x_new = sos[s, 0] * x_cur + zi_slice[s, 0]
zi_slice = zi_slice.at[s, 0].set(
sos[s, 1] * x_cur - sos[s, 4] * x_new + zi_slice[s, 1]
)
zi_slice = zi_slice.at[s, 1].set(
sos[s, 2] * x_cur - sos[s, 5] * x_new)
x_cur = x_new
return x_cur, zi_slice

def recursion_step(carry, x):
epoch_num, zi_slice = carry

x_cur, zi_slice = jax.lax.fori_loop(
lower=0, upper=sos.shape[0], body_fun=internal_loop, init_val=(x, zi_slice)
)

# Use jax.lax.cond to choose between nan_case and not_nan_case
epoch_num, zi_slice = jax.lax.cond(
jnp.isnan(x), # Condition to check
nan_function, # Function to call if x is NaN
lambda i, x: (i, zi_slice), # Function to call if x is not NaN
epoch_num,
zi,
)

return (epoch_num, zi_slice), x_cur

_, res = jax.lax.scan(recursion_step, (0, zi[..., 0]), signal)

return res


# vectorize the recursion over signals.
_vmap_recursion_sos = jax.vmap(_recursion_loop_sos, in_axes=(1, None, 2, None), out_axes=1)


def _insert_constant(idx_start, idx_end, data_array, window_size, const=jnp.nan):
"""
Insert a constant value array between epochs in a time series data array.
This function interleaves a constant value array of specified size between each epoch in the data array.
Parameters
----------
idx_start : jnp.ndarray
Array of start indices for each epoch.
idx_end : jnp.ndarray
Array of end indices for each epoch.
data_array : jnp.ndarray
The input data array, with shape (n_samples, ...).
window_size : int
The size of the constant array to be inserted between epochs.
const : float, optional
The constant value to be inserted, by default jnp.nan.
Returns
-------
data_array: jnp.ndarray
The modified data array with the constant arrays inserted.
ix_orig: jnp.ndarray
Indices corresponding to the samples in the original data array.
ix_shift: jnp.ndarray
The shifted indices after the constant array has been interleaved.
idx_start_shift:
The shifted start indices of the epochs in the modified array.
idx_end_shift:
The shifted end indices of the epochs in the modified array.
"""
# shift by a window every epoch
idx_start_shift, idx_end_shift = _get_shifted_indices(
idx_start, idx_end, window_size
)

# get the indices for setting elements
ix_orig = _get_slicing(idx_start, idx_end)
ix_shift = _get_slicing(idx_start_shift, idx_end_shift)

tot_size = ix_shift[-1] - ix_shift[0] + 1
data_array = (
jnp.full((tot_size, *data_array.shape[1:]), const)
.at[ix_shift]
.set(data_array[ix_orig])
)
return data_array, ix_orig, ix_shift, idx_start_shift, idx_end_shift


def jax_sosfiltfilt(sos, time_array, data_array, starts, ends):
"""
Apply forward-backward filtering using a second-order section (SOS) filter.
This function applies an SOS filter to the data array in both forward and reverse directions,
which results in zero-phase filtering.
Parameters
----------
sos : np.ndarray
Array of second-order filter coefficients in the 'sos' format, with shape (n_sections, 6).
time_array : np.ndarray
The time array corresponding to the data, with shape (n_samples,).
data_array : jnp.ndarray
The data array to be filtered, with shape (n_samples, ...).
starts : np.ndarray
Array of start indices for the epochs in the data array.
ends : np.ndarray
Array of end indices for the epochs in the data array.
Returns
-------
: jnp.ndarray
The zero-phase filtered data array, with the same shape as the input data array.
"""

original_shape = data_array.shape
data_array = data_array.reshape(data_array.shape[0], -1)

# same default padding as scipy.sosfiltfilt ("pad" method and "odd" padtype).
n_sections = sos.shape[0]
ntaps = 2 * n_sections + 1
ntaps -= min((sos[:, 2] == 0).sum(), (sos[:, 5] == 0).sum())
pad_num = 3 * ntaps

ext, ix_start_pad, ix_end_pad, ix_data = _odd_ext_multiepoch(pad_num, time_array, data_array, starts, ends)

# get the start/end index of each epoch after padding
ix_start_ep = np.hstack((ix_start_pad[0], ix_start_pad[1:-1] + pad_num))
ix_end_ep = np.hstack((ix_start_ep[1:], ix_end_pad[-1]))

zi = signal.sosfilt_zi(sos)

# this braodcast has shape (*zi.shape, data_array.shape[1], len(ix_start_pad))
z0 = zi[..., jnp.newaxis, jnp.newaxis] * ext.T[jnp.newaxis, jnp.newaxis, ..., ix_start_ep]

if len(starts) > 1:
# multi epoch case augmenting with nans.
aug_data, ix_orig, ix_shift, idx_start_shift, idx_end_shift = _insert_constant(
ix_start_ep, ix_end_ep, ext, window_size=1, const=np.nan
)

# grab the next initial condition, increase the epoch counter
nan_func = lambda ep_num, x: (ep_num + 1, x[..., ep_num + 1])
else:
# single epoch, no augmentation
nan_func = lambda ep_num, x: (ep_num + 1, x[..., 0])
aug_data = ext
idx_start_shift = ix_start_ep
idx_end_shift = ix_end_ep
ix_shift = slice(None)


# call forward recursion
out = _vmap_recursion_sos(aug_data, sos, z0, nan_func)

# reverse time axis
irev = _revert_epochs(idx_start_shift, idx_end_shift)
out = out.at[ix_shift].set(out[irev])

# compute new init cond
z0 = zi[..., jnp.newaxis, jnp.newaxis] * out.T[jnp.newaxis, jnp.newaxis, ..., idx_start_shift]

# call backward recursion
out = _vmap_recursion_sos(out, sos, z0, nan_func)

# re-flip axis
out = out.at[ix_shift].set(out[irev])

# remove nans and padding
out = out[ix_shift][ix_data]

return out.reshape(original_shape)
Loading

0 comments on commit 5e30a18

Please sign in to comment.