Skip to content

Commit 288662b

Browse files
committed
Adding dropna
1 parent e081914 commit 288662b

File tree

5 files changed

+109
-42
lines changed

5 files changed

+109
-42
lines changed

pynapple/core/jitted_functions.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# -*- coding: utf-8 -*-
22
# @Author: guillaume
33
# @Date: 2022-10-31 16:44:31
4-
# @Last Modified by: gviejo
5-
# @Last Modified time: 2023-10-15 16:05:27
4+
# @Last Modified by: Guillaume Viejo
5+
# @Last Modified time: 2023-11-19 18:27:43
66
import numpy as np
77
from numba import jit
88

@@ -749,6 +749,31 @@ def jitin_interval(time_array, starts, ends):
749749
return data
750750

751751

752+
@jit(nopython=True)
753+
def jitremove_nan(time_array, index_nan):
754+
n = len(time_array)
755+
ix_start = np.zeros(n, dtype=np.bool_)
756+
ix_end = np.zeros(n, dtype=np.bool_)
757+
758+
if not index_nan[0]: # First start
759+
ix_start[0] = True
760+
761+
t = 1
762+
while t < n:
763+
if index_nan[t - 1] and not index_nan[t]: # start
764+
ix_start[t] = True
765+
if not index_nan[t - 1] and index_nan[t]: # end
766+
ix_end[t - 1] = True
767+
t += 1
768+
769+
if not index_nan[-1]: # Last stop
770+
ix_end[-1] = True
771+
772+
starts = time_array[ix_start]
773+
ends = time_array[ix_end]
774+
return (starts, ends)
775+
776+
752777
@jit(nopython=True)
753778
def jit_poisson_IRLS(X, y, niter=100, tolerance=1e-5):
754779
y = y.astype(np.float64)

pynapple/core/time_series.py

Lines changed: 37 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# @Author: gviejo
33
# @Date: 2022-01-27 18:33:31
44
# @Last Modified by: Guillaume Viejo
5-
# @Last Modified time: 2023-11-17 16:09:35
5+
# @Last Modified time: 2023-11-19 18:59:08
66

77
"""
88
@@ -39,6 +39,7 @@
3939
jitbin,
4040
jitbin_array,
4141
jitcount,
42+
jitremove_nan,
4243
jitrestrict,
4344
jitthreshold,
4445
jittsrestrict,
@@ -787,43 +788,47 @@ def get(self, start, end=None, time_units="s"):
787788
idx_end = np.searchsorted(time_array, end, side="right")
788789
return self[idx_start:idx_end]
789790

790-
def dropna(self):
791-
nant = np.any(np.isnan(self.values), 1)
792-
if np.any(nant):
793-
starts = []
794-
ends = []
795-
n = 0
796-
if not nant[n]: # start is the time support
797-
starts.append(self.time_support.start.values[0])
798-
else:
799-
while n<len(self):
800-
if nant[n]:
801-
n+=1
802-
else:
803-
starts.append(self.index.values[n])
804-
break
791+
def dropna(self, update_time_support=True):
792+
"""Drop every rows containing NaNs. By default, the time support is updated to start and end around the time points that are non NaNs.
793+
To change this behavior, you can set update_time_support=False.
805794
806-
while n<len(self):
807-
if nant[n]:
808-
ends.append(nant[n-1])
809-
break
810-
else:
811-
n+=1
795+
Parameters
796+
----------
797+
update_time_support : bool, optional
812798
813-
while n<len(self):
814-
if not nant[n]:
815-
n+1
816-
else:
817-
starts.append(self.index.values[n])
818-
break
799+
Returns
800+
-------
801+
Tsd, TsdFrame or TsdTensor
802+
The time series without the NaNs
803+
"""
804+
index_nan = np.any(np.isnan(self.values), axis=tuple(range(1, self.ndim)))
805+
if np.all(index_nan): # In case it's only NaNs
806+
return self.__class__(
807+
t=np.array([]), d=np.empty(tuple([0] + [d for d in self.shape[1:]]))
808+
)
819809

820-
if not nant[-1]: # end is the time support
821-
ends.append(self.time_support.end.values[0])
810+
elif np.any(index_nan):
811+
if update_time_support:
812+
time_array = self.index.values
813+
starts, ends = jitremove_nan(time_array, index_nan)
822814

815+
to_fix = starts == ends
816+
if np.any(to_fix):
817+
ends[
818+
to_fix
819+
] += 1e-6 # adding 1 millisecond in case of a single point
823820

824-
else:
825-
return self
821+
ep = IntervalSet(starts, ends)
822+
823+
return self.__class__(
824+
t=time_array[~index_nan], d=self.values[~index_nan], time_support=ep
825+
)
826+
827+
else:
828+
return self[~index_nan]
826829

830+
else:
831+
return self
827832

828833

829834
class TsdTensor(NDArrayOperatorsMixin, _AbstractTsd):

pynapple/process/perievent.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# -*- coding: utf-8 -*-
22
# @Author: gviejo
33
# @Date: 2022-01-30 22:59:00
4-
# @Last Modified by: gviejo
5-
# @Last Modified time: 2023-11-16 11:34:48
4+
# @Last Modified by: Guillaume Viejo
5+
# @Last Modified time: 2023-11-19 19:13:24
66

77
import numpy as np
88
from scipy.linalg import hankel
@@ -168,6 +168,11 @@ def compute_event_trigger_average(
168168

169169
tmp = feature.bin_average(binsize, ep)
170170

171+
# Check for any NaNs in feature
172+
if np.any(np.isnan(tmp)):
173+
tmp = tmp.dropna()
174+
count = count.restrict(tmp.time_support)
175+
171176
# Build the Hankel matrix
172177
n_p = len(idx1)
173178
n_f = len(idx2)

tests/test_numpy_compatibility.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# -*- coding: utf-8 -*-
22
# @Author: Guillaume Viejo
33
# @Date: 2023-09-18 18:11:24
4-
# @Last Modified by: gviejo
5-
# @Last Modified time: 2023-11-08 18:14:12
4+
# @Last Modified by: Guillaume Viejo
5+
# @Last Modified time: 2023-11-19 16:55:26
66

77

88

@@ -17,7 +17,10 @@
1717

1818
tsd = nap.TsdTensor(t=np.arange(100), d=np.random.rand(100, 5, 3), time_units="s")
1919

20-
tsd = nap.TsdFrame(t=np.arange(100), d=np.random.randn(100, 6))
20+
# tsd = nap.TsdFrame(t=np.arange(100), d=np.random.randn(100, 6))
21+
22+
tsd.d[tsd.values>0.9] = np.NaN
23+
2124

2225
@pytest.mark.parametrize(
2326
"tsd",

tests/test_time_series.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# -*- coding: utf-8 -*-
22
# @Author: gviejo
33
# @Date: 2022-04-01 09:57:55
4-
# @Last Modified by: gviejo
5-
# @Last Modified time: 2023-11-08 18:46:52
4+
# @Last Modified by: Guillaume Viejo
5+
# @Last Modified time: 2023-11-19 18:48:57
66
#!/usr/bin/env python
77

88
"""Tests of time series for `pynapple` package."""
@@ -271,7 +271,7 @@ def __init__(self):
271271
@pytest.mark.parametrize(
272272
"tsd",
273273
[
274-
nap.Tsd(t=np.arange(100), d=np.arange(100), time_units="s"),
274+
nap.Tsd(t=np.arange(100), d=np.random.rand(100), time_units="s"),
275275
nap.TsdFrame(t=np.arange(100), d=np.random.rand(100, 5), time_units="s"),
276276
nap.TsdTensor(t=np.arange(100), d=np.random.rand(100, 5, 2), time_units="s"),
277277
nap.Ts(t=np.arange(100), time_units="s"),
@@ -393,6 +393,35 @@ def test_get_timepoint(self, tsd):
393393
np.testing.assert_array_equal(tsd.get(1), tsd[1])
394394
np.testing.assert_array_equal(tsd.get(1000), tsd[-1])
395395

396+
def test_dropna(self, tsd):
397+
if not isinstance(tsd, nap.Ts):
398+
399+
new_tsd = tsd.dropna()
400+
np.testing.assert_array_equal(tsd.index.values, new_tsd.index.values)
401+
np.testing.assert_array_equal(tsd.values, new_tsd.values)
402+
403+
tsd.values[tsd.values>0.9] = np.NaN
404+
new_tsd = tsd.dropna()
405+
assert not np.all(np.isnan(new_tsd))
406+
tokeep = np.array([~np.any(np.isnan(tsd[i])) for i in range(len(tsd))])
407+
np.testing.assert_array_equal(tsd.index.values[tokeep], new_tsd.index.values)
408+
np.testing.assert_array_equal(tsd.values[tokeep], new_tsd.values)
409+
410+
newtsd2 = tsd.restrict(new_tsd.time_support)
411+
np.testing.assert_array_equal(newtsd2.index.values, new_tsd.index.values)
412+
np.testing.assert_array_equal(newtsd2.values, new_tsd.values)
413+
414+
new_tsd = tsd.dropna(update_time_support=False)
415+
np.testing.assert_array_equal(tsd.index.values[tokeep], new_tsd.index.values)
416+
np.testing.assert_array_equal(tsd.values[tokeep], new_tsd.values)
417+
pd.testing.assert_frame_equal(new_tsd.time_support, tsd.time_support)
418+
419+
tsd.values[:] = np.NaN
420+
new_tsd = tsd.dropna()
421+
assert len(new_tsd) == 0
422+
assert len(new_tsd.time_support) == 0
423+
424+
396425
####################################################
397426
# Test for tsd
398427
####################################################

0 commit comments

Comments
 (0)