From 25b179c68d26bc4d56b3abf762c4ea5a199dcd2b Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Mon, 29 Jan 2024 11:18:04 -0500 Subject: [PATCH 1/7] solving Issue #220 --- pynapple/process/tuning_curves.py | 77 ++++++++++++++++++++----------- tests/test_decoding.py | 4 +- tests/test_tuning_curves.py | 12 ++--- 3 files changed, 58 insertions(+), 35 deletions(-) diff --git a/pynapple/process/tuning_curves.py b/pynapple/process/tuning_curves.py index 3311c119..4e1647d2 100644 --- a/pynapple/process/tuning_curves.py +++ b/pynapple/process/tuning_curves.py @@ -4,7 +4,7 @@ # @Author: gviejo # @Date: 2022-01-02 23:33:42 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-01-26 15:28:51 +# @Last Modified time: 2024-01-29 11:10:07 import warnings @@ -145,16 +145,16 @@ def compute_1d_tuning_curves(group, feature, nb_bins, ep=None, minmax=None): return tuning_curves -def compute_2d_tuning_curves(group, feature, nb_bins, ep=None, minmax=None): +def compute_2d_tuning_curves(group, features, nb_bins, ep=None, minmax=None): """ - Computes 2-dimensional tuning curves relative to a 2d feature + Computes 2-dimensional tuning curves relative to a 2d features Parameters ---------- group : TsGroup The group of Ts/Tsd for which the tuning curves will be computed - feature : TsdFrame - The 2d feature (i.e. 2 columns features). + features : TsdFrame + The 2d features (i.e. 2 columns features). nb_bins : int Number of bins in the tuning curves ep : IntervalSet, optional @@ -175,32 +175,32 @@ def compute_2d_tuning_curves(group, feature, nb_bins, ep=None, minmax=None): Raises ------ RuntimeError - If group is not a TsGroup object or if feature is not 2 columns only. + If group is not a TsGroup object or if features is not 2 columns only. """ assert isinstance(group, nap.TsGroup), "group should be a TsGroup." assert isinstance( - feature, nap.TsdFrame - ), "feature should be a TsdFrame with 2 columns" - if isinstance(feature, nap.TsdFrame): - assert feature.shape[1] == 2, "feature should have 2 columns only." + features, nap.TsdFrame + ), "features should be a TsdFrame with 2 columns" + if isinstance(features, nap.TsdFrame): + assert features.shape[1] == 2, "features should have 2 columns only." assert isinstance(nb_bins, int) if ep is None: - ep = feature.time_support + ep = features.time_support else: assert isinstance(ep, nap.IntervalSet), "ep should be an IntervalSet" - feature = feature.restrict(ep) + features = features.restrict(ep) - cols = list(feature.columns) + cols = list(features.columns) groups_value = {} binsxy = {} for i, c in enumerate(cols): - groups_value[c] = group.value_from(feature.loc[c], ep) + groups_value[c] = group.value_from(features.loc[c], ep) if minmax is None: bins = np.linspace( - np.min(feature.loc[c]), np.max(feature.loc[c]), nb_bins + 1 + np.min(features.loc[c]), np.max(features.loc[c]), nb_bins + 1 ) else: assert isinstance(minmax, tuple), "minmax should be a tuple of 4 elements" @@ -208,8 +208,8 @@ def compute_2d_tuning_curves(group, feature, nb_bins, ep=None, minmax=None): binsxy[c] = bins occupancy, _, _ = np.histogram2d( - feature.loc[cols[0]].values.flatten(), - feature.loc[cols[1]].values.flatten(), + features.loc[cols[0]].values.flatten(), + features.loc[cols[1]].values.flatten(), [binsxy[cols[0]], binsxy[cols[1]]], ) @@ -222,7 +222,7 @@ def compute_2d_tuning_curves(group, feature, nb_bins, ep=None, minmax=None): ) count = count / occupancy # count[np.isnan(count)] = 0.0 - tc[n] = count * feature.rate + tc[n] = count * features.rate xy = [binsxy[c][0:-1] + np.diff(binsxy[c]) / 2 for c in binsxy.keys()] @@ -241,8 +241,8 @@ def compute_1d_mutual_info(tc, feature, ep=None, minmax=None, bitssec=False): ---------- tc : pandas.DataFrame or numpy.ndarray Tuning curves in columns - feature : Tsd - The feature that was used to compute the tuning curves + feature : Tsd (or TsdFrame with 1 column only) + The 1-dimensional target feature (e.g. head-direction) ep : IntervalSet, optional The epoch over which the tuning curves were computed If None, the epoch is the time support of the feature. @@ -265,6 +265,14 @@ def compute_1d_mutual_info(tc, feature, ep=None, minmax=None, bitssec=False): fx = np.atleast_2d(tc) columns = np.arange(tc.shape[1]) + assert isinstance( + feature, (nap.Tsd, nap.TsdFrame) + ), "feature should be a Tsd (or TsdFrame with 1 column only)" + if isinstance(feature, nap.TsdFrame): + assert ( + feature.shape[1] == 1 + ), "feature should be a Tsd (or TsdFrame with 1 column only)" + nb_bins = tc.shape[0] + 1 if minmax is None: bins = np.linspace(np.min(feature), np.max(feature), nb_bins) @@ -332,6 +340,12 @@ def compute_2d_mutual_info(tc, features, ep=None, minmax=None, bitssec=False): fx = tc idx = np.arange(len(tc)) + assert isinstance( + features, nap.TsdFrame + ), "features should be a TsdFrame with 2 columns" + if isinstance(features, nap.TsdFrame): + assert features.shape[1] == 2, "features should have 2 columns only." + nb_bins = (fx.shape[1] + 1, fx.shape[2] + 1) cols = features.columns @@ -388,8 +402,8 @@ def compute_1d_tuning_curves_continuous( tsdframe : Tsd or TsdFrame Input data (e.g. continus calcium data where each column is the calcium activity of one neuron) - feature : Tsd - The feature (one column) + feature : Tsd (or TsdFrame with 1 column only) + The 1-dimensional target feature (e.g. head-direction) nb_bins : int Number of bins in the tuning curves ep : IntervalSet, optional @@ -413,6 +427,15 @@ def compute_1d_tuning_curves_continuous( if not isinstance(tsdframe, (nap.Tsd, nap.TsdFrame)): raise RuntimeError("Unknown format for tsdframe.") + assert isinstance( + feature, (nap.Tsd, nap.TsdFrame) + ), "feature should be a Tsd (or TsdFrame with 1 column only)" + if isinstance(feature, nap.TsdFrame): + assert ( + feature.shape[1] == 1 + ), "feature should be a Tsd (or TsdFrame with 1 column only)" + feature = np.squeeze(feature) + if isinstance(ep, nap.IntervalSet): feature = feature.restrict(ep) tsdframe = tsdframe.restrict(ep) @@ -474,8 +497,11 @@ def compute_2d_tuning_curves_continuous( if not isinstance(tsdframe, (nap.Tsd, nap.TsdFrame)): raise RuntimeError("Unknown format for tsdframe.") - if not isinstance(features, nap.TsdFrame): - raise RuntimeError("Unknown format for features.") + assert isinstance( + features, nap.TsdFrame + ), "features should be a TsdFrame with 2 columns" + if isinstance(features, nap.TsdFrame): + assert features.shape[1] == 2, "features should have 2 columns only." if isinstance(ep, nap.IntervalSet): features = features.restrict(ep) @@ -483,9 +509,6 @@ def compute_2d_tuning_curves_continuous( else: tsdframe = tsdframe.restrict(features.time_support) - if features.shape[1] != 2: - raise RuntimeError("features input is not 2 columns.") - if isinstance(nb_bins, int): nb_bins = (nb_bins, nb_bins) elif len(nb_bins) != 2: diff --git a/tests/test_decoding.py b/tests/test_decoding.py index 988342ea..f181badf 100644 --- a/tests/test_decoding.py +++ b/tests/test_decoding.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-03-30 11:16:39 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-09-18 14:30:10 +# @Last Modified time: 2024-01-29 11:15:41 #!/usr/bin/env python """Tests of decoding for `pynapple` package.""" @@ -112,7 +112,7 @@ def get_testing_set_2d(): ) tc, xy = nap.compute_2d_tuning_curves( - group=group, feature=features, nb_bins=2, minmax=(-0.5, 1.5, -0.5, 1.5) + group=group, features=features, nb_bins=2, minmax=(-0.5, 1.5, -0.5, 1.5) ) ep = nap.IntervalSet(start=0, end=100) return features, group, tc, ep, tuple(xy) diff --git a/tests/test_tuning_curves.py b/tests/test_tuning_curves.py index 730e5486..e93c9d97 100644 --- a/tests/test_tuning_curves.py +++ b/tests/test_tuning_curves.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-03-30 11:16:30 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-01-26 15:23:20 +# @Last Modified time: 2024-01-29 11:05:11 """Tests of tuning curves for `pynapple` package.""" @@ -120,7 +120,7 @@ def test_compute_2d_tuning_curves_error(): features = nap.TsdFrame(t=np.arange(100), d=np.random.rand(100, 3)) with pytest.raises(AssertionError) as e_info: nap.compute_2d_tuning_curves(tsgroup, features, 10) - assert str(e_info.value) == "feature should have 2 columns only." + assert str(e_info.value) == "features should have 2 columns only." def test_compute_2d_tuning_curves_with_ep(): tsgroup = nap.TsGroup( @@ -300,16 +300,16 @@ def test_compute_2d_tuning_curves_continuous_error(): nap.compute_2d_tuning_curves_continuous([1,2,3], features, 2) assert str(e_info.value) == "Unknown format for tsdframe." - with pytest.raises(RuntimeError) as e_info: + with pytest.raises(AssertionError) as e_info: nap.compute_2d_tuning_curves_continuous(tsdframe, [1,2,3], 2) - assert str(e_info.value) == "Unknown format for features." + assert str(e_info.value) == "features should be a TsdFrame with 2 columns" features = nap.TsdFrame( t=np.arange(100), d=np.tile(np.array([[0, 0, 1, 1], [0, 1, 0, 1], [0,0,0,0]]), 25).T ) - with pytest.raises(RuntimeError) as e_info: + with pytest.raises(AssertionError) as e_info: nap.compute_2d_tuning_curves_continuous(tsdframe, features, 2) - assert str(e_info.value) == "features input is not 2 columns." + assert str(e_info.value) == "features should have 2 columns only." @pytest.mark.filterwarnings("ignore") From 3df6037faf4fa595b99a507f867db592a328e682 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Mon, 29 Jan 2024 11:28:08 -0500 Subject: [PATCH 2/7] Fixing #219 --- pynapple/core/ts_group.py | 10 +++++----- tests/test_ts_group.py | 10 +++++++--- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 7f1dfaad..7b9427cc 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-01-28 15:10:48 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-09-21 15:57:41 +# @Last Modified time: 2024-01-29 11:23:07 import os @@ -254,7 +254,7 @@ def set_info(self, *args, **kwargs): *args pandas.Dataframe or list of pandas.DataFrame **kwargs - Can be either pandas.Series or numpy.ndarray + Can be either pandas.Series, numpy.ndarray, list or tuple Raises ------ @@ -286,7 +286,7 @@ def set_info(self, *args, **kwargs): 1 2 pfc 2 4 ca1 - To add metadata with a pd.Series or numpy.ndarray: + To add metadata with a pd.Series, numpy.ndarray, list or tuple: >>> hd = pd.Series(index = [0,1,2], data = [0,1,1]) >>> tsgroup.set_info(hd=hd) @@ -314,9 +314,9 @@ def set_info(self, *args, **kwargs): self._metadata[k] = v else: raise RuntimeError("Index are not equals") - elif isinstance(v, np.ndarray): + elif isinstance(v, (np.ndarray, list, tuple)): if len(self._metadata) == len(v): - self._metadata[k] = v + self._metadata[k] = np.asarray(v) else: raise RuntimeError("Array is not the same length.") return diff --git a/tests/test_ts_group.py b/tests/test_ts_group.py index 05effe7e..4ec2c0fa 100644 --- a/tests/test_ts_group.py +++ b/tests/test_ts_group.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-03-30 11:14:41 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-09-18 17:05:05 +# @Last Modified time: 2024-01-29 11:26:22 """Tests of ts group for `pynapple` package.""" @@ -89,11 +89,15 @@ def test_add_metainfo(self, group): df_info = pd.DataFrame(index=[0, 1, 2], data=[0, 0, 0], columns=["df"]) sr_info = pd.Series(index=[0, 1, 2], data=[1, 1, 1], name="sr") ar_info = np.ones(3) * 3 - tsgroup.set_info(df_info, sr=sr_info, ar=ar_info) - assert tsgroup._metadata.shape == (3, 4) + lt_info = [3,4,5] + tu_info = (6,8,3) + tsgroup.set_info(df_info, sr=sr_info, ar=ar_info, lt=lt_info, tu=tu_info) + assert tsgroup._metadata.shape == (3, 6) pd.testing.assert_series_equal(tsgroup._metadata["df"], df_info["df"]) pd.testing.assert_series_equal(tsgroup._metadata["sr"], sr_info) np.testing.assert_array_almost_equal(tsgroup._metadata["ar"].values, ar_info) + np.testing.assert_array_almost_equal(tsgroup._metadata["lt"].values, lt_info) + np.testing.assert_array_almost_equal(tsgroup._metadata["tu"].values, tu_info) def test_add_metainfo_raise_error(self, group): tsgroup = nap.TsGroup(group) From a3a2ba6dde1e92e489c23a2011953ca62afe52fd Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Mon, 29 Jan 2024 12:50:38 -0500 Subject: [PATCH 3/7] Changing jitted_functions to _jitted_functions --- docs/gen_ref_pages.py | 94 +-- pynapple/core/interval_set.py | 6 +- pynapple/core/jitted_functions.py | 1084 ----------------------------- pynapple/core/time_series.py | 4 +- pynapple/core/ts_group.py | 4 +- pynapple/process/perievent.py | 8 +- pyproject.toml | 2 +- tests/test_jitted.py | 34 +- 8 files changed, 82 insertions(+), 1154 deletions(-) delete mode 100644 pynapple/core/jitted_functions.py diff --git a/docs/gen_ref_pages.py b/docs/gen_ref_pages.py index 9ce064fc..278a6d1d 100644 --- a/docs/gen_ref_pages.py +++ b/docs/gen_ref_pages.py @@ -12,56 +12,68 @@ io_orders = ['interface_nwb', 'interface_npz', 'folder', 'misc'] + deprecated +ignored = ['_jitted_functions'] for path in sorted(Path("pynapple").rglob("*.py")): - print(path) module_path = path.relative_to("pynapple").with_suffix("") - doc_path = path.relative_to("pynapple").with_suffix(".md") - full_doc_path = Path("reference", doc_path) - - parts = tuple(module_path.parts) - - if parts[-1] == "__init__": - parts = parts[:-1] - doc_path = doc_path.with_name("index.md") - full_doc_path = full_doc_path.with_name("index.md") - elif parts[-1] == "__main__": - continue - - if len(parts): - nav[parts] = doc_path.as_posix() - # if the md file name is `module.md`, generate documentation from docstrings - if full_doc_path.name != 'index.md': - # sys.exit() - with mkdocs_gen_files.open(full_doc_path, "w") as fd: - ident = "pynapple."+".".join(parts) - fd.write(f"::: {ident}") - # if the md file name is `index.md`, add the list of modules with hyperlinks - else: - this_module_path = Path("pynapple") / path.parent.name - module_index = "" - - module_order = sorted(this_module_path.rglob("*.py")) - module_order = [m.name.replace('.py', '') for m in module_order] + + if module_path.name not in ignored: + + doc_path = path.relative_to("pynapple").with_suffix(".md") + full_doc_path = Path("reference", doc_path) + + # print(module_path, "\t", doc_path, "\t", full_doc_path) + + parts = tuple(module_path.parts) + + if parts[-1] == "__init__": + parts = parts[:-1] + doc_path = doc_path.with_name("index.md") + full_doc_path = full_doc_path.with_name("index.md") + elif parts[-1] == "__main__": + continue + + # print(parts, doc_path) + # if str(doc_path) == "core/_jitted_functions.md": + # sys.exit() + + if len(parts): + nav[parts] = doc_path.as_posix() + + # if the md file name is `module.md`, generate documentation from docstrings + if full_doc_path.name != 'index.md': + + with mkdocs_gen_files.open(full_doc_path, "w") as fd: + ident = "pynapple."+".".join(parts) + fd.write(f"::: {ident}") - if "io" in this_module_path.name: - module_order = io_orders + else: + this_module_path = Path("pynapple") / path.parent.name + module_index = "" + + module_order = sorted(this_module_path.rglob("*.py")) + module_order = [m.name.replace('.py', '') for m in module_order] + + if "io" in this_module_path.name: + module_order = io_orders - for m in module_order: - if "__init__" in m: - continue + for m in module_order: + if "__init__" in m: + continue + if m[0] == "_": + continue - module_name = m - if m in deprecated: - module_name += " (deprecated)" - module_index += f"* [{module_name}]" \ - "("+m+".md)\n" + module_name = m + if m in deprecated: + module_name += " (deprecated)" + module_index += f"* [{module_name}]" \ + "("+m+".md)\n" - with mkdocs_gen_files.open(full_doc_path, "w") as fd: - fd.write(module_index) + with mkdocs_gen_files.open(full_doc_path, "w") as fd: + fd.write(module_index) - mkdocs_gen_files.set_edit_path(full_doc_path, path) + mkdocs_gen_files.set_edit_path(full_doc_path, path) with mkdocs_gen_files.open("reference/index.md", "w") as nav_file: nav_file.writelines(nav.build_literate_nav()) \ No newline at end of file diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index 2b4a1a00..62de64ff 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # @Date: 2022-01-25 21:50:48 -# @Last Modified by: gviejo -# @Last Modified time: 2023-10-15 16:18:42 +# @Last Modified by: Guillaume Viejo +# @Last Modified time: 2024-01-29 12:16:38 """ """ @@ -14,7 +14,7 @@ import pandas as pd from numba import jit -from .jitted_functions import jitdiff, jitin_interval, jitintersect, jitunion +from ._jitted_functions import jitdiff, jitin_interval, jitintersect, jitunion from .time_index import TsIndex all_warnings = np.array( diff --git a/pynapple/core/jitted_functions.py b/pynapple/core/jitted_functions.py deleted file mode 100644 index 8679adfa..00000000 --- a/pynapple/core/jitted_functions.py +++ /dev/null @@ -1,1084 +0,0 @@ -# -*- coding: utf-8 -*- -# @Author: guillaume -# @Date: 2022-10-31 16:44:31 -# @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-01-25 16:43:34 -import numpy as np -from numba import jit, njit, prange - - -@jit(nopython=True) -def jitrestrict(time_array, data_array, starts, ends): - n = len(time_array) - m = len(starts) - ix = np.zeros(n, dtype=np.bool_) - - k = 0 - t = 0 - - while ends[k] < time_array[t]: - k += 1 - - while k < m: - # Outside - while t < n: - if time_array[t] >= starts[k]: - # ix[t] = True - # t += 1 - break - t += 1 - - # Inside - while t < n: - if time_array[t] > ends[k]: - k += 1 - break - else: - ix[t] = True - t += 1 - - if k == m: - break - if t == n: - break - - new_time_array = time_array[ix] - new_data_array = data_array[ix] - return (new_time_array, new_data_array) - - -@jit(nopython=True) -def jittsrestrict(time_array, starts, ends): - n = len(time_array) - m = len(starts) - ix = np.zeros(n, dtype=np.bool_) - - k = 0 - t = 0 - - while ends[k] < time_array[t]: - k += 1 - - while k < m: - # Outside - while t < n: - if time_array[t] >= starts[k]: - # ix[t] = True - # t += 1 - break - t += 1 - - # Inside - while t < n: - if time_array[t] > ends[k]: - k += 1 - break - else: - ix[t] = True - t += 1 - - if k == m: - break - if t == n: - break - - new_time_array = time_array[ix] - return new_time_array - - -@jit(nopython=True) -def jitrestrict_with_count(time_array, data_array, starts, ends): - n = len(time_array) - m = len(starts) - ix = np.zeros(n, dtype=np.bool_) - count = np.zeros(m, dtype=np.int64) - - k = 0 - t = 0 - - while ends[k] < time_array[t]: - k += 1 - - while k < m: - # Outside - while t < n: - if time_array[t] >= starts[k]: - # ix[t] = True - # count[k] += 1 - # t += 1 - break - t += 1 - - # Inside - while t < n: - if time_array[t] > ends[k]: - k += 1 - break - else: - ix[t] = True - count[k] += 1 - t += 1 - - if k == m: - break - if t == n: - break - - new_time_array = time_array[ix] - new_data_array = data_array[ix] - return new_time_array, new_data_array, count - - -@jit(nopython=True) -def jittsrestrict_with_count(time_array, starts, ends): - n = len(time_array) - m = len(starts) - ix = np.zeros(n, dtype=np.bool_) - count = np.zeros(m, dtype=np.int64) - - k = 0 - t = 0 - - while ends[k] < time_array[t]: - k += 1 - - while k < m: - # Outside - while t < n: - if time_array[t] >= starts[k]: - # ix[t] = True - # count[k] += 1 - # t += 1 - break - t += 1 - - # Inside - while t < n: - if time_array[t] > ends[k]: - k += 1 - break - else: - ix[t] = True - count[k] += 1 - t += 1 - - if k == m: - break - if t == n: - break - - new_time_array = time_array[ix] - return new_time_array, count - - -@jit(nopython=True) -def jitthreshold(time_array, data_array, starts, ends, thr, method="above"): - n = time_array.shape[0] - - if method == "above": - ix = data_array > thr - elif method == "below": - ix = data_array < thr - elif method == "aboveequal": - ix = data_array >= thr - elif method == "belowequal": - ix = data_array <= thr - - k = 0 - t = 0 - - ix_start = np.zeros(n, dtype=np.bool_) - ix_end = np.zeros(n, dtype=np.bool_) - new_start = np.zeros(n, dtype=np.float64) - new_end = np.zeros(n, dtype=np.float64) - - while time_array[t] < starts[k]: - k += 1 - - if ix[t]: - ix_start[t] = 1 - new_start[t] = time_array[t] - - t += 1 - - while t < n - 1: - # transition - if time_array[t] > ends[k]: - k += 1 - if ix[t - 1]: - ix_end[t - 1] = 1 - new_end[t - 1] = time_array[t - 1] - if ix[t]: - ix_start[t] = 1 - new_start[t] = time_array[t] - - else: - if not ix[t - 1] and ix[t]: - ix_start[t] = 1 - new_start[t] = time_array[t] - (time_array[t] - time_array[t - 1]) / 2 - - if ix[t - 1] and not ix[t]: - ix_end[t] = 1 - new_end[t] = time_array[t] - (time_array[t] - time_array[t - 1]) / 2 - - t += 1 - - if ix[t] and ix[t - 1]: - ix_end[t] = 1 - new_end[t] = time_array[t] - - if ix[t] and not ix[t - 1]: - ix_start[t] = 1 - ix_end[t] = 1 - new_start[t] = time_array[t] - (time_array[t] - time_array[t - 1]) / 2 - new_end[t] = time_array[t] - - elif ix[t - 1] and not ix[t]: - ix_end[t] = 1 - new_end[t] = time_array[t] - (time_array[t] - time_array[t - 1]) / 2 - - new_time_array = time_array[ix] - new_data_array = data_array[ix] - new_starts = new_start[ix_start] - new_ends = new_end[ix_end] - - return (new_time_array, new_data_array, new_starts, new_ends) - - -@jit(nopython=True) -def jitvaluefrom(time_array, time_target_array, data_target_array, starts, ends): - time_array, count = jittsrestrict_with_count(time_array, starts, ends) - time_target_array, data_target_array, count_target = jitrestrict_with_count( - time_target_array, data_target_array, starts, ends - ) - - m = starts.shape[0] - n = time_array.shape[0] - d = time_target_array.shape[0] - - new_data_array = np.zeros(n, dtype=data_target_array.dtype) - - if n > 0 and d > 0: - for k in range(m): - if count[k] > 0 and count_target[k] > 0: - t = np.sum(count[0:k]) - i = np.sum(count_target[0:k]) - maxt = t + count[k] - maxi = i + count_target[k] - while t < maxt: - interval = abs(time_array[t] - time_target_array[i]) - new_data_array[t] = data_target_array[i] - i += 1 - while i < maxi: - new_interval = abs(time_array[t] - time_target_array[i]) - if new_interval > interval: - break - else: - new_data_array[t] = data_target_array[i] - interval = new_interval - i += 1 - i -= 1 - t += 1 - - return (time_array, new_data_array, starts, ends) - - -@jit(nopython=True) -def jitvaluefromtensor(time_array, time_target_array, data_target_array, starts, ends): - time_array, _, count = jitrestrict_with_count( - time_array, np.zeros(time_array.shape[0]), starts, ends - ) - time_target_array, data_target_array, count_target = jitrestrict_with_count( - time_target_array, data_target_array, starts, ends - ) - - m = starts.shape[0] - n = time_array.shape[0] - d = time_target_array.shape[0] - - new_data_array = np.zeros( - (n, *data_target_array.shape[1:]), dtype=data_target_array.dtype - ) - - if n > 0 and d > 0: - for k in range(m): - if count[k] > 0 and count_target[k] > 0: - t = np.sum(count[0:k]) - i = np.sum(count_target[0:k]) - maxt = t + count[k] - maxi = i + count_target[k] - while t < maxt: - interval = abs(time_array[t] - time_target_array[i]) - new_data_array[t] = data_target_array[i] - i += 1 - while i < maxi: - new_interval = abs(time_array[t] - time_target_array[i]) - if new_interval > interval: - break - else: - new_data_array[t] = data_target_array[i] - interval = new_interval - i += 1 - i -= 1 - t += 1 - - return (time_array, new_data_array, starts, ends) - - -@jit(nopython=True) -def jitcount(time_array, starts, ends, bin_size): - time_array, countin = jittsrestrict_with_count(time_array, starts, ends) - - m = starts.shape[0] - - nb_bins = np.zeros(m, dtype=np.int32) - for k in range(m): - if (ends[k] - starts[k]) > bin_size: - nb_bins[k] = int(np.ceil((ends[k] + bin_size - starts[k]) / bin_size)) - else: - nb_bins[k] = 1 - - nb = np.sum(nb_bins) - bins = np.zeros(nb, dtype=np.float64) - cnt = np.zeros(nb, dtype=np.int64) - - k = 0 - t = 0 - b = 0 - - while k < m: - maxb = b + nb_bins[k] - maxt = t + countin[k] - lbound = starts[k] - - while b < maxb: - xpos = lbound + bin_size / 2 - if xpos > ends[k]: - break - else: - bins[b] = xpos - rbound = np.round(lbound + bin_size, 9) - while t < maxt: - if time_array[t] < rbound: # similar to numpy hisrogram - cnt[b] += 1 - t += 1 - else: - break - - lbound += bin_size - b += 1 - t = maxt - k += 1 - - new_time_array = bins[0:b] - new_data_array = cnt[0:b] - - return (new_time_array, new_data_array) - - -@jit(nopython=True) -def jitbin(time_array, data_array, starts, ends, bin_size): - time_array, data_array, countin = jitrestrict_with_count( - time_array, data_array, starts, ends - ) - - m = starts.shape[0] - - nb_bins = np.zeros(m, dtype=np.int32) - for k in range(m): - if (ends[k] - starts[k]) > bin_size: - nb_bins[k] = int(np.ceil((ends[k] + bin_size - starts[k]) / bin_size)) - else: - nb_bins[k] = 1 - - nb = np.sum(nb_bins) - bins = np.zeros(nb, dtype=np.float64) - cnt = np.zeros(nb, dtype=np.float64) - average = np.zeros(nb, dtype=np.float64) - - k = 0 - t = 0 - b = 0 - - while k < m: - maxb = b + nb_bins[k] - maxt = t + countin[k] - lbound = starts[k] - - while b < maxb: - xpos = lbound + bin_size / 2 - if xpos > ends[k]: - break - else: - bins[b] = xpos - rbound = np.round(lbound + bin_size, 9) - while t < maxt: - if time_array[t] < rbound: # similar to numpy hisrogram - cnt[b] += 1.0 - average[b] += data_array[t] - t += 1 - else: - break - - lbound += bin_size - b += 1 - t = maxt - k += 1 - - new_time_array = bins[0:b] - new_data_array = average[0:b] / cnt[0:b] - - return (new_time_array, new_data_array) - - -@jit(nopython=True) -def jitbin_array(time_array, data_array, starts, ends, bin_size): - time_array, data_array, countin = jitrestrict_with_count( - time_array, data_array, starts, ends - ) - - m = starts.shape[0] - f = data_array.shape[1:] - - nb_bins = np.zeros(m, dtype=np.int32) - for k in range(m): - if (ends[k] - starts[k]) > bin_size: - nb_bins[k] = int(np.ceil((ends[k] + bin_size - starts[k]) / bin_size)) - else: - nb_bins[k] = 1 - - nb = np.sum(nb_bins) - bins = np.zeros(nb, dtype=np.float64) - cnt = np.zeros((nb, *f), dtype=np.float64) - average = np.zeros((nb, *f), dtype=np.float64) - - k = 0 - t = 0 - b = 0 - - while k < m: - maxb = b + nb_bins[k] - maxt = t + countin[k] - lbound = starts[k] - - while b < maxb: - xpos = lbound + bin_size / 2 - if xpos > ends[k]: - break - else: - bins[b] = xpos - rbound = np.round(lbound + bin_size, 9) - while t < maxt: - if time_array[t] < rbound: # similar to numpy hisrogram - cnt[b] += 1.0 - average[b] += data_array[t] - t += 1 - else: - break - - lbound += bin_size - b += 1 - t = maxt - k += 1 - - new_time_array = bins[0:b] - - new_data_array = average[0:b] / cnt[0:b] - - return (new_time_array, new_data_array) - - -@jit(nopython=True) -def jitintersect(start1, end1, start2, end2): - m = start1.shape[0] - n = start2.shape[0] - - i = 0 - j = 0 - - newstart = np.zeros(m + n, dtype=np.float64) - newend = np.zeros(m + n, dtype=np.float64) - ct = 0 - - while i < m: - while j < n: - if end2[j] > start1[i]: - break - j += 1 - - if j == n: - break - - if start2[j] < end1[i]: - newstart[ct] = max(start1[i], start2[j]) - newend[ct] = min(end1[i], end2[j]) - ct += 1 - if end2[j] < end1[i]: - j += 1 - else: - i += 1 - else: - i += 1 - - newstart = newstart[0:ct] - newend = newend[0:ct] - - return (newstart, newend) - - -@jit(nopython=True) -def jitunion(start1, end1, start2, end2): - m = start1.shape[0] - n = start2.shape[0] - - i = 0 - j = 0 - - newstart = np.zeros(m + n, dtype=np.float64) - newend = np.zeros(m + n, dtype=np.float64) - ct = 0 - - while i < m: - while j < n: - if end2[j] > start1[i]: - break - newstart[ct] = start2[j] - newend[ct] = end2[j] - ct += 1 - j += 1 - - if j == n: - break - - # overlap - if start2[j] < end1[i]: - newstart[ct] = min(start1[i], start2[j]) - - while i < m and j < n: - newend[ct] = max(end1[i], end2[j]) - - if end1[i] < end2[j]: - i += 1 - else: - j += 1 - - if i == m: - j += 1 - ct += 1 - break - - if j == n: - i += 1 - ct += 1 - break - - if end2[j] < start1[i]: - j += 1 - ct += 1 - break - elif end1[i] < start2[j]: - i += 1 - ct += 1 - break - - else: - newstart[ct] = start1[i] - newend[ct] = end1[i] - ct += 1 - i += 1 - - while i < m: - newstart[ct] = start1[i] - newend[ct] = end1[i] - ct += 1 - i += 1 - - while j < n: - newstart[ct] = start2[j] - newend[ct] = end2[j] - ct += 1 - j += 1 - - newstart = newstart[0:ct] - newend = newend[0:ct] - - return (newstart, newend) - - -@jit(nopython=True) -def jitdiff(start1, end1, start2, end2): - m = start1.shape[0] - n = start2.shape[0] - - i = 0 - j = 0 - - newstart = np.zeros(m + n, dtype=np.float64) - newend = np.zeros(m + n, dtype=np.float64) - ct = 0 - - while i < m: - while j < n: - if end2[j] > start1[i]: - break - j += 1 - - if j == n: - break - - # overlap - if start2[j] < end1[i]: - if start2[j] < start1[i] and end1[i] < end2[j]: - i += 1 - - else: - if start2[j] > start1[i]: - newstart[ct] = start1[i] - newend[ct] = start2[j] - ct += 1 - j += 1 - - else: - newstart[ct] = end2[j] - newend[ct] = end1[i] - j += 1 - - while j < n: - if start2[j] < end1[i]: - newstart[ct] = end2[j - 1] - newend[ct] = start2[j] - ct += 1 - j += 1 - else: - break - - if end2[j - 1] < end1[i]: - newstart[ct] = end2[j - 1] - newend[ct] = end1[i] - ct += 1 - else: - j -= 1 - i += 1 - - else: - newstart[ct] = start1[i] - newend[ct] = end1[i] - ct += 1 - i += 1 - - while i < m: - newstart[ct] = start1[i] - newend[ct] = end1[i] - ct += 1 - i += 1 - - newstart = newstart[0:ct] - newend = newend[0:ct] - - return (newstart, newend) - - -@jit(nopython=True) -def jitunion_isets(starts, ends): - idx = np.argsort(starts) - starts = starts[idx] - ends = ends[idx] - - n = starts.shape[0] - new_start = np.zeros(n, dtype=np.float64) - new_end = np.zeros(n, dtype=np.float64) - - ct = 0 - new_start[ct] = starts[0] - e = ends[0] - i = 1 - while i < n: - if starts[i] > e: - new_end[ct] = e - ct += 1 - new_start[ct] = starts[i] - e = ends[i] - else: - e = max(e, ends[i]) - i += 1 - - new_end[ct] = e - ct += 1 - new_start = new_start[0:ct] - new_end = new_end[0:ct] - return (new_start, new_end) - - -@jit(nopython=True) -def jitin_interval(time_array, starts, ends): - n = len(time_array) - m = len(starts) - data = np.ones(n, dtype=np.float64) * np.nan - - k = 0 - t = 0 - - while ends[k] < time_array[t]: - k += 1 - - while k < m: - # Outside - while t < n: - if time_array[t] >= starts[k]: - # data[t] = k - # t += 1 - break - # data[t] = np.nan - t += 1 - - # Inside - while t < n: - if time_array[t] > ends[k]: - k += 1 - # data[t] = np.nan - break - else: - data[t] = k - t += 1 - - if k == m: - break - if t == n: - break - - return data - - -@jit(nopython=True) -def jitremove_nan(time_array, index_nan): - n = len(time_array) - ix_start = np.zeros(n, dtype=np.bool_) - ix_end = np.zeros(n, dtype=np.bool_) - - if not index_nan[0]: # First start - ix_start[0] = True - - t = 1 - while t < n: - if index_nan[t - 1] and not index_nan[t]: # start - ix_start[t] = True - if not index_nan[t - 1] and index_nan[t]: # end - ix_end[t - 1] = True - t += 1 - - if not index_nan[-1]: # Last stop - ix_end[-1] = True - - starts = time_array[ix_start] - ends = time_array[ix_end] - return (starts, ends) - - -@jit(nopython=True) -def jitconvolve(d, a): - return np.convolve(d, a) - - -@njit(parallel=True) -def pjitconvolve(data_array, array, trim="both"): - shape = data_array.shape - t = shape[0] - k = array.shape[0] - - data_array = data_array.reshape(t, -1) - new_data_array = np.zeros(data_array.shape) - - if trim == "both": - cut = ((1 - k % 2) + (k - 1) // 2, t + k - 1 - ((k - 1) // 2)) - elif trim == "left": - cut = (k - 1, t + k - 1) - elif trim == "right": - cut = (0, t) - - for i in prange(data_array.shape[1]): - new_data_array[:, i] = jitconvolve(data_array[:, i], array)[cut[0] : cut[1]] - - new_data_array = new_data_array.reshape(shape) - - return new_data_array - - -@njit(parallel=True) -def jitcontinuous_perievent( - time_array, data_array, time_target_array, starts, ends, windowsize -): - N_samples = len(time_array) - N_target = len(time_target_array) - N_epochs = len(starts) - count = np.zeros((N_epochs, 2), dtype=np.int64) - start_t = np.zeros((N_epochs, 2), dtype=np.int64) - - k = 0 # Epochs - t = 0 # Samples - i = 0 # Target - - while ends[k] < time_array[t] and ends[k] < time_target_array[i]: - k += 1 - - while k < N_epochs: - # Outside - while t < N_samples: - if time_array[t] >= starts[k]: - break - t += 1 - - while i < N_target: - if time_target_array[i] >= starts[k]: - break - i += 1 - - if time_array[t] <= ends[k]: - start_t[k, 0] = t - - if time_target_array[i] <= ends[k]: - start_t[k, 1] = i - - # Inside - while t < N_samples: - if time_array[t] > ends[k]: - break - else: - count[k, 0] += 1 - t += 1 - - while i < N_target: - if time_target_array[i] > ends[k]: - break - else: - count[k, 1] += 1 - i += 1 - - k += 1 - - if k == N_epochs: - break - if t == N_samples: - break - if i == N_target: - break - - new_data_array = np.full( - (np.sum(windowsize) + 1, np.sum(count[:, 1]), *data_array.shape[1:]), np.nan - ) - - if np.any((count[:, 0] * count[:, 1]) > 0): - for k in prange(N_epochs): - if count[k, 0] > 0 and count[k, 1] > 0: - t = start_t[k, 0] - i = start_t[k, 1] - maxt = t + count[k, 0] - maxi = i + count[k, 1] - cnt_i = np.sum(count[0:k, 1]) - - while i < maxi: - interval = abs(time_array[t] - time_target_array[i]) - t_pos = t - t += 1 - while t < maxt: - new_interval = abs(time_array[t] - time_target_array[i]) - if new_interval > interval: - break - else: - interval = new_interval - t_pos = t - t += 1 - - left = np.minimum(windowsize[0], t_pos - start_t[k, 0]) - right = np.minimum(windowsize[1], maxt - t_pos - 1) - center = windowsize[0] + 1 - new_data_array[center - left - 1 : center + right, cnt_i] = ( - data_array[t_pos - left : t_pos + right + 1] - ) - - t -= 1 - i += 1 - cnt_i += 1 - - return new_data_array - - -@jit(nopython=True) -def jitperievent_trigger_average( - time_array, - count_array, - time_target_array, - data_target_array, - starts, - ends, - windows, - binsize, -): - T = time_array.shape[0] - N = count_array.shape[1] - N_epochs = len(starts) - - time_target_array, data_target_array, count = jitrestrict_with_count( - time_target_array, data_target_array, starts, ends - ) - max_count = np.cumsum(count) - - new_data_array = np.full( - (int(windows.sum()) + 1, count_array.shape[1], *data_target_array.shape[1:]), - 0.0, - ) - - t = 0 # count events - - hankel_array = np.zeros((new_data_array.shape[0], *data_target_array.shape[1:])) - - for k in range(N_epochs): - if count[k] > 0: - t_start = t - maxi = max_count[k] - i = maxi - count[k] - - while t < T: - lbound = time_array[t] - rbound = np.round(lbound + binsize, 9) - - if time_target_array[i] < rbound: - i_start = i - i_stop = i - - while i_stop < maxi: - if time_target_array[i_stop] < rbound: - i_stop += 1 - else: - break - - while i_start < i_stop - 1: - if time_target_array[i_start] < lbound: - i_start += 1 - else: - break - v = np.sum(data_target_array[i_start:i_stop], 0) / float( - i_stop - i_start - ) - - checknan = np.sum(v) - if not np.isnan(checknan): - hankel_array[-1] = v - - if t - t_start >= windows[1]: - for n in range(N): - new_data_array[:, n] += ( - hankel_array * count_array[t - windows[1], n] - ) - - # hankel_array = np.roll(hankel_array, -1, axis=0) - hankel_array[0:-1] = hankel_array[1:] - hankel_array[-1] = 0.0 - - t += 1 - - i = i_start - - if t == T or time_array[t] > ends[k]: - if t - t_start > windows[1]: - for j in range(windows[1]): - for n in range(N): - new_data_array[:, n] += ( - hankel_array * count_array[t - windows[1] + j, n] - ) - - # hankel_array = np.roll(hankel_array, -1, axis=0) - hankel_array[0:-1] = hankel_array[1:] - hankel_array[-1] = 0.0 - - hankel_array *= 0.0 - break - - total = np.sum(count_array, 0) - for n in range(N): - if total[n] > 0.0: - new_data_array[:, n] /= total[n] - - return new_data_array - - -# @jit(nopython=True) -# def jit_poisson_IRLS(X, y, niter=100, tolerance=1e-5): -# y = y.astype(np.float64) -# X = X.astype(np.float64) -# n, d = X.shape -# W = np.ones(n) -# iXtWX = np.linalg.inv(np.dot(X.T * W, X)) -# XtWY = np.dot(X.T * W, y) -# B = np.dot(iXtWX, XtWY) - -# for _ in range(niter): -# B_ = B -# L = np.exp(X.dot(B)) # Link function -# Z = L.reshape((-1, 1)) * X # partial derivatives -# delta = np.dot(np.linalg.inv(np.dot(Z.T * W, Z)), np.dot(Z.T * W, y)) -# B = B + delta -# tol = np.sum(np.abs((B - B_) / B_)) -# if tol < tolerance: -# return B -# return B - - -# @jit(nopython=True) -# def jitfind_gaps(time_array, starts, ends, min_gap): -# """ -# Jitted version of find_gap - -# Parameters -# ---------- -# time_array : numpy.ndarray -# Description -# data_array : numpy.ndarray -# Description -# starts : numpy.ndarray -# Description -# ends : numpy.ndarray -# Description - -# Returns -# ------- -# TYPE -# Description -# """ -# n = len(time_array) -# m = len(starts) - -# new_start = np.zeros(n+m, dtype=np.float64) -# new_end = np.zeros(n+m, dtype=np.float64) - -# k = 0 -# t = 0 -# i = 0 - -# while k ends[k]: -# break - -# if (time_array[t] - start) > min_gap: -# new_start[i] = start+1e-6 -# new_end[i] = time_array[t]-1e-6 -# start = time_array[t] -# t += 1 -# i += 1 - -# else: -# start = time_array[t] -# t += 1 - - -# k += 1 - - -# new_start = new_start[0:i] -# new_end = new_end[0:i] - -# return new_start, new_end diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index fc4ccad0..e7602cb2 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-01-27 18:33:31 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-01-08 16:09:01 +# @Last Modified time: 2024-01-29 12:16:45 """ @@ -36,7 +36,7 @@ from tabulate import tabulate from .interval_set import IntervalSet -from .jitted_functions import ( +from ._jitted_functions import ( jitbin, jitbin_array, jitcount, diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 7b9427cc..5e6dc76b 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-01-28 15:10:48 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-01-29 11:23:07 +# @Last Modified time: 2024-01-29 12:16:24 import os @@ -14,7 +14,7 @@ from tabulate import tabulate from .interval_set import IntervalSet -from .jitted_functions import ( +from ._jitted_functions import ( jitcount, jittsrestrict_with_count, jitunion, diff --git a/pynapple/process/perievent.py b/pynapple/process/perievent.py index ac5e5b1e..d418512b 100644 --- a/pynapple/process/perievent.py +++ b/pynapple/process/perievent.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-01-30 22:59:00 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-01-26 15:52:19 +# @Last Modified time: 2024-01-29 12:47:56 import numpy as np @@ -179,7 +179,7 @@ def compute_perievent_continuous(data, tref, minmax, ep=None, time_unit="s"): time_idx = np.hstack((idx1, np.zeros(1), idx2)) windowsize = np.array([idx1.shape[0], idx2.shape[0]]) - new_data_array = nap.jitted_functions.jitcontinuous_perievent( + new_data_array = nap._jitted_functions.jitcontinuous_perievent( time_array, data_array, time_target_array, starts, ends, windowsize ) @@ -289,7 +289,7 @@ def compute_event_trigger_average( data_target_array = feature.values if data_target_array.ndim == 1: - eta = nap.jitted_functions.jitperievent_trigger_average( + eta = nap._jitted_functions.jitperievent_trigger_average( time_array, count_array, time_target_array, @@ -301,7 +301,7 @@ def compute_event_trigger_average( ) eta = np.squeeze(eta, -1) else: - eta = nap.jitted_functions.jitperievent_trigger_average( + eta = nap._jitted_functions.jitperievent_trigger_average( time_array, count_array, time_target_array, diff --git a/pyproject.toml b/pyproject.toml index 7cde031d..8c370cda 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,5 +105,5 @@ omit = [ "pynapple/io/neurosuite.py", "pynapple/io/phy.py", "pynapple/io/suite2p.py", - "pynapple/core/jitted_functions.py" # coverage doesn't work with numba + "pynapple/core/_jitted_functions.py" # coverage doesn't work with numba ] \ No newline at end of file diff --git a/tests/test_jitted.py b/tests/test_jitted.py index 8f1bd91f..8370637b 100644 --- a/tests/test_jitted.py +++ b/tests/test_jitted.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-12-02 17:17:03 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-09-18 13:59:43 +# @Last Modified time: 2024-01-29 12:37:44 """Tests of jitted core functions for `pynapple` package.""" @@ -75,7 +75,7 @@ def test_jitrestrict(): ep, ts, tsd, tsdframe = get_example_dataset() tsd2 = restrict(ep, tsd) - t, d= nap.core.jitted_functions.jitrestrict(tsd.index, tsd.values, ep['start'].values, ep['end'].values) + t, d= nap.core._jitted_functions.jitrestrict(tsd.index, tsd.values, ep['start'].values, ep['end'].values) tsd3 = pd.Series(index=t, data=d) pd.testing.assert_series_equal(tsd2, tsd3) @@ -84,7 +84,7 @@ def test_jittsrestrict(): ep, ts, tsd, tsdframe = get_example_dataset() ts2 = restrict(ep, ts) - t = nap.core.jitted_functions.jittsrestrict(ts.index, ep['start'].values, ep['end'].values) + t = nap.core._jitted_functions.jittsrestrict(ts.index, ep['start'].values, ep['end'].values) ts3 = pd.Series(index=t, dtype="object") pd.testing.assert_series_equal(ts2, ts3) @@ -93,7 +93,7 @@ def test_jitrestrict_with_count(): ep, ts, tsd, tsdframe = get_example_dataset() tsd2 = restrict(ep, tsd) - t, d, count = nap.core.jitted_functions.jitrestrict_with_count(tsd.index, tsd.values, ep['start'].values, ep['end'].values) + t, d, count = nap.core._jitted_functions.jitrestrict_with_count(tsd.index, tsd.values, ep['start'].values, ep['end'].values) tsd3 = pd.Series(index=t, data=d) pd.testing.assert_series_equal(tsd2, tsd3) @@ -114,7 +114,7 @@ def test_jittsrestrict_with_count(): ep, ts, tsd, tsdframe = get_example_dataset() ts2 = restrict(ep, ts) - t, count = nap.core.jitted_functions.jittsrestrict_with_count(ts.index, ep['start'].values, ep['end'].values) + t, count = nap.core._jitted_functions.jittsrestrict_with_count(ts.index, ep['start'].values, ep['end'].values) ts3 = pd.Series(index=t, dtype="object") pd.testing.assert_series_equal(ts2, ts3) @@ -136,25 +136,25 @@ def test_jitthreshold(): thr = np.random.rand() - t, d, s, e = nap.core.jitted_functions.jitthreshold(tsd.index, tsd.values, ep['start'].values, ep['end'].values, thr) + t, d, s, e = nap.core._jitted_functions.jitthreshold(tsd.index, tsd.values, ep['start'].values, ep['end'].values, thr) assert len(t) == np.sum(tsd.values > thr) assert len(d) == np.sum(tsd.values > thr) np.testing.assert_array_equal(d, tsd.values[tsd.values > thr]) - t, d, s, e = nap.core.jitted_functions.jitthreshold(tsd.index, tsd.values, ep['start'].values, ep['end'].values, thr, "below") + t, d, s, e = nap.core._jitted_functions.jitthreshold(tsd.index, tsd.values, ep['start'].values, ep['end'].values, thr, "below") assert len(t) == np.sum(tsd.values < thr) assert len(d) == np.sum(tsd.values < thr) np.testing.assert_array_equal(d, tsd.values[tsd.values < thr]) - t, d, s, e = nap.core.jitted_functions.jitthreshold(tsd.index, tsd.values, ep['start'].values, ep['end'].values, thr, "aboveequal") + t, d, s, e = nap.core._jitted_functions.jitthreshold(tsd.index, tsd.values, ep['start'].values, ep['end'].values, thr, "aboveequal") assert len(t) == np.sum(tsd.values >= thr) assert len(d) == np.sum(tsd.values >= thr) np.testing.assert_array_equal(d, tsd.values[tsd.values >= thr]) - t, d, s, e = nap.core.jitted_functions.jitthreshold(tsd.index, tsd.values, ep['start'].values, ep['end'].values, thr, "belowequal") + t, d, s, e = nap.core._jitted_functions.jitthreshold(tsd.index, tsd.values, ep['start'].values, ep['end'].values, thr, "belowequal") assert len(t) == np.sum(tsd.values <= thr) assert len(d) == np.sum(tsd.values <= thr) @@ -170,7 +170,7 @@ def test_jitvalue_from(): for i in range(10): ep, ts, tsd, tsdframe = get_example_dataset() - t, d, s, e = nap.core.jitted_functions.jitvaluefrom(ts.index, tsd.index, tsd.values, ep['start'].values, ep['end'].values) + t, d, s, e = nap.core._jitted_functions.jitvaluefrom(ts.index, tsd.index, tsd.values, ep['start'].values, ep['end'].values) tsd3 = pd.Series(index=t, data=d) tsd2 = [] @@ -191,7 +191,7 @@ def test_jitcount(): starts = ep['start'].values ends = ep['end'].values bin_size = 1.0 - t, d = nap.core.jitted_functions.jitcount(time_array, starts, ends, bin_size) + t, d = nap.core._jitted_functions.jitcount(time_array, starts, ends, bin_size) tsd3 = nap.Tsd(t=t, d=d, time_support = ep) tsd2 = [] @@ -219,7 +219,7 @@ def test_jitbin(): starts = ep['start'].values ends = ep['end'].values bin_size = 1.0 - t, d = nap.core.jitted_functions.jitbin(time_array, data_array, starts, ends, bin_size) + t, d = nap.core._jitted_functions.jitbin(time_array, data_array, starts, ends, bin_size) # tsd3 = nap.Tsd(t=t, d=d, time_support = ep) tsd3 = pd.Series(index=t, data=d) tsd3 = tsd3.fillna(0.0) @@ -256,7 +256,7 @@ def test_jitbin_array(): starts = ep['start'].values ends = ep['end'].values bin_size = 1.0 - t, d = nap.core.jitted_functions.jitbin_array(time_array, data_array, starts, ends, bin_size) + t, d = nap.core._jitted_functions.jitbin_array(time_array, data_array, starts, ends, bin_size) tsd3 = pd.DataFrame(index=t, data=d) tsd3 = tsd3.fillna(0.0) # tsd3 = nap.TsdFrame(tsd3, time_support = ep) @@ -288,7 +288,7 @@ def test_jitintersect(): for i in range(10): ep1, ep2 = get_example_isets() - s, e = nap.core.jitted_functions.jitintersect(ep1.start.values, ep1.end.values, ep2.start.values, ep2.end.values) + s, e = nap.core._jitted_functions.jitintersect(ep1.start.values, ep1.end.values, ep2.start.values, ep2.end.values) ep3 = nap.IntervalSet(s, e) @@ -320,7 +320,7 @@ def test_jitunion(): for i in range(10): ep1, ep2 = get_example_isets() - s, e = nap.core.jitted_functions.jitunion(ep1.start.values, ep1.end.values, ep2.start.values, ep2.end.values) + s, e = nap.core._jitted_functions.jitunion(ep1.start.values, ep1.end.values, ep2.start.values, ep2.end.values) ep3 = nap.IntervalSet(s, e) @@ -353,7 +353,7 @@ def test_jitdiff(): for i in range(10): ep1, ep2 = get_example_isets() - s, e = nap.core.jitted_functions.jitdiff(ep1.start.values, ep1.end.values, ep2.start.values, ep2.end.values) + s, e = nap.core._jitted_functions.jitdiff(ep1.start.values, ep1.end.values, ep2.start.values, ep2.end.values) ep3 = nap.IntervalSet(s, e) i_sets = (ep1, ep2) @@ -426,7 +426,7 @@ def test_jitin_interval(): for i in range(10): ep, ts, tsd, tsdframe = get_example_dataset() - inep = nap.core.jitted_functions.jitin_interval(tsd.index, ep['start'].values, ep['end'].values) + inep = nap.core._jitted_functions.jitin_interval(tsd.index, ep['start'].values, ep['end'].values) inep[np.isnan(inep)] = -1 bins = ep.values.ravel() From d09216c17b04693fc3330f86a974fb2a348e735d Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Mon, 29 Jan 2024 14:20:26 -0500 Subject: [PATCH 4/7] Fixing black --- pynapple/core/time_series.py | 2 +- pynapple/core/ts_group.py | 2 +- pynapple/process/tuning_curves.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index e7602cb2..ca9614ea 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -35,7 +35,6 @@ from scipy import signal from tabulate import tabulate -from .interval_set import IntervalSet from ._jitted_functions import ( jitbin, jitbin_array, @@ -49,6 +48,7 @@ jitvaluefromtensor, pjitconvolve, ) +from .interval_set import IntervalSet from .time_index import TsIndex diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 5e6dc76b..0a3a930a 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -13,13 +13,13 @@ import pandas as pd from tabulate import tabulate -from .interval_set import IntervalSet from ._jitted_functions import ( jitcount, jittsrestrict_with_count, jitunion, jitunion_isets, ) +from .interval_set import IntervalSet # from .time_units import format_timestamps from .time_index import TsIndex diff --git a/pynapple/process/tuning_curves.py b/pynapple/process/tuning_curves.py index 4e1647d2..edbde3cc 100644 --- a/pynapple/process/tuning_curves.py +++ b/pynapple/process/tuning_curves.py @@ -271,7 +271,7 @@ def compute_1d_mutual_info(tc, feature, ep=None, minmax=None, bitssec=False): if isinstance(feature, nap.TsdFrame): assert ( feature.shape[1] == 1 - ), "feature should be a Tsd (or TsdFrame with 1 column only)" + ), "feature should be a Tsd (or TsdFrame with 1 column only)" nb_bins = tc.shape[0] + 1 if minmax is None: @@ -344,7 +344,7 @@ def compute_2d_mutual_info(tc, features, ep=None, minmax=None, bitssec=False): features, nap.TsdFrame ), "features should be a TsdFrame with 2 columns" if isinstance(features, nap.TsdFrame): - assert features.shape[1] == 2, "features should have 2 columns only." + assert features.shape[1] == 2, "features should have 2 columns only." nb_bins = (fx.shape[1] + 1, fx.shape[2] + 1) @@ -433,7 +433,7 @@ def compute_1d_tuning_curves_continuous( if isinstance(feature, nap.TsdFrame): assert ( feature.shape[1] == 1 - ), "feature should be a Tsd (or TsdFrame with 1 column only)" + ), "feature should be a Tsd (or TsdFrame with 1 column only)" feature = np.squeeze(feature) if isinstance(ep, nap.IntervalSet): From fa12815c2a44abbab346b1c5c6902026e16e59bb Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Mon, 29 Jan 2024 15:02:55 -0500 Subject: [PATCH 5/7] Bumping to 0.5.1 --- docs/HISTORY.md | 7 +++++++ pynapple/__init__.py | 2 +- pynapple/core/__init__.py | 7 ++++--- pynapple/core/time_series.py | 2 +- pynapple/process/__init__.py | 31 ++++++++++++++++++++++++++----- pyproject.toml | 2 +- setup.py | 4 ++-- tests/test_jitted.py | 4 ++-- 8 files changed, 44 insertions(+), 15 deletions(-) diff --git a/docs/HISTORY.md b/docs/HISTORY.md index 9e5961d8..97d05bfd 100644 --- a/docs/HISTORY.md +++ b/docs/HISTORY.md @@ -9,6 +9,13 @@ In 2018, Francesco started neuroseries, a Python package built on Pandas. It was In 2021, Guillaume and other trainees in Adrien's lab decided to fork from neuroseries and started *pynapple*. The core of pynapple is largely built upon neuroseries. Some of the original changes to TSToolbox made by Luke were included in this package, especially the *time_support* property of all ts/tsd objects. +0.5.1 (2024-01-29) +------------------ + +- Implementing `event_trigger_average` for all dimensions. +- Hiding jitted functions from users. + + 0.5.0 (2023-12-12) ------------------ diff --git a/pynapple/__init__.py b/pynapple/__init__.py index e26a04cd..0c57730a 100644 --- a/pynapple/__init__.py +++ b/pynapple/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.5.0" +__version__ = "0.5.1" from .core import * from .io import * from .process import * diff --git a/pynapple/core/__init__.py b/pynapple/core/__init__.py index b44f0ae3..0fe00c5e 100644 --- a/pynapple/core/__init__.py +++ b/pynapple/core/__init__.py @@ -1,3 +1,4 @@ -from .interval_set import * -from .time_series import * -from .ts_group import * +from .interval_set import IntervalSet +from .time_index import TsIndex +from .time_series import Ts, Tsd, TsdFrame, TsdTensor +from .ts_group import TsGroup diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index ca9614ea..3c0b2e5d 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-01-27 18:33:31 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-01-29 12:16:45 +# @Last Modified time: 2024-01-29 14:36:05 """ diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index c92798a6..2e1af412 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -1,5 +1,26 @@ -from .correlograms import * -from .decoding import * -from .perievent import * -from .randomize import * -from .tuning_curves import * +from .correlograms import ( + compute_autocorrelogram, + compute_crosscorrelogram, + compute_eventcorrelogram, +) +from .decoding import decode_1d, decode_2d +from .perievent import ( + compute_event_trigger_average, + compute_perievent, + compute_perievent_continuous, +) +from .randomize import ( + jitter_timestamps, + resample_timestamps, + shift_timestamps, + shuffle_ts_intervals, +) +from .tuning_curves import ( + compute_1d_mutual_info, + compute_1d_tuning_curves, + compute_1d_tuning_curves_continuous, + compute_2d_mutual_info, + compute_2d_tuning_curves, + compute_2d_tuning_curves_continuous, + compute_discrete_tuning_curves, +) diff --git a/pyproject.toml b/pyproject.toml index 8c370cda..4e0e9bc5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pynapple" -version = "0.5.0" +version = "0.5.1" description = "PYthon Neural Analysis Package Pour Laboratoires d’Excellence" readme = "README.md" authors = [{ name = "Guillaume Viejo", email = "guillaume.viejo@gmail.com" }] diff --git a/setup.py b/setup.py index 5489cd00..700b4858 100644 --- a/setup.py +++ b/setup.py @@ -59,8 +59,8 @@ test_suite='tests', tests_require=test_requirements, url='https://github.com/pynapple-org/pynapple', - version='v0.5.0', + version='v0.5.1', zip_safe=False, long_description_content_type='text/markdown', - download_url='https://github.com/pynapple-org/pynapple/archive/refs/tags/v0.4.1.tar.gz' + download_url='https://github.com/pynapple-org/pynapple/archive/refs/tags/v0.5.1.tar.gz' ) diff --git a/tests/test_jitted.py b/tests/test_jitted.py index 8370637b..e06b76e7 100644 --- a/tests/test_jitted.py +++ b/tests/test_jitted.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-12-02 17:17:03 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-01-29 12:37:44 +# @Last Modified time: 2024-01-29 14:48:58 """Tests of jitted core functions for `pynapple` package.""" @@ -395,7 +395,7 @@ def test_jitunion_isets(): i_sets = [ep1, ep2, ep3, ep4] - ep6 = nap.core.union_intervals(i_sets) + ep6 = nap.core.ts_group.union_intervals(i_sets) time = np.hstack( From ed6ee818023f27362f698258d4a78feeb0ea3d00 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Mon, 29 Jan 2024 15:09:19 -0500 Subject: [PATCH 6/7] Fix correlogram --- tests/test_correlograms.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_correlograms.py b/tests/test_correlograms.py index 5c37e3b0..b8294ac2 100644 --- a/tests/test_correlograms.py +++ b/tests/test_correlograms.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-03-30 11:16:22 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-08-07 18:55:13 +# @Last Modified time: 2024-01-29 15:06:58 #!/usr/bin/env python """Tests of correlograms for `pynapple` package.""" @@ -17,19 +17,19 @@ def test_cross_correlogram(): t1 = np.array([0]) t2 = np.array([1]) - cc, bincenter = nap.cross_correlogram(t1, t2, 1, 100) + cc, bincenter = nap.process.correlograms.cross_correlogram(t1, t2, 1, 100) np.testing.assert_approx_equal(cc[101], 1.0) - cc, bincenter = nap.cross_correlogram(t2, t1, 1, 100) + cc, bincenter = nap.process.correlograms.cross_correlogram(t2, t1, 1, 100) np.testing.assert_approx_equal(cc[99], 1.0) t1 = np.array([0]) t2 = np.array([100]) - cc, bincenter = nap.cross_correlogram(t1, t2, 1, 100) + cc, bincenter = nap.process.correlograms.cross_correlogram(t1, t2, 1, 100) np.testing.assert_approx_equal(cc[200], 1.0) t1 = np.array([0, 10]) - cc, bincenter = nap.cross_correlogram(t1, t1, 1, 100) + cc, bincenter = nap.process.correlograms.cross_correlogram(t1, t1, 1, 100) np.testing.assert_approx_equal(cc[100], 1.0) np.testing.assert_approx_equal(cc[90], 0.5) np.testing.assert_approx_equal(cc[110], 0.5) @@ -38,7 +38,7 @@ def test_cross_correlogram(): for t in [100, 200, 1000]: np.testing.assert_array_almost_equal( - nap.cross_correlogram(np.arange(0, t), np.arange(0, t), 1, t)[0], + nap.process.correlograms.cross_correlogram(np.arange(0, t), np.arange(0, t), 1, t)[0], np.hstack( (np.arange(0, 1, 1 / t), np.ones(1), np.arange(0, 1, 1 / t)[::-1]) ), From 8ddff7de1e0152104b67573aff056573cca2b65d Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Mon, 29 Jan 2024 15:20:21 -0500 Subject: [PATCH 7/7] Adding _jitted_functions --- pynapple/core/_jitted_functions.py | 1084 ++++++++++++++++++++++++++++ 1 file changed, 1084 insertions(+) create mode 100644 pynapple/core/_jitted_functions.py diff --git a/pynapple/core/_jitted_functions.py b/pynapple/core/_jitted_functions.py new file mode 100644 index 00000000..8679adfa --- /dev/null +++ b/pynapple/core/_jitted_functions.py @@ -0,0 +1,1084 @@ +# -*- coding: utf-8 -*- +# @Author: guillaume +# @Date: 2022-10-31 16:44:31 +# @Last Modified by: Guillaume Viejo +# @Last Modified time: 2024-01-25 16:43:34 +import numpy as np +from numba import jit, njit, prange + + +@jit(nopython=True) +def jitrestrict(time_array, data_array, starts, ends): + n = len(time_array) + m = len(starts) + ix = np.zeros(n, dtype=np.bool_) + + k = 0 + t = 0 + + while ends[k] < time_array[t]: + k += 1 + + while k < m: + # Outside + while t < n: + if time_array[t] >= starts[k]: + # ix[t] = True + # t += 1 + break + t += 1 + + # Inside + while t < n: + if time_array[t] > ends[k]: + k += 1 + break + else: + ix[t] = True + t += 1 + + if k == m: + break + if t == n: + break + + new_time_array = time_array[ix] + new_data_array = data_array[ix] + return (new_time_array, new_data_array) + + +@jit(nopython=True) +def jittsrestrict(time_array, starts, ends): + n = len(time_array) + m = len(starts) + ix = np.zeros(n, dtype=np.bool_) + + k = 0 + t = 0 + + while ends[k] < time_array[t]: + k += 1 + + while k < m: + # Outside + while t < n: + if time_array[t] >= starts[k]: + # ix[t] = True + # t += 1 + break + t += 1 + + # Inside + while t < n: + if time_array[t] > ends[k]: + k += 1 + break + else: + ix[t] = True + t += 1 + + if k == m: + break + if t == n: + break + + new_time_array = time_array[ix] + return new_time_array + + +@jit(nopython=True) +def jitrestrict_with_count(time_array, data_array, starts, ends): + n = len(time_array) + m = len(starts) + ix = np.zeros(n, dtype=np.bool_) + count = np.zeros(m, dtype=np.int64) + + k = 0 + t = 0 + + while ends[k] < time_array[t]: + k += 1 + + while k < m: + # Outside + while t < n: + if time_array[t] >= starts[k]: + # ix[t] = True + # count[k] += 1 + # t += 1 + break + t += 1 + + # Inside + while t < n: + if time_array[t] > ends[k]: + k += 1 + break + else: + ix[t] = True + count[k] += 1 + t += 1 + + if k == m: + break + if t == n: + break + + new_time_array = time_array[ix] + new_data_array = data_array[ix] + return new_time_array, new_data_array, count + + +@jit(nopython=True) +def jittsrestrict_with_count(time_array, starts, ends): + n = len(time_array) + m = len(starts) + ix = np.zeros(n, dtype=np.bool_) + count = np.zeros(m, dtype=np.int64) + + k = 0 + t = 0 + + while ends[k] < time_array[t]: + k += 1 + + while k < m: + # Outside + while t < n: + if time_array[t] >= starts[k]: + # ix[t] = True + # count[k] += 1 + # t += 1 + break + t += 1 + + # Inside + while t < n: + if time_array[t] > ends[k]: + k += 1 + break + else: + ix[t] = True + count[k] += 1 + t += 1 + + if k == m: + break + if t == n: + break + + new_time_array = time_array[ix] + return new_time_array, count + + +@jit(nopython=True) +def jitthreshold(time_array, data_array, starts, ends, thr, method="above"): + n = time_array.shape[0] + + if method == "above": + ix = data_array > thr + elif method == "below": + ix = data_array < thr + elif method == "aboveequal": + ix = data_array >= thr + elif method == "belowequal": + ix = data_array <= thr + + k = 0 + t = 0 + + ix_start = np.zeros(n, dtype=np.bool_) + ix_end = np.zeros(n, dtype=np.bool_) + new_start = np.zeros(n, dtype=np.float64) + new_end = np.zeros(n, dtype=np.float64) + + while time_array[t] < starts[k]: + k += 1 + + if ix[t]: + ix_start[t] = 1 + new_start[t] = time_array[t] + + t += 1 + + while t < n - 1: + # transition + if time_array[t] > ends[k]: + k += 1 + if ix[t - 1]: + ix_end[t - 1] = 1 + new_end[t - 1] = time_array[t - 1] + if ix[t]: + ix_start[t] = 1 + new_start[t] = time_array[t] + + else: + if not ix[t - 1] and ix[t]: + ix_start[t] = 1 + new_start[t] = time_array[t] - (time_array[t] - time_array[t - 1]) / 2 + + if ix[t - 1] and not ix[t]: + ix_end[t] = 1 + new_end[t] = time_array[t] - (time_array[t] - time_array[t - 1]) / 2 + + t += 1 + + if ix[t] and ix[t - 1]: + ix_end[t] = 1 + new_end[t] = time_array[t] + + if ix[t] and not ix[t - 1]: + ix_start[t] = 1 + ix_end[t] = 1 + new_start[t] = time_array[t] - (time_array[t] - time_array[t - 1]) / 2 + new_end[t] = time_array[t] + + elif ix[t - 1] and not ix[t]: + ix_end[t] = 1 + new_end[t] = time_array[t] - (time_array[t] - time_array[t - 1]) / 2 + + new_time_array = time_array[ix] + new_data_array = data_array[ix] + new_starts = new_start[ix_start] + new_ends = new_end[ix_end] + + return (new_time_array, new_data_array, new_starts, new_ends) + + +@jit(nopython=True) +def jitvaluefrom(time_array, time_target_array, data_target_array, starts, ends): + time_array, count = jittsrestrict_with_count(time_array, starts, ends) + time_target_array, data_target_array, count_target = jitrestrict_with_count( + time_target_array, data_target_array, starts, ends + ) + + m = starts.shape[0] + n = time_array.shape[0] + d = time_target_array.shape[0] + + new_data_array = np.zeros(n, dtype=data_target_array.dtype) + + if n > 0 and d > 0: + for k in range(m): + if count[k] > 0 and count_target[k] > 0: + t = np.sum(count[0:k]) + i = np.sum(count_target[0:k]) + maxt = t + count[k] + maxi = i + count_target[k] + while t < maxt: + interval = abs(time_array[t] - time_target_array[i]) + new_data_array[t] = data_target_array[i] + i += 1 + while i < maxi: + new_interval = abs(time_array[t] - time_target_array[i]) + if new_interval > interval: + break + else: + new_data_array[t] = data_target_array[i] + interval = new_interval + i += 1 + i -= 1 + t += 1 + + return (time_array, new_data_array, starts, ends) + + +@jit(nopython=True) +def jitvaluefromtensor(time_array, time_target_array, data_target_array, starts, ends): + time_array, _, count = jitrestrict_with_count( + time_array, np.zeros(time_array.shape[0]), starts, ends + ) + time_target_array, data_target_array, count_target = jitrestrict_with_count( + time_target_array, data_target_array, starts, ends + ) + + m = starts.shape[0] + n = time_array.shape[0] + d = time_target_array.shape[0] + + new_data_array = np.zeros( + (n, *data_target_array.shape[1:]), dtype=data_target_array.dtype + ) + + if n > 0 and d > 0: + for k in range(m): + if count[k] > 0 and count_target[k] > 0: + t = np.sum(count[0:k]) + i = np.sum(count_target[0:k]) + maxt = t + count[k] + maxi = i + count_target[k] + while t < maxt: + interval = abs(time_array[t] - time_target_array[i]) + new_data_array[t] = data_target_array[i] + i += 1 + while i < maxi: + new_interval = abs(time_array[t] - time_target_array[i]) + if new_interval > interval: + break + else: + new_data_array[t] = data_target_array[i] + interval = new_interval + i += 1 + i -= 1 + t += 1 + + return (time_array, new_data_array, starts, ends) + + +@jit(nopython=True) +def jitcount(time_array, starts, ends, bin_size): + time_array, countin = jittsrestrict_with_count(time_array, starts, ends) + + m = starts.shape[0] + + nb_bins = np.zeros(m, dtype=np.int32) + for k in range(m): + if (ends[k] - starts[k]) > bin_size: + nb_bins[k] = int(np.ceil((ends[k] + bin_size - starts[k]) / bin_size)) + else: + nb_bins[k] = 1 + + nb = np.sum(nb_bins) + bins = np.zeros(nb, dtype=np.float64) + cnt = np.zeros(nb, dtype=np.int64) + + k = 0 + t = 0 + b = 0 + + while k < m: + maxb = b + nb_bins[k] + maxt = t + countin[k] + lbound = starts[k] + + while b < maxb: + xpos = lbound + bin_size / 2 + if xpos > ends[k]: + break + else: + bins[b] = xpos + rbound = np.round(lbound + bin_size, 9) + while t < maxt: + if time_array[t] < rbound: # similar to numpy hisrogram + cnt[b] += 1 + t += 1 + else: + break + + lbound += bin_size + b += 1 + t = maxt + k += 1 + + new_time_array = bins[0:b] + new_data_array = cnt[0:b] + + return (new_time_array, new_data_array) + + +@jit(nopython=True) +def jitbin(time_array, data_array, starts, ends, bin_size): + time_array, data_array, countin = jitrestrict_with_count( + time_array, data_array, starts, ends + ) + + m = starts.shape[0] + + nb_bins = np.zeros(m, dtype=np.int32) + for k in range(m): + if (ends[k] - starts[k]) > bin_size: + nb_bins[k] = int(np.ceil((ends[k] + bin_size - starts[k]) / bin_size)) + else: + nb_bins[k] = 1 + + nb = np.sum(nb_bins) + bins = np.zeros(nb, dtype=np.float64) + cnt = np.zeros(nb, dtype=np.float64) + average = np.zeros(nb, dtype=np.float64) + + k = 0 + t = 0 + b = 0 + + while k < m: + maxb = b + nb_bins[k] + maxt = t + countin[k] + lbound = starts[k] + + while b < maxb: + xpos = lbound + bin_size / 2 + if xpos > ends[k]: + break + else: + bins[b] = xpos + rbound = np.round(lbound + bin_size, 9) + while t < maxt: + if time_array[t] < rbound: # similar to numpy hisrogram + cnt[b] += 1.0 + average[b] += data_array[t] + t += 1 + else: + break + + lbound += bin_size + b += 1 + t = maxt + k += 1 + + new_time_array = bins[0:b] + new_data_array = average[0:b] / cnt[0:b] + + return (new_time_array, new_data_array) + + +@jit(nopython=True) +def jitbin_array(time_array, data_array, starts, ends, bin_size): + time_array, data_array, countin = jitrestrict_with_count( + time_array, data_array, starts, ends + ) + + m = starts.shape[0] + f = data_array.shape[1:] + + nb_bins = np.zeros(m, dtype=np.int32) + for k in range(m): + if (ends[k] - starts[k]) > bin_size: + nb_bins[k] = int(np.ceil((ends[k] + bin_size - starts[k]) / bin_size)) + else: + nb_bins[k] = 1 + + nb = np.sum(nb_bins) + bins = np.zeros(nb, dtype=np.float64) + cnt = np.zeros((nb, *f), dtype=np.float64) + average = np.zeros((nb, *f), dtype=np.float64) + + k = 0 + t = 0 + b = 0 + + while k < m: + maxb = b + nb_bins[k] + maxt = t + countin[k] + lbound = starts[k] + + while b < maxb: + xpos = lbound + bin_size / 2 + if xpos > ends[k]: + break + else: + bins[b] = xpos + rbound = np.round(lbound + bin_size, 9) + while t < maxt: + if time_array[t] < rbound: # similar to numpy hisrogram + cnt[b] += 1.0 + average[b] += data_array[t] + t += 1 + else: + break + + lbound += bin_size + b += 1 + t = maxt + k += 1 + + new_time_array = bins[0:b] + + new_data_array = average[0:b] / cnt[0:b] + + return (new_time_array, new_data_array) + + +@jit(nopython=True) +def jitintersect(start1, end1, start2, end2): + m = start1.shape[0] + n = start2.shape[0] + + i = 0 + j = 0 + + newstart = np.zeros(m + n, dtype=np.float64) + newend = np.zeros(m + n, dtype=np.float64) + ct = 0 + + while i < m: + while j < n: + if end2[j] > start1[i]: + break + j += 1 + + if j == n: + break + + if start2[j] < end1[i]: + newstart[ct] = max(start1[i], start2[j]) + newend[ct] = min(end1[i], end2[j]) + ct += 1 + if end2[j] < end1[i]: + j += 1 + else: + i += 1 + else: + i += 1 + + newstart = newstart[0:ct] + newend = newend[0:ct] + + return (newstart, newend) + + +@jit(nopython=True) +def jitunion(start1, end1, start2, end2): + m = start1.shape[0] + n = start2.shape[0] + + i = 0 + j = 0 + + newstart = np.zeros(m + n, dtype=np.float64) + newend = np.zeros(m + n, dtype=np.float64) + ct = 0 + + while i < m: + while j < n: + if end2[j] > start1[i]: + break + newstart[ct] = start2[j] + newend[ct] = end2[j] + ct += 1 + j += 1 + + if j == n: + break + + # overlap + if start2[j] < end1[i]: + newstart[ct] = min(start1[i], start2[j]) + + while i < m and j < n: + newend[ct] = max(end1[i], end2[j]) + + if end1[i] < end2[j]: + i += 1 + else: + j += 1 + + if i == m: + j += 1 + ct += 1 + break + + if j == n: + i += 1 + ct += 1 + break + + if end2[j] < start1[i]: + j += 1 + ct += 1 + break + elif end1[i] < start2[j]: + i += 1 + ct += 1 + break + + else: + newstart[ct] = start1[i] + newend[ct] = end1[i] + ct += 1 + i += 1 + + while i < m: + newstart[ct] = start1[i] + newend[ct] = end1[i] + ct += 1 + i += 1 + + while j < n: + newstart[ct] = start2[j] + newend[ct] = end2[j] + ct += 1 + j += 1 + + newstart = newstart[0:ct] + newend = newend[0:ct] + + return (newstart, newend) + + +@jit(nopython=True) +def jitdiff(start1, end1, start2, end2): + m = start1.shape[0] + n = start2.shape[0] + + i = 0 + j = 0 + + newstart = np.zeros(m + n, dtype=np.float64) + newend = np.zeros(m + n, dtype=np.float64) + ct = 0 + + while i < m: + while j < n: + if end2[j] > start1[i]: + break + j += 1 + + if j == n: + break + + # overlap + if start2[j] < end1[i]: + if start2[j] < start1[i] and end1[i] < end2[j]: + i += 1 + + else: + if start2[j] > start1[i]: + newstart[ct] = start1[i] + newend[ct] = start2[j] + ct += 1 + j += 1 + + else: + newstart[ct] = end2[j] + newend[ct] = end1[i] + j += 1 + + while j < n: + if start2[j] < end1[i]: + newstart[ct] = end2[j - 1] + newend[ct] = start2[j] + ct += 1 + j += 1 + else: + break + + if end2[j - 1] < end1[i]: + newstart[ct] = end2[j - 1] + newend[ct] = end1[i] + ct += 1 + else: + j -= 1 + i += 1 + + else: + newstart[ct] = start1[i] + newend[ct] = end1[i] + ct += 1 + i += 1 + + while i < m: + newstart[ct] = start1[i] + newend[ct] = end1[i] + ct += 1 + i += 1 + + newstart = newstart[0:ct] + newend = newend[0:ct] + + return (newstart, newend) + + +@jit(nopython=True) +def jitunion_isets(starts, ends): + idx = np.argsort(starts) + starts = starts[idx] + ends = ends[idx] + + n = starts.shape[0] + new_start = np.zeros(n, dtype=np.float64) + new_end = np.zeros(n, dtype=np.float64) + + ct = 0 + new_start[ct] = starts[0] + e = ends[0] + i = 1 + while i < n: + if starts[i] > e: + new_end[ct] = e + ct += 1 + new_start[ct] = starts[i] + e = ends[i] + else: + e = max(e, ends[i]) + i += 1 + + new_end[ct] = e + ct += 1 + new_start = new_start[0:ct] + new_end = new_end[0:ct] + return (new_start, new_end) + + +@jit(nopython=True) +def jitin_interval(time_array, starts, ends): + n = len(time_array) + m = len(starts) + data = np.ones(n, dtype=np.float64) * np.nan + + k = 0 + t = 0 + + while ends[k] < time_array[t]: + k += 1 + + while k < m: + # Outside + while t < n: + if time_array[t] >= starts[k]: + # data[t] = k + # t += 1 + break + # data[t] = np.nan + t += 1 + + # Inside + while t < n: + if time_array[t] > ends[k]: + k += 1 + # data[t] = np.nan + break + else: + data[t] = k + t += 1 + + if k == m: + break + if t == n: + break + + return data + + +@jit(nopython=True) +def jitremove_nan(time_array, index_nan): + n = len(time_array) + ix_start = np.zeros(n, dtype=np.bool_) + ix_end = np.zeros(n, dtype=np.bool_) + + if not index_nan[0]: # First start + ix_start[0] = True + + t = 1 + while t < n: + if index_nan[t - 1] and not index_nan[t]: # start + ix_start[t] = True + if not index_nan[t - 1] and index_nan[t]: # end + ix_end[t - 1] = True + t += 1 + + if not index_nan[-1]: # Last stop + ix_end[-1] = True + + starts = time_array[ix_start] + ends = time_array[ix_end] + return (starts, ends) + + +@jit(nopython=True) +def jitconvolve(d, a): + return np.convolve(d, a) + + +@njit(parallel=True) +def pjitconvolve(data_array, array, trim="both"): + shape = data_array.shape + t = shape[0] + k = array.shape[0] + + data_array = data_array.reshape(t, -1) + new_data_array = np.zeros(data_array.shape) + + if trim == "both": + cut = ((1 - k % 2) + (k - 1) // 2, t + k - 1 - ((k - 1) // 2)) + elif trim == "left": + cut = (k - 1, t + k - 1) + elif trim == "right": + cut = (0, t) + + for i in prange(data_array.shape[1]): + new_data_array[:, i] = jitconvolve(data_array[:, i], array)[cut[0] : cut[1]] + + new_data_array = new_data_array.reshape(shape) + + return new_data_array + + +@njit(parallel=True) +def jitcontinuous_perievent( + time_array, data_array, time_target_array, starts, ends, windowsize +): + N_samples = len(time_array) + N_target = len(time_target_array) + N_epochs = len(starts) + count = np.zeros((N_epochs, 2), dtype=np.int64) + start_t = np.zeros((N_epochs, 2), dtype=np.int64) + + k = 0 # Epochs + t = 0 # Samples + i = 0 # Target + + while ends[k] < time_array[t] and ends[k] < time_target_array[i]: + k += 1 + + while k < N_epochs: + # Outside + while t < N_samples: + if time_array[t] >= starts[k]: + break + t += 1 + + while i < N_target: + if time_target_array[i] >= starts[k]: + break + i += 1 + + if time_array[t] <= ends[k]: + start_t[k, 0] = t + + if time_target_array[i] <= ends[k]: + start_t[k, 1] = i + + # Inside + while t < N_samples: + if time_array[t] > ends[k]: + break + else: + count[k, 0] += 1 + t += 1 + + while i < N_target: + if time_target_array[i] > ends[k]: + break + else: + count[k, 1] += 1 + i += 1 + + k += 1 + + if k == N_epochs: + break + if t == N_samples: + break + if i == N_target: + break + + new_data_array = np.full( + (np.sum(windowsize) + 1, np.sum(count[:, 1]), *data_array.shape[1:]), np.nan + ) + + if np.any((count[:, 0] * count[:, 1]) > 0): + for k in prange(N_epochs): + if count[k, 0] > 0 and count[k, 1] > 0: + t = start_t[k, 0] + i = start_t[k, 1] + maxt = t + count[k, 0] + maxi = i + count[k, 1] + cnt_i = np.sum(count[0:k, 1]) + + while i < maxi: + interval = abs(time_array[t] - time_target_array[i]) + t_pos = t + t += 1 + while t < maxt: + new_interval = abs(time_array[t] - time_target_array[i]) + if new_interval > interval: + break + else: + interval = new_interval + t_pos = t + t += 1 + + left = np.minimum(windowsize[0], t_pos - start_t[k, 0]) + right = np.minimum(windowsize[1], maxt - t_pos - 1) + center = windowsize[0] + 1 + new_data_array[center - left - 1 : center + right, cnt_i] = ( + data_array[t_pos - left : t_pos + right + 1] + ) + + t -= 1 + i += 1 + cnt_i += 1 + + return new_data_array + + +@jit(nopython=True) +def jitperievent_trigger_average( + time_array, + count_array, + time_target_array, + data_target_array, + starts, + ends, + windows, + binsize, +): + T = time_array.shape[0] + N = count_array.shape[1] + N_epochs = len(starts) + + time_target_array, data_target_array, count = jitrestrict_with_count( + time_target_array, data_target_array, starts, ends + ) + max_count = np.cumsum(count) + + new_data_array = np.full( + (int(windows.sum()) + 1, count_array.shape[1], *data_target_array.shape[1:]), + 0.0, + ) + + t = 0 # count events + + hankel_array = np.zeros((new_data_array.shape[0], *data_target_array.shape[1:])) + + for k in range(N_epochs): + if count[k] > 0: + t_start = t + maxi = max_count[k] + i = maxi - count[k] + + while t < T: + lbound = time_array[t] + rbound = np.round(lbound + binsize, 9) + + if time_target_array[i] < rbound: + i_start = i + i_stop = i + + while i_stop < maxi: + if time_target_array[i_stop] < rbound: + i_stop += 1 + else: + break + + while i_start < i_stop - 1: + if time_target_array[i_start] < lbound: + i_start += 1 + else: + break + v = np.sum(data_target_array[i_start:i_stop], 0) / float( + i_stop - i_start + ) + + checknan = np.sum(v) + if not np.isnan(checknan): + hankel_array[-1] = v + + if t - t_start >= windows[1]: + for n in range(N): + new_data_array[:, n] += ( + hankel_array * count_array[t - windows[1], n] + ) + + # hankel_array = np.roll(hankel_array, -1, axis=0) + hankel_array[0:-1] = hankel_array[1:] + hankel_array[-1] = 0.0 + + t += 1 + + i = i_start + + if t == T or time_array[t] > ends[k]: + if t - t_start > windows[1]: + for j in range(windows[1]): + for n in range(N): + new_data_array[:, n] += ( + hankel_array * count_array[t - windows[1] + j, n] + ) + + # hankel_array = np.roll(hankel_array, -1, axis=0) + hankel_array[0:-1] = hankel_array[1:] + hankel_array[-1] = 0.0 + + hankel_array *= 0.0 + break + + total = np.sum(count_array, 0) + for n in range(N): + if total[n] > 0.0: + new_data_array[:, n] /= total[n] + + return new_data_array + + +# @jit(nopython=True) +# def jit_poisson_IRLS(X, y, niter=100, tolerance=1e-5): +# y = y.astype(np.float64) +# X = X.astype(np.float64) +# n, d = X.shape +# W = np.ones(n) +# iXtWX = np.linalg.inv(np.dot(X.T * W, X)) +# XtWY = np.dot(X.T * W, y) +# B = np.dot(iXtWX, XtWY) + +# for _ in range(niter): +# B_ = B +# L = np.exp(X.dot(B)) # Link function +# Z = L.reshape((-1, 1)) * X # partial derivatives +# delta = np.dot(np.linalg.inv(np.dot(Z.T * W, Z)), np.dot(Z.T * W, y)) +# B = B + delta +# tol = np.sum(np.abs((B - B_) / B_)) +# if tol < tolerance: +# return B +# return B + + +# @jit(nopython=True) +# def jitfind_gaps(time_array, starts, ends, min_gap): +# """ +# Jitted version of find_gap + +# Parameters +# ---------- +# time_array : numpy.ndarray +# Description +# data_array : numpy.ndarray +# Description +# starts : numpy.ndarray +# Description +# ends : numpy.ndarray +# Description + +# Returns +# ------- +# TYPE +# Description +# """ +# n = len(time_array) +# m = len(starts) + +# new_start = np.zeros(n+m, dtype=np.float64) +# new_end = np.zeros(n+m, dtype=np.float64) + +# k = 0 +# t = 0 +# i = 0 + +# while k ends[k]: +# break + +# if (time_array[t] - start) > min_gap: +# new_start[i] = start+1e-6 +# new_end[i] = time_array[t]-1e-6 +# start = time_array[t] +# t += 1 +# i += 1 + +# else: +# start = time_array[t] +# t += 1 + + +# k += 1 + + +# new_start = new_start[0:i] +# new_end = new_end[0:i] + +# return new_start, new_end