From a2fe7c1629f8c0f6288fb815123c23834a7604fd Mon Sep 17 00:00:00 2001
From: gviejo
Date: Tue, 27 Feb 2024 09:49:59 -0500
Subject: [PATCH 01/37] Working branch when passing jax for d

---
 pynapple/__init__.py          |  2 +-
 pynapple/core/__init__.py     |  2 +-
 pynapple/core/base_class.py   |  4 ++--
 pynapple/core/config.py       | 41 ++++++++++++++++++++++++++++++-
 pynapple/core/interval_set.py |  2 +-
 pynapple/core/time_series.py  | 20 ++++++----------
 pynapple/core/ts_group.py     |  2 +-
 pynapple/core/utils.py        | 45 ++++++++++++++++++++++++++++-------
 tests/test_jax_backend.py     | 14 +++++++++++
 9 files changed, 104 insertions(+), 28 deletions(-)
 create mode 100644 tests/test_jax_backend.py

diff --git a/pynapple/__init__.py b/pynapple/__init__.py
index 989e5ccc..f66dfa7c 100644
--- a/pynapple/__init__.py
+++ b/pynapple/__init__.py
@@ -1,4 +1,4 @@
 __version__ = "0.6.0"
-from .core import IntervalSet, Ts, Tsd, TsdFrame, TsdTensor, TsGroup, TsIndex, config
+from .core import IntervalSet, Ts, Tsd, TsdFrame, TsdTensor, TsGroup, TsIndex, nap_config
 from .io import *
 from .process import *
diff --git a/pynapple/core/__init__.py b/pynapple/core/__init__.py
index d26dafd0..01810331 100644
--- a/pynapple/core/__init__.py
+++ b/pynapple/core/__init__.py
@@ -1,4 +1,4 @@
-from . import config
+from .config import nap_config
 from .interval_set import IntervalSet
 from .time_index import TsIndex
 from .time_series import Ts, Tsd, TsdFrame, TsdTensor
diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py
index 1d647cc4..ef0ba8ad 100644
--- a/pynapple/core/base_class.py
+++ b/pynapple/core/base_class.py
@@ -18,7 +18,7 @@
 )
 from .interval_set import IntervalSet
 from .time_index import TsIndex
-from .utils import convert_to_numpy, is_array_like
+from .utils import cast_to_numpy, is_array_like
 
 
 class Base(abc.ABC):
@@ -43,7 +43,7 @@ def __init__(self, t, time_units="s", time_support=None):
         # convert array-like data to numpy.
         # raise a warning to avoid silent conversion if non-numpy array is provided (jax arrays for instance)
         elif is_array_like(t):
-            t = convert_to_numpy(t, "t")
+            t = cast_to_numpy(t, "t")
             self.index = TsIndex(t, time_units)
         else:
             raise RuntimeError(
diff --git a/pynapple/core/config.py b/pynapple/core/config.py
index 3dfd813b..28f3c118 100644
--- a/pynapple/core/config.py
+++ b/pynapple/core/config.py
@@ -54,7 +54,46 @@ class PynappleConfig:
 
     def __init__(self):
         self.suppress_conversion_warnings = False
-        self.suppress_time_index_sorting_warnings = False
+        self.suppress_time_index_sorting_warnings = False
+        self.backend = "numba"
+
+    @property
+    def backend(self):
+        """
+        Pynapple backend. Can be "jax" or "numpy".
+        """
+        return self._backend
+
+    @backend.setter
+    def backend(self, backend):
+        self.set_backend(backend)
+
+    def set_backend(self, backend):
+
+        assert backend in ["numba", "jax"], "Options for backend are 'jax' or 'numba'"
+
+        # Try to import pynajax
+        if backend == "jax":
+            try:
+                import pynajax as nax
+                self._backend = "jax"
+            except ImportError:
+                warnings.warn(
+                    "Importing pynajax failed. Falling back to numba.
To use the jax backend for pynapple, please install pynajax", + UserWarning, + ) + self._backend = "numba" + else: + self._backend = "numba" + + @property + def get_pynajax_backend_status(self): + try: + import pynajax + self.backend = "jax" + except: + self.backend = "numba" + @property def time_index_precision(self): diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index ec91f311..0b0ac61c 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -50,7 +50,7 @@ from .utils import ( _IntervalSetSliceHelper, _jitfix_iset, - convert_to_numpy, + cast_to_numpy, is_array_like, ) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 3c39f148..1dc22bf7 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -43,8 +43,10 @@ _concatenate_tsd, _split_tsd, _TsdFrameSliceHelper, - convert_to_numpy, + convert_to_numpy_array, + convert_to_jax_array, is_array_like, + get_backend ) @@ -78,19 +80,11 @@ class BaseTsd(Base, NDArrayOperatorsMixin, abc.ABC): def __init__(self, t, d, time_units="s", time_support=None): super().__init__(t, time_units, time_support) - # Converting d to numpy array - if isinstance(d, Number): - self.values = np.array([d]) - elif isinstance(d, (list, tuple)): - self.values = np.array(d) - elif isinstance(d, np.ndarray): - self.values = d - elif is_array_like(d): - self.values = convert_to_numpy(d, "d") + # Check if jax backend + if get_backend() == "jax": + self.values = convert_to_jax_array(d, "d") else: - raise RuntimeError( - "Unknown format for d. Accepted formats are numpy.ndarray, list, tuple or any array-like objects." - ) + self.values = convert_to_numpy_array(d, "d") assert len(self.index) == len( self.values diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index bd1a62f3..8db17957 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -22,7 +22,7 @@ from .interval_set import IntervalSet from .time_index import TsIndex from .time_series import BaseTsd, Ts, Tsd, TsdFrame, is_array_like -from .utils import convert_to_numpy +from .utils import cast_to_numpy def union_intervals(i_sets): diff --git a/pynapple/core/utils.py b/pynapple/core/utils.py index 91f88092..a456a3ef 100644 --- a/pynapple/core/utils.py +++ b/pynapple/core/utils.py @@ -2,7 +2,7 @@ # @Author: Guillaume Viejo # @Date: 2024-02-09 11:45:45 # @Last Modified by: gviejo -# @Last Modified time: 2024-02-21 21:27:04 +# @Last Modified time: 2024-02-27 09:47:11 """ Utility functions @@ -17,6 +17,41 @@ from .config import nap_config +def convert_to_numpy_array(array, array_name): + if isinstance(array, Number): + return np.array([d]) + elif isinstance(array, (list, tuple)): + return np.array(array) + elif isinstance(array, np.ndarray): + return array + elif is_array_like(array): + return cast_to_numpy(array, array_name) + else: + raise RuntimeError( + "Unknown format for d. Accepted formats are numpy.ndarray, list, tuple or any array-like objects." + ) + +def convert_to_jax_array(array, array_name): + import jax.numpy as jnp + if isinstance(array, Number): + return jnp.array([d]) + elif isinstance(array, (list, tuple)): + return jnp.array(array) + elif isinstance(array, jnp.ndarray): + return array + elif isinstance(array, np.ndarray): + return jnp.asarray(array) + else: + raise RuntimeError( + "Unknown format for d. Accepted formats are numpy.ndarray, list, tuple or any array-like objects." 
+ ) + +def get_backend(): + """ + Return the current backend of pynapple + """ + return nap_config.backend + def is_array_like(obj): """ Check if an object is array-like. @@ -75,8 +110,7 @@ def is_array_like(obj): # and not_tsd_type ) - -def convert_to_numpy(array, array_name): +def cast_to_numpy(array, array_name): """ Convert an input array-like object to a NumPy array. @@ -118,7 +152,6 @@ def convert_to_numpy(array, array_name): ) return np.asarray(array) - def _split_tsd(func, tsd, indices_or_sections, axis=0): """ Wrappers of numpy split functions @@ -135,7 +168,6 @@ def _split_tsd(func, tsd, indices_or_sections, axis=0): else: return func._implementation(tsd.values, indices_or_sections, axis) - def _concatenate_tsd(func, tsds): """ Wrappers of np.concatenate and np.vstack @@ -179,7 +211,6 @@ def _concatenate_tsd(func, tsds): else: raise TypeError - @jit(nopython=True) def _jitfix_iset(start, end): """ @@ -254,7 +285,6 @@ def _jitfix_iset(start, end): return (data, to_warn) - class _TsdFrameSliceHelper: def __init__(self, tsdframe): self.tsdframe = tsdframe @@ -277,7 +307,6 @@ def __getitem__(self, key): (slice(None, None, None), index), columns=key ) - class _IntervalSetSliceHelper: def __init__(self, intervalset): self.intervalset = intervalset diff --git a/tests/test_jax_backend.py b/tests/test_jax_backend.py new file mode 100644 index 00000000..c764a434 --- /dev/null +++ b/tests/test_jax_backend.py @@ -0,0 +1,14 @@ +import pynapple as nap +import numpy as np +import pytest + +import jax +import jax.numpy as jnp + +t = np.arange(10) +d = jnp.arange(10) + +nap.nap_config.set_backend("jax") + +tsd = nap.Tsd(t=t, d=d) + From 28b8b6649d60d1ec6809736e645d6e4bdfca2032 Mon Sep 17 00:00:00 2001 From: gviejo Date: Tue, 27 Feb 2024 11:28:54 -0500 Subject: [PATCH 02/37] Adding not_implemented_in_pynajax --- pynapple/core/time_series.py | 13 ++++++++----- pynapple/core/utils.py | 18 +++++++++++++++++- tests/test_jax_backend.py | 19 ++++++++++++++++++- 3 files changed, 43 insertions(+), 7 deletions(-) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 1dc22bf7..031fdbbb 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -46,7 +46,8 @@ convert_to_numpy_array, convert_to_jax_array, is_array_like, - get_backend + get_backend, + not_implemented_in_pynajax ) @@ -95,7 +96,9 @@ def __init__(self, t, d, time_units="s", time_support=None): if isinstance(time_support, IntervalSet) and len(self.index): starts = time_support.start ends = time_support.end - t, d = jitrestrict(self.index.values, self.values, starts, ends) + t, d = not_implemented_in_pynajax(jitrestrict, 1, 1, + self.index.values, self.values, starts, ends) + self.index = TsIndex(t) self.values = d self.rate = self.index.shape[0] / np.sum( @@ -618,7 +621,7 @@ def __getitem__(self, key, *args, **kwargs): if isinstance(index, Number): index = np.array([index]) - if all(isinstance(a, np.ndarray) for a in [index, output]): + if all(is_array_like(a) for a in [index, output]): if output.shape[0] == index.shape[0]: if output.ndim == 1: return Tsd(t=index, d=output, time_support=self.time_support) @@ -844,7 +847,7 @@ def __getitem__(self, key, *args, **kwargs): if isinstance(index, Number): index = np.array([index]) - if all(isinstance(a, np.ndarray) for a in [index, output]): + if all(is_array_like(a) for a in [index, output]): if output.shape[0] == index.shape[0]: return _get_class(output)( t=index, d=output, time_support=self.time_support, **kwargs @@ -1060,7 +1063,7 @@ def __getitem__(self, 
key, *args, **kwargs): if isinstance(index, Number): index = np.array([index]) - if all(isinstance(a, np.ndarray) for a in [index, output]): + if all(is_array_like(a) for a in [index, output]): if output.shape[0] == index.shape[0]: return _get_class(output)( t=index, d=output, time_support=self.time_support, **kwargs diff --git a/pynapple/core/utils.py b/pynapple/core/utils.py index a456a3ef..b570302c 100644 --- a/pynapple/core/utils.py +++ b/pynapple/core/utils.py @@ -2,7 +2,7 @@ # @Author: Guillaume Viejo # @Date: 2024-02-09 11:45:45 # @Last Modified by: gviejo -# @Last Modified time: 2024-02-27 09:47:11 +# @Last Modified time: 2024-02-27 11:26:25 """ Utility functions @@ -16,6 +16,22 @@ from .config import nap_config +def not_implemented_in_pynajax(func, which_in, which_out, *args, **kwargs): + + if nap_config.backend == "jax": + import jax + import jax.numpy as jnp + # def wrapper(*args, **kwargs): + arguments, struct = jax.tree_util.tree_flatten((args, kwargs)) + arguments[which_in] = jax.tree_map(np.asarray, arguments[which_in]) + args, kwargs = jax.tree_util.tree_unflatten(struct, arguments) + out = func(*args, **kwargs) + out = list(out) + out[which_out] = jax.tree_map(jnp.asarray, out[which_out]) + return tuple(out) + else: + # def wrapper(*args, **kwargs): + return func(*args, **kwargs) def convert_to_numpy_array(array, array_name): if isinstance(array, Number): diff --git a/tests/test_jax_backend.py b/tests/test_jax_backend.py index c764a434..18f78400 100644 --- a/tests/test_jax_backend.py +++ b/tests/test_jax_backend.py @@ -10,5 +10,22 @@ nap.nap_config.set_backend("jax") -tsd = nap.Tsd(t=t, d=d) +tsd = nap.Tsd(t=t, d=d, time_support = nap.IntervalSet(start=0, end=5)) + + +# # from numba import jit + + + +# def decoratorcasting(func): + +# def wrapper(array): +# return func(np.asarray(array)) + +# return wrapper + +# @decoratorcasting +# @jit(nopython=True) +# def test_numba(array): +# return np.sum(array) \ No newline at end of file From a2a0d76c2c78c7a8966d86cc8f9011ed24e0844d Mon Sep 17 00:00:00 2001 From: gviejo Date: Tue, 27 Feb 2024 12:22:33 -0500 Subject: [PATCH 03/37] adding test --- pynapple/core/time_series.py | 92 +++++++++++++++++++----------------- pynapple/core/utils.py | 3 +- tests/test_jax_backend.py | 22 +++++++-- 3 files changed, 70 insertions(+), 47 deletions(-) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 031fdbbb..a5bfc415 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -395,52 +395,58 @@ def convolve(self, array, ep=None, trim="both"): Tsd, TsdFrame or TsdTensor The convolved time series """ - assert isinstance(array, np.ndarray), "Input should be a 1-d numpy array." - assert array.ndim == 1, "Input should be a one dimensional array." - assert trim in [ - "both", - "left", - "right", - ], "Unknow argument. trim should be 'both', 'left' or 'right'." 
- - if ep is None: - ep = self.time_support - - time_array = self.index.values - data_array = self.values - starts = ep.start - ends = ep.end + # Check if jax backend + if get_backend() == "jax": + from pynajax.jax_core import convolve_epoch + return convolve_epoch(self, array) + else: - if data_array.ndim == 1: - new_data_array = np.zeros(data_array.shape) - k = array.shape[0] - for s, e in zip(starts, ends): - idx_s = np.searchsorted(time_array, s) - idx_e = np.searchsorted(time_array, e, side="right") - - t = idx_e - idx_s - if trim == "left": - cut = (k - 1, t + k - 1) - elif trim == "right": - cut = (0, t) - else: - cut = ((1 - k % 2) + (k - 1) // 2, t + k - 1 - ((k - 1) // 2)) - # scipy is actually faster for Tsd - new_data_array[idx_s:idx_e] = signal.convolve( - data_array[idx_s:idx_e], array - )[cut[0] : cut[1]] + assert isinstance(array, np.ndarray), "Input should be a 1-d numpy array." + assert array.ndim == 1, "Input should be a one dimensional array." + assert trim in [ + "both", + "left", + "right", + ], "Unknow argument. trim should be 'both', 'left' or 'right'." + + if ep is None: + ep = self.time_support + + time_array = self.index.values + data_array = self.values + starts = ep.start + ends = ep.end + + if data_array.ndim == 1: + new_data_array = np.zeros(data_array.shape) + k = array.shape[0] + for s, e in zip(starts, ends): + idx_s = np.searchsorted(time_array, s) + idx_e = np.searchsorted(time_array, e, side="right") + + t = idx_e - idx_s + if trim == "left": + cut = (k - 1, t + k - 1) + elif trim == "right": + cut = (0, t) + else: + cut = ((1 - k % 2) + (k - 1) // 2, t + k - 1 - ((k - 1) // 2)) + # scipy is actually faster for Tsd + new_data_array[idx_s:idx_e] = signal.convolve( + data_array[idx_s:idx_e], array + )[cut[0] : cut[1]] - return self.__class__(t=time_array, d=new_data_array, time_support=ep) - else: - new_data_array = np.zeros(data_array.shape) - for s, e in zip(starts, ends): - idx_s = np.searchsorted(time_array, s) - idx_e = np.searchsorted(time_array, e, side="right") - new_data_array[idx_s:idx_e] = pjitconvolve( - data_array[idx_s:idx_e], array, trim=trim - ) + return self.__class__(t=time_array, d=new_data_array, time_support=ep) + else: + new_data_array = np.zeros(data_array.shape) + for s, e in zip(starts, ends): + idx_s = np.searchsorted(time_array, s) + idx_e = np.searchsorted(time_array, e, side="right") + new_data_array[idx_s:idx_e] = pjitconvolve( + data_array[idx_s:idx_e], array, trim=trim + ) - return self.__class__(t=time_array, d=new_data_array, time_support=ep) + return self.__class__(t=time_array, d=new_data_array, time_support=ep) def smooth(self, std, size): """Smooth a time series with a gaussian kernel. std is the standard deviation and size is the number of point of the window. 
diff --git a/pynapple/core/utils.py b/pynapple/core/utils.py index b570302c..4e9cff91 100644 --- a/pynapple/core/utils.py +++ b/pynapple/core/utils.py @@ -2,7 +2,7 @@ # @Author: Guillaume Viejo # @Date: 2024-02-09 11:45:45 # @Last Modified by: gviejo -# @Last Modified time: 2024-02-27 11:26:25 +# @Last Modified time: 2024-02-27 11:34:40 """ Utility functions @@ -16,6 +16,7 @@ from .config import nap_config + def not_implemented_in_pynajax(func, which_in, which_out, *args, **kwargs): if nap_config.backend == "jax": diff --git a/tests/test_jax_backend.py b/tests/test_jax_backend.py index 18f78400..3fc70db5 100644 --- a/tests/test_jax_backend.py +++ b/tests/test_jax_backend.py @@ -4,13 +4,29 @@ import jax import jax.numpy as jnp +from time import time -t = np.arange(10) -d = jnp.arange(10) +t = np.arange(1000) +d = jnp.asarray(np.random.randn(1000)) nap.nap_config.set_backend("jax") -tsd = nap.Tsd(t=t, d=d, time_support = nap.IntervalSet(start=0, end=5)) +tsd = nap.Tsd(t=t, d=d)#, time_support = nap.IntervalSet(start=0, end=5)) + +t1 = time() +tsd2 = tsd.convolve(np.ones(10)) +print(time() - t1) + +nap.nap_config.set_backend("numba") + +tsd = nap.Tsd(t=t, d=tsd.values) +t2 = time() +tsd3 = tsd.convolve(np.ones(10)) +print(time() - t2) + + + + From 7237a19b0aaf37fa43bf8dcfed8f7831161d13ba Mon Sep 17 00:00:00 2001 From: gviejo Date: Tue, 27 Feb 2024 13:18:57 -0500 Subject: [PATCH 04/37] comparing --- pynapple/core/time_series.py | 4 ++-- tests/test_jax_backend.py | 30 ++++++++++++++++++++++-------- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index a5bfc415..231ea312 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -397,8 +397,8 @@ def convolve(self, array, ep=None, trim="both"): """ # Check if jax backend if get_backend() == "jax": - from pynajax.jax_core import convolve_epoch - return convolve_epoch(self, array) + from pynajax.jax_core import convolve + return convolve(self, array) else: assert isinstance(array, np.ndarray), "Input should be a 1-d numpy array." 
diff --git a/tests/test_jax_backend.py b/tests/test_jax_backend.py index 3fc70db5..b0dddb43 100644 --- a/tests/test_jax_backend.py +++ b/tests/test_jax_backend.py @@ -6,23 +6,37 @@ import jax.numpy as jnp from time import time -t = np.arange(1000) -d = jnp.asarray(np.random.randn(1000)) +t = np.arange(10000) +d = jnp.asarray(np.random.randn(10000, 10)) nap.nap_config.set_backend("jax") -tsd = nap.Tsd(t=t, d=d)#, time_support = nap.IntervalSet(start=0, end=5)) +# ep = nap.IntervalSet(start=np.arange(0, 1000, 100),end = np.arange(0, 1000, 100)+50) +tsd = nap.TsdFrame(t=t, d=d)#, time_support = ep) + +kernel = np.ones(10) + +tsd2 = tsd.convolve(kernel) t1 = time() -tsd2 = tsd.convolve(np.ones(10)) -print(time() - t1) +tsd2 = tsd.convolve(kernel) +print("jax", time() - t1) + +from pynajax.jax_core import _convolve_vec + +_convolve_vec(tsd.values, jnp.asarray(kernel)) +t4 = time() +out = _convolve_vec(tsd.values, jnp.asarray(kernel)) +print("convolvevec", time() - t4) nap.nap_config.set_backend("numba") -tsd = nap.Tsd(t=t, d=tsd.values) +tsd = nap.TsdFrame(t=t, d=np.asarray(d))#, time_support = ep) +tsd3 = tsd.convolve(kernel) t2 = time() -tsd3 = tsd.convolve(np.ones(10)) -print(time() - t2) +tsd3 = tsd.convolve(kernel) +print("numba", time() - t2) + From 1168ab4b547c6d26a15870b6f387edab848be84c Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 28 Feb 2024 00:14:57 +0000 Subject: [PATCH 05/37] speed test with great results on CPU --- tests/test_jax_backend.py | 62 +++++++++++++++++++++++++++++++-------- 1 file changed, 50 insertions(+), 12 deletions(-) diff --git a/tests/test_jax_backend.py b/tests/test_jax_backend.py index b0dddb43..2c4317b8 100644 --- a/tests/test_jax_backend.py +++ b/tests/test_jax_backend.py @@ -5,44 +5,82 @@ import jax import jax.numpy as jnp from time import time +from pynajax.jax_core import _convolve_vec, _convolve_mat -t = np.arange(10000) -d = jnp.asarray(np.random.randn(10000, 10)) +t = np.arange(100000) +d = jnp.asarray(np.random.randn(100000, 10)) +if d.ndim >2: + NAPTYPE = nap.TsdTensor + +else: + NAPTYPE = nap.TsdFrame + +CONV_JAX = _convolve_vec + +### +# SET BACKEND JAX +### nap.nap_config.set_backend("jax") -# ep = nap.IntervalSet(start=np.arange(0, 1000, 100),end = np.arange(0, 1000, 100)+50) +ep = nap.IntervalSet(start=np.arange(0, 1000, 100),end = np.arange(0, 1000, 100)+50) -tsd = nap.TsdFrame(t=t, d=d)#, time_support = ep) +tsd_jax = NAPTYPE(t=t, d=d, time_support = ep) kernel = np.ones(10) -tsd2 = tsd.convolve(kernel) +tsd2 = tsd_jax.convolve(kernel) t1 = time() -tsd2 = tsd.convolve(kernel) +tsd2 = tsd_jax.convolve(kernel) print("jax", time() - t1) -from pynajax.jax_core import _convolve_vec -_convolve_vec(tsd.values, jnp.asarray(kernel)) + +CONV_JAX(tsd_jax.values, jnp.asarray(kernel)) t4 = time() -out = _convolve_vec(tsd.values, jnp.asarray(kernel)) +out = CONV_JAX(tsd_jax.values, jnp.asarray(kernel)) print("convolvevec", time() - t4) +### +# SET BACKEND NUMBA +### nap.nap_config.set_backend("numba") -tsd = nap.TsdFrame(t=t, d=np.asarray(d))#, time_support = ep) -tsd3 = tsd.convolve(kernel) +tsd_numpy_one_ep = NAPTYPE(t=t, d=np.asarray(d))#, time_support = ep) +tsd3 = tsd_numpy_one_ep.convolve(kernel) t2 = time() -tsd3 = tsd.convolve(kernel) +tsd3 = tsd_numpy_one_ep.convolve(kernel) print("numba", time() - t2) +print("... 
testing multi epoch") +tsd_numpy_multi_ep = NAPTYPE(t=t, d=np.asarray(d), time_support = ep) +tsd3 = tsd_numpy_multi_ep.convolve(kernel) +t2 = time() +tsd3 = tsd_numpy_multi_ep.convolve(kernel) +print("numba mutlti-ep", time() - t2) +### +# SET BACKEND JAX +### +nap.nap_config.set_backend("jax") + +tsd_jax = NAPTYPE(t=t, d=d, time_support = ep) +tree_of_jax = [jnp.asarray(tsd_jax.get(s, e).d) for s, e in tsd_jax.time_support.values] +@jax.jit +def map_par(tree, jax_kernel): + return jax.tree_map(lambda x: CONV_JAX(x, jax_kernel), tree) +jkernel = jnp.asarray(kernel) +tree_of_jax = [jnp.asarray(tsd_jax.get(s, e).d) for s, e in tsd_jax.time_support.values] +map_par(tree_of_jax, jkernel) + +t2 = time() +tsd6 = map_par(tree_of_jax, kernel) +print("jax jit pytree", time() - t2) # # from numba import jit From 0854886a7476db6ecae0f2443ed314eb5eb391bb Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 28 Feb 2024 00:50:20 +0000 Subject: [PATCH 06/37] massive speedup on convolve --- tests/test_jax_backend.py | 90 +++++++++++++++++++-------------------- 1 file changed, 45 insertions(+), 45 deletions(-) diff --git a/tests/test_jax_backend.py b/tests/test_jax_backend.py index 2c4317b8..6652ddfe 100644 --- a/tests/test_jax_backend.py +++ b/tests/test_jax_backend.py @@ -52,48 +52,48 @@ tsd3 = tsd_numpy_one_ep.convolve(kernel) print("numba", time() - t2) - -print("... testing multi epoch") -tsd_numpy_multi_ep = NAPTYPE(t=t, d=np.asarray(d), time_support = ep) -tsd3 = tsd_numpy_multi_ep.convolve(kernel) -t2 = time() -tsd3 = tsd_numpy_multi_ep.convolve(kernel) -print("numba mutlti-ep", time() - t2) - -### -# SET BACKEND JAX -### -nap.nap_config.set_backend("jax") - -tsd_jax = NAPTYPE(t=t, d=d, time_support = ep) - -tree_of_jax = [jnp.asarray(tsd_jax.get(s, e).d) for s, e in tsd_jax.time_support.values] - - -@jax.jit -def map_par(tree, jax_kernel): - return jax.tree_map(lambda x: CONV_JAX(x, jax_kernel), tree) - -jkernel = jnp.asarray(kernel) -tree_of_jax = [jnp.asarray(tsd_jax.get(s, e).d) for s, e in tsd_jax.time_support.values] -map_par(tree_of_jax, jkernel) - -t2 = time() -tsd6 = map_par(tree_of_jax, kernel) -print("jax jit pytree", time() - t2) - -# # from numba import jit - - - -# def decoratorcasting(func): - -# def wrapper(array): -# return func(np.asarray(array)) - -# return wrapper - -# @decoratorcasting -# @jit(nopython=True) -# def test_numba(array): -# return np.sum(array) \ No newline at end of file +# +# print("... 
testing multi epoch") +# tsd_numpy_multi_ep = NAPTYPE(t=t, d=np.asarray(d), time_support = ep) +# tsd3 = tsd_numpy_multi_ep.convolve(kernel) +# t2 = time() +# tsd3 = tsd_numpy_multi_ep.convolve(kernel) +# print("numba mutlti-ep", time() - t2) +# +# ### +# # SET BACKEND JAX +# ### +# nap.nap_config.set_backend("jax") +# +# tsd_jax = NAPTYPE(t=t, d=d, time_support = ep) +# +# tree_of_jax = [jnp.asarray(tsd_jax.get(s, e).d) for s, e in tsd_jax.time_support.values] +# +# +# @jax.jit +# def map_par(tree, jax_kernel): +# return jax.tree_map(lambda x: CONV_JAX(x, jax_kernel), tree) +# +# jkernel = jnp.asarray(kernel) +# tree_of_jax = [jnp.asarray(tsd_jax.get(s, e).d) for s, e in tsd_jax.time_support.values] +# map_par(tree_of_jax, jkernel) +# +# t2 = time() +# tsd6 = map_par(tree_of_jax, kernel) +# print("jax jit pytree", time() - t2) +# +# # # from numba import jit +# +# +# +# # def decoratorcasting(func): +# +# # def wrapper(array): +# # return func(np.asarray(array)) +# +# # return wrapper +# +# # @decoratorcasting +# # @jit(nopython=True) +# # def test_numba(array): +# # return np.sum(array) \ No newline at end of file From 8dcd108aaf3576072d7faed330b357a88a74e42e Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 28 Feb 2024 00:53:04 +0000 Subject: [PATCH 07/37] maxi-speedup --- tests/test_jax_backend.py | 58 ++++----------------------------------- 1 file changed, 6 insertions(+), 52 deletions(-) diff --git a/tests/test_jax_backend.py b/tests/test_jax_backend.py index 6652ddfe..09c51c16 100644 --- a/tests/test_jax_backend.py +++ b/tests/test_jax_backend.py @@ -23,23 +23,23 @@ ### nap.nap_config.set_backend("jax") -ep = nap.IntervalSet(start=np.arange(0, 1000, 100),end = np.arange(0, 1000, 100)+50) +ep = nap.IntervalSet(start=np.arange(0, 1000, 100), end = np.arange(0, 1000, 100)+50) -tsd_jax = NAPTYPE(t=t, d=d, time_support = ep) +tsd_jax = NAPTYPE(t=t, d=d, time_support=ep) kernel = np.ones(10) tsd2 = tsd_jax.convolve(kernel) t1 = time() tsd2 = tsd_jax.convolve(kernel) -print("jax", time() - t1) +print("pynajax convolve multi-epoch", time() - t1) CONV_JAX(tsd_jax.values, jnp.asarray(kernel)) t4 = time() -out = CONV_JAX(tsd_jax.values, jnp.asarray(kernel)) -print("convolvevec", time() - t4) +out = _convolve_vec(tsd_jax.values, jnp.asarray(kernel)) +print("convolve-vec", time() - t4) ### # SET BACKEND NUMBA @@ -50,50 +50,4 @@ tsd3 = tsd_numpy_one_ep.convolve(kernel) t2 = time() tsd3 = tsd_numpy_one_ep.convolve(kernel) -print("numba", time() - t2) - -# -# print("... 
testing multi epoch") -# tsd_numpy_multi_ep = NAPTYPE(t=t, d=np.asarray(d), time_support = ep) -# tsd3 = tsd_numpy_multi_ep.convolve(kernel) -# t2 = time() -# tsd3 = tsd_numpy_multi_ep.convolve(kernel) -# print("numba mutlti-ep", time() - t2) -# -# ### -# # SET BACKEND JAX -# ### -# nap.nap_config.set_backend("jax") -# -# tsd_jax = NAPTYPE(t=t, d=d, time_support = ep) -# -# tree_of_jax = [jnp.asarray(tsd_jax.get(s, e).d) for s, e in tsd_jax.time_support.values] -# -# -# @jax.jit -# def map_par(tree, jax_kernel): -# return jax.tree_map(lambda x: CONV_JAX(x, jax_kernel), tree) -# -# jkernel = jnp.asarray(kernel) -# tree_of_jax = [jnp.asarray(tsd_jax.get(s, e).d) for s, e in tsd_jax.time_support.values] -# map_par(tree_of_jax, jkernel) -# -# t2 = time() -# tsd6 = map_par(tree_of_jax, kernel) -# print("jax jit pytree", time() - t2) -# -# # # from numba import jit -# -# -# -# # def decoratorcasting(func): -# -# # def wrapper(array): -# # return func(np.asarray(array)) -# -# # return wrapper -# -# # @decoratorcasting -# # @jit(nopython=True) -# # def test_numba(array): -# # return np.sum(array) \ No newline at end of file +print("numba multi-epoch", time() - t2) From 1d79b6a208c3423828de5c9d1193e4ead59f79a4 Mon Sep 17 00:00:00 2001 From: gviejo Date: Wed, 28 Feb 2024 11:41:19 -0500 Subject: [PATCH 08/37] working version of convolve --- pynapple/core/time_series.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 231ea312..92f474ea 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -398,9 +398,9 @@ def convolve(self, array, ep=None, trim="both"): # Check if jax backend if get_backend() == "jax": from pynajax.jax_core import convolve - return convolve(self, array) + new_data_array = convolve(self, array) + return self.__class__(t=self.index, d=new_data_array, time_support=self.time_support) else: - assert isinstance(array, np.ndarray), "Input should be a 1-d numpy array." assert array.ndim == 1, "Input should be a one dimensional array." 
assert trim in [ From f7eed3d341aa6982a530550cd3aaa2ad89fde478 Mon Sep 17 00:00:00 2001 From: gviejo Date: Thu, 29 Feb 2024 10:30:37 -0500 Subject: [PATCH 09/37] Update --- tests/test_jax_backend.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/tests/test_jax_backend.py b/tests/test_jax_backend.py index 09c51c16..e91d9674 100644 --- a/tests/test_jax_backend.py +++ b/tests/test_jax_backend.py @@ -5,10 +5,10 @@ import jax import jax.numpy as jnp from time import time -from pynajax.jax_core import _convolve_vec, _convolve_mat +from pynajax.jax_core import _convolve_vec, _convolve_mat, _jit_tree_convolve_1d_kernel t = np.arange(100000) -d = jnp.asarray(np.random.randn(100000, 10)) +d = jnp.asarray(np.random.randn(100000, 100)) if d.ndim >2: NAPTYPE = nap.TsdTensor @@ -25,22 +25,31 @@ ep = nap.IntervalSet(start=np.arange(0, 1000, 100), end = np.arange(0, 1000, 100)+50) -tsd_jax = NAPTYPE(t=t, d=d, time_support=ep) +tsd_jax = NAPTYPE(t=t, d=d)#, time_support=ep) kernel = np.ones(10) +jkernel = jnp.asarray(kernel) -tsd2 = tsd_jax.convolve(kernel) +tsd2 = tsd_jax.convolve(jkernel) t1 = time() -tsd2 = tsd_jax.convolve(kernel) -print("pynajax convolve multi-epoch", time() - t1) +tsd2 = tsd_jax.convolve(jkernel) +print("pynajax convolve", time() - t1) +print(tsd2) -CONV_JAX(tsd_jax.values, jnp.asarray(kernel)) + +# CONV_JAX(tsd_jax.values, jnp.asarray(kernel)) +out = _convolve_vec(tsd_jax.values, jkernel) t4 = time() -out = _convolve_vec(tsd_jax.values, jnp.asarray(kernel)) +out = _convolve_vec(tsd_jax.values, jkernel) print("convolve-vec", time() - t4) +data = _jit_tree_convolve_1d_kernel(d, jkernel) +t5 = time() +data = _jit_tree_convolve_1d_kernel(d, jkernel) +print("_jit_tree_convolve_1d_kernel", time() - t5) + ### # SET BACKEND NUMBA ### From cb01bc2bade91c577424c79ddb1fd67a0b8b6eeb Mon Sep 17 00:00:00 2001 From: gviejo Date: Tue, 5 Mar 2024 16:57:34 -0500 Subject: [PATCH 10/37] Update --- pynapple/core/base_class.py | 22 ++++----- pynapple/core/core_function.py | 83 ++++++++++++++++++++++++++++++++++ pynapple/core/time_series.py | 81 ++++++++------------------------- 3 files changed, 111 insertions(+), 75 deletions(-) create mode 100644 pynapple/core/core_function.py diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index ef0ba8ad..c3db8afe 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -8,14 +8,7 @@ import numpy as np -from ._jitted_functions import ( - jitcount, - jitrestrict, - jittsrestrict, - jittsrestrict_with_count, - jitvaluefrom, - jitvaluefromtensor, -) +from .core_function import _restrict from .interval_set import IntervalSet from .time_index import TsIndex from .utils import cast_to_numpy, is_array_like @@ -365,21 +358,22 @@ def restrict(self, iset): assert isinstance(iset, IntervalSet), "Argument should be IntervalSet" time_array = self.index.values + data_array = None starts = iset.start ends = iset.end if hasattr(self, "values"): data_array = self.values - t, d = jitrestrict(time_array, data_array, starts, ends) - kwargs = {} - if hasattr(self, "columns"): - kwargs["columns"] = self.columns + t, d = _restrict(time_array, data_array, starts, ends) - return self.__class__(t=t, d=d, time_support=iset, **kwargs) + kwargs = {} + if hasattr(self, "columns"): + kwargs["columns"] = self.columns + if hasattr(self, "values"): + return self.__class__(t=t, d=d, time_support=iset, **kwargs) else: - t = jittsrestrict(time_array, starts, ends) return self.__class__(t=t, time_support=iset) def 
copy(self): diff --git a/pynapple/core/core_function.py b/pynapple/core/core_function.py new file mode 100644 index 00000000..0cb82325 --- /dev/null +++ b/pynapple/core/core_function.py @@ -0,0 +1,83 @@ +""" + This module holds the core function of pynapple as well as + the dispatch between numba and jax. + + If pynajax is installed and `nap.nap_config.backend` is set + to `jax`, the module will call the functions within pynajax. + Otherwise the module will call the functions within `_jitted_functions.py`. + +""" +import numpy as np +from ._jitted_functions import ( + jitbin, + jitbin_array, + jitremove_nan, + jitrestrict, + jitthreshold, + jittsrestrict, + pjitconvolve, +) +from .utils import get_backend + +def _convolve(time_array, data_array, starts, ends, array, trim="both"): + if get_backend() == "jax": + # Todo convert to jax if numpy + from pynajax.jax_core_convolve import convolve + return convolve(time_array, data_array, starts, ends, array) + else: + if data_array.ndim == 1: + new_data_array = np.zeros(data_array.shape) + k = array.shape[0] + for s, e in zip(starts, ends): + idx_s = np.searchsorted(time_array, s) + idx_e = np.searchsorted(time_array, e, side="right") + + t = idx_e - idx_s + if trim == "left": + cut = (k - 1, t + k - 1) + elif trim == "right": + cut = (0, t) + else: + cut = ((1 - k % 2) + (k - 1) // 2, t + k - 1 - ((k - 1) // 2)) + # scipy is actually faster for Tsd + new_data_array[idx_s:idx_e] = signal.convolve( + data_array[idx_s:idx_e], array + )[cut[0] : cut[1]] + + return new_data_array + else: + new_data_array = np.zeros(data_array.shape) + for s, e in zip(starts, ends): + idx_s = np.searchsorted(time_array, s) + idx_e = np.searchsorted(time_array, e, side="right") + new_data_array[idx_s:idx_e] = pjitconvolve( + data_array[idx_s:idx_e], array, trim=trim + ) + + return new_data_array + +def _restrict(time_array, data_array, starts, ends): + if get_backend() == "jax": + # Todo convert to jax if numpy + from pynajax.jax_core_restrict import restrict + return restrict(time_array, data_array, starts, ends) + else: + if data_array is not None: + return jitrestrict(time_array, data_array, starts, ends) + else: + return (jittsrestrict(time_array, starts, ends), None) + +def _value_from(): + pass + +def _count(): + pass + +def _bin_average(): + pass + +def _interpolate(): + pass + +def _threshold(): + pass diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 92f474ea..0659891a 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -27,15 +27,7 @@ from scipy import signal from tabulate import tabulate -from ._jitted_functions import ( - jitbin, - jitbin_array, - jitremove_nan, - jitrestrict, - jitthreshold, - jittsrestrict, - pjitconvolve, -) +from .core_function import _convolve from .base_class import Base from .interval_set import IntervalSet from .time_index import TsIndex @@ -383,8 +375,8 @@ def convolve(self, array, ep=None, trim="both"): Parameters ---------- - array : np.ndarray - One dimensional input array + array : array-like + ep : None, optional The epochs to apply the convolution trim : str, optional @@ -395,58 +387,25 @@ def convolve(self, array, ep=None, trim="both"): Tsd, TsdFrame or TsdTensor The convolved time series """ - # Check if jax backend - if get_backend() == "jax": - from pynajax.jax_core import convolve - new_data_array = convolve(self, array) - return self.__class__(t=self.index, d=new_data_array, time_support=self.time_support) - else: - assert isinstance(array, np.ndarray), "Input 
should be a 1-d numpy array." - assert array.ndim == 1, "Input should be a one dimensional array." - assert trim in [ - "both", - "left", - "right", - ], "Unknow argument. trim should be 'both', 'left' or 'right'." - - if ep is None: - ep = self.time_support - - time_array = self.index.values - data_array = self.values - starts = ep.start - ends = ep.end - - if data_array.ndim == 1: - new_data_array = np.zeros(data_array.shape) - k = array.shape[0] - for s, e in zip(starts, ends): - idx_s = np.searchsorted(time_array, s) - idx_e = np.searchsorted(time_array, e, side="right") - - t = idx_e - idx_s - if trim == "left": - cut = (k - 1, t + k - 1) - elif trim == "right": - cut = (0, t) - else: - cut = ((1 - k % 2) + (k - 1) // 2, t + k - 1 - ((k - 1) // 2)) - # scipy is actually faster for Tsd - new_data_array[idx_s:idx_e] = signal.convolve( - data_array[idx_s:idx_e], array - )[cut[0] : cut[1]] + assert is_array_like(array), "Input should be a numpy array (or jax array if pynajax is installed)." + assert array.ndim in [1,2], "Input should be one or two dimensional array." + assert trim in [ + "both", + "left", + "right", + ], "Unknow argument. trim should be 'both', 'left' or 'right'." + + if ep is None: + ep = self.time_support - return self.__class__(t=time_array, d=new_data_array, time_support=ep) - else: - new_data_array = np.zeros(data_array.shape) - for s, e in zip(starts, ends): - idx_s = np.searchsorted(time_array, s) - idx_e = np.searchsorted(time_array, e, side="right") - new_data_array[idx_s:idx_e] = pjitconvolve( - data_array[idx_s:idx_e], array, trim=trim - ) + time_array = self.index.values + data_array = self.values + starts = ep.start + ends = ep.end + + new_data_array = _convolve(time_array, data_array, kernel, ep, trim) - return self.__class__(t=time_array, d=new_data_array, time_support=ep) + return self.__class__(t=time_array, d=new_data_array, time_support=ep) def smooth(self, std, size): """Smooth a time series with a gaussian kernel. std is the standard deviation and size is the number of point of the window. 
From 55827f7d940632cea4f049a2714bcb1f1b8fb03b Mon Sep 17 00:00:00 2001
From: gviejo
Date: Thu, 7 Mar 2024 07:53:23 -0500
Subject: [PATCH 11/37] Working version with jax backend

---
 pynapple/core/base_class.py    | 46 +++++++++++++------------
 pynapple/core/config.py        | 13 +++----
 pynapple/core/core_function.py | 19 ++++++++---
 pynapple/core/interval_set.py  |  9 ++---
 pynapple/core/time_index.py    |  3 +-
 pynapple/core/time_series.py   | 35 ++++++++++++-------
 pynapple/core/ts_group.py      |  2 +-
 pynapple/core/utils.py         | 47 ++++++++++++++++++--------
 tests/test_jax_backend.py      | 62 ----------------------------------
 tests/test_time_series.py      |  4 +--
 10 files changed, 108 insertions(+), 132 deletions(-)
 delete mode 100644 tests/test_jax_backend.py

diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py
index c3db8afe..895b3f26 100644
--- a/pynapple/core/base_class.py
+++ b/pynapple/core/base_class.py
@@ -11,8 +11,23 @@
 from .core_function import _restrict
 from .interval_set import IntervalSet
 from .time_index import TsIndex
-from .utils import cast_to_numpy, is_array_like
-
+from .utils import (
+    cast_to_numpy,
+    convert_to_jax_array,
+    convert_to_numpy_array,
+    get_backend,
+    is_array_like,
+)
+
+# to remove
+from ._jitted_functions import (
+    jitcount,
+    jitrestrict,
+    jittsrestrict,
+    jittsrestrict_with_count,
+    jitvaluefrom,
+    jitvaluefromtensor,
+)
 
 class Base(abc.ABC):
     """
@@ -23,25 +38,14 @@ class Base(abc.ABC):
     _initialized = False
 
     def __init__(self, t, time_units="s", time_support=None):
-        # Converting t to TsIndex array
+
         if isinstance(t, TsIndex):
            self.index = t
-        elif isinstance(t, Number):
-            self.index = TsIndex(np.array([t]), time_units)
-        elif isinstance(t, (list, tuple)):
-            self.index = TsIndex(np.array(t).flatten(), time_units)
-        elif isinstance(t, np.ndarray):
-            assert t.ndim == 1, "t should be 1 dimensional"
-            self.index = TsIndex(t, time_units)
-        # convert array-like data to numpy.
-        # raise a warning to avoid silent conversion if non-numpy array is provided (jax arrays for instance)
-        elif is_array_like(t):
-            t = cast_to_numpy(t, "t")
-            self.index = TsIndex(t, time_units)
         else:
-            raise RuntimeError(
-                "Unknown format for t. Accepted formats are numpy.ndarray, list, tuple or any array-like objects."
- ) + if get_backend() == "jax": + self.index = TsIndex(convert_to_jax_array(t, "t"), time_units) + else: + self.index = TsIndex(convert_to_numpy_array(t, "t"), time_units) if time_support is not None: assert isinstance( @@ -365,16 +369,16 @@ def restrict(self, iset): if hasattr(self, "values"): data_array = self.values - t, d = _restrict(time_array, data_array, starts, ends) + out = _restrict(time_array, data_array, starts, ends) kwargs = {} if hasattr(self, "columns"): kwargs["columns"] = self.columns if hasattr(self, "values"): - return self.__class__(t=t, d=d, time_support=iset, **kwargs) + return self.__class__(t=out[0], d=out[1], time_support=iset, **kwargs) else: - return self.__class__(t=t, time_support=iset) + return self.__class__(t=out, time_support=iset) def copy(self): """Copy the data, index and time support""" diff --git a/pynapple/core/config.py b/pynapple/core/config.py index 28f3c118..88d32baa 100644 --- a/pynapple/core/config.py +++ b/pynapple/core/config.py @@ -54,9 +54,9 @@ class PynappleConfig: def __init__(self): self.suppress_conversion_warnings = False - self.suppress_time_index_sorting_warnings = False + self.suppress_time_index_sorting_warnings = False self.backend = "numba" - + @property def backend(self): """ @@ -66,7 +66,7 @@ def backend(self): @backend.setter def backend(self, backend): - self.set_backend(backend) + self.set_backend(backend) def set_backend(self, backend): @@ -75,13 +75,14 @@ def set_backend(self, backend): # Try to import pynajax if backend == "jax": try: - import pynajax as nax + import pynajax as nax + self._backend = "jax" except ImportError: warnings.warn( "Importing pynajax failed. Falling back to numba. To use the jax backend for pynapple, please install pynajax", UserWarning, - ) + ) self._backend = "numba" else: self._backend = "numba" @@ -90,11 +91,11 @@ def set_backend(self, backend): def get_pynajax_backend_status(self): try: import pynajax + self.backend = "jax" except: self.backend = "numba" - @property def time_index_precision(self): """Precision for the time index diff --git a/pynapple/core/core_function.py b/pynapple/core/core_function.py index 0cb82325..7a37f75d 100644 --- a/pynapple/core/core_function.py +++ b/pynapple/core/core_function.py @@ -7,7 +7,9 @@ Otherwise the module will call the functions within `_jitted_functions.py`. 
""" + import numpy as np + from ._jitted_functions import ( jitbin, jitbin_array, @@ -18,12 +20,13 @@ pjitconvolve, ) from .utils import get_backend +from scipy import signal -def _convolve(time_array, data_array, starts, ends, array, trim="both"): +def _convolve(time_array, data_array, starts, ends, array, trim="both"): if get_backend() == "jax": - # Todo convert to jax if numpy from pynajax.jax_core_convolve import convolve - return convolve(time_array, data_array, starts, ends, array) + + return convolve(time_array, data_array, starts, ends, array) else: if data_array.ndim == 1: new_data_array = np.zeros(data_array.shape) @@ -56,28 +59,34 @@ def _convolve(time_array, data_array, starts, ends, array, trim="both"): return new_data_array + def _restrict(time_array, data_array, starts, ends): if get_backend() == "jax": - # Todo convert to jax if numpy from pynajax.jax_core_restrict import restrict + return restrict(time_array, data_array, starts, ends) else: if data_array is not None: return jitrestrict(time_array, data_array, starts, ends) else: - return (jittsrestrict(time_array, starts, ends), None) + return jittsrestrict(time_array, starts, ends) + def _value_from(): pass + def _count(): pass + def _bin_average(): pass + def _interpolate(): pass + def _threshold(): pass diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index 1e3e1c53..3fe6e3a0 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -47,12 +47,7 @@ from ._jitted_functions import jitdiff, jitin_interval, jitintersect, jitunion from .config import nap_config from .time_index import TsIndex -from .utils import ( - _IntervalSetSliceHelper, - _jitfix_iset, - cast_to_numpy, - is_array_like, -) +from .utils import _IntervalSetSliceHelper, _jitfix_iset, cast_to_numpy, is_array_like all_warnings = np.array( [ @@ -124,7 +119,7 @@ def __init__(self, start, end=None, time_units="s", **kwargs): elif isinstance(data, np.ndarray): args[arg] = np.ravel(data) elif is_array_like(data): - args[arg] = convert_to_numpy(data, arg) + args[arg] = cast_to_numpy(data, arg) else: raise RuntimeError( "Unknown format for {}. 
Accepted formats are numpy.ndarray, list, tuple or any array-like objects.".format( diff --git a/pynapple/core/time_index.py b/pynapple/core/time_index.py index c1c1333d..4d8f8f0c 100644 --- a/pynapple/core/time_index.py +++ b/pynapple/core/time_index.py @@ -110,7 +110,8 @@ def sort_timestamps(t, give_warning=True): return t def __new__(cls, t, time_units="s"): - t = t.astype(np.float64).flatten() + assert t.ndim == 1, "t should be 1 dimensional" + t = t.astype(np.float64) t = TsIndex.format_timestamps(t, time_units) t = TsIndex.sort_timestamps(t) obj = np.asarray(t).view(cls) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 7c04d2d8..86ec3e35 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -24,24 +24,34 @@ import numpy as np import pandas as pd from numpy.lib.mixins import NDArrayOperatorsMixin -from scipy import signal from tabulate import tabulate -from .core_function import _convolve from .base_class import Base +from .core_function import _convolve, _restrict from .interval_set import IntervalSet from .time_index import TsIndex from .utils import ( _concatenate_tsd, _split_tsd, _TsdFrameSliceHelper, - convert_to_numpy_array, convert_to_jax_array, - is_array_like, + convert_to_numpy_array, get_backend, - not_implemented_in_pynajax + is_array_like, + not_implemented_in_pynajax, ) +# to remove +from ._jitted_functions import ( + jitbin, + jitbin_array, + jitremove_nan, + jitrestrict, + jitthreshold, + jittsrestrict, + pjitconvolve, +) +from scipy import signal def _get_class(data): """Select the right time series object and return the class @@ -88,8 +98,7 @@ def __init__(self, t, d, time_units="s", time_support=None): if isinstance(time_support, IntervalSet) and len(self.index): starts = time_support.start ends = time_support.end - t, d = not_implemented_in_pynajax(jitrestrict, 1, 1, - self.index.values, self.values, starts, ends) + t, d = _restrict(self.index.values, self.values, starts, ends) self.index = TsIndex(t) self.values = d @@ -376,7 +385,7 @@ def convolve(self, array, ep=None, trim="both"): Parameters ---------- array : array-like - + ep : None, optional The epochs to apply the convolution trim : str, optional @@ -387,8 +396,10 @@ def convolve(self, array, ep=None, trim="both"): Tsd, TsdFrame or TsdTensor The convolved time series """ - assert is_array_like(array), "Input should be a numpy array (or jax array if pynajax is installed)." - assert array.ndim in [1,2], "Input should be one or two dimensional array." + assert is_array_like( + array + ), "Input should be a numpy array (or jax array if pynajax is installed)." + assert array.ndim == 1, "Input should be a one dimensional array." 
assert trim in [ "both", "left", @@ -403,7 +414,7 @@ def convolve(self, array, ep=None, trim="both"): starts = ep.start ends = ep.end - new_data_array = _convolve(time_array, data_array, kernel, ep, trim) + new_data_array = _convolve(time_array, data_array, starts, ends, array, trim) return self.__class__(t=time_array, d=new_data_array, time_support=ep) @@ -1283,7 +1294,7 @@ def __init__(self, t, time_units="s", time_support=None): if isinstance(time_support, IntervalSet) and len(self.index): starts = time_support.start ends = time_support.end - t = jittsrestrict(self.index.values, starts, ends) + t = _restrict(self.index.values, None, starts, ends) self.index = TsIndex(t) self.rate = self.index.shape[0] / np.sum( time_support.values[:, 1] - time_support.values[:, 0] diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 8db17957..05371f7f 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -113,7 +113,7 @@ def __init__( stacklevel=2, ) data[k] = Ts( - t=convert_to_numpy(data[k], "key {}".format(k)), + t=cast_to_numpy(data[k], "key {}".format(k)), time_support=time_support, time_units=time_units, ) diff --git a/pynapple/core/utils.py b/pynapple/core/utils.py index 7f066782..52579e73 100644 --- a/pynapple/core/utils.py +++ b/pynapple/core/utils.py @@ -2,13 +2,14 @@ # @Author: Guillaume Viejo # @Date: 2024-02-09 11:45:45 # @Last Modified by: gviejo -# @Last Modified time: 2024-03-03 06:28:59 +# @Last Modified time: 2024-03-06 16:34:07 """ Utility functions """ import warnings +from numbers import Number import numpy as np from numba import jit @@ -18,24 +19,26 @@ def not_implemented_in_pynajax(func, which_in, which_out, *args, **kwargs): - if nap_config.backend == "jax": - import jax - import jax.numpy as jnp + if nap_config.backend == "jax": + import jax + import jax.numpy as jnp + # def wrapper(*args, **kwargs): - arguments, struct = jax.tree_util.tree_flatten((args, kwargs)) - arguments[which_in] = jax.tree_map(np.asarray, arguments[which_in]) - args, kwargs = jax.tree_util.tree_unflatten(struct, arguments) - out = func(*args, **kwargs) - out = list(out) - out[which_out] = jax.tree_map(jnp.asarray, out[which_out]) - return tuple(out) + arguments, struct = jax.tree_util.tree_flatten((args, kwargs)) + arguments[which_in] = jax.tree_map(np.asarray, arguments[which_in]) + args, kwargs = jax.tree_util.tree_unflatten(struct, arguments) + out = func(*args, **kwargs) + out = list(out) + out[which_out] = jax.tree_map(jnp.asarray, out[which_out]) + return tuple(out) else: # def wrapper(*args, **kwargs): return func(*args, **kwargs) + def convert_to_numpy_array(array, array_name): if isinstance(array, Number): - return np.array([d]) + return np.array([array]) elif isinstance(array, (list, tuple)): return np.array(array) elif isinstance(array, np.ndarray): @@ -44,13 +47,17 @@ def convert_to_numpy_array(array, array_name): return cast_to_numpy(array, array_name) else: raise RuntimeError( - "Unknown format for d. Accepted formats are numpy.ndarray, list, tuple or any array-like objects." + "Unknown format for {}. 
Accepted formats are numpy.ndarray, list, tuple or any array-like objects.".format( + array_name + ) ) + def convert_to_jax_array(array, array_name): import jax.numpy as jnp + if isinstance(array, Number): - return jnp.array([d]) + return jnp.array([array]) elif isinstance(array, (list, tuple)): return jnp.array(array) elif isinstance(array, jnp.ndarray): @@ -59,15 +66,19 @@ def convert_to_jax_array(array, array_name): return jnp.asarray(array) else: raise RuntimeError( - "Unknown format for d. Accepted formats are numpy.ndarray, list, tuple or any array-like objects." + "Unknown format for {}. Accepted formats are numpy.ndarray, list, tuple or any array-like objects.".format( + array_name + ) ) + def get_backend(): """ Return the current backend of pynapple """ return nap_config.backend + def is_array_like(obj): """ Check if an object is array-like. @@ -126,6 +137,7 @@ def is_array_like(obj): # and not_tsd_type ) + def cast_to_numpy(array, array_name): """ Convert an input array-like object to a NumPy array. @@ -168,6 +180,7 @@ def cast_to_numpy(array, array_name): ) return np.asarray(array) + def _split_tsd(func, tsd, indices_or_sections, axis=0): """ Wrappers of numpy split functions @@ -184,6 +197,7 @@ def _split_tsd(func, tsd, indices_or_sections, axis=0): else: return func._implementation(tsd.values, indices_or_sections, axis) + def _concatenate_tsd(func, tsds): """ Wrappers of np.concatenate and np.vstack @@ -227,6 +241,7 @@ def _concatenate_tsd(func, tsds): else: raise TypeError + @jit(nopython=True) def _jitfix_iset(start, end): """ @@ -301,6 +316,7 @@ def _jitfix_iset(start, end): return (data, to_warn) + class _TsdFrameSliceHelper: def __init__(self, tsdframe): self.tsdframe = tsdframe @@ -323,6 +339,7 @@ def __getitem__(self, key): (slice(None, None, None), index), columns=key ) + class _IntervalSetSliceHelper: """ This class helps `IntervalSet` behaves like pandas.DataFrame for the `loc` function. 
diff --git a/tests/test_jax_backend.py b/tests/test_jax_backend.py deleted file mode 100644 index e91d9674..00000000 --- a/tests/test_jax_backend.py +++ /dev/null @@ -1,62 +0,0 @@ -import pynapple as nap -import numpy as np -import pytest - -import jax -import jax.numpy as jnp -from time import time -from pynajax.jax_core import _convolve_vec, _convolve_mat, _jit_tree_convolve_1d_kernel - -t = np.arange(100000) -d = jnp.asarray(np.random.randn(100000, 100)) - -if d.ndim >2: - NAPTYPE = nap.TsdTensor - -else: - NAPTYPE = nap.TsdFrame - -CONV_JAX = _convolve_vec - -### -# SET BACKEND JAX -### -nap.nap_config.set_backend("jax") - -ep = nap.IntervalSet(start=np.arange(0, 1000, 100), end = np.arange(0, 1000, 100)+50) - -tsd_jax = NAPTYPE(t=t, d=d)#, time_support=ep) - -kernel = np.ones(10) -jkernel = jnp.asarray(kernel) - -tsd2 = tsd_jax.convolve(jkernel) -t1 = time() -tsd2 = tsd_jax.convolve(jkernel) -print("pynajax convolve", time() - t1) -print(tsd2) - - - - -# CONV_JAX(tsd_jax.values, jnp.asarray(kernel)) -out = _convolve_vec(tsd_jax.values, jkernel) -t4 = time() -out = _convolve_vec(tsd_jax.values, jkernel) -print("convolve-vec", time() - t4) - -data = _jit_tree_convolve_1d_kernel(d, jkernel) -t5 = time() -data = _jit_tree_convolve_1d_kernel(d, jkernel) -print("_jit_tree_convolve_1d_kernel", time() - t5) - -### -# SET BACKEND NUMBA -### -nap.nap_config.set_backend("numba") - -tsd_numpy_one_ep = NAPTYPE(t=t, d=np.asarray(d))#, time_support = ep) -tsd3 = tsd_numpy_one_ep.convolve(kernel) -t2 = time() -tsd3 = tsd_numpy_one_ep.convolve(kernel) -print("numba multi-epoch", time() - t2) diff --git a/tests/test_time_series.py b/tests/test_time_series.py index 24806da4..674d0de5 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-04-01 09:57:55 # @Last Modified by: gviejo -# @Last Modified time: 2024-02-29 10:34:07 +# @Last Modified time: 2024-03-07 07:23:42 #!/usr/bin/env python """Tests of time series for `pynapple` package.""" @@ -446,7 +446,7 @@ def test_convolve(self, tsd): with pytest.raises(AssertionError) as e_info: tsd.convolve([1,2,3]) - assert str(e_info.value) == "Input should be a 1-d numpy array." + assert str(e_info.value) == "Input should be a numpy array (or jax array if pynajax is installed)." 
with pytest.raises(AssertionError) as e_info: tsd.convolve(np.random.rand(2,3)) From 76b8bf2d3b3b020ccd8c7f1a05f7bfadf573132c Mon Sep 17 00:00:00 2001 From: gviejo Date: Fri, 8 Mar 2024 12:41:17 -0500 Subject: [PATCH 12/37] Adding count --- pynapple/core/base_class.py | 34 ++++++----------------------- pynapple/core/config.py | 2 -- pynapple/core/core_function.py | 39 +++++++++++++++++++++++++++------- pynapple/core/time_series.py | 19 +++++++---------- 4 files changed, 46 insertions(+), 48 deletions(-) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index 895b3f26..63808bea 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -8,26 +8,15 @@ import numpy as np -from .core_function import _restrict +from .core_function import _restrict, _count, _value_from from .interval_set import IntervalSet from .time_index import TsIndex -from .utils import ( - cast_to_numpy, +from .utils import ( convert_to_jax_array, convert_to_numpy_array, get_backend, - is_array_like, ) -# to remove -from ._jitted_functions import ( - jitcount, - jitrestrict, - jittsrestrict, - jittsrestrict_with_count, - jitvaluefrom, - jitvaluefromtensor, -) class Base(abc.ABC): """ @@ -210,14 +199,7 @@ def value_from(self, data, ep=None): starts = ep.start ends = ep.end - if data_target_array.ndim == 1: - t, d, ns, ne = jitvaluefrom( - time_array, time_target_array, data_target_array, starts, ends - ) - else: - t, d, ns, ne = jitvaluefromtensor( - time_array, time_target_array, data_target_array, starts, ends - ) + t, d, ns, ne = _value_from(time_array, time_target_array, data_target_array, starts, ends) time_support = IntervalSet(start=ns, end=ne) @@ -304,6 +286,9 @@ def count(self, *args, **kwargs): if isinstance(a, str) and a in ["s", "ms", "us"]: time_units = a + if isinstance(bin_size, (float, int)): + bin_size = TsIndex.format_timestamps(np.array([bin_size]), time_units)[0] + ep = self.time_support if "ep" in kwargs: ep = kwargs["ep"] @@ -318,12 +303,7 @@ def count(self, *args, **kwargs): starts = ep.start ends = ep.end - if isinstance(bin_size, (float, int)): - bin_size = TsIndex.format_timestamps(np.array([bin_size]), time_units)[0] - t, d = jitcount(time_array, starts, ends, bin_size) - else: - _, d = jittsrestrict_with_count(time_array, starts, ends) - t = starts + (ends - starts) / 2 + t, d = _count(time_array, starts, ends, bin_size) return t, d, ep diff --git a/pynapple/core/config.py b/pynapple/core/config.py index 88d32baa..b8876deb 100644 --- a/pynapple/core/config.py +++ b/pynapple/core/config.py @@ -76,7 +76,6 @@ def set_backend(self, backend): if backend == "jax": try: import pynajax as nax - self._backend = "jax" except ImportError: warnings.warn( @@ -91,7 +90,6 @@ def set_backend(self, backend): def get_pynajax_backend_status(self): try: import pynajax - self.backend = "jax" except: self.backend = "numba" diff --git a/pynapple/core/core_function.py b/pynapple/core/core_function.py index 7a37f75d..d7f57f7e 100644 --- a/pynapple/core/core_function.py +++ b/pynapple/core/core_function.py @@ -9,6 +9,7 @@ """ import numpy as np +from scipy import signal from ._jitted_functions import ( jitbin, @@ -18,9 +19,13 @@ jitthreshold, jittsrestrict, pjitconvolve, + jitcount, + jittsrestrict_with_count, + jitvaluefrom, + jitvaluefromtensor, ) from .utils import get_backend -from scipy import signal + def _convolve(time_array, data_array, starts, ends, array, trim="both"): if get_backend() == "jax": @@ -63,7 +68,6 @@ def _convolve(time_array, data_array, starts, ends, array, 
trim="both"): def _restrict(time_array, data_array, starts, ends): if get_backend() == "jax": from pynajax.jax_core_restrict import restrict - return restrict(time_array, data_array, starts, ends) else: if data_array is not None: @@ -72,13 +76,32 @@ def _restrict(time_array, data_array, starts, ends): return jittsrestrict(time_array, starts, ends) -def _value_from(): - pass - - -def _count(): - pass +def _count(time_array, starts, ends, bin_size=None): + if get_backend() == "jax": + from pynajax.jax_core_count import count + return count(time_array, starts, ends, bin_size) + else: + if isinstance(bin_size, (float, int)): + return jitcount(time_array, starts, ends, bin_size) + else: + _, d = jittsrestrict_with_count(time_array, starts, ends) + t = starts + (ends - starts) / 2 + return t, d +def _value_from(time_array, time_target_array, data_target_array, starts, ends): + if get_backend() == "jax": + from pynajax.jax_core_value_from import value_from + return value_from(time_array, time_target_array, data_target_array, starts, ends) + else: + if data_target_array.ndim == 1: + t, d, ns, ne = jitvaluefrom( + time_array, time_target_array, data_target_array, starts, ends + ) + else: + t, d, ns, ne = jitvaluefromtensor( + time_array, time_target_array, data_target_array, starts, ends + ) + return t, d, ns, ne def _bin_average(): pass diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 86ec3e35..8afd4e2d 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -24,8 +24,16 @@ import numpy as np import pandas as pd from numpy.lib.mixins import NDArrayOperatorsMixin +from scipy import signal from tabulate import tabulate +# to remove +from ._jitted_functions import ( + jitbin, + jitbin_array, + jitremove_nan, + jitthreshold, +) from .base_class import Base from .core_function import _convolve, _restrict from .interval_set import IntervalSet @@ -41,17 +49,6 @@ not_implemented_in_pynajax, ) -# to remove -from ._jitted_functions import ( - jitbin, - jitbin_array, - jitremove_nan, - jitrestrict, - jitthreshold, - jittsrestrict, - pjitconvolve, -) -from scipy import signal def _get_class(data): """Select the right time series object and return the class From e2271b26c0273b236882fe2b457776308737da7b Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Wed, 3 Apr 2024 11:05:51 -0400 Subject: [PATCH 13/37] Updating history --- docs/HISTORY.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/HISTORY.md b/docs/HISTORY.md index 02c53ef5..2a100e31 100644 --- a/docs/HISTORY.md +++ b/docs/HISTORY.md @@ -16,6 +16,7 @@ In 2021, Guillaume and other trainees in Adrien's lab decided to fork from neuro - Fixed `TsGroup` saving method. - `__getattr__` of `BaseTsd` allow numpy functions to be attached as attributes of Tsd objects - Added `get` method for `TsGroup` +- Tsds can be concatenate vertically if time indexes matches. 
0.6.1 (2024-03-03) From 4f1bf1216fab67b5b6c489a47cddd2453291298c Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 4 Apr 2024 15:30:11 -0400 Subject: [PATCH 14/37] working version with backend dispatch --- pynapple/__init__.py | 11 ++- pynapple/core/base_class.py | 21 ++-- pynapple/core/config.py | 111 ++++++++++++--------- pynapple/core/core_function.py | 76 +++++++++++--- pynapple/core/time_series.py | 159 ++++++++++++++++++++++-------- pynapple/core/ts_group.py | 1 - pynapple/core/utils.py | 78 +++++++++++---- pynapple/process/correlograms.py | 12 +-- tests/test_config.py | 57 ++++++++--- tests/test_correlograms.py | 12 +-- tests/test_interval_set.py | 10 +- tests/test_numpy_compatibility.py | 7 +- 12 files changed, 396 insertions(+), 159 deletions(-) diff --git a/pynapple/__init__.py b/pynapple/__init__.py index 0855144a..570fdc53 100644 --- a/pynapple/__init__.py +++ b/pynapple/__init__.py @@ -1,4 +1,13 @@ __version__ = "0.6.1" -from .core import IntervalSet, Ts, Tsd, TsdFrame, TsdTensor, TsGroup, TsIndex, config +from .core import ( + IntervalSet, + Ts, + Tsd, + TsdFrame, + TsdTensor, + TsGroup, + TsIndex, + nap_config, +) from .io import * from .process import * diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index 87b01b90..367cdad3 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -8,14 +8,10 @@ import numpy as np -from .core_function import _restrict, _count, _value_from +from .core_function import _count, _restrict, _value_from from .interval_set import IntervalSet from .time_index import TsIndex -from .utils import ( - convert_to_jax_array, - convert_to_numpy_array, - get_backend, -) +from .utils import convert_to_jax_array, convert_to_numpy_array, get_backend class Base(abc.ABC): @@ -199,7 +195,9 @@ def value_from(self, data, ep=None): starts = ep.start ends = ep.end - t, d, ns, ne = _value_from(time_array, time_target_array, data_target_array, starts, ends) + t, d, ns, ne = _value_from( + time_array, time_target_array, data_target_array, starts, ends + ) time_support = IntervalSet(start=ns, end=ne) @@ -286,9 +284,6 @@ def count(self, *args, **kwargs): if isinstance(a, str) and a in ["s", "ms", "us"]: time_units = a - if isinstance(bin_size, (float, int)): - bin_size = TsIndex.format_timestamps(np.array([bin_size]), time_units)[0] - ep = self.time_support if "ep" in kwargs: ep = kwargs["ep"] @@ -299,10 +294,14 @@ def count(self, *args, **kwargs): if isinstance(a, IntervalSet): ep = a - time_array = self.index.values starts = ep.start ends = ep.end + if isinstance(bin_size, (float, int)): + bin_size = TsIndex.format_timestamps(np.array([bin_size]), time_units)[0] + + time_array = self.index.values + t, d = _count(time_array, starts, ends, bin_size) return t, d, ep diff --git a/pynapple/core/config.py b/pynapple/core/config.py index b8876deb..97eaafaa 100644 --- a/pynapple/core/config.py +++ b/pynapple/core/config.py @@ -1,6 +1,62 @@ -"""This module deals with package configurations. For now it includes only warning configurations. +"""This module controls the pynapple configurations. + +## Backend configuration + +By default, pynapple core functions are compiled with [Numba](https://numba.pydata.org/). +It is possible to change the backend to [Jax](https://jax.readthedocs.io/en/latest/index.html) +through the [pynajax package](https://github.com/pynapple-org/pynajax). + +While numba core functions runs on CPU, the `jax` backend allows pynapple to use GPU accelerated core functions. 
+For some core functions, the `jax` backend offers speed gains (provided that Jax runs on the GPU). + +See the example below to update the backend. Don't forget to install [pynajax](https://github.com/pynapple-org/pynajax). + +``` py +import pynapple as nap +import numpy as np +nap.nap_config.set_backend("jax") # Default option is 'numba'. +``` + +You can view the current backend with +``` py +>>> print(nap.nap_config.backend) +'jax' +``` + +## Warnings configuration + +pynapple gives warnings that can be helpful to debug. For example when passing time indexes that are not sorted: + +``` py +>>> import pynapple as nap +>>> t = [0, 2, 1] +>>> nap.Ts(t) +UserWarning: timestamps are not sorted + warn("timestamps are not sorted", UserWarning) +Time (s) +0.0 +1.0 +2.0 +shape: 3 +``` + +pynapple's warnings can be suppressed : + +``` py +>>> nap.nap_config.suppress_time_index_sorting_warnings = True +>>> nap.Ts(t=t) +Time (s) +0.0 +1.0 +2.0 +shape: 3 +``` + """ +import importlib.util +import warnings + class PynappleConfig: """ @@ -9,37 +65,10 @@ class PynappleConfig: This class includes all configuration settings that control the behavior of pynapple. It offers a structured way to access and modify settings. - Examples - -------- - >>> import pynapple as nap - >>> import jax.numpy as jnp - >>> t = jnp.arange(3) - >>> print(t) - Array([0, 1, 2], dtype=int32) - - >>> # Suppress warnings when converting a non-numpy array to numpy array - >>> nap.config.nap_config.suppress_conversion_warnings = True - >>> nap.Ts(t=t) - Time (s) - 0.0 - 1.0 - 2.0 - shape: 3 - - >>> # Restore to defaults - >>> nap.config.nap_config.restore_defaults() - >>> nap.Ts(t=t) - /mnt/home/gviejo/pynapple/pynapple/core/time_series.py:151: UserWarning: Converting 't' to n - umpy.array. The provided array was of type 'ArrayImpl'. - warnings.warn( - Time (s) - 0.0 - 1.0 - 2.0 - shape: 3 - Attributes ---------- + backend : str + Current pynapple backend. Options are ('numba' [default], 'jax') suppress_conversion_warnings : boolean Determines whether to suppress warnings when automatically converting non-NumPy array-like objects to NumPy arrays. This is useful for users who frequently work with array-like objects from other @@ -49,7 +78,7 @@ class PynappleConfig: Control the warning raised when passing a non-sorted array for time index. It can be useful to catch data where timestamps are not properly sorted before using pynapple. time_index_precision : int - Precision for the time index is set to nanoseconds. It's a fixed parameter in pynapple and cannot be changed. + Number of decimal places to round time index. Pynapple's precision is set by default to 9. """ def __init__(self): @@ -74,26 +103,18 @@ def set_backend(self, backend): # Try to import pynajax if backend == "jax": - try: - import pynajax as nax - self._backend = "jax" - except ImportError: + spec = importlib.util.find_spec("pynajax") + if spec is None: warnings.warn( - "Importing pynajax failed. Falling back to numba. To use the jax backend for pynapple, please install pynajax", - UserWarning, + "Package pynajax is not found. Falling back to numba backend. 
To use the jax backend for pynapple, please install pynajax", + stacklevel=2, ) self._backend = "numba" + else: + self._backend = "jax" else: self._backend = "numba" - @property - def get_pynajax_backend_status(self): - try: - import pynajax - self.backend = "jax" - except: - self.backend = "numba" - @property def time_index_precision(self): """Precision for the time index diff --git a/pynapple/core/core_function.py b/pynapple/core/core_function.py index d7f57f7e..3ccba773 100644 --- a/pynapple/core/core_function.py +++ b/pynapple/core/core_function.py @@ -14,15 +14,15 @@ from ._jitted_functions import ( jitbin, jitbin_array, + jitcount, jitremove_nan, jitrestrict, jitthreshold, jittsrestrict, - pjitconvolve, - jitcount, - jittsrestrict_with_count, + jittsrestrict_with_count, jitvaluefrom, - jitvaluefromtensor, + jitvaluefromtensor, + pjitconvolve, ) from .utils import get_backend @@ -68,6 +68,7 @@ def _convolve(time_array, data_array, starts, ends, array, trim="both"): def _restrict(time_array, data_array, starts, ends): if get_backend() == "jax": from pynajax.jax_core_restrict import restrict + return restrict(time_array, data_array, starts, ends) else: if data_array is not None: @@ -79,19 +80,24 @@ def _restrict(time_array, data_array, starts, ends): def _count(time_array, starts, ends, bin_size=None): if get_backend() == "jax": from pynajax.jax_core_count import count + return count(time_array, starts, ends, bin_size) else: if isinstance(bin_size, (float, int)): - return jitcount(time_array, starts, ends, bin_size) + return jitcount(time_array, starts, ends, bin_size) else: _, d = jittsrestrict_with_count(time_array, starts, ends) t = starts + (ends - starts) / 2 return t, d + def _value_from(time_array, time_target_array, data_target_array, starts, ends): if get_backend() == "jax": from pynajax.jax_core_value_from import value_from - return value_from(time_array, time_target_array, data_target_array, starts, ends) + + return value_from( + time_array, time_target_array, data_target_array, starts, ends + ) else: if data_target_array.ndim == 1: t, d, ns, ne = jitvaluefrom( @@ -103,13 +109,59 @@ def _value_from(time_array, time_target_array, data_target_array, starts, ends): ) return t, d, ns, ne -def _bin_average(): - pass + +def _bin_average(time_array, data_array, starts, ends, bin_size): + if get_backend() == "jax": + from pynajax.jax_core_bin_average import bin_average + + return bin_average(time_array, data_array, starts, ends, bin_size) + else: + if data_array.ndim > 1: + t, d = jitbin_array(time_array, data_array, starts, ends, bin_size) + else: + t, d = jitbin(time_array, data_array, starts, ends, bin_size) + + return t, d -def _interpolate(): - pass +def _threshold(time_array, data_array, starts, ends, thr, method): + if get_backend() == "jax": + from pynajax.jax_core_threshold import threshold + + return threshold(time_array, data_array, starts, ends, thr, method) + else: + return jitthreshold(time_array, data_array, starts, ends, thr, method) -def _threshold(): - pass +def _dropna(time_array, data_array, starts, ends, update_time_support, ndim): + if get_backend() == "jax": + from pynajax.jax_core_dropna import dropna + + return dropna(time_array, data_array, starts, ends, update_time_support) + else: + index_nan = np.any(np.isnan(data_array), axis=tuple(range(1, ndim))) + if np.all(index_nan): # In case it's only NaNs + if update_time_support: + starts = None + ends = None + return ( + np.array([]), + np.empty(tuple([0] + [d for d in data_array.shape[1:]])), + starts, + ends, 
+ ) + elif np.any(index_nan): + if update_time_support: + starts, ends = jitremove_nan(time_array, index_nan) + + to_fix = starts == ends + if np.any(to_fix): + ends[ + to_fix + ] += 1e-6 # adding 1 millisecond in case of a single point + + return (time_array[~index_nan], data_array[~index_nan], starts, ends) + else: + return (time_array[~index_nan], data_array[~index_nan], starts, ends) + else: + return (time_array, data_array, starts, ends) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 8da079d7..82c99b59 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -27,15 +27,8 @@ from scipy import signal from tabulate import tabulate -# to remove -from ._jitted_functions import ( - jitbin, - jitbin_array, - jitremove_nan, - jitthreshold, -) from .base_class import Base -from .core_function import _convolve, _restrict +from .core_function import _bin_average, _convolve, _dropna, _restrict, _threshold from .interval_set import IntervalSet from .time_index import TsIndex from .utils import ( @@ -46,7 +39,6 @@ convert_to_numpy_array, get_backend, is_array_like, - not_implemented_in_pynajax, ) @@ -264,6 +256,43 @@ def copy(self): ) def value_from(self, data, ep=None): + """ + Replace the value with the closest value from Tsd/TsdFrame/TsdTensor argument + + Parameters + ---------- + data : Tsd, TsdFrame or TsdTensor + The object holding the values to replace. + ep : IntervalSet (optional) + The IntervalSet object to restrict the operation. + If None, the time support of the tsd input object is used. + + Returns + ------- + out : Tsd, TsdFrame or TsdTensor + Object with the new values + + Examples + -------- + In this example, the ts object will receive the closest values in time from tsd. + + >>> import pynapple as nap + >>> import numpy as np + >>> t = np.unique(np.sort(np.random.randint(0, 1000, 100))) # random times + >>> ts = nap.Ts(t=t, time_units='s') + >>> tsd = nap.Tsd(t=np.arange(0,1000), d=np.random.rand(1000), time_units='s') + >>> ep = nap.IntervalSet(start = 0, end = 500, time_units = 's') + + The variable ts is a time series object containing only nan. + The tsd object containing the values, for example the tracking data, and the epoch to restrict the operation. + + >>> newts = ts.value_from(tsd, ep) + + newts is the same size as ts restrict to ep. + + >>> print(len(ts.restrict(ep)), len(newts)) + 52 52 + """ assert isinstance( data, BaseTsd ), "First argument should be an instance of Tsd, TsdFrame or TsdTensor" @@ -272,6 +301,60 @@ def value_from(self, data, ep=None): return data.__class__(t=t, d=d, time_support=time_support, **kwargs) def count(self, *args, **kwargs): + """ + Count occurences of events within bin_size or within a set of bins defined as an IntervalSet. + You can call this function in multiple ways : + + 1. *tsd.count(bin_size=1, time_units = 'ms')* + -> Count occurence of events within a 1 ms bin defined on the time support of the object. + + 2. *tsd.count(1, ep=my_epochs)* + -> Count occurent of events within a 1 second bin defined on the IntervalSet my_epochs. + + 3. *tsd.count(ep=my_bins)* + -> Count occurent of events within each epoch of the intervalSet object my_bins + + 4. *tsd.count()* + -> Count occurent of events within each epoch of the time support. + + bin_size should be seconds unless specified. + If bin_size is used and no epochs is passed, the data will be binned based on the time support of the object. 
+ + Parameters + ---------- + bin_size : None or float, optional + The bin size (default is second) + ep : None or IntervalSet, optional + IntervalSet to restrict the operation + time_units : str, optional + Time units of bin size ('us', 'ms', 's' [default]) + + Returns + ------- + out: Tsd + A Tsd object indexed by the center of the bins. + + Examples + -------- + This example shows how to count events within bins of 0.1 second. + + >>> import pynapple as nap + >>> import numpy as np + >>> t = np.unique(np.sort(np.random.randint(0, 1000, 100))) + >>> ts = nap.Ts(t=t, time_units='s') + >>> bincount = ts.count(0.1) + + An epoch can be specified: + + >>> ep = nap.IntervalSet(start = 100, end = 800, time_units = 's') + >>> bincount = ts.count(0.1, ep=ep) + + And bincount automatically inherit ep as time support: + + >>> bincount.time_support + start end + 0 100.0 800.0 + """ t, d, ep = super().count(*args, **kwargs) return Tsd(t=t, d=d, time_support=ep) @@ -324,14 +407,13 @@ def bin_average(self, bin_size, ep=None, time_units="s"): data_array = self.values starts = ep.start ends = ep.end - if data_array.ndim > 1: - t, d = jitbin_array(time_array, data_array, starts, ends, bin_size) - else: - t, d = jitbin(time_array, data_array, starts, ends, bin_size) + + t, d = _bin_average(time_array, data_array, starts, ends, bin_size) kwargs = {} if hasattr(self, "columns"): kwargs["columns"] = self.columns + return self.__class__(t=t, d=d, time_support=ep, **kwargs) def dropna(self, update_time_support=True): @@ -347,34 +429,30 @@ def dropna(self, update_time_support=True): Tsd, TsdFrame or TsdTensor The time series without the NaNs """ - index_nan = np.any(np.isnan(self.values), axis=tuple(range(1, self.ndim))) - if np.all(index_nan): # In case it's only NaNs - return self.__class__( - t=np.array([]), d=np.empty(tuple([0] + [d for d in self.shape[1:]])) - ) + assert isinstance(update_time_support, bool) - elif np.any(index_nan): - if update_time_support: - time_array = self.index.values - starts, ends = jitremove_nan(time_array, index_nan) + time_array = self.index.values + data_array = self.values + starts = self.time_support.start + ends = self.time_support.end - to_fix = starts == ends - if np.any(to_fix): - ends[ - to_fix - ] += 1e-6 # adding 1 millisecond in case of a single point + t, d, starts, ends = _dropna( + time_array, data_array, starts, ends, update_time_support, self.ndim + ) + if update_time_support: + if is_array_like(starts) and is_array_like(ends): ep = IntervalSet(starts, ends) - - return self.__class__( - t=time_array[~index_nan], d=self.values[~index_nan], time_support=ep - ) - else: - return self[~index_nan] - + ep = None else: - return self + ep = self.time_support + + kwargs = {} + if hasattr(self, "columns"): + kwargs["columns"] = self.columns + + return self.__class__(t=t, d=d, time_support=ep) def convolve(self, array, ep=None, trim="both"): """Return the discrete linear convolution of the time series with a one dimensional sequence. 
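The `dropna` rewrite in the hunk above follows the backend-dispatch pattern used throughout this series in `pynapple/core/core_function.py`: each core operation checks `get_backend()` and routes to a pynajax implementation when the jax backend is active, otherwise it falls back to the numba kernels. A minimal sketch of that pattern, assuming this patch series is applied; `_my_op`, `_jit_my_op` and `pynajax.jax_core_my_op` are hypothetical names used only for illustration:

``` py
from pynapple.core.utils import get_backend


def _jit_my_op(time_array, data_array, starts, ends):
    # Stand-in for a numba-compiled kernel from _jitted_functions.py.
    return time_array, data_array


def _my_op(time_array, data_array, starts, ends):
    if get_backend() == "jax":
        # pynajax is imported lazily so it remains an optional dependency.
        from pynajax.jax_core_my_op import my_op  # hypothetical module

        return my_op(time_array, data_array, starts, ends)
    return _jit_my_op(time_array, data_array, starts, ends)
```

Keeping the dispatch in one module lets the time-series classes stay backend-agnostic: `dropna`, `threshold` and `bin_average` now only prepare `time_array`, `data_array`, `starts` and `ends` and delegate the computation.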
@@ -1150,16 +1228,17 @@ def threshold(self, thr, method="above"): >>> 0 50.5 99.0 """ - time_array = self.index.values - data_array = self.values - starts = self.time_support.start - ends = self.time_support.end if method not in ["above", "below", "aboveequal", "belowequal"]: raise ValueError( "Method {} for thresholding is not accepted.".format(method) ) - t, d, ns, ne = jitthreshold(time_array, data_array, starts, ends, thr, method) + time_array = self.index.values + data_array = self.values + starts = self.time_support.start + ends = self.time_support.end + + t, d, ns, ne = _threshold(time_array, data_array, starts, ends, thr, method) time_support = IntervalSet(start=ns, end=ne) return Tsd(t=t, d=d, time_support=time_support) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index c82759e8..1d45f9e8 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -537,7 +537,6 @@ def count(self, *args, **kwargs): ends = ep.end if isinstance(bin_size, (float, int)): - bin_size = float(bin_size) bin_size = TsIndex.format_timestamps(np.array([bin_size]), time_units)[0] time_index, _ = jitcount(np.array([]), starts, ends, bin_size) n = len(self.index) diff --git a/pynapple/core/utils.py b/pynapple/core/utils.py index 81f580c9..d1813927 100644 --- a/pynapple/core/utils.py +++ b/pynapple/core/utils.py @@ -2,42 +2,60 @@ # @Author: Guillaume Viejo # @Date: 2024-02-09 11:45:45 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-04-02 16:19:08 +# @Last Modified time: 2024-04-03 11:54:27 """ Utility functions """ import warnings -from numbers import Number from itertools import combinations +from numbers import Number import numpy as np from numba import jit from .config import nap_config +# def not_implemented_in_pynajax(func, which_in, which_out, *args, **kwargs): -def not_implemented_in_pynajax(func, which_in, which_out, *args, **kwargs): - - if nap_config.backend == "jax": - import jax - import jax.numpy as jnp +# if nap_config.backend == "jax": +# import jax +# import jax.numpy as jnp - # def wrapper(*args, **kwargs): - arguments, struct = jax.tree_util.tree_flatten((args, kwargs)) - arguments[which_in] = jax.tree_map(np.asarray, arguments[which_in]) - args, kwargs = jax.tree_util.tree_unflatten(struct, arguments) - out = func(*args, **kwargs) - out = list(out) - out[which_out] = jax.tree_map(jnp.asarray, out[which_out]) - return tuple(out) - else: - # def wrapper(*args, **kwargs): - return func(*args, **kwargs) +# # def wrapper(*args, **kwargs): +# arguments, struct = jax.tree_util.tree_flatten((args, kwargs)) +# arguments[which_in] = jax.tree_map(np.asarray, arguments[which_in]) +# args, kwargs = jax.tree_util.tree_unflatten(struct, arguments) +# out = func(*args, **kwargs) +# out = list(out) +# out[which_out] = jax.tree_map(jnp.asarray, out[which_out]) +# return tuple(out) +# else: +# # def wrapper(*args, **kwargs): +# return func(*args, **kwargs) def convert_to_numpy_array(array, array_name): + """Convert any array like object to numpy ndarray. + + Parameters + ---------- + array : ArrayLike + + array_name : str + Array name if RuntimeError is raised + + Returns + ------- + numpy.ndarray + Numpy array object + + Raises + ------ + RuntimeError + If input can't be converted to numpy array + """ if isinstance(array, Number): return np.array([array]) elif isinstance(array, (list, tuple)): @@ -55,6 +73,25 @@ def convert_to_numpy_array(array, array_name): def convert_to_jax_array(array, array_name): + """Convert any array like object to jax Array. 
+ + Parameters + ---------- + array : ArrayLike + + array_name : str + Array name if RuntimeError is raised or object is casted to numpy + + Returns + ------- + jax.Array + Jax array object + + Raises + ------ + RuntimeError + If input can't be converted to jax array + """ import jax.numpy as jnp if isinstance(array, Number): @@ -75,7 +112,8 @@ def convert_to_jax_array(array, array_name): def get_backend(): """ - Return the current backend of pynapple + Return the current backend of pynapple. Possible backends are + 'numba' or 'jax'. """ return nap_config.backend @@ -189,7 +227,7 @@ def _check_time_equals(time_arrays): Parameters ---------- - time_arrays : list + time_arrays : list of arrays The time arrays to compare to each other Returns diff --git a/pynapple/process/correlograms.py b/pynapple/process/correlograms.py index e47f1161..e49bd293 100644 --- a/pynapple/process/correlograms.py +++ b/pynapple/process/correlograms.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-01-02 11:39:55 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-09-21 15:48:15 +# @Last Modified time: 2024-04-04 10:46:03 from itertools import combinations, product @@ -18,7 +18,7 @@ # CORRELATION ######################################################### @jit(nopython=True) -def cross_correlogram(t1, t2, binsize, windowsize): +def _cross_correlogram(t1, t2, binsize, windowsize): """ Performs the discrete cross-correlogram of two time series. The units should be in s for all arguments. @@ -142,7 +142,7 @@ def compute_autocorrelogram( for n in newgroup.keys(): spk_time = newgroup[n].index - auc, times = cross_correlogram(spk_time, spk_time, binsize, windowsize) + auc, times = _cross_correlogram(spk_time, spk_time, binsize, windowsize) autocorrs[n] = pd.Series(index=np.round(times, 6), data=auc, dtype="float") autocorrs = pd.DataFrame.from_dict(autocorrs) @@ -225,7 +225,7 @@ def compute_crosscorrelogram( for i, j in pairs: spk1 = newgroup[i].index spk2 = newgroup[j].index - auc, times = cross_correlogram(spk1, spk2, binsize, windowsize) + auc, times = _cross_correlogram(spk1, spk2, binsize, windowsize) crosscorrs[(i, j)] = pd.Series(index=times, data=auc, dtype="float") crosscorrs = pd.DataFrame.from_dict(crosscorrs) @@ -252,7 +252,7 @@ def compute_crosscorrelogram( for i, j in pairs: spk1 = newgroup[0][i].index spk2 = newgroup[1][j].index - auc, times = cross_correlogram(spk1, spk2, binsize, windowsize) + auc, times = _cross_correlogram(spk1, spk2, binsize, windowsize) if norm: auc /= newgroup[1][j].rate crosscorrs[(i, j)] = pd.Series(index=times, data=auc, dtype="float") @@ -327,7 +327,7 @@ def compute_eventcorrelogram( for n in newgroup.keys(): spk_time = newgroup[n].index - auc, times = cross_correlogram(tsd1, spk_time, binsize, windowsize) + auc, times = _cross_correlogram(tsd1, spk_time, binsize, windowsize) crosscorrs[n] = pd.Series(index=times, data=auc, dtype="float") crosscorrs = pd.DataFrame.from_dict(crosscorrs) diff --git a/tests/test_config.py b/tests/test_config.py index 09b6ed8a..e63ffd8c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -6,6 +6,8 @@ import numpy as np +import warnings + class MockArray: """ A mock array class designed for testing purposes. 
It mimics the behavior of array-like objects @@ -54,7 +56,40 @@ def __len__(self): """ return len(self.data) +################################## +# Test for backend +################################## + +def test_change_backend(): + nap.nap_config.set_backend("numba") + + assert nap.core.utils.get_backend() == "numba" + assert nap.nap_config.backend == "numba" + with pytest.raises(AssertionError, match="Options for backend are 'jax' or 'numba'"): + nap.nap_config.set_backend("blabla") + + # For local tests. + # Should not be installed for github actions + try: + import pynajax + + nap.nap_config.set_backend("jax") + assert nap.core.utils.get_backend() == "jax" + assert nap.nap_config.backend == "jax" + + except ModuleNotFoundError: + + with warnings.catch_warnings(record=True) as w: + nap.nap_config.set_backend("jax") + + assert str(w[0].message) == 'Package pynajax is not found. Falling back to numba backend. To use the jax backend for pynapple, please install pynajax' + assert nap.core.utils.get_backend() == "numba" + assert nap.nap_config.backend == "numba" + +################################## +# Tests for warnings +################################## @pytest.mark.parametrize("param, expectation", [ (True, does_not_raise()), @@ -65,20 +100,20 @@ def __len__(self): def test_config_setter_input_validity(param, expectation): """Test setting suppress_conversion_warnings with various inputs to validate type checking.""" with expectation: - nap.config.nap_config.suppress_conversion_warnings = param + nap.nap_config.suppress_conversion_warnings = param def test_config_setter_output(): """Test if suppress_conversion_warnings property correctly retains a True value after being set.""" - nap.config.nap_config.suppress_conversion_warnings = True - assert nap.config.nap_config.suppress_conversion_warnings + nap.nap_config.suppress_conversion_warnings = True + assert nap.nap_config.suppress_conversion_warnings def test_config_restore_default(): """Test if the restore_defaults method correctly resets suppress_conversion_warnings to its default.""" - nap.config.nap_config.suppress_conversion_warnings = True - nap.config.nap_config.restore_defaults() - assert not nap.config.nap_config.suppress_conversion_warnings + nap.nap_config.suppress_conversion_warnings = True + nap.nap_config.restore_defaults() + assert not nap.nap_config.suppress_conversion_warnings @pytest.mark.parametrize("cls, t, d, conf, expectation", @@ -99,7 +134,7 @@ def test_config_restore_default(): ]) def test_config_supress_warning_t(cls, t, d, conf, expectation): """Test if the restore_defaults method correctly resets suppress_conversion_warnings to its default.""" - nap.config.nap_config.suppress_conversion_warnings = conf + nap.nap_config.suppress_conversion_warnings = conf try: with expectation: if d is None: @@ -107,7 +142,7 @@ def test_config_supress_warning_t(cls, t, d, conf, expectation): else: cls(t=MockArray(t), d=d) finally: - nap.config.nap_config.restore_defaults() + nap.nap_config.restore_defaults() @pytest.mark.parametrize("cls, t, d, conf, expectation", [ @@ -124,13 +159,13 @@ def test_config_supress_warning_t(cls, t, d, conf, expectation): ]) def test_config_supress_warning_d(cls, t, d, conf, expectation): """Test if the restore_defaults method correctly resets suppress_conversion_warnings to its default.""" - nap.config.nap_config.suppress_conversion_warnings = conf + nap.nap_config.suppress_conversion_warnings = conf try: with expectation: cls(t=t, d=MockArray(d)) finally: - nap.config.nap_config.restore_defaults() 
+ nap.nap_config.restore_defaults() def test_get_time_index_precision(): - assert nap.config.nap_config.time_index_precision == 9 \ No newline at end of file + assert nap.nap_config.time_index_precision == 9 \ No newline at end of file diff --git a/tests/test_correlograms.py b/tests/test_correlograms.py index b8294ac2..9e31e45d 100644 --- a/tests/test_correlograms.py +++ b/tests/test_correlograms.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-03-30 11:16:22 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-01-29 15:06:58 +# @Last Modified time: 2024-04-04 10:46:31 #!/usr/bin/env python """Tests of correlograms for `pynapple` package.""" @@ -17,19 +17,19 @@ def test_cross_correlogram(): t1 = np.array([0]) t2 = np.array([1]) - cc, bincenter = nap.process.correlograms.cross_correlogram(t1, t2, 1, 100) + cc, bincenter = nap.process.correlograms._cross_correlogram(t1, t2, 1, 100) np.testing.assert_approx_equal(cc[101], 1.0) - cc, bincenter = nap.process.correlograms.cross_correlogram(t2, t1, 1, 100) + cc, bincenter = nap.process.correlograms._cross_correlogram(t2, t1, 1, 100) np.testing.assert_approx_equal(cc[99], 1.0) t1 = np.array([0]) t2 = np.array([100]) - cc, bincenter = nap.process.correlograms.cross_correlogram(t1, t2, 1, 100) + cc, bincenter = nap.process.correlograms._cross_correlogram(t1, t2, 1, 100) np.testing.assert_approx_equal(cc[200], 1.0) t1 = np.array([0, 10]) - cc, bincenter = nap.process.correlograms.cross_correlogram(t1, t1, 1, 100) + cc, bincenter = nap.process.correlograms._cross_correlogram(t1, t1, 1, 100) np.testing.assert_approx_equal(cc[100], 1.0) np.testing.assert_approx_equal(cc[90], 0.5) np.testing.assert_approx_equal(cc[110], 0.5) @@ -38,7 +38,7 @@ def test_cross_correlogram(): for t in [100, 200, 1000]: np.testing.assert_array_almost_equal( - nap.process.correlograms.cross_correlogram(np.arange(0, t), np.arange(0, t), 1, t)[0], + nap.process.correlograms._cross_correlogram(np.arange(0, t), np.arange(0, t), 1, t)[0], np.hstack( (np.arange(0, 1, 1 / t), np.ones(1), np.arange(0, 1, 1 / t)[::-1]) ), diff --git a/tests/test_interval_set.py b/tests/test_interval_set.py index c8667f64..ef381e79 100644 --- a/tests/test_interval_set.py +++ b/tests/test_interval_set.py @@ -3,7 +3,7 @@ # @Author: gviejo # @Date: 2022-03-30 11:15:02 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-03-29 11:04:32 +# @Last Modified time: 2024-04-04 10:33:04 """Tests for IntervalSet of `pynapple` package.""" @@ -203,11 +203,11 @@ def test_array_ufunc(): # test warning from contextlib import nullcontext as does_not_raise - nap.config.nap_config.suppress_conversion_warnings = True + nap.nap_config.suppress_conversion_warnings = True with does_not_raise(): np.exp(ep) - nap.config.nap_config.suppress_conversion_warnings = False + nap.nap_config.suppress_conversion_warnings = False def test_array_func(): start = np.array([0, 10, 16], dtype=np.float64) @@ -226,11 +226,11 @@ def test_array_func(): # test warning from contextlib import nullcontext as does_not_raise - nap.config.nap_config.suppress_conversion_warnings = True + nap.nap_config.suppress_conversion_warnings = True with does_not_raise(): out = np.ravel(ep) - nap.config.nap_config.suppress_conversion_warnings = False + nap.nap_config.suppress_conversion_warnings = False def test_timespan(): start = [0, 10, 16, 25] diff --git a/tests/test_numpy_compatibility.py b/tests/test_numpy_compatibility.py index e8289bbb..cafc42b0 100644 --- a/tests/test_numpy_compatibility.py +++ b/tests/test_numpy_compatibility.py 
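The `test_change_backend` test earlier in this patch exercises the availability check that `set_backend` now performs: rather than importing pynajax inside a try/except, it queries `importlib.util.find_spec`, which looks a package up without importing it. A minimal sketch of that check (the helper name is hypothetical):

``` py
import importlib.util


def pynajax_available() -> bool:  # hypothetical helper, for illustration only
    # find_spec returns None when the package cannot be located.
    return importlib.util.find_spec("pynajax") is not None
```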
@@ -2,7 +2,7 @@ # @Author: Guillaume Viejo # @Date: 2023-09-18 18:11:24 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-04-01 17:09:04 +# @Last Modified time: 2024-04-04 11:34:35 @@ -154,6 +154,11 @@ def test_funcs(self, tsd): def test_attributes(self, tsd): assert tsd.min() == tsd.values.min() + with pytest.raises(AttributeError) as e_info: + tsd.blabla() + + assert str(e_info.value) == "Time series object does not have the attribute blabla" + def test_split(self, tsd): a = np.split(tsd, 4) b = np.split(tsd.values, 4) From f2381617f897ba7c10c7b7c62fd075d5aa1d7758 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Fri, 12 Apr 2024 11:51:26 -0400 Subject: [PATCH 15/37] keeping t as numpy array --- pynapple/core/base_class.py | 7 ++----- pynapple/core/core_function.py | 6 ++++++ 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index 367cdad3..e73208af 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -11,7 +11,7 @@ from .core_function import _count, _restrict, _value_from from .interval_set import IntervalSet from .time_index import TsIndex -from .utils import convert_to_jax_array, convert_to_numpy_array, get_backend +from .utils import convert_to_numpy_array class Base(abc.ABC): @@ -27,10 +27,7 @@ def __init__(self, t, time_units="s", time_support=None): if isinstance(t, TsIndex): self.index = t else: - if get_backend() == "jax": - self.index = TsIndex(convert_to_jax_array(t, "t"), time_units) - else: - self.index = TsIndex(convert_to_numpy_array(t, "t"), time_units) + self.index = TsIndex(convert_to_numpy_array(t, "t"), time_units) if time_support is not None: assert isinstance( diff --git a/pynapple/core/core_function.py b/pynapple/core/core_function.py index 3ccba773..2b511d5a 100644 --- a/pynapple/core/core_function.py +++ b/pynapple/core/core_function.py @@ -165,3 +165,9 @@ def _dropna(time_array, data_array, starts, ends, update_time_support, ndim): return (time_array[~index_nan], data_array[~index_nan], starts, ends) else: return (time_array, data_array, starts, ends) + + +# def _sta(): + + +# def _interp From e542d7539752925ffcf65e1e6db5b5ebe10e4652 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 18 Apr 2024 12:53:06 -0400 Subject: [PATCH 16/37] Updating --- pynapple/core/_jitted_functions.py | 43 ++++++++++++++++++++++++++++++ pynapple/core/base_class.py | 2 +- pynapple/core/ts_group.py | 4 +-- pynapple/core/utils.py | 41 +++++++++++++++++++++++++++- tox.ini | 3 ++- 5 files changed, 88 insertions(+), 5 deletions(-) diff --git a/pynapple/core/_jitted_functions.py b/pynapple/core/_jitted_functions.py index 88d7df31..4d75a11e 100644 --- a/pynapple/core/_jitted_functions.py +++ b/pynapple/core/_jitted_functions.py @@ -81,6 +81,49 @@ def jittsrestrict(time_array, starts, ends): return new_time_array +@jit(nopython=True) +def jitrestrict2(time_array, starts, ends): + n = len(time_array) + m = len(starts) + ix = np.zeros(n, dtype=np.bool_) + + k = 0 + t = 0 + + while ends[k] < time_array[t]: + k += 1 + + while k < m: + # Outside + while t < n: + if time_array[t] >= starts[k]: + # ix[t] = True + # t += 1 + break + t += 1 + + # Inside + while t < n: + if time_array[t] > ends[k]: + k += 1 + break + else: + ix[t] = True + t += 1 + + if k == m: + break + if t == n: + break + + # new_time_array = time_array[ix] + return ix + + +def restrict(time_array, starts, ends): + return time_array[jitrestrict2(time_array, starts, ends)] + + @jit(nopython=True) def 
jitrestrict_with_count(time_array, data_array, starts, ends): n = len(time_array) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index e73208af..b18fadee 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -314,7 +314,7 @@ def restrict(self, iset): Returns ------- - out: Ts, Tsd, TsdFrame or TsdTensor + Ts, Tsd, TsdFrame or TsdTensor Tsd object restricted to ep Examples diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 7e33cff6..8435e13c 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -24,7 +24,7 @@ from .interval_set import IntervalSet from .time_index import TsIndex from .time_series import BaseTsd, Ts, Tsd, TsdFrame, is_array_like -from .utils import _get_terminal_size, convert_to_numpy +from .utils import _get_terminal_size, convert_to_numpy_array def _union_intervals(i_sets): @@ -138,7 +138,7 @@ def __init__( stacklevel=2, ) data[k] = Ts( - t=cast_to_numpy(data[k], "key {}".format(k)), + t=convert_to_numpy_array(data[k], "key {}".format(k)), time_support=time_support, time_units=time_units, ) diff --git a/pynapple/core/utils.py b/pynapple/core/utils.py index 884f8ce2..cb4e4e9c 100644 --- a/pynapple/core/utils.py +++ b/pynapple/core/utils.py @@ -96,7 +96,7 @@ def convert_to_jax_array(array, array_name): elif isinstance(array, jnp.ndarray): return array elif isinstance(array, np.ndarray): - return jnp.asarray(array) + return cast_to_jax(array, array_name) else: raise RuntimeError( "Unknown format for {}. Accepted formats are numpy.ndarray, list, tuple or any array-like objects.".format( @@ -235,6 +235,45 @@ def cast_to_numpy(array, array_name): return np.asarray(array) +def cast_to_jax(array, array_name): + """ + Convert an input array-like object to a jax Array. + + + Parameters + ---------- + array : array_like + The input object to convert. This can be any object that `np.asarray` is capable of + converting to a jax array, such as lists, tuples, and other array-like objects. + array_name : str + The name of the variable that we are converting, printed in the warning message. + + Returns + ------- + ndarray + A jax Array representation of the input `values`. If `values` is already a jax + Array, it is returned unchanged. Otherwise, a new jax Array is created and returned. + + Warnings + -------- + A warning is issued if the input `values` is not already a jax Array, indicating + that a conversion has taken place and showing the original type of the input. + + """ + import jax.numpy as jnp + + if ( + not isinstance(array, jnp.ndarray) + and not nap_config.suppress_conversion_warnings + ): + original_type = type(array).__name__ + warnings.warn( + f"Converting '{array_name}' to jax.ndarray. The provided array was of type '{original_type}'.", + UserWarning, + ) + return jnp.asarray(array) + + def _check_time_equals(time_arrays): """ Check if a list of time arrays are all equal. 
diff --git a/tox.ini b/tox.ini index e417d66e..8fecfc8a 100644 --- a/tox.ini +++ b/tox.ini @@ -21,4 +21,5 @@ commands = python = 3.8: py38 3.9: py39 - 3.10: py310 \ No newline at end of file + 3.10: py310 + 3.11: py311 \ No newline at end of file From 73da1c1f94d450cf3293ca51dcc05f5a1e927123 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 18 Apr 2024 17:27:01 -0400 Subject: [PATCH 17/37] Updating jittedfucntions --- pynapple/core/_jitted_functions.py | 1014 +++++++++++++--------------- pynapple/core/base_class.py | 19 +- pynapple/core/core_function.py | 144 ++-- pynapple/core/interval_set.py | 13 +- pynapple/core/time_series.py | 8 +- pynapple/core/utils.py | 94 --- tests/test_config.py | 24 +- tests/test_interval_set.py | 12 +- tests/test_jitted.py | 16 +- tox.ini | 7 +- 10 files changed, 562 insertions(+), 789 deletions(-) diff --git a/pynapple/core/_jitted_functions.py b/pynapple/core/_jitted_functions.py index 4d75a11e..c855236b 100644 --- a/pynapple/core/_jitted_functions.py +++ b/pynapple/core/_jitted_functions.py @@ -2,8 +2,11 @@ from numba import jit, njit, prange +################################ +# Time only functions +################################ @jit(nopython=True) -def jitrestrict(time_array, data_array, starts, ends): +def jitrestrict(time_array, starts, ends): n = len(time_array) m = len(starts) ix = np.zeros(n, dtype=np.bool_) @@ -37,16 +40,15 @@ def jitrestrict(time_array, data_array, starts, ends): if t == n: break - new_time_array = time_array[ix] - new_data_array = data_array[ix] - return (new_time_array, new_data_array) + return ix @jit(nopython=True) -def jittsrestrict(time_array, starts, ends): +def jittsrestrict_with_count(time_array, starts, ends): n = len(time_array) m = len(starts) ix = np.zeros(n, dtype=np.bool_) + count = np.zeros(m, dtype=np.int64) k = 0 t = 0 @@ -59,6 +61,7 @@ def jittsrestrict(time_array, starts, ends): while t < n: if time_array[t] >= starts[k]: # ix[t] = True + # count[k] += 1 # t += 1 break t += 1 @@ -70,6 +73,7 @@ def jittsrestrict(time_array, starts, ends): break else: ix[t] = True + count[k] += 1 t += 1 if k == m: @@ -78,101 +82,104 @@ def jittsrestrict(time_array, starts, ends): break new_time_array = time_array[ix] - return new_time_array + return new_time_array, count @jit(nopython=True) -def jitrestrict2(time_array, starts, ends): - n = len(time_array) - m = len(starts) - ix = np.zeros(n, dtype=np.bool_) +def jitvaluefrom(time_array, time_target_array, starts, ends): + time_array, count = jittsrestrict_with_count(time_array, starts, ends) + time_target_array, count_target = jittsrestrict_with_count( + time_target_array, starts, ends + ) - k = 0 - t = 0 + m = starts.shape[0] + n = time_array.shape[0] + d = time_target_array.shape[0] - while ends[k] < time_array[t]: - k += 1 + idx = np.zeros(n, dtype=np.int_) - while k < m: - # Outside - while t < n: - if time_array[t] >= starts[k]: - # ix[t] = True - # t += 1 - break - t += 1 + if n > 0 and d > 0: + for k in range(m): + if count[k] > 0 and count_target[k] > 0: + t = np.sum(count[0:k]) + i = np.sum(count_target[0:k]) + maxt = t + count[k] + maxi = i + count_target[k] + while t < maxt: + interval = abs(time_array[t] - time_target_array[i]) + idx[t] = i - # Inside - while t < n: - if time_array[t] > ends[k]: - k += 1 - break - else: - ix[t] = True - t += 1 + i += 1 + while i < maxi: + new_interval = abs(time_array[t] - time_target_array[i]) + if new_interval > interval: + break + else: + idx[t] = i + interval = new_interval + i += 1 + i -= 1 + t += 1 - if k == m: - break 
- if t == n: - break + return (time_array, idx) - # new_time_array = time_array[ix] - return ix +@jit(nopython=True) +def jitcount(time_array, starts, ends, bin_size): + time_array, countin = jittsrestrict_with_count(time_array, starts, ends) -def restrict(time_array, starts, ends): - return time_array[jitrestrict2(time_array, starts, ends)] + m = starts.shape[0] + nb_bins = np.zeros(m, dtype=np.int32) + for k in range(m): + if (ends[k] - starts[k]) > bin_size: + nb_bins[k] = int(np.ceil((ends[k] + bin_size - starts[k]) / bin_size)) + else: + nb_bins[k] = 1 -@jit(nopython=True) -def jitrestrict_with_count(time_array, data_array, starts, ends): - n = len(time_array) - m = len(starts) - ix = np.zeros(n, dtype=np.bool_) - count = np.zeros(m, dtype=np.int64) + nb = np.sum(nb_bins) + bins = np.zeros(nb, dtype=np.float64) + cnt = np.zeros(nb, dtype=np.int64) k = 0 t = 0 - - while ends[k] < time_array[t]: - k += 1 + b = 0 while k < m: - # Outside - while t < n: - if time_array[t] >= starts[k]: - # ix[t] = True - # count[k] += 1 - # t += 1 - break - t += 1 + maxb = b + nb_bins[k] + maxt = t + countin[k] + lbound = starts[k] - # Inside - while t < n: - if time_array[t] > ends[k]: - k += 1 + while b < maxb: + xpos = lbound + bin_size / 2 + if xpos > ends[k]: break else: - ix[t] = True - count[k] += 1 - t += 1 + bins[b] = xpos + rbound = np.round(lbound + bin_size, 9) + while t < maxt: + if time_array[t] < rbound: # similar to numpy hisrogram + cnt[b] += 1 + t += 1 + else: + break - if k == m: - break - if t == n: - break + lbound += bin_size + b += 1 + t = maxt + k += 1 - new_time_array = time_array[ix] - new_data_array = data_array[ix] - return new_time_array, new_data_array, count + new_time_array = bins[0:b] + new_data_array = cnt[0:b] + + return (new_time_array, new_data_array) @jit(nopython=True) -def jittsrestrict_with_count(time_array, starts, ends): +def jitin_interval(time_array, starts, ends): n = len(time_array) m = len(starts) - ix = np.zeros(n, dtype=np.bool_) - count = np.zeros(m, dtype=np.int64) + data = np.ones(n, dtype=np.float64) * np.nan k = 0 t = 0 @@ -184,20 +191,20 @@ def jittsrestrict_with_count(time_array, starts, ends): # Outside while t < n: if time_array[t] >= starts[k]: - # ix[t] = True - # count[k] += 1 + # data[t] = k # t += 1 break + # data[t] = np.nan t += 1 # Inside while t < n: if time_array[t] > ends[k]: k += 1 + # data[t] = np.nan break else: - ix[t] = True - count[k] += 1 + data[t] = k t += 1 if k == m: @@ -205,10 +212,37 @@ def jittsrestrict_with_count(time_array, starts, ends): if t == n: break - new_time_array = time_array[ix] - return new_time_array, count + return data + + +@jit(nopython=True) +def jitremove_nan(time_array, index_nan): + n = len(time_array) + ix_start = np.zeros(n, dtype=np.bool_) + ix_end = np.zeros(n, dtype=np.bool_) + + if not index_nan[0]: # First start + ix_start[0] = True + + t = 1 + while t < n: + if index_nan[t - 1] and not index_nan[t]: # start + ix_start[t] = True + if not index_nan[t - 1] and index_nan[t]: # end + ix_end[t - 1] = True + t += 1 + + if not index_nan[-1]: # Last stop + ix_end[-1] = True + + starts = time_array[ix_start] + ends = time_array[ix_end] + return (starts, ends) +################################ +# Time Data functions +################################ @jit(nopython=True) def jitthreshold(time_array, data_array, starts, ends, thr, method="above"): n = time_array.shape[0] @@ -284,206 +318,63 @@ def jitthreshold(time_array, data_array, starts, ends, thr, method="above"): @jit(nopython=True) -def 
jitvaluefrom(time_array, time_target_array, data_target_array, starts, ends): - time_array, count = jittsrestrict_with_count(time_array, starts, ends) - time_target_array, data_target_array, count_target = jitrestrict_with_count( - time_target_array, data_target_array, starts, ends - ) +def jitrestrict_with_count(time_array, data_array, starts, ends): + n = len(time_array) + m = len(starts) + ix = np.zeros(n, dtype=np.bool_) + count = np.zeros(m, dtype=np.int64) - m = starts.shape[0] - n = time_array.shape[0] - d = time_target_array.shape[0] + k = 0 + t = 0 - new_data_array = np.zeros(n, dtype=data_target_array.dtype) + while ends[k] < time_array[t]: + k += 1 - if n > 0 and d > 0: - for k in range(m): - if count[k] > 0 and count_target[k] > 0: - t = np.sum(count[0:k]) - i = np.sum(count_target[0:k]) - maxt = t + count[k] - maxi = i + count_target[k] - while t < maxt: - interval = abs(time_array[t] - time_target_array[i]) - new_data_array[t] = data_target_array[i] - i += 1 - while i < maxi: - new_interval = abs(time_array[t] - time_target_array[i]) - if new_interval > interval: - break - else: - new_data_array[t] = data_target_array[i] - interval = new_interval - i += 1 - i -= 1 - t += 1 + while k < m: + # Outside + while t < n: + if time_array[t] >= starts[k]: + # ix[t] = True + # count[k] += 1 + # t += 1 + break + t += 1 + + # Inside + while t < n: + if time_array[t] > ends[k]: + k += 1 + break + else: + ix[t] = True + count[k] += 1 + t += 1 + + if k == m: + break + if t == n: + break - return (time_array, new_data_array, starts, ends) + new_time_array = time_array[ix] + new_data_array = data_array[ix] + return new_time_array, new_data_array, count @jit(nopython=True) -def jitvaluefromtensor(time_array, time_target_array, data_target_array, starts, ends): - time_array, _, count = jitrestrict_with_count( - time_array, np.zeros(time_array.shape[0]), starts, ends - ) - time_target_array, data_target_array, count_target = jitrestrict_with_count( - time_target_array, data_target_array, starts, ends +def jitbin_array(time_array, data_array, starts, ends, bin_size): + time_array, data_array, countin = jitrestrict_with_count( + time_array, data_array, starts, ends ) m = starts.shape[0] - n = time_array.shape[0] - d = time_target_array.shape[0] + f = data_array.shape[1:] - new_data_array = np.zeros( - (n, *data_target_array.shape[1:]), dtype=data_target_array.dtype - ) - - if n > 0 and d > 0: - for k in range(m): - if count[k] > 0 and count_target[k] > 0: - t = np.sum(count[0:k]) - i = np.sum(count_target[0:k]) - maxt = t + count[k] - maxi = i + count_target[k] - while t < maxt: - interval = abs(time_array[t] - time_target_array[i]) - new_data_array[t] = data_target_array[i] - i += 1 - while i < maxi: - new_interval = abs(time_array[t] - time_target_array[i]) - if new_interval > interval: - break - else: - new_data_array[t] = data_target_array[i] - interval = new_interval - i += 1 - i -= 1 - t += 1 - - return (time_array, new_data_array, starts, ends) - - -@jit(nopython=True) -def jitcount(time_array, starts, ends, bin_size): - time_array, countin = jittsrestrict_with_count(time_array, starts, ends) - - m = starts.shape[0] - - nb_bins = np.zeros(m, dtype=np.int32) - for k in range(m): - if (ends[k] - starts[k]) > bin_size: - nb_bins[k] = int(np.ceil((ends[k] + bin_size - starts[k]) / bin_size)) - else: - nb_bins[k] = 1 - - nb = np.sum(nb_bins) - bins = np.zeros(nb, dtype=np.float64) - cnt = np.zeros(nb, dtype=np.int64) - - k = 0 - t = 0 - b = 0 - - while k < m: - maxb = b + nb_bins[k] - maxt = t + 
countin[k] - lbound = starts[k] - - while b < maxb: - xpos = lbound + bin_size / 2 - if xpos > ends[k]: - break - else: - bins[b] = xpos - rbound = np.round(lbound + bin_size, 9) - while t < maxt: - if time_array[t] < rbound: # similar to numpy hisrogram - cnt[b] += 1 - t += 1 - else: - break - - lbound += bin_size - b += 1 - t = maxt - k += 1 - - new_time_array = bins[0:b] - new_data_array = cnt[0:b] - - return (new_time_array, new_data_array) - - -@jit(nopython=True) -def jitbin(time_array, data_array, starts, ends, bin_size): - time_array, data_array, countin = jitrestrict_with_count( - time_array, data_array, starts, ends - ) - - m = starts.shape[0] - - nb_bins = np.zeros(m, dtype=np.int32) - for k in range(m): - if (ends[k] - starts[k]) > bin_size: - nb_bins[k] = int(np.ceil((ends[k] + bin_size - starts[k]) / bin_size)) - else: - nb_bins[k] = 1 - - nb = np.sum(nb_bins) - bins = np.zeros(nb, dtype=np.float64) - cnt = np.zeros(nb, dtype=np.float64) - average = np.zeros(nb, dtype=np.float64) - - k = 0 - t = 0 - b = 0 - - while k < m: - maxb = b + nb_bins[k] - maxt = t + countin[k] - lbound = starts[k] - - while b < maxb: - xpos = lbound + bin_size / 2 - if xpos > ends[k]: - break - else: - bins[b] = xpos - rbound = np.round(lbound + bin_size, 9) - while t < maxt: - if time_array[t] < rbound: # similar to numpy hisrogram - cnt[b] += 1.0 - average[b] += data_array[t] - t += 1 - else: - break - - lbound += bin_size - b += 1 - t = maxt - k += 1 - - new_time_array = bins[0:b] - new_data_array = average[0:b] / cnt[0:b] - - return (new_time_array, new_data_array) - - -@jit(nopython=True) -def jitbin_array(time_array, data_array, starts, ends, bin_size): - time_array, data_array, countin = jitrestrict_with_count( - time_array, data_array, starts, ends - ) - - m = starts.shape[0] - f = data_array.shape[1:] - - nb_bins = np.zeros(m, dtype=np.int32) - for k in range(m): - if (ends[k] - starts[k]) > bin_size: - nb_bins[k] = int(np.ceil((ends[k] + bin_size - starts[k]) / bin_size)) - else: - nb_bins[k] = 1 + nb_bins = np.zeros(m, dtype=np.int32) + for k in range(m): + if (ends[k] - starts[k]) > bin_size: + nb_bins[k] = int(np.ceil((ends[k] + bin_size - starts[k]) / bin_size)) + else: + nb_bins[k] = 1 nb = np.sum(nb_bins) bins = np.zeros(nb, dtype=np.float64) @@ -521,295 +412,9 @@ def jitbin_array(time_array, data_array, starts, ends, bin_size): new_time_array = bins[0:b] - new_data_array = average[0:b] / cnt[0:b] - - return (new_time_array, new_data_array) - - -@jit(nopython=True) -def jitintersect(start1, end1, start2, end2): - m = start1.shape[0] - n = start2.shape[0] - - i = 0 - j = 0 - - newstart = np.zeros(m + n, dtype=np.float64) - newend = np.zeros(m + n, dtype=np.float64) - ct = 0 - - while i < m: - while j < n: - if end2[j] > start1[i]: - break - j += 1 - - if j == n: - break - - if start2[j] < end1[i]: - newstart[ct] = max(start1[i], start2[j]) - newend[ct] = min(end1[i], end2[j]) - ct += 1 - if end2[j] < end1[i]: - j += 1 - else: - i += 1 - else: - i += 1 - - newstart = newstart[0:ct] - newend = newend[0:ct] - - return (newstart, newend) - - -@jit(nopython=True) -def jitunion(start1, end1, start2, end2): - m = start1.shape[0] - n = start2.shape[0] - - i = 0 - j = 0 - - newstart = np.zeros(m + n, dtype=np.float64) - newend = np.zeros(m + n, dtype=np.float64) - ct = 0 - - while i < m: - while j < n: - if end2[j] > start1[i]: - break - newstart[ct] = start2[j] - newend[ct] = end2[j] - ct += 1 - j += 1 - - if j == n: - break - - # overlap - if start2[j] < end1[i]: - newstart[ct] = 
min(start1[i], start2[j]) - - while i < m and j < n: - newend[ct] = max(end1[i], end2[j]) - - if end1[i] < end2[j]: - i += 1 - else: - j += 1 - - if i == m: - j += 1 - ct += 1 - break - - if j == n: - i += 1 - ct += 1 - break - - if end2[j] < start1[i]: - j += 1 - ct += 1 - break - elif end1[i] < start2[j]: - i += 1 - ct += 1 - break - - else: - newstart[ct] = start1[i] - newend[ct] = end1[i] - ct += 1 - i += 1 - - while i < m: - newstart[ct] = start1[i] - newend[ct] = end1[i] - ct += 1 - i += 1 - - while j < n: - newstart[ct] = start2[j] - newend[ct] = end2[j] - ct += 1 - j += 1 - - newstart = newstart[0:ct] - newend = newend[0:ct] - - return (newstart, newend) - - -@jit(nopython=True) -def jitdiff(start1, end1, start2, end2): - m = start1.shape[0] - n = start2.shape[0] - - i = 0 - j = 0 - - newstart = np.zeros(m + n, dtype=np.float64) - newend = np.zeros(m + n, dtype=np.float64) - ct = 0 - - while i < m: - while j < n: - if end2[j] > start1[i]: - break - j += 1 - - if j == n: - break - - # overlap - if start2[j] < end1[i]: - if start2[j] < start1[i] and end1[i] < end2[j]: - i += 1 - - else: - if start2[j] > start1[i]: - newstart[ct] = start1[i] - newend[ct] = start2[j] - ct += 1 - j += 1 - - else: - newstart[ct] = end2[j] - newend[ct] = end1[i] - j += 1 - - while j < n: - if start2[j] < end1[i]: - newstart[ct] = end2[j - 1] - newend[ct] = start2[j] - ct += 1 - j += 1 - else: - break - - if end2[j - 1] < end1[i]: - newstart[ct] = end2[j - 1] - newend[ct] = end1[i] - ct += 1 - else: - j -= 1 - i += 1 - - else: - newstart[ct] = start1[i] - newend[ct] = end1[i] - ct += 1 - i += 1 - - while i < m: - newstart[ct] = start1[i] - newend[ct] = end1[i] - ct += 1 - i += 1 - - newstart = newstart[0:ct] - newend = newend[0:ct] - - return (newstart, newend) - - -@jit(nopython=True) -def jitunion_isets(starts, ends): - idx = np.argsort(starts) - starts = starts[idx] - ends = ends[idx] - - n = starts.shape[0] - new_start = np.zeros(n, dtype=np.float64) - new_end = np.zeros(n, dtype=np.float64) - - ct = 0 - new_start[ct] = starts[0] - e = ends[0] - i = 1 - while i < n: - if starts[i] > e: - new_end[ct] = e - ct += 1 - new_start[ct] = starts[i] - e = ends[i] - else: - e = max(e, ends[i]) - i += 1 - - new_end[ct] = e - ct += 1 - new_start = new_start[0:ct] - new_end = new_end[0:ct] - return (new_start, new_end) - - -@jit(nopython=True) -def jitin_interval(time_array, starts, ends): - n = len(time_array) - m = len(starts) - data = np.ones(n, dtype=np.float64) * np.nan - - k = 0 - t = 0 - - while ends[k] < time_array[t]: - k += 1 - - while k < m: - # Outside - while t < n: - if time_array[t] >= starts[k]: - # data[t] = k - # t += 1 - break - # data[t] = np.nan - t += 1 - - # Inside - while t < n: - if time_array[t] > ends[k]: - k += 1 - # data[t] = np.nan - break - else: - data[t] = k - t += 1 - - if k == m: - break - if t == n: - break - - return data - - -@jit(nopython=True) -def jitremove_nan(time_array, index_nan): - n = len(time_array) - ix_start = np.zeros(n, dtype=np.bool_) - ix_end = np.zeros(n, dtype=np.bool_) - - if not index_nan[0]: # First start - ix_start[0] = True - - t = 1 - while t < n: - if index_nan[t - 1] and not index_nan[t]: # start - ix_start[t] = True - if not index_nan[t - 1] and index_nan[t]: # end - ix_end[t - 1] = True - t += 1 - - if not index_nan[-1]: # Last stop - ix_end[-1] = True + new_data_array = average[0:b] / cnt[0:b] - starts = time_array[ix_start] - ends = time_array[ix_end] - return (starts, ends) + return (new_time_array, new_data_array) @jit(nopython=True) @@ -1037,3 
+642,302 @@ def jitperievent_trigger_average( new_data_array[:, n] /= total[n] return new_data_array + + +################################ +# IntervalSet functions +################################ +@jit(nopython=True) +def jitintersect(start1, end1, start2, end2): + m = start1.shape[0] + n = start2.shape[0] + + i = 0 + j = 0 + + newstart = np.zeros(m + n, dtype=np.float64) + newend = np.zeros(m + n, dtype=np.float64) + ct = 0 + + while i < m: + while j < n: + if end2[j] > start1[i]: + break + j += 1 + + if j == n: + break + + if start2[j] < end1[i]: + newstart[ct] = max(start1[i], start2[j]) + newend[ct] = min(end1[i], end2[j]) + ct += 1 + if end2[j] < end1[i]: + j += 1 + else: + i += 1 + else: + i += 1 + + newstart = newstart[0:ct] + newend = newend[0:ct] + + return (newstart, newend) + + +@jit(nopython=True) +def jitunion(start1, end1, start2, end2): + m = start1.shape[0] + n = start2.shape[0] + + i = 0 + j = 0 + + newstart = np.zeros(m + n, dtype=np.float64) + newend = np.zeros(m + n, dtype=np.float64) + ct = 0 + + while i < m: + while j < n: + if end2[j] > start1[i]: + break + newstart[ct] = start2[j] + newend[ct] = end2[j] + ct += 1 + j += 1 + + if j == n: + break + + # overlap + if start2[j] < end1[i]: + newstart[ct] = min(start1[i], start2[j]) + + while i < m and j < n: + newend[ct] = max(end1[i], end2[j]) + + if end1[i] < end2[j]: + i += 1 + else: + j += 1 + + if i == m: + j += 1 + ct += 1 + break + + if j == n: + i += 1 + ct += 1 + break + + if end2[j] < start1[i]: + j += 1 + ct += 1 + break + elif end1[i] < start2[j]: + i += 1 + ct += 1 + break + + else: + newstart[ct] = start1[i] + newend[ct] = end1[i] + ct += 1 + i += 1 + + while i < m: + newstart[ct] = start1[i] + newend[ct] = end1[i] + ct += 1 + i += 1 + + while j < n: + newstart[ct] = start2[j] + newend[ct] = end2[j] + ct += 1 + j += 1 + + newstart = newstart[0:ct] + newend = newend[0:ct] + + return (newstart, newend) + + +@jit(nopython=True) +def jitdiff(start1, end1, start2, end2): + m = start1.shape[0] + n = start2.shape[0] + + i = 0 + j = 0 + + newstart = np.zeros(m + n, dtype=np.float64) + newend = np.zeros(m + n, dtype=np.float64) + ct = 0 + + while i < m: + while j < n: + if end2[j] > start1[i]: + break + j += 1 + + if j == n: + break + + # overlap + if start2[j] < end1[i]: + if start2[j] < start1[i] and end1[i] < end2[j]: + i += 1 + + else: + if start2[j] > start1[i]: + newstart[ct] = start1[i] + newend[ct] = start2[j] + ct += 1 + j += 1 + + else: + newstart[ct] = end2[j] + newend[ct] = end1[i] + j += 1 + + while j < n: + if start2[j] < end1[i]: + newstart[ct] = end2[j - 1] + newend[ct] = start2[j] + ct += 1 + j += 1 + else: + break + + if end2[j - 1] < end1[i]: + newstart[ct] = end2[j - 1] + newend[ct] = end1[i] + ct += 1 + else: + j -= 1 + i += 1 + + else: + newstart[ct] = start1[i] + newend[ct] = end1[i] + ct += 1 + i += 1 + + while i < m: + newstart[ct] = start1[i] + newend[ct] = end1[i] + ct += 1 + i += 1 + + newstart = newstart[0:ct] + newend = newend[0:ct] + + return (newstart, newend) + + +@jit(nopython=True) +def jitunion_isets(starts, ends): + idx = np.argsort(starts) + starts = starts[idx] + ends = ends[idx] + + n = starts.shape[0] + new_start = np.zeros(n, dtype=np.float64) + new_end = np.zeros(n, dtype=np.float64) + + ct = 0 + new_start[ct] = starts[0] + e = ends[0] + i = 1 + while i < n: + if starts[i] > e: + new_end[ct] = e + ct += 1 + new_start[ct] = starts[i] + e = ends[i] + else: + e = max(e, ends[i]) + i += 1 + + new_end[ct] = e + ct += 1 + new_start = new_start[0:ct] + new_end = new_end[0:ct] + 
return (new_start, new_end) + + +@jit(nopython=True) +def _jitfix_iset(start, end): + """ + 0 - > "Some starts and ends are equal. Removing 1 microsecond!", + 1 - > "Some ends precede the relative start. Dropping them!", + 2 - > "Some starts precede the previous end. Joining them!", + 3 - > "Some epochs have no duration" + + Parameters + ---------- + start : numpy.ndarray + Description + end : numpy.ndarray + Description + + Returns + ------- + TYPE + Description + """ + to_warn = np.zeros(4, dtype=np.bool_) + m = start.shape[0] + data = np.zeros((m, 2), dtype=np.float64) + i = 0 + ct = 0 + + while i < m: + newstart = start[i] + newend = end[i] + + while i < m: + if end[i] == start[i]: + to_warn[3] = True + i += 1 + else: + newstart = start[i] + newend = end[i] + break + + while i < m: + if end[i] < start[i]: + to_warn[1] = True + i += 1 + else: + newstart = start[i] + newend = end[i] + break + + if i >= m: + break + + while i < m - 1: + if start[i + 1] < end[i]: + to_warn[2] = True + i += 1 + newend = max(end[i - 1], end[i]) + else: + break + + if i < m - 1: + if newend == start[i + 1]: + to_warn[0] = True + newend -= 1.0e-6 + + data[ct, 0] = newstart + data[ct, 1] = newend + + ct += 1 + i += 1 + + data = data[0:ct] + + return (data, to_warn) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index b18fadee..8b826a51 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -174,7 +174,7 @@ def value_from(self, data, ep=None): >>> tsd = nap.Tsd(t=np.arange(0,1000), d=np.random.rand(1000), time_units='s') >>> ep = nap.IntervalSet(start = 0, end = 500, time_units = 's') - The variable ts is a time series object containing only nan. + The variable ts is a timestamp object. The tsd object containing the values, for example the tracking data, and the epoch to restrict the operation. 
>>> newts = ts.value_from(tsd, ep) @@ -192,11 +192,11 @@ def value_from(self, data, ep=None): starts = ep.start ends = ep.end - t, d, ns, ne = _value_from( + t, d = _value_from( time_array, time_target_array, data_target_array, starts, ends ) - time_support = IntervalSet(start=ns, end=ne) + time_support = IntervalSet(start=starts, end=ends) kwargs = {} if hasattr(data, "columns"): @@ -338,23 +338,22 @@ def restrict(self, iset): assert isinstance(iset, IntervalSet), "Argument should be IntervalSet" time_array = self.index.values - data_array = None starts = iset.start ends = iset.end - if hasattr(self, "values"): - data_array = self.values - - out = _restrict(time_array, data_array, starts, ends) + idx = _restrict(time_array, starts, ends) kwargs = {} if hasattr(self, "columns"): kwargs["columns"] = self.columns if hasattr(self, "values"): - return self.__class__(t=out[0], d=out[1], time_support=iset, **kwargs) + data_array = self.values + return self.__class__( + t=time_array[idx], d=data_array[idx], time_support=iset, **kwargs + ) else: - return self.__class__(t=out, time_support=iset) + return self.__class__(t=time_array[idx], time_support=iset) def copy(self): """Copy the data, index and time support""" diff --git a/pynapple/core/core_function.py b/pynapple/core/core_function.py index 2b511d5a..d82880b3 100644 --- a/pynapple/core/core_function.py +++ b/pynapple/core/core_function.py @@ -11,22 +11,69 @@ import numpy as np from scipy import signal -from ._jitted_functions import ( - jitbin, +from ._jitted_functions import ( jitbin_array, jitcount, jitremove_nan, jitrestrict, jitthreshold, - jittsrestrict, jittsrestrict_with_count, jitvaluefrom, - jitvaluefromtensor, pjitconvolve, ) from .utils import get_backend +def _restrict(time_array, starts, ends): + return jitrestrict(time_array, starts, ends) + + +def _count(time_array, starts, ends, bin_size=None): + if isinstance(bin_size, (float, int)): + return jitcount(time_array, starts, ends, bin_size) + else: + _, d = jittsrestrict_with_count(time_array, starts, ends) + t = starts + (ends - starts) / 2 + return t, d + + +def _value_from(time_array, time_target_array, data_target_array, starts, ends): + t, idx = jitvaluefrom(time_array, time_target_array, starts, ends) + return t, data_target_array[idx] + + +def _dropna(time_array, data_array, starts, ends, update_time_support, ndim): + index_nan = np.asarray(np.any(np.isnan(data_array), axis=tuple(range(1, ndim)))) + if np.all(index_nan): # In case it's only NaNs + if update_time_support: + starts = None + ends = None + return ( + np.array([]), + np.empty(tuple([0] + [d for d in data_array.shape[1:]])), + starts, + ends, + ) + elif np.any(index_nan): + if update_time_support: + starts, ends = jitremove_nan(time_array, index_nan) + + to_fix = starts == ends + if np.any(to_fix): + ends[to_fix] += 1e-6 # adding 1 millisecond in case of a single point + + return (time_array[~index_nan], data_array[~index_nan], starts, ends) + else: + return (time_array[~index_nan], data_array[~index_nan], starts, ends) + else: + return (time_array, data_array, starts, ends) + + +#################################### +# Can call pynajax +#################################### + + def _convolve(time_array, data_array, starts, ends, array, trim="both"): if get_backend() == "jax": from pynajax.jax_core_convolve import convolve @@ -65,107 +112,22 @@ def _convolve(time_array, data_array, starts, ends, array, trim="both"): return new_data_array -def _restrict(time_array, data_array, starts, ends): - if get_backend() == 
"jax": - from pynajax.jax_core_restrict import restrict - - return restrict(time_array, data_array, starts, ends) - else: - if data_array is not None: - return jitrestrict(time_array, data_array, starts, ends) - else: - return jittsrestrict(time_array, starts, ends) - - -def _count(time_array, starts, ends, bin_size=None): - if get_backend() == "jax": - from pynajax.jax_core_count import count - - return count(time_array, starts, ends, bin_size) - else: - if isinstance(bin_size, (float, int)): - return jitcount(time_array, starts, ends, bin_size) - else: - _, d = jittsrestrict_with_count(time_array, starts, ends) - t = starts + (ends - starts) / 2 - return t, d - - -def _value_from(time_array, time_target_array, data_target_array, starts, ends): - if get_backend() == "jax": - from pynajax.jax_core_value_from import value_from - - return value_from( - time_array, time_target_array, data_target_array, starts, ends - ) - else: - if data_target_array.ndim == 1: - t, d, ns, ne = jitvaluefrom( - time_array, time_target_array, data_target_array, starts, ends - ) - else: - t, d, ns, ne = jitvaluefromtensor( - time_array, time_target_array, data_target_array, starts, ends - ) - return t, d, ns, ne - - def _bin_average(time_array, data_array, starts, ends, bin_size): if get_backend() == "jax": from pynajax.jax_core_bin_average import bin_average - return bin_average(time_array, data_array, starts, ends, bin_size) else: - if data_array.ndim > 1: - t, d = jitbin_array(time_array, data_array, starts, ends, bin_size) - else: - t, d = jitbin(time_array, data_array, starts, ends, bin_size) - - return t, d + return jitbin_array(time_array, data_array, starts, ends, bin_size) def _threshold(time_array, data_array, starts, ends, thr, method): if get_backend() == "jax": from pynajax.jax_core_threshold import threshold - return threshold(time_array, data_array, starts, ends, thr, method) else: return jitthreshold(time_array, data_array, starts, ends, thr, method) - -def _dropna(time_array, data_array, starts, ends, update_time_support, ndim): - if get_backend() == "jax": - from pynajax.jax_core_dropna import dropna - - return dropna(time_array, data_array, starts, ends, update_time_support) - else: - index_nan = np.any(np.isnan(data_array), axis=tuple(range(1, ndim))) - if np.all(index_nan): # In case it's only NaNs - if update_time_support: - starts = None - ends = None - return ( - np.array([]), - np.empty(tuple([0] + [d for d in data_array.shape[1:]])), - starts, - ends, - ) - elif np.any(index_nan): - if update_time_support: - starts, ends = jitremove_nan(time_array, index_nan) - - to_fix = starts == ends - if np.any(to_fix): - ends[ - to_fix - ] += 1e-6 # adding 1 millisecond in case of a single point - - return (time_array[~index_nan], data_array[~index_nan], starts, ends) - else: - return (time_array[~index_nan], data_array[~index_nan], starts, ends) - else: - return (time_array, data_array, starts, ends) - +# def perievent # def _sta(): diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index c7beb19a..121f854c 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -49,14 +49,19 @@ from numpy.lib.mixins import NDArrayOperatorsMixin from tabulate import tabulate -from ._jitted_functions import jitdiff, jitin_interval, jitintersect, jitunion +from ._jitted_functions import ( + jitdiff, + jitin_interval, + jitintersect, + jitunion, + _jitfix_iset, +) from .config import nap_config from .time_index import TsIndex from .utils import ( _get_terminal_size, 
_IntervalSetSliceHelper, - _jitfix_iset, - convert_to_numpy, + convert_to_numpy_array, is_array_like, ) @@ -134,7 +139,7 @@ def __init__(self, start, end=None, time_units="s", **kwargs): elif isinstance(data, np.ndarray): args[arg] = np.ravel(data) elif is_array_like(data): - args[arg] = cast_to_numpy(data, arg) + args[arg] = convert_to_numpy_array(data, arg) else: raise RuntimeError( "Unknown format for {}. Accepted formats are numpy.ndarray, list, tuple or any array-like objects.".format( diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 6a594579..041c46ad 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -88,7 +88,9 @@ def __init__(self, t, d, time_units="s", time_support=None): if isinstance(time_support, IntervalSet) and len(self.index): starts = time_support.start ends = time_support.end - t, d = _restrict(self.index.values, self.values, starts, ends) + idx = _restrict(self.index.values, starts, ends) + t = self.index.values[idx] + d = self.values[idx] self.index = TsIndex(t) self.values = d @@ -1429,8 +1431,8 @@ def __init__(self, t, time_units="s", time_support=None): if isinstance(time_support, IntervalSet) and len(self.index): starts = time_support.start ends = time_support.end - t = _restrict(self.index.values, None, starts, ends) - self.index = TsIndex(t) + idx = _restrict(self.index.values, starts, ends) + self.index = TsIndex(self.index.values[idx]) self.rate = self.index.shape[0] / np.sum( time_support.values[:, 1] - time_support.values[:, 0] ) diff --git a/pynapple/core/utils.py b/pynapple/core/utils.py index cb4e4e9c..3d5578ab 100644 --- a/pynapple/core/utils.py +++ b/pynapple/core/utils.py @@ -8,28 +8,9 @@ from numbers import Number import numpy as np -from numba import jit from .config import nap_config -# def not_implemented_in_pynajax(func, which_in, which_out, *args, **kwargs): - -# if nap_config.backend == "jax": -# import jax -# import jax.numpy as jnp - -# # def wrapper(*args, **kwargs): -# arguments, struct = jax.tree_util.tree_flatten((args, kwargs)) -# arguments[which_in] = jax.tree_map(np.asarray, arguments[which_in]) -# args, kwargs = jax.tree_util.tree_unflatten(struct, arguments) -# out = func(*args, **kwargs) -# out = list(out) -# out[which_out] = jax.tree_map(jnp.asarray, out[which_out]) -# return tuple(out) -# else: -# # def wrapper(*args, **kwargs): -# return func(*args, **kwargs) - def convert_to_numpy_array(array, array_name): """Convert any array like object to numpy ndarray. @@ -399,81 +380,6 @@ def _concatenate_tsd(func, *args, **kwargs): return output -@jit(nopython=True) -def _jitfix_iset(start, end): - """ - 0 - > "Some starts and ends are equal. Removing 1 microsecond!", - 1 - > "Some ends precede the relative start. Dropping them!", - 2 - > "Some starts precede the previous end. 
Joining them!", - 3 - > "Some epochs have no duration" - - Parameters - ---------- - start : numpy.ndarray - Description - end : numpy.ndarray - Description - - Returns - ------- - TYPE - Description - """ - to_warn = np.zeros(4, dtype=np.bool_) - m = start.shape[0] - data = np.zeros((m, 2), dtype=np.float64) - i = 0 - ct = 0 - - while i < m: - newstart = start[i] - newend = end[i] - - while i < m: - if end[i] == start[i]: - to_warn[3] = True - i += 1 - else: - newstart = start[i] - newend = end[i] - break - - while i < m: - if end[i] < start[i]: - to_warn[1] = True - i += 1 - else: - newstart = start[i] - newend = end[i] - break - - if i >= m: - break - - while i < m - 1: - if start[i + 1] < end[i]: - to_warn[2] = True - i += 1 - newend = max(end[i - 1], end[i]) - else: - break - - if i < m - 1: - if newend == start[i + 1]: - to_warn[0] = True - newend -= 1.0e-6 - - data[ct, 0] = newstart - data[ct, 1] = newend - - ct += 1 - i += 1 - - data = data[0:ct] - - return (data, to_warn) - - class _TsdFrameSliceHelper: def __init__(self, tsdframe): self.tsdframe = tsdframe diff --git a/tests/test_config.py b/tests/test_config.py index e63ffd8c..b668ba0b 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -59,7 +59,6 @@ def __len__(self): ################################## # Test for backend ################################## - def test_change_backend(): nap.nap_config.set_backend("numba") @@ -71,21 +70,20 @@ def test_change_backend(): # For local tests. # Should not be installed for github actions - try: - import pynajax + # try: + # import pynajax - nap.nap_config.set_backend("jax") - assert nap.core.utils.get_backend() == "jax" - assert nap.nap_config.backend == "jax" - - except ModuleNotFoundError: + # nap.nap_config.set_backend("jax") + # assert nap.core.utils.get_backend() == "jax" + # assert nap.nap_config.backend == "jax" - with warnings.catch_warnings(record=True) as w: - nap.nap_config.set_backend("jax") + # except ModuleNotFoundError: + # with warnings.catch_warnings(record=True) as w: + # nap.nap_config.set_backend("jax") - assert str(w[0].message) == 'Package pynajax is not found. Falling back to numba backend. To use the jax backend for pynapple, please install pynajax' - assert nap.core.utils.get_backend() == "numba" - assert nap.nap_config.backend == "numba" + # assert str(w[0].message) == 'Package pynajax is not found. Falling back to numba backend. 
To use the jax backend for pynapple, please install pynajax' + # assert nap.core.utils.get_backend() == "numba" + # assert nap.nap_config.backend == "numba" ################################## # Tests for warnings diff --git a/tests/test_interval_set.py b/tests/test_interval_set.py index 322c68ae..bd4b5631 100644 --- a/tests/test_interval_set.py +++ b/tests/test_interval_set.py @@ -345,7 +345,7 @@ def test_jitfix_iset(): starts = np.array([0, 10, 16]) ends = np.array([5, 15, 20]) - ep, to_warn = nap.core.utils._jitfix_iset(starts, ends) + ep, to_warn = nap.core._jitted_functions._jitfix_iset(starts, ends) np.testing.assert_array_almost_equal(starts, ep[:,0]) np.testing.assert_array_almost_equal(ends, ep[:,1]) np.testing.assert_array_almost_equal(to_warn, np.zeros(4)) @@ -354,7 +354,7 @@ def test_jitfix_iset_error0(): start = np.around(np.array([0, 10, 15], dtype=np.float64), 9) end = np.around(np.array([10, 15, 20], dtype=np.float64), 9) - ep, to_warn = nap.core.utils._jitfix_iset(start, end) + ep, to_warn = nap.core._jitted_functions._jitfix_iset(start, end) end[1:] -= 1e-6 @@ -373,7 +373,7 @@ def test_jitfix_iset_error1(): start = np.around(np.array([0, 15, 16], dtype=np.float64), 9) end = np.around(np.array([5, 10, 20], dtype=np.float64), 9) - ep, to_warn = nap.core.utils._jitfix_iset(start, end) + ep, to_warn = nap.core._jitted_functions._jitfix_iset(start, end) np.testing.assert_array_almost_equal(start[[0,2]], ep[:,0]) np.testing.assert_array_almost_equal(end[[0,2]], ep[:,1]) @@ -390,7 +390,7 @@ def test_jitfix_iset_error2(): start = np.around(np.array([0, 10, 16], dtype=np.float64), 9) end = np.around(np.array([11, 15, 20], dtype=np.float64), 9) - ep, to_warn = nap.core.utils._jitfix_iset(start, end) + ep, to_warn = nap.core._jitted_functions._jitfix_iset(start, end) np.testing.assert_array_almost_equal(start[[0,2]], ep[:,0]) np.testing.assert_array_almost_equal(end[[1,2]], ep[:,1]) @@ -407,7 +407,7 @@ def test_jitfix_iset_error3(): start = np.around(np.array([0, 15, 16], dtype=np.float64), 9) end = np.around(np.array([5, 15, 20], dtype=np.float64), 9) - ep, to_warn = nap.core.utils._jitfix_iset(start, end) + ep, to_warn = nap.core._jitted_functions._jitfix_iset(start, end) np.testing.assert_array_almost_equal(start[[0,2]], ep[:,0]) np.testing.assert_array_almost_equal(end[[0,2]], ep[:,1]) @@ -423,7 +423,7 @@ def test_jitfix_iset_random(): start = np.sort(np.random.uniform(0, 1000, 100)) end = np.sort(np.random.uniform(0, 1000, 100)) - ep, to_warn = nap.core.utils._jitfix_iset(start, end) + ep, to_warn = nap.core._jitted_functions._jitfix_iset(start, end) if len(ep): assert np.all(ep[:,1] - ep[:,0] > 0) diff --git a/tests/test_jitted.py b/tests/test_jitted.py index f7a1bb09..3cf5ee51 100644 --- a/tests/test_jitted.py +++ b/tests/test_jitted.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-12-02 17:17:03 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-04-01 17:37:20 +# @Last Modified time: 2024-04-18 17:23:42 """Tests of jitted core functions for `pynapple` package.""" @@ -75,8 +75,8 @@ def test_jitrestrict(): ep, ts, tsd, tsdframe = get_example_dataset() tsd2 = restrict(ep, tsd) - t, d= nap.core._jitted_functions.jitrestrict(tsd.index, tsd.values, ep.start, ep.end) - tsd3 = pd.Series(index=t, data=d) + ix = nap.core._jitted_functions.jitrestrict(tsd.index, ep.start, ep.end) + tsd3 = pd.Series(index=tsd.index[ix], data=tsd.values[ix]) pd.testing.assert_series_equal(tsd2, tsd3) def test_jittsrestrict(): @@ -84,8 +84,8 @@ def test_jittsrestrict(): ep, ts, tsd, 
tsdframe = get_example_dataset() ts2 = restrict(ep, ts) - t = nap.core._jitted_functions.jittsrestrict(ts.index, ep.start, ep.end) - ts3 = pd.Series(index=t, dtype="object") + idx = nap.core._jitted_functions.jitrestrict(ts.index, ep.start, ep.end) + ts3 = pd.Series(index=ts.index[idx], dtype="object") pd.testing.assert_series_equal(ts2, ts3) def test_jitrestrict_with_count(): @@ -170,8 +170,8 @@ def test_jitvalue_from(): for i in range(10): ep, ts, tsd, tsdframe = get_example_dataset() - t, d, s, e = nap.core._jitted_functions.jitvaluefrom(ts.index, tsd.index, tsd.values, ep.start, ep.end) - tsd3 = pd.Series(index=t, data=d) + t, idx = nap.core._jitted_functions.jitvaluefrom(ts.index, tsd.index, ep.start, ep.end) + tsd3 = pd.Series(index=t, data=tsd.values[idx]) tsd2 = [] for j in ep.index: @@ -219,7 +219,7 @@ def test_jitbin(): starts = ep.start ends = ep.end bin_size = 1.0 - t, d = nap.core._jitted_functions.jitbin(time_array, data_array, starts, ends, bin_size) + t, d = nap.core._jitted_functions.jitbin_array(time_array, data_array, starts, ends, bin_size) # tsd3 = nap.Tsd(t=t, d=d, time_support = ep) tsd3 = pd.Series(index=t, data=d) tsd3 = tsd3.fillna(0.0) diff --git a/tox.ini b/tox.ini index 8fecfc8a..d3e36cee 100644 --- a/tox.ini +++ b/tox.ini @@ -1,6 +1,6 @@ [tox] isolated_build = True -envlist = py37,py38,py39,py310 +envlist = py310,py311 requires = tox-conda [testenv] @@ -12,14 +12,11 @@ extras = dev commands = black --check pynapple isort --check pynapple --profile black - flake8 pynapple --max-complexity 10 - #pytest tests/ + flake8 pynapple --max-complexity 10 coverage run --source=pynapple --branch -m pytest tests/ coverage report -m [gh-actions] python = - 3.8: py38 - 3.9: py39 3.10: py310 3.11: py311 \ No newline at end of file From fb9877cb15b9bd6a5ddbb210991034968ecf8ebc Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Fri, 19 Apr 2024 18:06:22 -0400 Subject: [PATCH 18/37] Working version of dev with minimal jitted functions --- pynapple/core/_jitted_functions.py | 79 ++++++------------------------ pynapple/core/core_function.py | 38 ++++++++++---- pynapple/core/interval_set.py | 2 +- pynapple/core/time_series.py | 2 +- pynapple/core/ts_group.py | 41 +++++++--------- tests/test_jitted.py | 43 +++------------- 6 files changed, 72 insertions(+), 133 deletions(-) diff --git a/pynapple/core/_jitted_functions.py b/pynapple/core/_jitted_functions.py index c855236b..b5caf04a 100644 --- a/pynapple/core/_jitted_functions.py +++ b/pynapple/core/_jitted_functions.py @@ -44,7 +44,7 @@ def jitrestrict(time_array, starts, ends): @jit(nopython=True) -def jittsrestrict_with_count(time_array, starts, ends): +def jitrestrict_with_count(time_array, starts, ends): n = len(time_array) m = len(starts) ix = np.zeros(n, dtype=np.bool_) @@ -81,22 +81,16 @@ def jittsrestrict_with_count(time_array, starts, ends): if t == n: break - new_time_array = time_array[ix] - return new_time_array, count + return ix, count @jit(nopython=True) -def jitvaluefrom(time_array, time_target_array, starts, ends): - time_array, count = jittsrestrict_with_count(time_array, starts, ends) - time_target_array, count_target = jittsrestrict_with_count( - time_target_array, starts, ends - ) - +def jitvaluefrom(time_array, time_target_array, count, count_target, starts, ends): m = starts.shape[0] n = time_array.shape[0] d = time_target_array.shape[0] - idx = np.zeros(n, dtype=np.int_) + idx = np.full(n, np.nan) if n > 0 and d > 0: for k in range(m): @@ -107,7 +101,7 @@ def jitvaluefrom(time_array, time_target_array, 
starts, ends): maxi = i + count_target[k] while t < maxt: interval = abs(time_array[t] - time_target_array[i]) - idx[t] = i + idx[t] = float(i) i += 1 while i < maxi: @@ -115,18 +109,19 @@ def jitvaluefrom(time_array, time_target_array, starts, ends): if new_interval > interval: break else: - idx[t] = i + idx[t] = float(i) interval = new_interval i += 1 i -= 1 t += 1 - return (time_array, idx) + return idx @jit(nopython=True) def jitcount(time_array, starts, ends, bin_size): - time_array, countin = jittsrestrict_with_count(time_array, starts, ends) + idx, countin = jitrestrict_with_count(time_array, starts, ends) + time_array = time_array[idx] m = starts.shape[0] @@ -317,54 +312,11 @@ def jitthreshold(time_array, data_array, starts, ends, thr, method="above"): return (new_time_array, new_data_array, new_starts, new_ends) -@jit(nopython=True) -def jitrestrict_with_count(time_array, data_array, starts, ends): - n = len(time_array) - m = len(starts) - ix = np.zeros(n, dtype=np.bool_) - count = np.zeros(m, dtype=np.int64) - - k = 0 - t = 0 - - while ends[k] < time_array[t]: - k += 1 - - while k < m: - # Outside - while t < n: - if time_array[t] >= starts[k]: - # ix[t] = True - # count[k] += 1 - # t += 1 - break - t += 1 - - # Inside - while t < n: - if time_array[t] > ends[k]: - k += 1 - break - else: - ix[t] = True - count[k] += 1 - t += 1 - - if k == m: - break - if t == n: - break - - new_time_array = time_array[ix] - new_data_array = data_array[ix] - return new_time_array, new_data_array, count - - @jit(nopython=True) def jitbin_array(time_array, data_array, starts, ends, bin_size): - time_array, data_array, countin = jitrestrict_with_count( - time_array, data_array, starts, ends - ) + idx, countin = jitrestrict_with_count(time_array, starts, ends) + time_array = time_array[idx] + data_array = data_array[idx] m = starts.shape[0] f = data_array.shape[1:] @@ -560,9 +512,10 @@ def jitperievent_trigger_average( N = count_array.shape[1] N_epochs = len(starts) - time_target_array, data_target_array, count = jitrestrict_with_count( - time_target_array, data_target_array, starts, ends - ) + idx, count = jitrestrict_with_count(time_target_array, starts, ends) + time_target_array = time_target_array[idx] + data_target_array = data_target_array[idx] + max_count = np.cumsum(count) new_data_array = np.full( diff --git a/pynapple/core/core_function.py b/pynapple/core/core_function.py index d82880b3..1483a186 100644 --- a/pynapple/core/core_function.py +++ b/pynapple/core/core_function.py @@ -11,13 +11,13 @@ import numpy as np from scipy import signal -from ._jitted_functions import ( +from ._jitted_functions import ( jitbin_array, jitcount, jitremove_nan, jitrestrict, + jitrestrict_with_count, jitthreshold, - jittsrestrict_with_count, jitvaluefrom, pjitconvolve, ) @@ -32,14 +32,33 @@ def _count(time_array, starts, ends, bin_size=None): if isinstance(bin_size, (float, int)): return jitcount(time_array, starts, ends, bin_size) else: - _, d = jittsrestrict_with_count(time_array, starts, ends) + _, d = jitrestrict_with_count(time_array, starts, ends) t = starts + (ends - starts) / 2 return t, d def _value_from(time_array, time_target_array, data_target_array, starts, ends): - t, idx = jitvaluefrom(time_array, time_target_array, starts, ends) - return t, data_target_array[idx] + idx_t, count = jitrestrict_with_count(time_array, starts, ends) + idx_target, count_target = jitrestrict_with_count(time_target_array, starts, ends) + idx = jitvaluefrom( + time_array[idx_t], + time_target_array[idx_target], + count, 
+ count_target, + starts, + ends, + ) + + new_time_array = time_array[idx_t] + + new_data_array = np.zeros( + (len(new_time_array), *data_target_array.shape[1:]), + dtype=data_target_array.dtype, + ) + idx2 = ~np.isnan(idx) + new_data_array[idx2] = data_target_array[idx_target][idx[idx2].astype(int)] + + return new_time_array, new_data_array def _dropna(time_array, data_array, starts, ends, update_time_support, ndim): @@ -115,6 +134,7 @@ def _convolve(time_array, data_array, starts, ends, array, trim="both"): def _bin_average(time_array, data_array, starts, ends, bin_size): if get_backend() == "jax": from pynajax.jax_core_bin_average import bin_average + return bin_average(time_array, data_array, starts, ends, bin_size) else: return jitbin_array(time_array, data_array, starts, ends, bin_size) @@ -123,13 +143,11 @@ def _bin_average(time_array, data_array, starts, ends, bin_size): def _threshold(time_array, data_array, starts, ends, thr, method): if get_backend() == "jax": from pynajax.jax_core_threshold import threshold + return threshold(time_array, data_array, starts, ends, thr, method) else: return jitthreshold(time_array, data_array, starts, ends, thr, method) -# def perievent - -# def _sta(): - -# def _interp +def _event_trigger_average(): + pass diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index 121f854c..3f65f802 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -50,11 +50,11 @@ from tabulate import tabulate from ._jitted_functions import ( + _jitfix_iset, jitdiff, jitin_interval, jitintersect, jitunion, - _jitfix_iset, ) from .config import nap_config from .time_index import TsIndex diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 041c46ad..c134e6ae 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -1431,7 +1431,7 @@ def __init__(self, t, time_units="s", time_support=None): if isinstance(time_support, IntervalSet) and len(self.index): starts = time_support.start ends = time_support.end - idx = _restrict(self.index.values, starts, ends) + idx = _restrict(self.index.values, starts, ends) self.index = TsIndex(self.index.values[idx]) self.rate = self.index.shape[0] / np.sum( time_support.values[:, 1] - time_support.values[:, 0] diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 8435e13c..f1e4571f 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -14,13 +14,9 @@ import pandas as pd from tabulate import tabulate -from ._jitted_functions import ( - jitcount, - jittsrestrict_with_count, - jitunion, - jitunion_isets, -) +from ._jitted_functions import jitunion, jitunion_isets from .base_class import Base +from .core_function import _count from .interval_set import IntervalSet from .time_index import TsIndex from .time_series import BaseTsd, Ts, Tsd, TsdFrame, is_array_like @@ -691,28 +687,29 @@ def count(self, *args, **kwargs): ends = ep.end if isinstance(bin_size, (float, int)): + bin_size = float(bin_size) bin_size = TsIndex.format_timestamps(np.array([bin_size]), time_units)[0] - time_index, _ = jitcount(np.array([]), starts, ends, bin_size) - n = len(self.index) - count = np.zeros((time_index.shape[0], n), dtype=np.int64) - for i in range(n): - count[:, i] = jitcount( - self.data[self.index[i]].index, starts, ends, bin_size - )[1] + # Call it on first element to pre-allocate the array + if len(self) >= 1: + time_index, d = _count( + self.data[self.index[0]].index.values, starts, ends, bin_size + ) - else: - time_index = 
starts + (ends - starts) / 2 - n = len(self.index) - count = np.zeros((time_index.shape[0], n), dtype=np.int64) + count = np.zeros((len(time_index), len(self.index)), dtype=np.int64) + count[:, 0] = d - for i in range(n): - count[:, i] = jittsrestrict_with_count( - self.data[self.index[i]].index, starts, ends + for i in range(1, len(self.index)): + count[:, i] = _count( + self.data[self.index[i]].index.values, starts, ends, bin_size )[1] - toreturn = TsdFrame(t=time_index, d=count, time_support=ep, columns=self.index) - return toreturn + return TsdFrame(t=time_index, d=count, time_support=ep, columns=self.index) + else: + time_index, _ = _count(np.array([]), starts, ends, bin_size) + return TsdFrame( + t=time_index, d=np.empty((len(time_index), 0)), time_support=ep + ) def to_tsd(self, *args): """ diff --git a/tests/test_jitted.py b/tests/test_jitted.py index 3cf5ee51..693e774c 100644 --- a/tests/test_jitted.py +++ b/tests/test_jitted.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-12-02 17:17:03 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-04-18 17:23:42 +# @Last Modified time: 2024-04-19 17:54:34 """Tests of jitted core functions for `pynapple` package.""" @@ -79,22 +79,13 @@ def test_jitrestrict(): tsd3 = pd.Series(index=tsd.index[ix], data=tsd.values[ix]) pd.testing.assert_series_equal(tsd2, tsd3) -def test_jittsrestrict(): - for i in range(100): - ep, ts, tsd, tsdframe = get_example_dataset() - - ts2 = restrict(ep, ts) - idx = nap.core._jitted_functions.jitrestrict(ts.index, ep.start, ep.end) - ts3 = pd.Series(index=ts.index[idx], dtype="object") - pd.testing.assert_series_equal(ts2, ts3) - def test_jitrestrict_with_count(): for i in range(100): ep, ts, tsd, tsdframe = get_example_dataset() tsd2 = restrict(ep, tsd) - t, d, count = nap.core._jitted_functions.jitrestrict_with_count(tsd.index, tsd.values, ep.start, ep.end) - tsd3 = pd.Series(index=t, data=d) + ix, count = nap.core._jitted_functions.jitrestrict_with_count(tsd.index, ep.start, ep.end) + tsd3 = pd.Series(index=tsd.index[ix], data=tsd.values[ix]) pd.testing.assert_series_equal(tsd2, tsd3) bins = ep.values.ravel() @@ -109,27 +100,6 @@ def test_jitrestrict_with_count(): np.testing.assert_array_equal(count, count2) -def test_jittsrestrict_with_count(): - for i in range(100): - ep, ts, tsd, tsdframe = get_example_dataset() - - ts2 = restrict(ep, ts) - t, count = nap.core._jitted_functions.jittsrestrict_with_count(ts.index, ep.start, ep.end) - ts3 = pd.Series(index=t, dtype="object") - pd.testing.assert_series_equal(ts2, ts3) - - bins = ep.values.ravel() - ix = np.array(pd.cut(ts.index, bins, labels=np.arange(len(bins) - 1, dtype=np.float64))) - ix2 = np.array(pd.cut(ts.index,bins,labels=np.arange(len(bins) - 1, dtype=np.float64),right=False,)) - ix3 = np.vstack((ix, ix2)).T - ix3[np.floor(ix3 / 2) * 2 != ix3] = np.NaN - ix3 = np.floor(ix3 / 2) - ix3[np.isnan(ix3[:, 0]), 0] = ix3[np.isnan(ix3[:, 0]), 1] - ix = ix3[:,0] - count2 = np.array([np.sum(ix==j) for j in range(len(ep))]) - - np.testing.assert_array_equal(count, count2) - def test_jitthreshold(): for i in range(100): ep, ts, tsd, tsdframe = get_example_dataset() @@ -167,11 +137,12 @@ def test_jitthreshold(): # new_tsd = restrict(new_ep, tsd) def test_jitvalue_from(): - for i in range(10): + for i in range(100): ep, ts, tsd, tsdframe = get_example_dataset() - t, idx = nap.core._jitted_functions.jitvaluefrom(ts.index, tsd.index, ep.start, ep.end) - tsd3 = pd.Series(index=t, data=tsd.values[idx]) + t, d = nap.core.core_function._value_from(ts.t, tsd.t, 
tsd.d, ep.start, ep.end) + + tsd3 = pd.Series(index=t, data=d) tsd2 = [] for j in ep.index: From 40820327cf95048d2064c5db38707977e0bfec59 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Mon, 22 Apr 2024 23:14:18 -0400 Subject: [PATCH 19/37] Adding _process_functions.py --- .../{core_function.py => _core_functions.py} | 4 - pynapple/core/_jitted_functions.py | 100 -------- pynapple/core/base_class.py | 2 +- pynapple/core/time_series.py | 19 +- pynapple/core/ts_group.py | 2 +- pynapple/process/_process_functions.py | 233 ++++++++++++++++++ pynapple/process/correlograms.py | 72 +----- pynapple/process/perievent.py | 42 ++-- tests/test_jitted.py | 4 +- tests/test_spike_trigger_average.py | 15 +- 10 files changed, 272 insertions(+), 221 deletions(-) rename pynapple/core/{core_function.py => _core_functions.py} (99%) create mode 100644 pynapple/process/_process_functions.py diff --git a/pynapple/core/core_function.py b/pynapple/core/_core_functions.py similarity index 99% rename from pynapple/core/core_function.py rename to pynapple/core/_core_functions.py index 1483a186..82097f11 100644 --- a/pynapple/core/core_function.py +++ b/pynapple/core/_core_functions.py @@ -147,7 +147,3 @@ def _threshold(time_array, data_array, starts, ends, thr, method): return threshold(time_array, data_array, starts, ends, thr, method) else: return jitthreshold(time_array, data_array, starts, ends, thr, method) - - -def _event_trigger_average(): - pass diff --git a/pynapple/core/_jitted_functions.py b/pynapple/core/_jitted_functions.py index b5caf04a..b211b187 100644 --- a/pynapple/core/_jitted_functions.py +++ b/pynapple/core/_jitted_functions.py @@ -497,106 +497,6 @@ def jitcontinuous_perievent( return new_data_array -@jit(nopython=True) -def jitperievent_trigger_average( - time_array, - count_array, - time_target_array, - data_target_array, - starts, - ends, - windows, - binsize, -): - T = time_array.shape[0] - N = count_array.shape[1] - N_epochs = len(starts) - - idx, count = jitrestrict_with_count(time_target_array, starts, ends) - time_target_array = time_target_array[idx] - data_target_array = data_target_array[idx] - - max_count = np.cumsum(count) - - new_data_array = np.full( - (int(windows.sum()) + 1, count_array.shape[1], *data_target_array.shape[1:]), - 0.0, - ) - - t = 0 # count events - - hankel_array = np.zeros((new_data_array.shape[0], *data_target_array.shape[1:])) - - for k in range(N_epochs): - if count[k] > 0: - t_start = t - maxi = max_count[k] - i = maxi - count[k] - - while t < T: - lbound = time_array[t] - rbound = np.round(lbound + binsize, 9) - - if time_target_array[i] < rbound: - i_start = i - i_stop = i - - while i_stop < maxi: - if time_target_array[i_stop] < rbound: - i_stop += 1 - else: - break - - while i_start < i_stop - 1: - if time_target_array[i_start] < lbound: - i_start += 1 - else: - break - v = np.sum(data_target_array[i_start:i_stop], 0) / float( - i_stop - i_start - ) - - checknan = np.sum(v) - if not np.isnan(checknan): - hankel_array[-1] = v - - if t - t_start >= windows[1]: - for n in range(N): - new_data_array[:, n] += ( - hankel_array * count_array[t - windows[1], n] - ) - - # hankel_array = np.roll(hankel_array, -1, axis=0) - hankel_array[0:-1] = hankel_array[1:] - hankel_array[-1] = 0.0 - - t += 1 - - i = i_start - - if t == T or time_array[t] > ends[k]: - if t - t_start > windows[1]: - for j in range(windows[1]): - for n in range(N): - new_data_array[:, n] += ( - hankel_array * count_array[t - windows[1] + j, n] - ) - - # hankel_array = np.roll(hankel_array, -1, 
axis=0) - hankel_array[0:-1] = hankel_array[1:] - hankel_array[-1] = 0.0 - - hankel_array *= 0.0 - break - - total = np.sum(count_array, 0) - for n in range(N): - if total[n] > 0.0: - new_data_array[:, n] /= total[n] - - return new_data_array - - ################################ # IntervalSet functions ################################ diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index 8b826a51..da8c91ce 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -8,7 +8,7 @@ import numpy as np -from .core_function import _count, _restrict, _value_from +from ._core_functions import _count, _restrict, _value_from from .interval_set import IntervalSet from .time_index import TsIndex from .utils import convert_to_numpy_array diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index c134e6ae..d9d25054 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -27,8 +27,8 @@ from scipy import signal from tabulate import tabulate +from ._core_functions import _bin_average, _convolve, _dropna, _restrict, _threshold from .base_class import Base -from .core_function import _bin_average, _convolve, _dropna, _restrict, _threshold from .interval_set import IntervalSet from .time_index import TsIndex from .utils import ( @@ -492,13 +492,20 @@ def convolve(self, array, ep=None, trim="both"): "right", ], "Unknow argument. trim should be 'both', 'left' or 'right'." - if ep is None: - ep = self.time_support - time_array = self.index.values data_array = self.values - starts = ep.start - ends = ep.end + + if ep is None: + ep = self.time_support + starts = ep.start + ends = ep.end + else: + assert isinstance(ep, IntervalSet) + starts = ep.start + ends = ep.end + idx = _restrict(time_array, starts, ends) + time_array = time_array[idx] + data_array = data_array[idx] new_data_array = _convolve(time_array, data_array, starts, ends, array, trim) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index f1e4571f..e35a6937 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -14,9 +14,9 @@ import pandas as pd from tabulate import tabulate +from ._core_functions import _count from ._jitted_functions import jitunion, jitunion_isets from .base_class import Base -from .core_function import _count from .interval_set import IntervalSet from .time_index import TsIndex from .time_series import BaseTsd, Ts, Tsd, TsdFrame, is_array_like diff --git a/pynapple/process/_process_functions.py b/pynapple/process/_process_functions.py new file mode 100644 index 00000000..8f8e48e5 --- /dev/null +++ b/pynapple/process/_process_functions.py @@ -0,0 +1,233 @@ +""" + This module holds some process function of pynapple that can be + called with numba or pynajax as backend + + If pynajax is installed and `nap.nap_config.backend` is set + to `jax`, the module will call the functions within pynajax. + Otherwise the module will call the functions within `_jitted_functions.py`. + +""" + +import numpy as np +from numba import jit + +from .. import core as nap + + +@jit(nopython=True) +def _cross_correlogram(t1, t2, binsize, windowsize): + """ + Performs the discrete cross-correlogram of two time series. + The units should be in s for all arguments. + Return the firing rate of the series t2 relative to the timings of t1. + See compute_crosscorrelogram, compute_autocorrelogram and compute_eventcorrelogram + for wrappers of this function. 
+ + Parameters + ---------- + t1 : numpy.ndarray + The timestamps of the reference time series (in seconds) + t2 : numpy.ndarray + The timestamps of the target time series (in seconds) + binsize : float + The bin size (in seconds) + windowsize : float + The window size (in seconds) + + Returns + ------- + numpy.ndarray + The cross-correlogram + numpy.ndarray + Center of the bins (in s) + + """ + # nbins = ((windowsize//binsize)*2) + + nt1 = len(t1) + nt2 = len(t2) + + nbins = int((windowsize * 2) // binsize) + if np.floor(nbins / 2) * 2 == nbins: + nbins = nbins + 1 + + w = (nbins / 2) * binsize + C = np.zeros(nbins) + i2 = 0 + + for i1 in range(nt1): + lbound = t1[i1] - w + while i2 < nt2 and t2[i2] < lbound: + i2 = i2 + 1 + while i2 > 0 and t2[i2 - 1] > lbound: + i2 = i2 - 1 + + rbound = lbound + leftb = i2 + for j in range(nbins): + k = 0 + rbound = rbound + binsize + while leftb < nt2 and t2[leftb] < rbound: + leftb = leftb + 1 + k = k + 1 + + C[j] += k + + C = C / (nt1 * binsize) + + m = -w + binsize / 2 + B = np.zeros(nbins) + for j in range(nbins): + B[j] = m + j * binsize + + return C, B + + +@jit(nopython=True) +def _jitperievent_trigger_average( + time_array, + count_array, + time_target_array, + data_target_array, + starts, + ends, + windows, + binsize, +): + T = time_array.shape[0] + N = count_array.shape[1] + N_epochs = len(starts) + + idx, count = nap._jitted_functions.jitrestrict_with_count(time_target_array, starts, ends) + time_target_array = time_target_array[idx] + data_target_array = data_target_array[idx] + + max_count = np.cumsum(count) + + new_data_array = np.full( + (int(windows.sum()) + 1, count_array.shape[1], *data_target_array.shape[1:]), + 0.0, + ) + + t = 0 # count events + + hankel_array = np.zeros((new_data_array.shape[0], *data_target_array.shape[1:])) + + for k in range(N_epochs): + if count[k] > 0: + t_start = t + maxi = max_count[k] + i = maxi - count[k] + + while t < T: + lbound = time_array[t] + rbound = np.round(lbound + binsize, 9) + + if time_target_array[i] < rbound: + i_start = i + i_stop = i + + while i_stop < maxi: + if time_target_array[i_stop] < rbound: + i_stop += 1 + else: + break + + while i_start < i_stop - 1: + if time_target_array[i_start] < lbound: + i_start += 1 + else: + break + v = np.sum(data_target_array[i_start:i_stop], 0) / float( + i_stop - i_start + ) + + checknan = np.sum(v) + if not np.isnan(checknan): + hankel_array[-1] = v + + if t - t_start >= windows[1]: + for n in range(N): + new_data_array[:, n] += ( + hankel_array * count_array[t - windows[1], n] + ) + + # hankel_array = np.roll(hankel_array, -1, axis=0) + hankel_array[0:-1] = hankel_array[1:] + hankel_array[-1] = 0.0 + + t += 1 + + i = i_start + + if t == T or time_array[t] > ends[k]: + if t - t_start > windows[1]: + for j in range(windows[1]): + for n in range(N): + new_data_array[:, n] += ( + hankel_array * count_array[t - windows[1] + j, n] + ) + + # hankel_array = np.roll(hankel_array, -1, axis=0) + hankel_array[0:-1] = hankel_array[1:] + hankel_array[-1] = 0.0 + + hankel_array *= 0.0 + break + + total = np.sum(count_array, 0) + for n in range(N): + if total[n] > 0.0: + new_data_array[:, n] /= total[n] + + return new_data_array + + +def _perievent_trigger_average( + time_array, + count_array, + time_target_array, + data_target_array, + starts, + ends, + windows, + binsize, +): + if nap.utils.get_backend() == "jax": + from pynajax.jax_process_eta import event_trigger_average + + return event_trigger_average( + time_array, + count_array, + time_target_array, + 
np.expand_dims(data_target_array, -1), + starts, + ends, + windows, + binsize, + ) + + else: + if data_target_array.ndim == 1: + eta = _jitperievent_trigger_average( + time_array, + count_array, + time_target_array, + np.expand_dims(data_target_array, -1), + starts, + ends, + windows, + binsize, + ) + return np.squeeze(eta, -1) + else: + return _jitperievent_trigger_average( + time_array, + count_array, + time_target_array, + data_target_array, + starts, + ends, + windows, + binsize, + ) diff --git a/pynapple/process/correlograms.py b/pynapple/process/correlograms.py index e49bd293..c19387ac 100644 --- a/pynapple/process/correlograms.py +++ b/pynapple/process/correlograms.py @@ -2,88 +2,20 @@ # @Author: gviejo # @Date: 2022-01-02 11:39:55 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-04-04 10:46:03 +# @Last Modified time: 2024-04-22 17:23:27 from itertools import combinations, product import numpy as np import pandas as pd -from numba import jit from .. import core as nap - +from ._process_functions import _cross_correlogram ######################################################### # CORRELATION ######################################################### -@jit(nopython=True) -def _cross_correlogram(t1, t2, binsize, windowsize): - """ - Performs the discrete cross-correlogram of two time series. - The units should be in s for all arguments. - Return the firing rate of the series t2 relative to the timings of t1. - See compute_crosscorrelogram, compute_autocorrelogram and compute_eventcorrelogram - for wrappers of this function. - - Parameters - ---------- - t1 : numpy.ndarray - The timestamps of the reference time series (in seconds) - t2 : numpy.ndarray - The timestamps of the target time series (in seconds) - binsize : float - The bin size (in seconds) - windowsize : float - The window size (in seconds) - - Returns - ------- - numpy.ndarray - The cross-correlogram - numpy.ndarray - Center of the bins (in s) - - """ - # nbins = ((windowsize//binsize)*2) - - nt1 = len(t1) - nt2 = len(t2) - - nbins = int((windowsize * 2) // binsize) - if np.floor(nbins / 2) * 2 == nbins: - nbins = nbins + 1 - - w = (nbins / 2) * binsize - C = np.zeros(nbins) - i2 = 0 - - for i1 in range(nt1): - lbound = t1[i1] - w - while i2 < nt2 and t2[i2] < lbound: - i2 = i2 + 1 - while i2 > 0 and t2[i2 - 1] > lbound: - i2 = i2 - 1 - - rbound = lbound - leftb = i2 - for j in range(nbins): - k = 0 - rbound = rbound + binsize - while leftb < nt2 and t2[leftb] < rbound: - leftb = leftb + 1 - k = k + 1 - - C[j] += k - - C = C / (nt1 * binsize) - - m = -w + binsize / 2 - B = np.zeros(nbins) - for j in range(nbins): - B[j] = m + j * binsize - - return C, B def compute_autocorrelogram( diff --git a/pynapple/process/perievent.py b/pynapple/process/perievent.py index ac68a7ee..af034526 100644 --- a/pynapple/process/perievent.py +++ b/pynapple/process/perievent.py @@ -1,12 +1,11 @@ -# -*- coding: utf-8 -*- -# @Author: gviejo -# @Date: 2022-01-30 22:59:00 -# @Last Modified by: gviejo -# @Last Modified time: 2024-02-20 22:27:23 +"""Perievent functions + +""" import numpy as np from .. 
import core as nap +from ._process_functions import _perievent_trigger_average def _align_tsd(tsd, tref, window, time_support): @@ -288,29 +287,16 @@ def compute_event_trigger_average( time_target_array = feature.index.values data_target_array = feature.values - if data_target_array.ndim == 1: - eta = nap._jitted_functions.jitperievent_trigger_average( - time_array, - count_array, - time_target_array, - np.expand_dims(data_target_array, -1), - starts, - ends, - windows, - binsize, - ) - eta = np.squeeze(eta, -1) - else: - eta = nap._jitted_functions.jitperievent_trigger_average( - time_array, - count_array, - time_target_array, - data_target_array, - starts, - ends, - windows, - binsize, - ) + eta = _perievent_trigger_average( + time_array, + count_array, + time_target_array, + data_target_array, + starts, + ends, + windows, + binsize, + ) if eta.ndim == 2: return nap.TsdFrame(t=time_idx, d=eta, columns=group.index) diff --git a/tests/test_jitted.py b/tests/test_jitted.py index 693e774c..6d609849 100644 --- a/tests/test_jitted.py +++ b/tests/test_jitted.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-12-02 17:17:03 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-04-19 17:54:34 +# @Last Modified time: 2024-04-22 17:30:35 """Tests of jitted core functions for `pynapple` package.""" @@ -140,7 +140,7 @@ def test_jitvalue_from(): for i in range(100): ep, ts, tsd, tsdframe = get_example_dataset() - t, d = nap.core.core_function._value_from(ts.t, tsd.t, tsd.d, ep.start, ep.end) + t, d = nap.core._core_functions._value_from(ts.t, tsd.t, tsd.d, ep.start, ep.end) tsd3 = pd.Series(index=t, data=d) diff --git a/tests/test_spike_trigger_average.py b/tests/test_spike_trigger_average.py index 56ee193a..3c657905 100644 --- a/tests/test_spike_trigger_average.py +++ b/tests/test_spike_trigger_average.py @@ -1,8 +1,3 @@ -# -*- coding: utf-8 -*- -# @Author: gviejo -# @Date: 2022-08-29 17:27:02 -# @Last Modified by: gviejo -# @Last Modified time: 2024-02-20 22:45:51 #!/usr/bin/env python """Tests of spike trigger average for `pynapple` package.""" @@ -188,15 +183,17 @@ def test_compute_spike_trigger_average_raise_error(): def test_compute_spike_trigger_average_time_unit(): ep = nap.IntervalSet(0, 100) - feature = pd.Series(index=np.arange(0, 101, 0.01), data=np.zeros(int(101 / 0.01))) + t = np.arange(0, 101, 0.01) + d = np.zeros(int(101 / 0.01)) t1 = np.arange(1, 100) - feature.loc[t1] = 1.0 + for i in range(len(t1)): + d[t==t1[i]] = 1.0 + feature = nap.Tsd(t=t, d=d, time_support=ep) + spikes = nap.TsGroup( {0: nap.Ts(t1), 1: nap.Ts(t1 - 0.1), 2: nap.Ts(t1 + 0.2)}, time_support=ep ) - feature = nap.Tsd(feature, time_support=ep) - sta = nap.compute_event_trigger_average(spikes, feature, 0.2, (0.6, 0.6), ep) output = np.zeros((7, 3)) From fcdf0e24d52352d472256f0d3dcef461f73df3b7 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Mon, 22 Apr 2024 23:15:34 -0400 Subject: [PATCH 20/37] linting --- pynapple/process/_process_functions.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pynapple/process/_process_functions.py b/pynapple/process/_process_functions.py index 8f8e48e5..9e1e245c 100644 --- a/pynapple/process/_process_functions.py +++ b/pynapple/process/_process_functions.py @@ -98,7 +98,9 @@ def _jitperievent_trigger_average( N = count_array.shape[1] N_epochs = len(starts) - idx, count = nap._jitted_functions.jitrestrict_with_count(time_target_array, starts, ends) + idx, count = nap._jitted_functions.jitrestrict_with_count( + time_target_array, starts, ends + ) 
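+    # idx flags the target samples that fall inside the [starts, ends] epochs;
+    # count holds the number of kept target samples per epoch.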
time_target_array = time_target_array[idx] data_target_array = data_target_array[idx] From e0457ee4872cba23f0c8b1d7702f5fe53704c828 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Tue, 23 Apr 2024 11:14:47 -0400 Subject: [PATCH 21/37] changing perievent --- pynapple/core/_jitted_functions.py | 99 ------------- pynapple/process/_process_functions.py | 183 +++++++++++++++---------- pynapple/process/correlograms.py | 77 ++++++++++- pynapple/process/perievent.py | 4 +- 4 files changed, 186 insertions(+), 177 deletions(-) diff --git a/pynapple/core/_jitted_functions.py b/pynapple/core/_jitted_functions.py index b211b187..efec30d8 100644 --- a/pynapple/core/_jitted_functions.py +++ b/pynapple/core/_jitted_functions.py @@ -398,105 +398,6 @@ def pjitconvolve(data_array, array, trim="both"): return new_data_array -@njit(parallel=True) -def jitcontinuous_perievent( - time_array, data_array, time_target_array, starts, ends, windowsize -): - N_samples = len(time_array) - N_target = len(time_target_array) - N_epochs = len(starts) - count = np.zeros((N_epochs, 2), dtype=np.int64) - start_t = np.zeros((N_epochs, 2), dtype=np.int64) - - k = 0 # Epochs - t = 0 # Samples - i = 0 # Target - - while ends[k] < time_array[t] and ends[k] < time_target_array[i]: - k += 1 - - while k < N_epochs: - # Outside - while t < N_samples: - if time_array[t] >= starts[k]: - break - t += 1 - - while i < N_target: - if time_target_array[i] >= starts[k]: - break - i += 1 - - if time_array[t] <= ends[k]: - start_t[k, 0] = t - - if time_target_array[i] <= ends[k]: - start_t[k, 1] = i - - # Inside - while t < N_samples: - if time_array[t] > ends[k]: - break - else: - count[k, 0] += 1 - t += 1 - - while i < N_target: - if time_target_array[i] > ends[k]: - break - else: - count[k, 1] += 1 - i += 1 - - k += 1 - - if k == N_epochs: - break - if t == N_samples: - break - if i == N_target: - break - - new_data_array = np.full( - (np.sum(windowsize) + 1, np.sum(count[:, 1]), *data_array.shape[1:]), np.nan - ) - - if np.any((count[:, 0] * count[:, 1]) > 0): - for k in prange(N_epochs): - if count[k, 0] > 0 and count[k, 1] > 0: - t = start_t[k, 0] - i = start_t[k, 1] - maxt = t + count[k, 0] - maxi = i + count[k, 1] - cnt_i = np.sum(count[0:k, 1]) - - while i < maxi: - interval = abs(time_array[t] - time_target_array[i]) - t_pos = t - t += 1 - while t < maxt: - new_interval = abs(time_array[t] - time_target_array[i]) - if new_interval > interval: - break - else: - interval = new_interval - t_pos = t - t += 1 - - left = np.minimum(windowsize[0], t_pos - start_t[k, 0]) - right = np.minimum(windowsize[1], maxt - t_pos - 1) - center = windowsize[0] + 1 - new_data_array[center - left - 1 : center + right, cnt_i] = ( - data_array[t_pos - left : t_pos + right + 1] - ) - - t -= 1 - i += 1 - cnt_i += 1 - - return new_data_array - - ################################ # IntervalSet functions ################################ diff --git a/pynapple/process/_process_functions.py b/pynapple/process/_process_functions.py index 9e1e245c..6a143144 100644 --- a/pynapple/process/_process_functions.py +++ b/pynapple/process/_process_functions.py @@ -9,78 +9,108 @@ """ import numpy as np -from numba import jit +from numba import jit, njit, prange from .. import core as nap -@jit(nopython=True) -def _cross_correlogram(t1, t2, binsize, windowsize): - """ - Performs the discrete cross-correlogram of two time series. - The units should be in s for all arguments. - Return the firing rate of the series t2 relative to the timings of t1. 
- See compute_crosscorrelogram, compute_autocorrelogram and compute_eventcorrelogram - for wrappers of this function. - - Parameters - ---------- - t1 : numpy.ndarray - The timestamps of the reference time series (in seconds) - t2 : numpy.ndarray - The timestamps of the target time series (in seconds) - binsize : float - The bin size (in seconds) - windowsize : float - The window size (in seconds) - - Returns - ------- - numpy.ndarray - The cross-correlogram - numpy.ndarray - Center of the bins (in s) - - """ - # nbins = ((windowsize//binsize)*2) - - nt1 = len(t1) - nt2 = len(t2) - - nbins = int((windowsize * 2) // binsize) - if np.floor(nbins / 2) * 2 == nbins: - nbins = nbins + 1 - - w = (nbins / 2) * binsize - C = np.zeros(nbins) - i2 = 0 - - for i1 in range(nt1): - lbound = t1[i1] - w - while i2 < nt2 and t2[i2] < lbound: - i2 = i2 + 1 - while i2 > 0 and t2[i2 - 1] > lbound: - i2 = i2 - 1 - - rbound = lbound - leftb = i2 - for j in range(nbins): - k = 0 - rbound = rbound + binsize - while leftb < nt2 and t2[leftb] < rbound: - leftb = leftb + 1 - k = k + 1 - - C[j] += k - - C = C / (nt1 * binsize) - - m = -w + binsize / 2 - B = np.zeros(nbins) - for j in range(nbins): - B[j] = m + j * binsize - - return C, B +@njit(parallel=True) +def _jitcontinuous_perievent( + time_array, data_array, time_target_array, starts, ends, windowsize +): + N_samples = len(time_array) + N_target = len(time_target_array) + N_epochs = len(starts) + count = np.zeros((N_epochs, 2), dtype=np.int64) + start_t = np.zeros((N_epochs, 2), dtype=np.int64) + + k = 0 # Epochs + t = 0 # Samples + i = 0 # Target + + while ends[k] < time_array[t] and ends[k] < time_target_array[i]: + k += 1 + + while k < N_epochs: + # Outside + while t < N_samples: + if time_array[t] >= starts[k]: + break + t += 1 + + while i < N_target: + if time_target_array[i] >= starts[k]: + break + i += 1 + + if time_array[t] <= ends[k]: + start_t[k, 0] = t + + if time_target_array[i] <= ends[k]: + start_t[k, 1] = i + + # Inside + while t < N_samples: + if time_array[t] > ends[k]: + break + else: + count[k, 0] += 1 + t += 1 + + while i < N_target: + if time_target_array[i] > ends[k]: + break + else: + count[k, 1] += 1 + i += 1 + + k += 1 + + if k == N_epochs: + break + if t == N_samples: + break + if i == N_target: + break + + new_data_array = np.full( + (np.sum(windowsize) + 1, np.sum(count[:, 1]), *data_array.shape[1:]), np.nan + ) + + if np.any((count[:, 0] * count[:, 1]) > 0): + for k in prange(N_epochs): + if count[k, 0] > 0 and count[k, 1] > 0: + t = start_t[k, 0] + i = start_t[k, 1] + maxt = t + count[k, 0] + maxi = i + count[k, 1] + cnt_i = np.sum(count[0:k, 1]) + + while i < maxi: + interval = abs(time_array[t] - time_target_array[i]) + t_pos = t + t += 1 + while t < maxt: + new_interval = abs(time_array[t] - time_target_array[i]) + if new_interval > interval: + break + else: + interval = new_interval + t_pos = t + t += 1 + + left = np.minimum(windowsize[0], t_pos - start_t[k, 0]) + right = np.minimum(windowsize[1], maxt - t_pos - 1) + center = windowsize[0] + 1 + new_data_array[center - left - 1 : center + right, cnt_i] = ( + data_array[t_pos - left : t_pos + right + 1] + ) + + t -= 1 + i += 1 + cnt_i += 1 + + return new_data_array @jit(nopython=True) @@ -196,7 +226,7 @@ def _perievent_trigger_average( binsize, ): if nap.utils.get_backend() == "jax": - from pynajax.jax_process_eta import event_trigger_average + from pynajax.jax_process_perievent import event_trigger_average return event_trigger_average( time_array, @@ -233,3 +263,18 @@ def 
_perievent_trigger_average( windows, binsize, ) + + +def _perievent_continuous( + time_array, data_array, time_target_array, starts, ends, windowsize +): + if nap.utils.get_backend() == "jax": + from pynajax.jax_process_perievent import perievent_continuous + + return perievent_continuous( + time_array, data_array, time_target_array, starts, ends, windowsize + ) + else: + return _jitcontinuous_perievent( + time_array, data_array, time_target_array, starts, ends, windowsize + ) diff --git a/pynapple/process/correlograms.py b/pynapple/process/correlograms.py index c19387ac..d83329c2 100644 --- a/pynapple/process/correlograms.py +++ b/pynapple/process/correlograms.py @@ -1,9 +1,4 @@ -# -*- coding: utf-8 -*- -# @Author: gviejo -# @Date: 2022-01-02 11:39:55 -# @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-04-22 17:23:27 - +"""Cross-correlograms """ from itertools import combinations, product @@ -11,11 +6,79 @@ import pandas as pd from .. import core as nap -from ._process_functions import _cross_correlogram + +from numba import jit ######################################################### # CORRELATION ######################################################### +@jit(nopython=True) +def _cross_correlogram(t1, t2, binsize, windowsize): + """ + Performs the discrete cross-correlogram of two time series. + The units should be in s for all arguments. + Return the firing rate of the series t2 relative to the timings of t1. + See compute_crosscorrelogram, compute_autocorrelogram and compute_eventcorrelogram + for wrappers of this function. + + Parameters + ---------- + t1 : numpy.ndarray + The timestamps of the reference time series (in seconds) + t2 : numpy.ndarray + The timestamps of the target time series (in seconds) + binsize : float + The bin size (in seconds) + windowsize : float + The window size (in seconds) + + Returns + ------- + numpy.ndarray + The cross-correlogram + numpy.ndarray + Center of the bins (in s) + + """ + # nbins = ((windowsize//binsize)*2) + + nt1 = len(t1) + nt2 = len(t2) + + nbins = int((windowsize * 2) // binsize) + if np.floor(nbins / 2) * 2 == nbins: + nbins = nbins + 1 + + w = (nbins / 2) * binsize + C = np.zeros(nbins) + i2 = 0 + + for i1 in range(nt1): + lbound = t1[i1] - w + while i2 < nt2 and t2[i2] < lbound: + i2 = i2 + 1 + while i2 > 0 and t2[i2 - 1] > lbound: + i2 = i2 - 1 + + rbound = lbound + leftb = i2 + for j in range(nbins): + k = 0 + rbound = rbound + binsize + while leftb < nt2 and t2[leftb] < rbound: + leftb = leftb + 1 + k = k + 1 + + C[j] += k + + C = C / (nt1 * binsize) + + m = -w + binsize / 2 + B = np.zeros(nbins) + for j in range(nbins): + B[j] = m + j * binsize + + return C, B def compute_autocorrelogram( diff --git a/pynapple/process/perievent.py b/pynapple/process/perievent.py index af034526..d807e9c9 100644 --- a/pynapple/process/perievent.py +++ b/pynapple/process/perievent.py @@ -5,7 +5,7 @@ import numpy as np from .. 
import core as nap -from ._process_functions import _perievent_trigger_average +from ._process_functions import _perievent_trigger_average, _perievent_continuous def _align_tsd(tsd, tref, window, time_support): @@ -178,7 +178,7 @@ def compute_perievent_continuous(data, tref, minmax, ep=None, time_unit="s"): time_idx = np.hstack((idx1, np.zeros(1), idx2)) windowsize = np.array([idx1.shape[0], idx2.shape[0]]) - new_data_array = nap._jitted_functions.jitcontinuous_perievent( + new_data_array = _perievent_continuous( time_array, data_array, time_target_array, starts, ends, windowsize ) From 667f348bf5a844256c52bf71e8d33911a3c84331 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Tue, 23 Apr 2024 11:16:58 -0400 Subject: [PATCH 22/37] Linting --- pynapple/process/correlograms.py | 2 +- pynapple/process/perievent.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pynapple/process/correlograms.py b/pynapple/process/correlograms.py index d83329c2..41140c18 100644 --- a/pynapple/process/correlograms.py +++ b/pynapple/process/correlograms.py @@ -4,10 +4,10 @@ import numpy as np import pandas as pd +from numba import jit from .. import core as nap -from numba import jit ######################################################### # CORRELATION diff --git a/pynapple/process/perievent.py b/pynapple/process/perievent.py index d807e9c9..65ac1136 100644 --- a/pynapple/process/perievent.py +++ b/pynapple/process/perievent.py @@ -5,7 +5,7 @@ import numpy as np from .. import core as nap -from ._process_functions import _perievent_trigger_average, _perievent_continuous +from ._process_functions import _perievent_continuous, _perievent_trigger_average def _align_tsd(tsd, tref, window, time_support): From 01e5e284adf7e07a3d899c00f96748aeca849896 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Tue, 23 Apr 2024 18:23:07 -0400 Subject: [PATCH 23/37] Update --- pynapple/process/_process_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pynapple/process/_process_functions.py b/pynapple/process/_process_functions.py index 6a143144..cb83da45 100644 --- a/pynapple/process/_process_functions.py +++ b/pynapple/process/_process_functions.py @@ -226,7 +226,7 @@ def _perievent_trigger_average( binsize, ): if nap.utils.get_backend() == "jax": - from pynajax.jax_process_perievent import event_trigger_average + from pynajax.jax_process_perievent import event_trigger_average, fill_forward return event_trigger_average( time_array, From ca97c64c92e6f590cf7bbd762deb02d4b4097f8b Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Wed, 24 Apr 2024 12:40:06 -0400 Subject: [PATCH 24/37] Update --- pynapple/core/time_series.py | 4 ++-- pynapple/process/_process_functions.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index d9d25054..d24ee589 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -1241,7 +1241,7 @@ def threshold(self, thr, method="above"): thr : float The threshold value method : str, optional - The threshold method (above/below/aboveequal/belowequal) + The threshold method ("above"[default], "below", "aboveequal", "belowequal") Returns ------- @@ -1251,7 +1251,7 @@ def threshold(self, thr, method="above"): Raises ------ ValueError - Raise an error if method is not 'below' or 'above' + Raise an error if method is unknown. RuntimeError Raise an error if thr is too high/low and no epochs is found. 
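The updated `threshold` docstring above now lists all four comparison methods. A minimal sketch of the documented behavior (the data values below are illustrative, not taken from the patch): the method keeps only the samples satisfying the comparison and restricts the returned time support accordingly.

```python
import numpy as np
import pynapple as nap

# Illustrative Tsd; any 1-dimensional time series works.
tsd = nap.Tsd(t=np.arange(10), d=np.array([0.1, 0.9, 0.4, 0.8, 0.2, 0.7, 0.3, 0.6, 0.5, 0.0]))

above = tsd.threshold(0.5)                   # "above" is the default method
below_eq = tsd.threshold(0.5, "belowequal")  # keeps samples where the value is <= 0.5

# The result only contains the samples passing the comparison, and its
# time_support is reduced to the epochs where the condition holds.
assert len(above) == int(np.sum(tsd.values > 0.5))
```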
diff --git a/pynapple/process/_process_functions.py b/pynapple/process/_process_functions.py index cb83da45..489979d0 100644 --- a/pynapple/process/_process_functions.py +++ b/pynapple/process/_process_functions.py @@ -224,9 +224,10 @@ def _perievent_trigger_average( ends, windows, binsize, + batch_size=64, ): if nap.utils.get_backend() == "jax": - from pynajax.jax_process_perievent import event_trigger_average, fill_forward + from pynajax.jax_process_perievent import event_trigger_average return event_trigger_average( time_array, @@ -237,6 +238,7 @@ def _perievent_trigger_average( ends, windows, binsize, + batch_size, ) else: From 4c4a14cdc20454b40b6f0622ad6aa4d42dd39b94 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Mon, 29 Apr 2024 11:20:28 -0400 Subject: [PATCH 25/37] Pasing tests for pynajax --- pynapple/core/_core_functions.py | 4 +- pynapple/core/_jitted_functions.py | 2 +- pynapple/core/time_series.py | 12 +- pynapple/process/_process_functions.py | 26 +-- pynapple/process/perievent.py | 12 +- pynapple/process/tuning_curves.py | 96 +---------- test_fastplotlib.py | 68 -------- tests/test_abstract_tsd.py | 211 ++++++++++++------------- tests/test_jitted.py | 7 +- tests/test_spike_trigger_average.py | 58 ++++++- tests/test_time_series.py | 62 +++++--- 11 files changed, 228 insertions(+), 330 deletions(-) delete mode 100644 test_fastplotlib.py diff --git a/pynapple/core/_core_functions.py b/pynapple/core/_core_functions.py index 82097f11..b969e787 100644 --- a/pynapple/core/_core_functions.py +++ b/pynapple/core/_core_functions.py @@ -97,7 +97,7 @@ def _convolve(time_array, data_array, starts, ends, array, trim="both"): if get_backend() == "jax": from pynajax.jax_core_convolve import convolve - return convolve(time_array, data_array, starts, ends, array) + return convolve(time_array, data_array, starts, ends, array, trim) else: if data_array.ndim == 1: new_data_array = np.zeros(data_array.shape) @@ -112,7 +112,7 @@ def _convolve(time_array, data_array, starts, ends, array, trim="both"): elif trim == "right": cut = (0, t) else: - cut = ((1 - k % 2) + (k - 1) // 2, t + k - 1 - ((k - 1) // 2)) + cut = ((k - 1) // 2, t + k - 1 - ((k - 1) // 2) - (1 - k % 2)) # scipy is actually faster for Tsd new_data_array[idx_s:idx_e] = signal.convolve( data_array[idx_s:idx_e], array diff --git a/pynapple/core/_jitted_functions.py b/pynapple/core/_jitted_functions.py index efec30d8..a30d143a 100644 --- a/pynapple/core/_jitted_functions.py +++ b/pynapple/core/_jitted_functions.py @@ -384,7 +384,7 @@ def pjitconvolve(data_array, array, trim="both"): new_data_array = np.zeros(data_array.shape) if trim == "both": - cut = ((1 - k % 2) + (k - 1) // 2, t + k - 1 - ((k - 1) // 2)) + cut = ((k - 1) // 2, t + k - 1 - ((k - 1) // 2) - (1 - k % 2)) elif trim == "left": cut = (k - 1, t + k - 1) elif trim == "right": diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index d24ee589..4edd7828 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -165,7 +165,7 @@ def __array_ufunc__(self, ufunc, method, *args, **kwargs): else: out = ufunc(*new_args, **kwargs) - if isinstance(out, np.ndarray): + if isinstance(out, np.ndarray) or is_array_like(out): if out.shape[0] == self.index.shape[0]: kwargs = {} if hasattr(self, "columns"): @@ -181,6 +181,14 @@ def __array_ufunc__(self, ufunc, method, *args, **kwargs): return NotImplemented def __array_function__(self, func, types, args, kwargs): + # print("In __array_function__") + # print(" func = ", func) + # print(" types = 
", types) + # print(" args = ", args) + # for inp in args: + # print(type(inp)) + # print(" kwargs = ", kwargs) + if func in [ np.sort, np.lexsort, @@ -208,7 +216,7 @@ def __array_function__(self, func, types, args, kwargs): out = func._implementation(*new_args, **kwargs) - if isinstance(out, np.ndarray): + if isinstance(out, np.ndarray) or is_array_like(out): # # if dims increased in any case, we can't return safely a time series # if out.ndim > self.ndim: # return out diff --git a/pynapple/process/_process_functions.py b/pynapple/process/_process_functions.py index 489979d0..2b403057 100644 --- a/pynapple/process/_process_functions.py +++ b/pynapple/process/_process_functions.py @@ -216,10 +216,10 @@ def _jitperievent_trigger_average( def _perievent_trigger_average( - time_array, - count_array, time_target_array, - data_target_array, + count_array, + time_array, + data_array, starts, ends, windows, @@ -230,10 +230,10 @@ def _perievent_trigger_average( from pynajax.jax_process_perievent import event_trigger_average return event_trigger_average( - time_array, - count_array, time_target_array, - np.expand_dims(data_target_array, -1), + count_array, + time_array, + np.expand_dims(data_array, -1), starts, ends, windows, @@ -242,12 +242,12 @@ def _perievent_trigger_average( ) else: - if data_target_array.ndim == 1: + if data_array.ndim == 1: eta = _jitperievent_trigger_average( - time_array, - count_array, time_target_array, - np.expand_dims(data_target_array, -1), + count_array, + time_array, + np.expand_dims(data_array, -1), starts, ends, windows, @@ -256,10 +256,10 @@ def _perievent_trigger_average( return np.squeeze(eta, -1) else: return _jitperievent_trigger_average( - time_array, - count_array, time_target_array, - data_target_array, + count_array, + time_array, + data_array, starts, ends, windows, diff --git a/pynapple/process/perievent.py b/pynapple/process/perievent.py index 65ac1136..f6210c6a 100644 --- a/pynapple/process/perievent.py +++ b/pynapple/process/perievent.py @@ -279,19 +279,19 @@ def compute_event_trigger_average( # Bin the spike train count = group.count(binsize, ep) - time_array = np.round(count.index.values - (binsize / 2), 9) + time_target_array = np.round(count.index.values - (binsize / 2), 9) count_array = count.values starts = ep.start ends = ep.end - time_target_array = feature.index.values - data_target_array = feature.values + time_array = feature.index.values + data_array = feature.values eta = _perievent_trigger_average( - time_array, - count_array, time_target_array, - data_target_array, + count_array, + time_array, + data_array, starts, ends, windows, diff --git a/pynapple/process/tuning_curves.py b/pynapple/process/tuning_curves.py index d3383a73..af9153af 100644 --- a/pynapple/process/tuning_curves.py +++ b/pynapple/process/tuning_curves.py @@ -541,7 +541,7 @@ def compute_2d_tuning_curves_continuous( for k, tmp in idxs.groupby(cols): if (0 <= k[0] < nb_bins[0]) and (0 <= k[1] < nb_bins[1]): - tc_np[:, k[0], k[1]] = np.mean(tsdframe[tmp.index].values, 0) + tc_np[:, k[0], k[1]] = np.mean(tsdframe[tmp.index.values].values, 0) tc_np[np.isnan(tc_np)] = 0.0 @@ -550,97 +550,3 @@ def compute_2d_tuning_curves_continuous( tc = {c: tc_np[i] for i, c in enumerate(tsdframe.columns)} return tc, xy - - -# def compute_1d_poisson_glm( -# group, feature, binsize, windowsize, ep, time_units="s", niter=100, tolerance=1e-5 -# ): -# """ -# Poisson GLM - -# Warning : this function is still experimental! 
- -# Parameters -# ---------- -# group : TsGroup -# Spike trains -# feature : Tsd -# The regressors -# binsize : float -# Bin size -# windowsize : Float -# The window for offsetting the regressors -# ep : IntervalSet, optional -# On which epoch to perfom the GLM -# time_units : str, optional -# Time units of binsize and windowsize -# niter : int, optional -# Number of iteration for fitting the GLM -# tolerance : float, optional -# Tolerance for stopping the IRLS - -# Returns -# ------- -# tuple -# regressors : TsdFrame\n -# offset : pandas.Series\n -# prediction : TsdFrame\n - -# Raises -# ------ -# RuntimeError -# if group is not a TsGroup - -# """ -# if type(group) is nap.TsGroup: -# newgroup = group.restrict(ep) -# else: -# raise RuntimeError("Unknown format for group") - -# binsize = nap.TsIndex.format_timestamps(binsize, time_units)[0] -# windowsize = nap.TsIndex.format_timestamps(windowsize, time_units)[0] - -# # Bin the spike train -# count = newgroup.count(binsize) - -# # Downsample the feature to binsize -# tidx = [] -# dfeat = [] -# for i in ep.index: -# bins = np.arange(ep.start[i], ep.end[i] + binsize, binsize) -# idx = np.digitize(feature.index.values, bins) - 1 -# tmp = feature.groupby(idx).mean() -# tidx.append(bins[0:-1] + np.diff(bins) / 2) -# dfeat.append(tmp) -# dfeat = nap.Tsd(t=np.hstack(tidx), d=np.hstack(dfeat), time_support=ep) - -# # Build the Hankel matrix -# nt = np.abs(windowsize // binsize).astype("int") + 1 -# X = hankel( -# np.hstack((np.zeros(nt - 1), dfeat.values))[: -nt + 1], dfeat.values[-nt:] -# ) -# X = np.hstack((np.ones((len(dfeat), 1)), X)) - -# # Fitting GLM for each neuron -# regressors = [] -# for i, n in enumerate(group.keys()): -# print("Fitting Poisson GLM for unit %i" % n) -# b = nap.jitted_functions.jit_poisson_IRLS( -# X, count[n].values, niter=niter, tolerance=tolerance -# ) -# regressors.append(b) - -# regressors = np.array(regressors).T -# offset = regressors[0] -# regressors = regressors[1:] -# regressors = nap.TsdFrame( -# t=np.arange(-nt + 1, 1) * binsize, d=regressors, columns=list(group.keys()) -# ) -# offset = pd.Series(index=group.keys(), data=offset) - -# prediction = nap.TsdFrame( -# t=dfeat.index.values, -# d=np.exp(np.dot(X[:, 1:], regressors.values) + offset.values) * binsize, -# ) - -# return (regressors, offset, prediction) diff --git a/test_fastplotlib.py b/test_fastplotlib.py deleted file mode 100644 index 36dae64b..00000000 --- a/test_fastplotlib.py +++ /dev/null @@ -1,68 +0,0 @@ -# -*- coding: utf-8 -*- -# @Author: Guillaume Viejo -# @Date: 2023-10-31 18:34:19 -# @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-10-31 19:24:48 - -import fastplotlib as fpl -import pynapple as nap -import numpy as np -import sys, os -sys.path.append(os.path.expanduser("~/fastplotlib-sfn2023")) -from _video import LazyVideo -from pathlib import Path -from ipywidgets import HBox - -behavior_path = Path('/mnt/home/gviejo/fastplotlib-sfn2023/sample_data/M238Slc17a7_Chr2/20170824') - -paths_side = sorted(behavior_path.glob("*side_v*.avi")) -paths_front = sorted(behavior_path.glob("*front_v*.avi")) - - -class Concat: - def __init__(self, files): - self.files = files - self.videos = [LazyVideo(p) for p in self.files] - self._nframes_per_video = [v.shape[0] for v in self.videos] - self._cumsum = np.cumsum(self._nframes_per_video) - self.nframes = sum(self._nframes_per_video) - self.shape = (self.nframes, self.videos[0].shape[1], self.videos[0].shape[2]) - self.ndim = 3 - - self.dtype = self.videos[0].dtype - - def __len__(self) -> 
int: - return self.nframes - - def _get_vid_ix_sub_ix(self, key): - vid_ix = np.searchsorted(self._cumsum, key) - if vid_ix != 0: - sub_ix = key - self._cumsum[vid_ix - 1] - else: - sub_ix = key - - return vid_ix, sub_ix - - def __getitem__(self, key)-> np.ndarray: - if isinstance(key, slice): - start, stop = key.start, key.stop - vid_ix, sub_ix0 = self._get_vid_ix_sub_ix(start) - vid_ix, sub_ix1 = self._get_vid_ix_sub_ix(stop) - return self.videos[vid_ix][sub_ix0:sub_ix1] - elif isinstance(key, int): - vid_ix, sub_ix0 = self._get_vid_ix_sub_ix(key) - return self.videos[vid_ix][sub_ix0] - - - -concat = Concat(paths_side) - -# print(concat.videos) - -t = np.linspace(0, concat.nframes / 500, concat.nframes) - -tsd_video = nap.TsdTensor(t, concat) - -v = LazyVideo(concat.files[0]) - -tsd = nap.TsdTensor(t=np.arange(0, len(v)), d=v) \ No newline at end of file diff --git a/tests/test_abstract_tsd.py b/tests/test_abstract_tsd.py index a54f3f8b..fdde8c0e 100644 --- a/tests/test_abstract_tsd.py +++ b/tests/test_abstract_tsd.py @@ -1,9 +1,3 @@ -# -*- coding: utf-8 -*- -# @Author: Guillaume Viejo -# @Date: 2023-09-25 11:53:30 -# @Last Modified by: gviejo -# @Last Modified time: 2024-02-26 13:06:55 - import pynapple as nap import numpy as np import pandas as pd @@ -15,137 +9,140 @@ class MyClass(BaseTsd): - def __getitem__(self, key): - return key + def __getitem__(self, key): + return key - def __setitem__(self, key, value): - pass + def __setitem__(self, key, value): + pass - def __str__(self): - return "In str" + def __str__(self): + return "In str" - def __repr__(self): - return "In repr" + def __repr__(self): + return "In repr" class MyClass2(Base): - def __getitem__(self, key): - return key + def __getitem__(self, key): + return key - def __setitem__(self, key, value): - pass + def __setitem__(self, key, value): + pass - def __str__(self): - return "In str" + def __str__(self): + return "In str" - def __repr__(self): - return "In repr" + def __repr__(self): + return "In repr" def test_create_atsd(): - a = MyClass(t=np.arange(10), d=np.arange(10)) - - assert hasattr(a, "rate") - assert hasattr(a, "index") - assert hasattr(a, "values") - assert hasattr(a, "time_support") - - assert np.isclose(a.rate, 10/9) - assert isinstance(a.index, nap.TsIndex) - assert isinstance(a.values, np.ndarray) - assert isinstance(a.time_support, nap.IntervalSet) - - assert hasattr(a, "t") - assert hasattr(a, "d") - assert hasattr(a, "start") - assert hasattr(a, "end") - assert hasattr(a, "__array__") - assert hasattr(a, "shape") - assert hasattr(a, "ndim") - assert hasattr(a, "size") - - np.testing.assert_array_equal(a.values, np.arange(10)) - np.testing.assert_array_equal(a.__array__(), np.arange(10)) - - assert len(a) == 10 - - assert a.__repr__() == "In repr" - assert a.__str__() == "In str" - - assert hasattr(a, "__getitem__") - assert hasattr(a, "__setitem__") - assert a[0] == 0 - - b = a.copy() - np.testing.assert_array_equal(a.values, b.values) - np.testing.assert_array_equal(a.index.values, b.index.values) + a = MyClass(t=np.arange(10), d=np.arange(10)) + + assert hasattr(a, "rate") + assert hasattr(a, "index") + assert hasattr(a, "values") + assert hasattr(a, "time_support") + + assert np.isclose(a.rate, 10/9) + assert isinstance(a.index, nap.TsIndex) + try: + assert isinstance(a.values, np.ndarray) + except: + assert nap.core.utils.is_array_like(a.values) + assert isinstance(a.time_support, nap.IntervalSet) + + assert hasattr(a, "t") + assert hasattr(a, "d") + assert hasattr(a, "start") + assert hasattr(a, 
"end") + assert hasattr(a, "__array__") + assert hasattr(a, "shape") + assert hasattr(a, "ndim") + assert hasattr(a, "size") + + np.testing.assert_array_equal(a.values, np.arange(10)) + np.testing.assert_array_equal(a.__array__(), np.arange(10)) + + assert len(a) == 10 + + assert a.__repr__() == "In repr" + assert a.__str__() == "In str" + + assert hasattr(a, "__getitem__") + assert hasattr(a, "__setitem__") + assert a[0] == 0 + + b = a.copy() + np.testing.assert_array_equal(a.values, b.values) + np.testing.assert_array_equal(a.index.values, b.index.values) def test_create_ats(): - a = MyClass2(t=np.arange(10)) + a = MyClass2(t=np.arange(10)) - assert hasattr(a, "rate") - assert hasattr(a, "index") - assert hasattr(a, "time_support") - assert hasattr(a, "shape") + assert hasattr(a, "rate") + assert hasattr(a, "index") + assert hasattr(a, "time_support") + assert hasattr(a, "shape") - assert np.isclose(a.rate, 10/9) - assert isinstance(a.index, nap.TsIndex) - assert isinstance(a.time_support, nap.IntervalSet) - assert a.shape == a.index.shape + assert np.isclose(a.rate, 10/9) + assert isinstance(a.index, nap.TsIndex) + assert isinstance(a.time_support, nap.IntervalSet) + assert a.shape == a.index.shape - assert hasattr(a, "t") - assert a[0] == 0 + assert hasattr(a, "t") + assert a[0] == 0 def test_create_ats_from_tsindex(): - a = MyClass2(t=TsIndex(np.arange(10))) + a = MyClass2(t=TsIndex(np.arange(10))) - assert hasattr(a, "rate") - assert hasattr(a, "index") - assert hasattr(a, "time_support") - assert hasattr(a, "shape") + assert hasattr(a, "rate") + assert hasattr(a, "index") + assert hasattr(a, "time_support") + assert hasattr(a, "shape") - assert np.isclose(a.rate, 10/9) - assert isinstance(a.index, nap.TsIndex) - assert isinstance(a.time_support, nap.IntervalSet) - assert a.shape == a.index.shape + assert np.isclose(a.rate, 10/9) + assert isinstance(a.index, nap.TsIndex) + assert isinstance(a.time_support, nap.IntervalSet) + assert a.shape == a.index.shape - assert hasattr(a, "t") + assert hasattr(a, "t") @pytest.mark.filterwarnings("ignore") def test_create_ats_from_number(): - a = MyClass2(t=1) + a = MyClass2(t=1) - assert hasattr(a, "rate") - assert hasattr(a, "index") - assert hasattr(a, "time_support") - assert hasattr(a, "shape") - + assert hasattr(a, "rate") + assert hasattr(a, "index") + assert hasattr(a, "time_support") + assert hasattr(a, "shape") + def test_methods(): - a = MyClass(t=[], d=[]) - - np.testing.assert_array_equal(a.times(), np.empty(0)) - np.testing.assert_array_equal(a.as_array(), np.empty(0)) - np.testing.assert_array_equal(a.data(), np.empty(0)) - np.testing.assert_array_equal(a.to_numpy(), np.empty(0)) - - assert a.start_time() is None - assert a.end_time() is None - - assert hasattr(a, "value_from") - assert hasattr(a, "count") - assert hasattr(a, "restrict") - assert hasattr(a, "as_array") - assert hasattr(a, "data") - assert hasattr(a, "to_numpy") - assert hasattr(a, "copy") - assert hasattr(a, "bin_average") - assert hasattr(a, "dropna") - assert hasattr(a, "convolve") - assert hasattr(a, "smooth") - assert hasattr(a, "interpolate") + a = MyClass(t=[], d=[]) + + np.testing.assert_array_equal(a.times(), np.empty(0)) + np.testing.assert_array_equal(a.as_array(), np.empty(0)) + np.testing.assert_array_equal(a.data(), np.empty(0)) + np.testing.assert_array_equal(a.to_numpy(), np.empty(0)) + + assert a.start_time() is None + assert a.end_time() is None + + assert hasattr(a, "value_from") + assert hasattr(a, "count") + assert hasattr(a, "restrict") + assert 
hasattr(a, "as_array") + assert hasattr(a, "data") + assert hasattr(a, "to_numpy") + assert hasattr(a, "copy") + assert hasattr(a, "bin_average") + assert hasattr(a, "dropna") + assert hasattr(a, "convolve") + assert hasattr(a, "smooth") + assert hasattr(a, "interpolate") diff --git a/tests/test_jitted.py b/tests/test_jitted.py index 6d609849..f97096d5 100644 --- a/tests/test_jitted.py +++ b/tests/test_jitted.py @@ -1,9 +1,3 @@ -# -*- coding: utf-8 -*- -# @Author: gviejo -# @Date: 2022-12-02 17:17:03 -# @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-04-22 17:30:35 - """Tests of jitted core functions for `pynapple` package.""" import pynapple as nap @@ -12,6 +6,7 @@ import pytest import warnings + def get_example_dataset(n=100): with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/tests/test_spike_trigger_average.py b/tests/test_spike_trigger_average.py index 3c657905..e97c9d33 100644 --- a/tests/test_spike_trigger_average.py +++ b/tests/test_spike_trigger_average.py @@ -11,15 +11,65 @@ def test_compute_spike_trigger_average_tsd(): ep = nap.IntervalSet(0, 100) - feature = nap.Tsd( - t=np.arange(0, 101, 0.01), d=np.zeros(int(101 / 0.01)), time_support=ep - ) + d = np.zeros(int(101 / 0.01)) t1 = np.arange(1, 100) x = np.arange(100, 10000, 100) - feature[x] = 1.0 + d[x] = 1.0 + feature = nap.Tsd( + t=np.arange(0, 101, 0.01), d=d, time_support=ep + ) spikes = nap.TsGroup( {0: nap.Ts(t1), 1: nap.Ts(t1 - 0.1), 2: nap.Ts(t1 + 0.2)}, time_support=ep ) + + # ################# + group = spikes + binsize = 0.2 + windowsize = (0.6, 0.6) + count = spikes.count(binsize, ep) + + start, end = windowsize + + idx1 = -np.arange(0, start + binsize, binsize)[::-1][:-1] + idx2 = np.arange(0, end + binsize, binsize)[1:] + time_idx = np.hstack((idx1, np.zeros(1), idx2)) + + eta = np.zeros((time_idx.shape[0], len(group), *feature.shape[1:])) + + windows = np.array([len(idx1), len(idx2)]) + + # Bin the spike train + count = group.count(binsize, ep) + + time_target_array = count.index.values + count_array = count.values + starts = ep.start + ends = ep.end + time_array = feature.index.values + data_array = feature.values + batch_size = 61 + + from pynajax.utils import _get_idxs, _get_slicing, _get_shifted_indices + from pynajax.jax_core_bin_average import bin_average + import jax.numpy as jnp + import jax + from matplotlib.pyplot import * + + time_array, data_array = bin_average(time_array, data_array, starts, ends, binsize) + + figure() + ax = subplot(211) + plot(spikes.to_tsd(), "o", ms = 5) + [axvline(t) for t in np.arange(0, 100, binsize)] + subplot(212, sharex = ax) + plot(time_target_array, count_array[:,0], label="count") + # plot(time_array, data_array.flatten(), label="feature") + legend() + show() + + # ################ + + sta = nap.compute_event_trigger_average(spikes, feature, 0.2, (0.6, 0.6), ep) diff --git a/tests/test_time_series.py b/tests/test_time_series.py index 1c48d794..8b8e40b9 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -182,12 +182,14 @@ def test_raise_error_tsdtensor_init(): def test_index_error(): tsd = nap.Tsd(t=np.arange(100), d=np.arange(100)) - with pytest.raises(IndexError): - tsd[1000] = 0 - ts = nap.Ts(t=np.arange(100)) - with pytest.raises(IndexError): - ts[1000] + + if isinstance(tsd.d, np.ndarray): + with pytest.raises(IndexError): + tsd[1000] = 0 + + with pytest.raises(IndexError): + ts[1000] def test_find_support(): tsd = nap.Tsd(t=np.arange(100), d=np.arange(100)) @@ -203,7 +205,7 @@ def test_find_support(): 
def test_properties(): t = np.arange(100) - d = np.random.rand(100) + d = np.random.rand(100).astype(np.float32) # to match pynajax tsd = nap.Tsd(t=t, d = d) assert hasattr(tsd, "t") @@ -233,14 +235,17 @@ def __init__(self, t, d): def __getitem__(self, key): return self.values.__getitem__(key) - tsd = DummyTsd([], []) - assert np.isnan(tsd.rate) + tsd = DummyTsd([0, 1], [1, 2]) + assert isinstance(tsd.rate, float) assert isinstance(tsd.index, nap.TsIndex) - assert isinstance(tsd.values, np.ndarray) + try: + assert isinstance(tsd.values, np.ndarray) + except AssertionError: + assert nap.core.utils.is_array_like(tsd.values) # for pynajax assert isinstance(tsd.__repr__(), str) - with pytest.raises(IndexError): + with pytest.raises((IndexError, TypeError)): tsd['a'] @@ -379,9 +384,12 @@ def test_dropna(self, tsd): np.testing.assert_array_equal(tsd.index.values, new_tsd.index.values) np.testing.assert_array_equal(tsd.values, new_tsd.values) - tsd.values[tsd.values>0.9] = np.NaN + tmp = np.random.rand(*tsd.shape) + tmp[tmp>0.9] = np.NaN + tsd = tsd.__class__(t=tsd.t, d=tmp) + new_tsd = tsd.dropna() - assert not np.all(np.isnan(new_tsd)) + assert not np.all(np.isnan(new_tsd.values)) tokeep = np.array([~np.any(np.isnan(tsd[i])) for i in range(len(tsd))]) np.testing.assert_array_equal(tsd.index.values[tokeep], new_tsd.index.values) np.testing.assert_array_equal(tsd.values[tokeep], new_tsd.values) @@ -395,7 +403,7 @@ def test_dropna(self, tsd): np.testing.assert_array_equal(tsd.values[tokeep], new_tsd.values) np.testing.assert_array_equal(new_tsd.time_support, tsd.time_support) - tsd.values[:] = np.NaN + tsd = tsd.__class__(t=tsd.t, d=np.ones(tsd.shape)*np.NaN) new_tsd = tsd.dropna() assert len(new_tsd) == 0 assert len(new_tsd.time_support) == 0 @@ -407,7 +415,7 @@ def test_convolve(self, tsd): tmp = tsd.values.reshape(tsd.shape[0], -1) tmp2 = np.zeros_like(tmp) for i in range(tmp.shape[-1]): - tmp2[:,i] = np.convolve(tmp[:,i], array, mode='full')[5:-4] + tmp2[:,i] = np.convolve(tmp[:,i], array, mode='full')[4:-5] np.testing.assert_array_almost_equal( tmp2, tsd2.values.reshape(tsd2.shape[0], -1) @@ -424,18 +432,20 @@ def test_convolve(self, tsd): ep = nap.IntervalSet(start=[0, 60], end=[40,100]) tsd3 = tsd.convolve(array, ep) + tmp3 = [] for i in range(len(ep)): tmp2 = tsd.restrict(ep[i]).values - tmp2 = tmp2.reshape(tmp2.shape[0], -1) + tmp2 = np.array(tmp2.reshape(tmp2.shape[0], -1)) # for pynajax for j in range(tmp2.shape[-1]): - tmp2[:,j] = np.convolve(tmp2[:,j], array, mode='full')[5:-4] + tmp2[:,j] = np.convolve(tmp2[:,j], array, mode='full')[4:-5] + tmp3.append(tmp2) np.testing.assert_array_almost_equal( tmp2, tsd3.restrict(ep[i]).values.reshape(tmp2.shape[0], -1) ) # Trim - for trim, sl in zip(['left', 'both', 'right'], [slice(9,None),slice(5,-4),slice(None,-9)]): + for trim, sl in zip(['left', 'both', 'right'], [slice(9,None),slice(4,-5),slice(None,-9)]): tsd2 = tsd.convolve(array, trim=trim) tmp = tsd.values.reshape(tsd.shape[0], -1) tmp2 = np.zeros_like(tmp) @@ -462,7 +472,7 @@ def test_smooth(self, tsd): window = signal.windows.gaussian(M, std=std) window = window / window.sum() for i in range(tmp.shape[-1]): - tmp2[:,i] = np.convolve(tmp[:,i], window, mode='full')[M//2:1-M//2] + tmp2[:,i] = np.convolve(tmp[:,i], window, mode='full')[M//2-1:1-M//2-1] np.testing.assert_array_almost_equal( tmp2, tsd2.values.reshape(tsd2.shape[0], -1) @@ -483,7 +493,7 @@ def test_smooth(self, tsd): M = std*200 window = signal.windows.gaussian(M, std=std) for i in range(tmp.shape[-1]): - tmp2[:,i] = 
np.convolve(tmp[:,i], window, mode='full')[M//2:1-M//2] + tmp2[:,i] = np.convolve(tmp[:,i], window, mode='full')[M//2-1:1-M//2-1] np.testing.assert_array_almost_equal( tmp2, tsd2.values.reshape(tsd2.shape[0], -1) @@ -496,7 +506,7 @@ def test_smooth(self, tsd): M = int(tsd.rate * 10) window = signal.windows.gaussian(M, std=std) for i in range(tmp.shape[-1]): - tmp2[:,i] = np.convolve(tmp[:,i], window, mode='full')[M//2:1-M//2] + tmp2[:,i] = np.convolve(tmp[:,i], window, mode='full')[M//2-1:1-M//2-1] np.testing.assert_array_almost_equal( tmp2, tsd2.values.reshape(tsd2.shape[0], -1) @@ -633,13 +643,13 @@ def test_bin_average_with_ep(self, tsd): def test_threshold(self, tsd): thrs = tsd.threshold(0.5, "above") - assert len(thrs) == np.sum(tsd.values > 0.5) + assert len(thrs) == int(np.sum(tsd.values > 0.5)) thrs = tsd.threshold(0.5, "below") - assert len(thrs) == np.sum(tsd.values < 0.5) + assert len(thrs) == int(np.sum(tsd.values < 0.5)) thrs = tsd.threshold(0.5, "aboveequal") - assert len(thrs) == np.sum(tsd.values >= 0.5) + assert len(thrs) == int(np.sum(tsd.values >= 0.5)) thrs = tsd.threshold(0.5, "belowequal") - assert len(thrs) == np.sum(tsd.values <= 0.5) + assert len(thrs) == int(np.sum(tsd.values <= 0.5)) def test_threshold_time_support(self, tsd): thrs = tsd.threshold(0.5, "above") @@ -806,8 +816,8 @@ def test_vertical_slicing(self, tsdframe): def test_str_indexing(self, tsdframe): tsdframe = nap.TsdFrame(t=np.arange(100), d=np.random.rand(100, 3), time_units="s", columns=['a', 'b', 'c']) - np.testing.assert_array_almost_equal(tsdframe.values[:,0], tsdframe['a']) - np.testing.assert_array_almost_equal(tsdframe.values[:,[0,2]], tsdframe[['a', 'c']]) + np.testing.assert_array_almost_equal(tsdframe.values[:,0], tsdframe['a'].values) + np.testing.assert_array_almost_equal(tsdframe.values[:,[0,2]], tsdframe[['a', 'c']].values) with pytest.raises(Exception): tsdframe['d'] From ad4ef3567085e2a68fac1b9fbdbb32660861f3d3 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Mon, 29 Apr 2024 16:41:33 -0400 Subject: [PATCH 26/37] Updating docs --- docs/examples/tutorial_pynapple_io.py | 2 +- .../examples/tutorial_pynapple_quick_start.py | 2 +- docs/external.md | 17 ++++ docs/pynajax.md | 65 +++++++++++++ draft_pynapple_fastplotlib.py | 2 +- mkdocs.yml | 2 + pynapple/process/_process_functions.py | 2 +- tests/test_spike_trigger_average.py | 95 +++++-------------- 8 files changed, 113 insertions(+), 74 deletions(-) create mode 100644 docs/external.md create mode 100644 docs/pynajax.md diff --git a/docs/examples/tutorial_pynapple_io.py b/docs/examples/tutorial_pynapple_io.py index 8d02e4c9..17f4dc72 100644 --- a/docs/examples/tutorial_pynapple_io.py +++ b/docs/examples/tutorial_pynapple_io.py @@ -50,7 +50,7 @@ # The object project behaves like a nested dictionnary. It is then easy to loop and navigate through a hierarchy of folders when doing analyses. In this case, we are gonna take only the session A2929-200711. -session = project["sub-A2929"]["ses-A2929-200711"] +session = project["sub-A2929"]["A2929-200711"] print(session) diff --git a/docs/examples/tutorial_pynapple_quick_start.py b/docs/examples/tutorial_pynapple_quick_start.py index 28492373..2d2ecb4f 100644 --- a/docs/examples/tutorial_pynapple_quick_start.py +++ b/docs/examples/tutorial_pynapple_quick_start.py @@ -54,7 +54,7 @@ # The object `data` is a [`Folder`](https://pynapple-org.github.io/pynapple/io.folder/) object that allows easy navigation and interaction with a dataset. 
# In this case, we want to load the NWB file in the folder `/pynapplenwb`. Data are always lazy loaded. No time series is loaded until it's actually called. # When calling the NWB file, the object `nwb` is an interface to the NWB file. All the data inside the NWB file that are compatible with one of the pynapple objects are shown with their corresponding keys. -nwb = data["sub-A2929"]["ses-A2929-200711"]["pynapplenwb"]["A2929-200711"] +nwb = data["sub-A2929"]["A2929-200711"]["pynapplenwb"]["A2929-200711"] print(nwb) diff --git a/docs/external.md b/docs/external.md new file mode 100644 index 00000000..3e1665e0 --- /dev/null +++ b/docs/external.md @@ -0,0 +1,17 @@ +Pynapple has been designed as a lightweight package for representing time series and epochs in system neuroscience. +As such, it can function as a foundational element for other analysis packages handling time series data. Here we keep track of external projects that uses pynapple. + + +## NEMOS + +![image](https://raw.githubusercontent.com/flatironinstitute/nemos/main/docs/assets/glm_features_scheme.svg) + +[NeMOs](https://nemos.readthedocs.io/en/stable/) is a statistical modeling framework optimized for systems neuroscience and powered by JAX. It streamlines the process of defining and selecting models, through a collection of easy-to-use methods for feature design. + +The core of nemos includes GPU-accelerated, well-tested implementations of standard statistical models, currently focusing on the Generalized Linear Model (GLM). + +Check out this [page](https://nemos.readthedocs.io/en/stable/generated/neural_modeling/) for many examples of neural modelling using nemos and pynapple. + +!!! note + Nemos is build on top of [jax](https://jax.readthedocs.io/en/latest/index.html), a library for high-performance numerical computing. + To ensure full compatibility with nemos, consider installing [pynajax](https://github.com/pynapple-org/pynajax), a pynapple backend for jax. \ No newline at end of file diff --git a/docs/pynajax.md b/docs/pynajax.md new file mode 100644 index 00000000..e664f8e4 --- /dev/null +++ b/docs/pynajax.md @@ -0,0 +1,65 @@ +### Motivation + +Multiple python packages exist for high-performance computing. Internally, pynapple makes extensive use of [numba](https://numba.pydata.org/) for accelerating some functions. Numba is a stable package that provide speed gains with minimal installation issues when running on CPUs. + +Another high-performance toolbox for numerical analysis is +[jax](https://jax.readthedocs.io/en/latest/index.html). In addition to accelerating python code on CPUs, GPUs, and TPUs, it provides a special representation of arrays using the [jax Array object](https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html). Unfortunately, jax Array is incompatible with Numba. To solve this issue, we developped [pynajax](https://github.com/pynapple-org/pynajax). + +Pynajax is an accelerated backend for pynapple built on top on jax. It offers a fast acceleration for some pynapple functions using CPU or GPU. Here is a minimal example on how to use pynajax: + +``` bash +$ pip install pynajax +``` + + + +``` python +import pynapple as nap +import numpy as np + +# Changed the backend from 'numba' to 'jax' +nap.nap_config.set_backend("jax") + +# This will convert the numpy array to a jax Array. 
+tsd = nap.Tsd(t=np.arange(100), d=np.random.randn(100)) + +# This will run on GPU or CPU depending on the jax installation +tsd.convolve(np.ones(11)) +``` + +This documentation page keeps tracks of the list of pynapple functions that can be jax-accelerated as well as their performances compared to pure numba. + +### Installation issues + +To get the best of the pynajax backend, jax needs to use the GPU. + +While installing pynajax will install all the dependencies necessary to use jax, it does not guarantee +the use of the GPU. + +To check if jax is using the GPU, you can run the following python commands : + +- no GPU found : + + ```python + >>> import jax + >>> print(jax.devices()) + [CpuDevice(id=0)] + ``` + +- GPU found : + + ```python + >>> import jax + >>> print(jax.devices()) + [cuda(id=0)] + ``` + +Support for installing with GPU can be found in the [jax documentation](https://jax.readthedocs.io/en/latest/installation.html) + + +### Typical use-case + + +In addition to providing high performance numerical computing, jax can be used a the backbone of a large scale machine learning model. Thus, pynajax can offer full compatibility between pynapple's time series representation and computational neuroscience models constructed using jax. + +An example of a python package using both pynapple and jax is [NeMOs](https://nemos.readthedocs.io/en/stable/). \ No newline at end of file diff --git a/draft_pynapple_fastplotlib.py b/draft_pynapple_fastplotlib.py index da4c4538..54a65bb9 100644 --- a/draft_pynapple_fastplotlib.py +++ b/draft_pynapple_fastplotlib.py @@ -35,7 +35,7 @@ # mkdocs_gallery_thumbnail_path = '../_static/fastplotlib_demo.png' #nwb = nap.load_file("/Users/gviejo/pynapple/Mouse32-220101.nwb") -nwb = nap.load_file("your/path/to/MyProject/sub-A2929/ses-A2929-200711/pynapplenwb/A2929-200711.nwb") +nwb = nap.load_file("your/path/to/MyProject/sub-A2929/A2929-200711/pynapplenwb/A2929-200711.nwb") units = nwb['units']#.getby_category("location")['adn'] diff --git a/mkdocs.yml b/mkdocs.yml index dd380a7f..2660731c 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -36,6 +36,8 @@ plugins: nav: - Overview: index.md - Usage: generated/gallery + - External projects: external.md + - Pynajax - GPU acceleration: pynajax.md - Modules : reference/ - Contributing: CONTRIBUTING.md - Authors: AUTHORS.md diff --git a/pynapple/process/_process_functions.py b/pynapple/process/_process_functions.py index 2b403057..862103f3 100644 --- a/pynapple/process/_process_functions.py +++ b/pynapple/process/_process_functions.py @@ -233,7 +233,7 @@ def _perievent_trigger_average( time_target_array, count_array, time_array, - np.expand_dims(data_array, -1), + data_array, starts, ends, windows, diff --git a/tests/test_spike_trigger_average.py b/tests/test_spike_trigger_average.py index e97c9d33..ae8c5dfb 100644 --- a/tests/test_spike_trigger_average.py +++ b/tests/test_spike_trigger_average.py @@ -22,55 +22,6 @@ def test_compute_spike_trigger_average_tsd(): {0: nap.Ts(t1), 1: nap.Ts(t1 - 0.1), 2: nap.Ts(t1 + 0.2)}, time_support=ep ) - # ################# - group = spikes - binsize = 0.2 - windowsize = (0.6, 0.6) - count = spikes.count(binsize, ep) - - start, end = windowsize - - idx1 = -np.arange(0, start + binsize, binsize)[::-1][:-1] - idx2 = np.arange(0, end + binsize, binsize)[1:] - time_idx = np.hstack((idx1, np.zeros(1), idx2)) - - eta = np.zeros((time_idx.shape[0], len(group), *feature.shape[1:])) - - windows = np.array([len(idx1), len(idx2)]) - - # Bin the spike train - count = group.count(binsize, ep) - - 
time_target_array = count.index.values - count_array = count.values - starts = ep.start - ends = ep.end - time_array = feature.index.values - data_array = feature.values - batch_size = 61 - - from pynajax.utils import _get_idxs, _get_slicing, _get_shifted_indices - from pynajax.jax_core_bin_average import bin_average - import jax.numpy as jnp - import jax - from matplotlib.pyplot import * - - time_array, data_array = bin_average(time_array, data_array, starts, ends, binsize) - - figure() - ax = subplot(211) - plot(spikes.to_tsd(), "o", ms = 5) - [axvline(t) for t in np.arange(0, 100, binsize)] - subplot(212, sharex = ax) - plot(time_target_array, count_array[:,0], label="count") - # plot(time_array, data_array.flatten(), label="feature") - legend() - show() - - # ################ - - - sta = nap.compute_event_trigger_average(spikes, feature, 0.2, (0.6, 0.6), ep) output = np.zeros((7, 3)) @@ -80,16 +31,17 @@ def test_compute_spike_trigger_average_tsd(): assert isinstance(sta, nap.TsdFrame) assert sta.shape == output.shape - np.testing.assert_array_almost_equal(sta, output) + np.testing.assert_array_almost_equal(sta.values, output) def test_compute_spike_trigger_average_tsdframe(): ep = nap.IntervalSet(0, 100) + d = np.zeros((int(101 / 0.01),1)) + x = np.arange(100, 10000, 100) + d[x] = 1.0 feature = nap.TsdFrame( - t=np.arange(0, 101, 0.01), d=np.zeros((int(101 / 0.01),1)), time_support=ep + t=np.arange(0, 101, 0.01), d=d, time_support=ep ) - t1 = np.arange(1, 100) - x = np.arange(100, 10000, 100) - feature[x] = 1.0 + t1 = np.arange(1, 100) spikes = nap.TsGroup( {0: nap.Ts(t1), 1: nap.Ts(t1 - 0.1), 2: nap.Ts(t1 + 0.2)}, time_support=ep ) @@ -103,16 +55,17 @@ def test_compute_spike_trigger_average_tsdframe(): assert isinstance(sta, nap.TsdTensor) assert sta.shape == (*output.shape, 1) - np.testing.assert_array_almost_equal(sta, np.expand_dims(output, 2)) + np.testing.assert_array_almost_equal(sta.values, np.expand_dims(output, 2)) def test_compute_spike_trigger_average_tsdtensor(): ep = nap.IntervalSet(0, 100) + d=np.zeros((int(101 / 0.01),1,1)) + x = np.arange(100, 10000, 100) + d[x] = 1.0 feature = nap.TsdTensor( - t=np.arange(0, 101, 0.01), d=np.zeros((int(101 / 0.01),1,1)), time_support=ep + t=np.arange(0, 101, 0.01), d=d, time_support=ep ) - t1 = np.arange(1, 100) - x = np.arange(100, 10000, 100) - feature[x] = 1.0 + t1 = np.arange(1, 100) spikes = nap.TsGroup( {0: nap.Ts(t1), 1: nap.Ts(t1 - 0.1), 2: nap.Ts(t1 + 0.2)}, time_support=ep ) @@ -126,7 +79,7 @@ def test_compute_spike_trigger_average_tsdtensor(): assert isinstance(sta, nap.TsdTensor) assert sta.shape == output.shape - np.testing.assert_array_almost_equal(sta, output) + np.testing.assert_array_almost_equal(sta.values, output) def test_compute_spike_trigger_average_random_feature(): ep = nap.IntervalSet(0, 100) @@ -163,18 +116,19 @@ def test_compute_spike_trigger_average_random_feature(): def test_compute_spike_trigger_average_add_nan(): ep = nap.IntervalSet(0, 110) + d=np.zeros(int(110 / 0.01)) + x = np.arange(100, 10000, 100) + d[x] = 1.0 + d[-1001:] = np.nan feature = nap.Tsd( - t=np.arange(0, 110, 0.01), d=np.zeros(int(110 / 0.01)), time_support=ep + t=np.arange(0, 110, 0.01), d=d, time_support=ep ) t1 = np.arange(1, 100) - x = np.arange(100, 10000, 100) - feature[x] = 1.0 + spikes = nap.TsGroup( {0: nap.Ts(t1), 1: nap.Ts(t1 - 0.1), 2: nap.Ts(t1 + 0.2)}, time_support=ep ) - feature[-1001:] = np.nan - sta = nap.compute_event_trigger_average(spikes, feature, 0.2, (0.6, 0.6), ep) output = np.zeros((7, 3)) @@ -184,16 +138,17 @@ 
def test_compute_spike_trigger_average_add_nan(): assert isinstance(sta, nap.TsdFrame) assert sta.shape == output.shape - np.testing.assert_array_almost_equal(sta, output) + np.testing.assert_array_almost_equal(sta.values, output) def test_compute_spike_trigger_average_raise_error(): ep = nap.IntervalSet(0, 101) + d=np.zeros(int(101 / 0.01)) + x = np.arange(100, 10000, 100)+1 + d[x] = 1.0 feature = nap.Tsd( - t=np.arange(0, 101, 0.01), d=np.zeros(int(101 / 0.01)), time_support=ep + t=np.arange(0, 101, 0.01), d=d , time_support=ep ) t1 = np.arange(1, 101) + 0.01 - x = np.arange(100, 10000, 100)+1 - feature[x] = 1.0 spikes = nap.TsGroup( {0: nap.Ts(t1), 1: nap.Ts(t1 - 0.1), 2: nap.Ts(t1 + 0.2)}, time_support=ep ) @@ -281,7 +236,7 @@ def test_compute_spike_trigger_average_no_windows(): assert isinstance(sta, nap.TsdFrame) assert sta.shape == output.shape - np.testing.assert_array_almost_equal(sta, output) + np.testing.assert_array_almost_equal(sta.values, output) def test_compute_spike_trigger_average_multiple_epochs(): From d86965c5f99f4f7bba1793c635ef401253994ccc Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Tue, 30 Apr 2024 20:49:08 -0400 Subject: [PATCH 27/37] testing fastplotlib --- draft_pynapple_fastplotlib.py | 88 ++++++++++++++++++++++++++--------- pyproject.toml | 5 ++ 2 files changed, 70 insertions(+), 23 deletions(-) diff --git a/draft_pynapple_fastplotlib.py b/draft_pynapple_fastplotlib.py index 54a65bb9..03c88db3 100644 --- a/draft_pynapple_fastplotlib.py +++ b/draft_pynapple_fastplotlib.py @@ -15,58 +15,100 @@ """ # %% -# !!! warning -# This tutorial uses seaborn and matplotlib for displaying the figure -# -# You can install all with `pip install matplotlib seaborn tqdm` -# -# mkdocs_gallery_thumbnail_number = 1 -# -# Now, import the necessary libraries: - -# %qui qt +# %gui qt import pynapple as nap import numpy as np import fastplotlib as fpl -import imageio.v3 as iio import sys # mkdocs_gallery_thumbnail_path = '../_static/fastplotlib_demo.png' -#nwb = nap.load_file("/Users/gviejo/pynapple/Mouse32-220101.nwb") -nwb = nap.load_file("your/path/to/MyProject/sub-A2929/A2929-200711/pynapplenwb/A2929-200711.nwb") +def get_memory_map(filepath, nChannels, frequency=20000): + n_channels = int(nChannels) + f = open(filepath, 'rb') + startoffile = f.seek(0, 0) + endoffile = f.seek(0, 2) + bytes_size = 2 + n_samples = int((endoffile-startoffile)/n_channels/bytes_size) + duration = n_samples/frequency + interval = 1/frequency + f.close() + fp = np.memmap(filepath, np.int16, 'r', shape = (n_samples, n_channels)) + timestep = np.arange(0, n_samples)/frequency + + return fp, timestep + + +#### LFP +data_array, time_array = get_memory_map("your/path/to/MyProject/sub-A2929/A2929-200711/A2929-200711.dat", 16) +lfp = nap.TsdFrame(t=time_array, d=data_array) + +lfp2 = lfp.get(0, 20)[:,14] +lfp2 = np.vstack((lfp2.t, lfp2.d)).T +#### NWB +nwb = nap.load_file("your/path/to/MyProject/sub-A2929/A2929-200711/pynapplenwb/A2929-200711.nwb") units = nwb['units']#.getby_category("location")['adn'] +tmp = units.to_tsd().get(0, 20) +tmp = np.vstack((tmp.index.values, tmp.values)).T + + + +fig = fpl.Figure(shape=(2,1)) +fig[0,0].add_line(data=lfp2, thickness=1, cmap="autumn") +fig[1,0].add_scatter(tmp) +fig.show() + + + + + +# grid_plot = fpl.GridPlot(shape=(2, 1), controller_ids="sync", names = ['lfp', 'wavelet']) +# grid_plot['lfp'].add_line(lfp.t, lfp[:,14].d) + + + + + + + -tmp = units.to_tsd() + + + + + + + + + +sys.exit() + 
+################################################################################################# + + +nwb = nap.load_file("your/path/to/MyProject/sub-A2929/A2929-200711/pynapplenwb/A2929-200711.nwb") +units = nwb['units']#.getby_category("location")['adn'] +tmp = units.to_tsd() tmp = np.vstack((tmp.index.values, tmp.values)).T # Example 1 fplot = fpl.Plot() - fplot.add_scatter(tmp) - fplot.graphics[0].cmap = "jet" - fplot.graphics[0].cmap.values = tmp[:, 1] - fplot.show(maintain_aspect=False) # Example 2 names = [['raster'], ['position']] - grid_plot = fpl.GridPlot(shape=(2, 1), controller_ids="sync", names = names) - grid_plot['raster'].add_scatter(tmp) - grid_plot['position'].add_line(np.vstack((nwb['ry'].t, nwb['ry'].d)).T) - grid_plot.show(maintain_aspect=False) - grid_plot['raster'].auto_scale(maintain_aspect=False) diff --git a/pyproject.toml b/pyproject.toml index 0bd2d464..8a2db4fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,9 @@ repository = "https://github.com/pynapple-org/pynapple" # Optional dependencies ########################################################################## [project.optional-dependencies] +jax = [ + "pynajax" +] dev = [ "black>=24.2.0", # Code formatter "isort", # Import sorter @@ -70,6 +73,8 @@ dandi = [ "aiohttp", ] + + [tool.black] target-version = ['py36', 'py37', 'py38', 'py39'] skip-string-normalization = false From bf866f0a56fb748e52b755c34d24adcc59220ca9 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Mon, 6 May 2024 17:01:51 -0400 Subject: [PATCH 28/37] changing tests for pynajax --- draft_pynapple_fastplotlib.py | 37 ++++++++++++++++++++++++++++++++--- pynapple/core/utils.py | 2 ++ tests/test_non_numpy_array.py | 24 ++++++++++++++--------- 3 files changed, 51 insertions(+), 12 deletions(-) diff --git a/draft_pynapple_fastplotlib.py b/draft_pynapple_fastplotlib.py index 03c88db3..a3dd423d 100644 --- a/draft_pynapple_fastplotlib.py +++ b/draft_pynapple_fastplotlib.py @@ -55,11 +55,11 @@ def get_memory_map(filepath, nChannels, frequency=20000): -fig = fpl.Figure(shape=(2,1)) +fig = fpl.Figure(canvas="glfw", shape=(2,1)) fig[0,0].add_line(data=lfp2, thickness=1, cmap="autumn") fig[1,0].add_scatter(tmp) -fig.show() - +fig.show(maintain_aspect=False) +# fpl.run() @@ -68,6 +68,37 @@ def get_memory_map(filepath, nChannels, frequency=20000): # grid_plot['lfp'].add_line(lfp.t, lfp[:,14].d) +import numpy as np +import fastplotlib as fpl + +fig = fpl.Figure(canvas="glfw")#, shape=(2,1), controller_ids="sync") +fig[0,0].add_line(data=np.random.randn(1000)) +fig.show(maintain_aspect=False) + +fig2 = fpl.Figure(canvas="glfw", controllers=fig.controllers)#, shape=(2,1), controller_ids="sync") +fig2[0,0].add_line(data=np.random.randn(1000)*1000) +fig2.show(maintain_aspect=False) + + + +# Not sure about this : +fig[1,0].controller.controls["mouse1"] = "pan", "drag", (1.0, 0.0) + +fig[1,0].controller.controls.pop("mouse2") +fig[1,0].controller.controls.pop("mouse4") +fig[1,0].controller.controls.pop("wheel") + +import pygfx + +controller = pygfx.PanZoomController() +controller.controls.pop("mouse1") +controller.add_camera(fig[0, 0].camera) +controller.register_events(fig[0, 0].viewport) + +controller2 = pygfx.PanZoomController() +controller2.add_camera(fig[1, 0].camera) +controller2.controls.pop("mouse1") +controller2.register_events(fig[1, 0].viewport) diff --git a/pynapple/core/utils.py b/pynapple/core/utils.py index 3d5578ab..9339e5aa 100644 --- a/pynapple/core/utils.py +++ b/pynapple/core/utils.py @@ -78,6 +78,8 @@ def 
convert_to_jax_array(array, array_name): return array elif isinstance(array, np.ndarray): return cast_to_jax(array, array_name) + elif is_array_like(array): + return cast_to_jax(array, array_name) else: raise RuntimeError( "Unknown format for {}. Accepted formats are numpy.ndarray, list, tuple or any array-like objects.".format( diff --git a/tests/test_non_numpy_array.py b/tests/test_non_numpy_array.py index 9f6d8c62..7553e661 100644 --- a/tests/test_non_numpy_array.py +++ b/tests/test_non_numpy_array.py @@ -99,7 +99,8 @@ def test_tsd_type_d(self, time, data, expectation): """Verify that the data attribute 'd' of a Tsd object is stored as a numpy.ndarray.""" with expectation: ts = nap.Tsd(t=time, d=data) - assert isinstance(ts.d, np.ndarray) + if nap.nap_config.backend == "numba": + assert isinstance(ts.d, np.ndarray) @pytest.mark.filterwarnings("ignore") @pytest.mark.parametrize( @@ -132,8 +133,9 @@ def test_tsd_type_t(self, time, data, expectation): ) def test_tsd_warn(self, data, expectation): """Check for warnings when the data attribute 'd' is automatically converted to numpy.ndarray.""" - with expectation: - nap.Tsd(t=np.array(data), d=data) + if nap.nap_config.backend == "numba": + with expectation: + nap.Tsd(t=np.array(data), d=data) class TestTsdFrameArray: @@ -169,7 +171,8 @@ def test_tsdframe_type(self, time, data, expectation): """Verify that the data attribute 'd' of a TsdFrame object is stored as a numpy.ndarray.""" with expectation: ts = nap.TsdFrame(t=time, d=data) - assert isinstance(ts.d, np.ndarray) + if nap.nap_config.backend == "numba": + assert isinstance(ts.d, np.ndarray) @pytest.mark.filterwarnings("ignore") @pytest.mark.parametrize( @@ -202,8 +205,9 @@ def test_tsdframe_type_t(self, time, data, expectation): ) def test_tsdframe_warn(self, data, expectation): """Check for warnings when the data attribute 'd' is automatically converted to numpy.ndarray.""" - with expectation: - nap.TsdFrame(t=np.array(data), d=data) + if nap.nap_config.backend == "numba": + with expectation: + nap.TsdFrame(t=np.array(data), d=data) class TestTsdTensorArray: @@ -245,7 +249,8 @@ def test_tsdtensor_type_d(self, time, data, expectation): """Verify that the data attribute 'd' of a TsdTensor object is stored as a numpy.ndarray.""" with expectation: ts = nap.TsdTensor(t=time, d=data) - assert isinstance(ts.d, np.ndarray) + if nap.nap_config.backend == "numba": + assert isinstance(ts.d, np.ndarray) @pytest.mark.filterwarnings("ignore") @pytest.mark.parametrize( @@ -278,6 +283,7 @@ def test_tsdtensor_type_t(self, time, data, expectation): ) def test_tsdtensor_warn(self, data, expectation): """Check for warnings when the data attribute 'd' is automatically converted to numpy.ndarray.""" - with expectation: - nap.TsdTensor(t=np.ravel(np.array(data)), d=data) + if nap.nap_config.backend == "numba": + with expectation: + nap.TsdTensor(t=np.ravel(np.array(data)), d=data) From ef0151c5fc834be2cc20eab69146468058644140 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Tue, 7 May 2024 15:27:18 -0400 Subject: [PATCH 29/37] CHanged perivent continuous --- pynapple/process/_process_functions.py | 128 +++++++++++-------------- tests/test_perievent.py | 30 +++++- 2 files changed, 84 insertions(+), 74 deletions(-) diff --git a/pynapple/process/_process_functions.py b/pynapple/process/_process_functions.py index 862103f3..3798882f 100644 --- a/pynapple/process/_process_functions.py +++ b/pynapple/process/_process_functions.py @@ -9,82 +9,40 @@ """ import numpy as np -from numba import jit, njit, prange 
+from numba import jit from .. import core as nap -@njit(parallel=True) -def _jitcontinuous_perievent( - time_array, data_array, time_target_array, starts, ends, windowsize -): - N_samples = len(time_array) - N_target = len(time_target_array) +@jit(nopython=True) +def _jitcontinuous_perievent(time_array, time_target_array, starts, ends, windowsize): N_epochs = len(starts) count = np.zeros((N_epochs, 2), dtype=np.int64) - start_t = np.zeros((N_epochs, 2), dtype=np.int64) - - k = 0 # Epochs - t = 0 # Samples - i = 0 # Target - - while ends[k] < time_array[t] and ends[k] < time_target_array[i]: - k += 1 - - while k < N_epochs: - # Outside - while t < N_samples: - if time_array[t] >= starts[k]: - break - t += 1 - - while i < N_target: - if time_target_array[i] >= starts[k]: - break - i += 1 - - if time_array[t] <= ends[k]: - start_t[k, 0] = t - - if time_target_array[i] <= ends[k]: - start_t[k, 1] = i - - # Inside - while t < N_samples: - if time_array[t] > ends[k]: - break - else: - count[k, 0] += 1 - t += 1 - - while i < N_target: - if time_target_array[i] > ends[k]: - break - else: - count[k, 1] += 1 - i += 1 - - k += 1 - - if k == N_epochs: - break - if t == N_samples: - break - if i == N_target: - break - new_data_array = np.full( - (np.sum(windowsize) + 1, np.sum(count[:, 1]), *data_array.shape[1:]), np.nan + idx, count[:, 1] = nap._jitted_functions.jitrestrict_with_count( + time_target_array, starts, ends ) + time_target_array = time_target_array[idx] + + idx, count[:, 0] = nap._jitted_functions.jitrestrict_with_count( + time_array, starts, ends + ) + time_array = time_array[idx] + + N_target = len(time_target_array) + + slice_idx = np.zeros((N_target, 2), dtype=np.int64) + start_w = np.zeros(N_target, dtype=np.int64) if np.any((count[:, 0] * count[:, 1]) > 0): - for k in prange(N_epochs): + for k in range(N_epochs): if count[k, 0] > 0 and count[k, 1] > 0: - t = start_t[k, 0] - i = start_t[k, 1] + t = np.sum(count[0:k, 0]) + i = np.sum(count[0:k, 1]) maxt = t + count[k, 0] maxi = i + count[k, 1] - cnt_i = np.sum(count[0:k, 1]) + + start_t = t while i < maxi: interval = abs(time_array[t] - time_target_array[i]) @@ -99,18 +57,17 @@ def _jitcontinuous_perievent( t_pos = t t += 1 - left = np.minimum(windowsize[0], t_pos - start_t[k, 0]) + left = np.minimum(windowsize[0], t_pos - start_t) right = np.minimum(windowsize[1], maxt - t_pos - 1) - center = windowsize[0] + 1 - new_data_array[center - left - 1 : center + right, cnt_i] = ( - data_array[t_pos - left : t_pos + right + 1] - ) + # center = windowsize[0] + 1 + + slice_idx[i] = (t_pos - left, t_pos + right + 1) + start_w[i] = windowsize[0] - left t -= 1 i += 1 - cnt_i += 1 - return new_data_array + return idx, slice_idx, np.sum(count[:, 1]), start_w @jit(nopython=True) @@ -270,13 +227,38 @@ def _perievent_trigger_average( def _perievent_continuous( time_array, data_array, time_target_array, starts, ends, windowsize ): + + idx, slice_idx, N_target, w_starts = _jitcontinuous_perievent( + time_array, time_target_array, starts, ends, windowsize + ) + + data_array = data_array[idx] + if nap.utils.get_backend() == "jax": from pynajax.jax_process_perievent import perievent_continuous return perievent_continuous( - time_array, data_array, time_target_array, starts, ends, windowsize + data_array, np.sum(windowsize) + 1, N_target, slice_idx, w_starts ) else: - return _jitcontinuous_perievent( - time_array, data_array, time_target_array, starts, ends, windowsize + new_data_array = np.full( + (np.sum(windowsize) + 1, N_target, *data_array.shape[1:]), 
np.nan ) + + w_sizes = slice_idx[:, 1] - slice_idx[:, 0] # Different sizes + + all_w_sizes = np.unique(w_sizes) + all_w_start = np.unique(w_starts) + + for w_size in all_w_sizes: + for w_start in all_w_start: + col_idx = w_sizes == w_size + new_idx = np.zeros((w_size, np.sum(col_idx)), dtype=int) + for i, tmp in enumerate(slice_idx[col_idx]): + new_idx[:, i] = np.arange(tmp[0], tmp[1]) + + new_data_array[w_start : w_start + w_size, col_idx] = data_array[ + new_idx + ] + + return new_data_array diff --git a/tests/test_perievent.py b/tests/test_perievent.py index 2c8f14b7..761e99cd 100644 --- a/tests/test_perievent.py +++ b/tests/test_perievent.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-03-30 11:16:53 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-12-12 17:45:38 +# @Last Modified time: 2024-05-07 15:22:24 #!/usr/bin/env python """Tests of perievent for `pynapple` package.""" @@ -109,7 +109,21 @@ def test_compute_perievent_continuous(): tsd = nap.Tsd(t=np.arange(100), d=np.arange(100)) tref = nap.Ts(t=np.array([20, 60])) minmax=(-5, 10) + + # time_array = tsd.t + # data_array = tsd.d + # time_target_array = tref.t + # starts = tsd.time_support.start + # ends = tsd.time_support.end + # window = np.abs(minmax) + # binsize = time_array[1] - time_array[0] + # idx1 = -np.arange(0, window[0] + binsize, binsize)[::-1][:-1] + # idx2 = np.arange(0, window[1] + binsize, binsize)[1:] + # time_idx = np.hstack((idx1, np.zeros(1), idx2)) + # windowsize = np.array([idx1.shape[0], idx2.shape[0]]) + pe = nap.compute_perievent_continuous(tsd, tref, minmax=minmax) + assert isinstance(pe, nap.TsdFrame) assert pe.shape[1] == len(tref) np.testing.assert_array_almost_equal(pe.index.values, np.arange(minmax[0], minmax[-1]+1)) @@ -182,6 +196,20 @@ def test_compute_perievent_continuous_with_ep(): np.testing.assert_array_almost_equal(pe.values, tmp) tref = ep.starts + + # time_array = tsd.t + # data_array = tsd.d + # time_target_array = tref.t + # starts = ep.start + # ends = ep.end + # window = np.abs(minmax) + # binsize = time_array[1] - time_array[0] + # idx1 = -np.arange(0, window[0] + binsize, binsize)[::-1][:-1] + # idx2 = np.arange(0, window[1] + binsize, binsize)[1:] + # time_idx = np.hstack((idx1, np.zeros(1), idx2)) + # windowsize = np.array([idx1.shape[0], idx2.shape[0]]) + + pe = nap.compute_perievent_continuous(tsd, tref, minmax=minmax, ep=ep) tmp = np.array([np.arange(t, t+minmax[1]+1) for t in tref.restrict(ep).t]).T np.testing.assert_array_almost_equal(pe.values[abs(minmax[0]):], tmp) From 1bff1c8672acf4df084d3dd834fb30d892998cea Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Tue, 7 May 2024 15:51:20 -0400 Subject: [PATCH 30/37] linting --- pynapple/process/_process_functions.py | 2 +- tox.ini | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pynapple/process/_process_functions.py b/pynapple/process/_process_functions.py index 3798882f..0713d26c 100644 --- a/pynapple/process/_process_functions.py +++ b/pynapple/process/_process_functions.py @@ -28,7 +28,7 @@ def _jitcontinuous_perievent(time_array, time_target_array, starts, ends, window time_array, starts, ends ) time_array = time_array[idx] - + N_target = len(time_target_array) slice_idx = np.zeros((N_target, 2), dtype=np.int64) diff --git a/tox.ini b/tox.ini index d3e36cee..e0a8a5ee 100644 --- a/tox.ini +++ b/tox.ini @@ -12,7 +12,7 @@ extras = dev commands = black --check pynapple isort --check pynapple --profile black - flake8 pynapple --max-complexity 10 + flake8 pynapple --max-complexity 10 
 	coverage run --source=pynapple --branch -m pytest tests/
 	coverage report -m
 

From 8847a2796bfe3849da8450840162bcd0d413455f Mon Sep 17 00:00:00 2001
From: Guillaume Viejo
Date: Tue, 7 May 2024 15:53:44 -0400
Subject: [PATCH 31/37] Update docs/pynajax.md

Co-authored-by: Edoardo Balzani
---
 docs/pynajax.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/pynajax.md b/docs/pynajax.md
index e664f8e4..6673806c 100644
--- a/docs/pynajax.md
+++ b/docs/pynajax.md
@@ -54,7 +54,7 @@ To check if jax is using the GPU, you can run the following python commands :
 	[cuda(id=0)]
 	```
 
-Support for installing with GPU can be found in the [jax documentation](https://jax.readthedocs.io/en/latest/installation.html)
+Support for installing `JAX` for GPU users can be found in the [jax documentation](https://jax.readthedocs.io/en/latest/installation.html)
 
 ### Typical use-case
 

From bc51b404a902a43fc04b48ef7990e337b8203d8c Mon Sep 17 00:00:00 2001
From: Guillaume Viejo
Date: Tue, 7 May 2024 16:50:03 -0400
Subject: [PATCH 32/37] Update docs/pynajax.md

Co-authored-by: Edoardo Balzani
---
 docs/pynajax.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/pynajax.md b/docs/pynajax.md
index 6673806c..52634d6e 100644
--- a/docs/pynajax.md
+++ b/docs/pynajax.md
@@ -60,6 +60,6 @@ Support for installing `JAX` for GPU users can be found in the [jax documentatio
 
 ### Typical use-case
 
-In addition to providing high performance numerical computing, jax can be used a the backbone of a large scale machine learning model. Thus, pynajax can offer full compatibility between pynapple's time series representation and computational neuroscience models constructed using jax.
+In addition to providing high performance numerical computing, jax can be used as the backbone for a large scale machine learning model. Thus, pynajax can offer full compatibility between pynapple's time series representation and computational neuroscience models constructed using jax.
 
 An example of a python package using both pynapple and jax is [NeMOs](https://nemos.readthedocs.io/en/stable/).
\ No newline at end of file

From c2f72b30a299ca4827dbdac6ff659ff1f6de8367 Mon Sep 17 00:00:00 2001
From: Guillaume Viejo
Date: Tue, 7 May 2024 16:56:13 -0400
Subject: [PATCH 33/37] Update pynapple/core/time_series.py

Co-authored-by: Edoardo Balzani
---
 pynapple/core/time_series.py | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py
index 4edd7828..f845e5a0 100644
--- a/pynapple/core/time_series.py
+++ b/pynapple/core/time_series.py
@@ -181,14 +181,6 @@ def __array_ufunc__(self, ufunc, method, *args, **kwargs):
             return NotImplemented
 
     def __array_function__(self, func, types, args, kwargs):
-        # print("In __array_function__")
-        # print("    func = ", func)
-        # print("    types = ", types)
-        # print("    args = ", args)
-        # for inp in args:
-        #     print(type(inp))
-        # print("    kwargs = ", kwargs)
-
         if func in [
             np.sort,
             np.lexsort,

From 6e28b6168cffa8cbf88ff0e7e0039945dad6a138 Mon Sep 17 00:00:00 2001
From: Guillaume Viejo
Date: Tue, 7 May 2024 16:56:33 -0400
Subject: [PATCH 34/37] Update pynapple/core/time_series.py

Co-authored-by: Edoardo Balzani
---
 pynapple/core/time_series.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py
index f845e5a0..f7c43f87 100644
--- a/pynapple/core/time_series.py
+++ b/pynapple/core/time_series.py
@@ -291,7 +291,7 @@ def value_from(self, data, ep=None):
 
         >>> newts = ts.value_from(tsd, ep)
 
-        newts is the same size as ts restrict to ep.
+        newts has the same size as ts restricted to ep.
 
         >>> print(len(ts.restrict(ep)), len(newts))
         52 52

From 7282d0d809709b2f61a2864d1b11265930b6c057 Mon Sep 17 00:00:00 2001
From: Guillaume Viejo
Date: Tue, 7 May 2024 17:03:43 -0400
Subject: [PATCH 35/37] Update pynapple/core/time_series.py

Co-authored-by: Edoardo Balzani
---
 pynapple/core/time_series.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py
index f7c43f87..6cc446d5 100644
--- a/pynapple/core/time_series.py
+++ b/pynapple/core/time_series.py
@@ -471,7 +471,8 @@ def convolve(self, array, ep=None, trim="both"):
         Parameters
         ----------
         array : array-like
-
+            One dimensional input array-like.
+ ep : None, optional The epochs to apply the convolution trim : str, optional From 95d078a6037d4ce6a9f925af4549122255cdf927 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Tue, 7 May 2024 17:13:15 -0400 Subject: [PATCH 36/37] Update --- pynapple/core/time_series.py | 14 +++------ pynapple/core/utils.py | 50 +++++++------------------------ tests/test_correlograms.py | 6 ---- tests/test_numpy_compatibility.py | 8 ----- 4 files changed, 14 insertions(+), 64 deletions(-) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 6cc446d5..c7751348 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -36,9 +36,7 @@ _get_terminal_size, _split_tsd, _TsdFrameSliceHelper, - convert_to_jax_array, - convert_to_numpy_array, - get_backend, + convert_to_array, is_array_like, ) @@ -73,11 +71,7 @@ class BaseTsd(Base, NDArrayOperatorsMixin, abc.ABC): def __init__(self, t, d, time_units="s", time_support=None): super().__init__(t, time_units, time_support) - # Check if jax backend - if get_backend() == "jax": - self.values = convert_to_jax_array(d, "d") - else: - self.values = convert_to_numpy_array(d, "d") + self.values = convert_to_array(d, "d") assert len(self.index) == len( self.values @@ -455,7 +449,7 @@ def dropna(self, update_time_support=True): if hasattr(self, "columns"): kwargs["columns"] = self.columns - return self.__class__(t=t, d=d, time_support=ep) + return self.__class__(t=t, d=d, time_support=ep, **kwargs) def convolve(self, array, ep=None, trim="both"): """Return the discrete linear convolution of the time series with a one dimensional sequence. @@ -472,7 +466,7 @@ def convolve(self, array, ep=None, trim="both"): ---------- array : array-like One dimensional input array-like. - + ep : None, optional The epochs to apply the convolution trim : str, optional diff --git a/pynapple/core/utils.py b/pynapple/core/utils.py index 9339e5aa..c44073ab 100644 --- a/pynapple/core/utils.py +++ b/pynapple/core/utils.py @@ -12,6 +12,16 @@ from .config import nap_config +def convert_to_array(array, array_name): + # Check if jax backend + if get_backend() == "jax": + from pynajax.utils import convert_to_jax_array + + return convert_to_jax_array(array, array_name) + else: + return convert_to_numpy_array(array, array_name) + + def convert_to_numpy_array(array, array_name): """Convert any array like object to numpy ndarray. @@ -48,46 +58,6 @@ def convert_to_numpy_array(array, array_name): ) -def convert_to_jax_array(array, array_name): - """Convert any array like object to jax Array. - - Parameters - ---------- - array : ArrayLike - - array_name : str - Array name if RuntimeError is raised or object is casted to numpy - - Returns - ------- - jax.Array - Jax array object - - Raises - ------ - RuntimeError - If input can't be converted to jax array - """ - import jax.numpy as jnp - - if isinstance(array, Number): - return jnp.array([array]) - elif isinstance(array, (list, tuple)): - return jnp.array(array) - elif isinstance(array, jnp.ndarray): - return array - elif isinstance(array, np.ndarray): - return cast_to_jax(array, array_name) - elif is_array_like(array): - return cast_to_jax(array, array_name) - else: - raise RuntimeError( - "Unknown format for {}. Accepted formats are numpy.ndarray, list, tuple or any array-like objects.".format( - array_name - ) - ) - - def get_backend(): """ Return the current backend of pynapple. 
Possible backends are diff --git a/tests/test_correlograms.py b/tests/test_correlograms.py index 9e31e45d..185c3522 100644 --- a/tests/test_correlograms.py +++ b/tests/test_correlograms.py @@ -1,9 +1,3 @@ -# -*- coding: utf-8 -*- -# @Author: gviejo -# @Date: 2022-03-30 11:16:22 -# @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-04-04 10:46:31 -#!/usr/bin/env python """Tests of correlograms for `pynapple` package.""" diff --git a/tests/test_numpy_compatibility.py b/tests/test_numpy_compatibility.py index cafc42b0..aecb5677 100644 --- a/tests/test_numpy_compatibility.py +++ b/tests/test_numpy_compatibility.py @@ -1,11 +1,3 @@ -# -*- coding: utf-8 -*- -# @Author: Guillaume Viejo -# @Date: 2023-09-18 18:11:24 -# @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-04-04 11:34:35 - - - import pynapple as nap import numpy as np import pytest From f71610d7881075107082a0c4271a6170af0949ec Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Tue, 7 May 2024 17:28:29 -0400 Subject: [PATCH 37/37] Final update --- pynapple/core/time_series.py | 2 +- pynapple/core/utils.py | 43 +++--------------------------------- 2 files changed, 4 insertions(+), 41 deletions(-) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index c7751348..4ff91aa2 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -36,7 +36,7 @@ _get_terminal_size, _split_tsd, _TsdFrameSliceHelper, - convert_to_array, + convert_to_array, is_array_like, ) diff --git a/pynapple/core/utils.py b/pynapple/core/utils.py index c44073ab..b6ea1e24 100644 --- a/pynapple/core/utils.py +++ b/pynapple/core/utils.py @@ -17,7 +17,9 @@ def convert_to_array(array, array_name): if get_backend() == "jax": from pynajax.utils import convert_to_jax_array - return convert_to_jax_array(array, array_name) + return convert_to_jax_array( + array, array_name, nap_config.suppress_conversion_warnings + ) else: return convert_to_numpy_array(array, array_name) @@ -188,45 +190,6 @@ def cast_to_numpy(array, array_name): return np.asarray(array) -def cast_to_jax(array, array_name): - """ - Convert an input array-like object to a jax Array. - - - Parameters - ---------- - array : array_like - The input object to convert. This can be any object that `np.asarray` is capable of - converting to a jax array, such as lists, tuples, and other array-like objects. - array_name : str - The name of the variable that we are converting, printed in the warning message. - - Returns - ------- - ndarray - A jax Array representation of the input `values`. If `values` is already a jax - Array, it is returned unchanged. Otherwise, a new jax Array is created and returned. - - Warnings - -------- - A warning is issued if the input `values` is not already a jax Array, indicating - that a conversion has taken place and showing the original type of the input. - - """ - import jax.numpy as jnp - - if ( - not isinstance(array, jnp.ndarray) - and not nap_config.suppress_conversion_warnings - ): - original_type = type(array).__name__ - warnings.warn( - f"Converting '{array_name}' to jax.ndarray. The provided array was of type '{original_type}'.", - UserWarning, - ) - return jnp.asarray(array) - - def _check_time_equals(time_arrays): """ Check if a list of time arrays are all equal.
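
For readers working through the series, here is a minimal usage sketch of the API these patches converge on. It is not part of any commit: the names (nap.nap_config.set_backend, nap.Tsd, nap.Ts, nap.compute_perievent_continuous, the minmax argument) are taken from the config module and the tests above, and the random data is purely illustrative. Selecting the "jax" backend assumes the optional pynajax package is importable; otherwise set_backend warns and falls back to "numba".

import numpy as np
import pynapple as nap

# Select the backend introduced by this series; "numba" remains the default,
# and "jax" requires pynajax (a warning is raised and "numba" is used if not).
nap.nap_config.set_backend("jax")

# A continuous time series and two reference timestamps, as in the tests.
tsd = nap.Tsd(t=np.arange(100), d=np.random.randn(100))
tref = nap.Ts(t=np.array([20, 60]))

# Extract peri-event windows from 5 s before to 10 s after each reference time.
pe = nap.compute_perievent_continuous(tsd, tref, minmax=(-5, 10))
print(type(pe), pe.shape)  # TsdFrame with one column per reference time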