Skip to content

Commit

Permalink
Merge pull request #255 from pynapple-org/dev
Browse files Browse the repository at this point in the history
Fixing tsgroup save
  • Loading branch information
gviejo authored Apr 2, 2024
2 parents 9fabb06 + 7e7e3a2 commit 4039000
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 98 deletions.
2 changes: 1 addition & 1 deletion docs/examples/tutorial_pynapple_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@

tsdframe = nap.TsdFrame(t=np.arange(5), d=np.random.randn(5, 3))

print(np.concatenate(tsdframe, tsdframe), 1)
print(np.concatenate((tsdframe, tsdframe), 1))

# %%
# Spliting
Expand Down
14 changes: 4 additions & 10 deletions pynapple/core/interval_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ def save(self, filename):
objects. For example, you determined some epochs for one session that you want to save
to avoid recomputing them.
You can load the object with numpy.load. Keys are 'start', 'end' and 'type'.
You can load the object with `nap.load_file`. Keys are 'start', 'end' and 'type'.
See the example below.
Parameters
Expand All @@ -577,16 +577,10 @@ def save(self, filename):
>>> ep = nap.IntervalSet(start=[0, 10, 20], end=[5, 12, 33])
>>> ep.save("my_ep.npz")
Here I can retrieve my data with numpy directly:
To load you file, you can use the `nap.load_file` function :
>>> file = np.load("my_ep.npz")
>>> print(list(file.keys()))
['start', 'end', 'type']
>>> print(file['start'])
[0. 10. 20.]
It is then easy to recreate the IntervalSet object.
>>> nap.IntervalSet(file['start'], file['end'])
>>> ep = nap.load_file("my_path/my_ep.npz")
>>> ep
start end
0 0.0 5.0
1 10.0 12.0
Expand Down
3 changes: 2 additions & 1 deletion pynapple/core/time_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
Similar to pandas.Index, `TsIndex` holds the timestamps associated with the data of a time series.
This class deals with conversion between different time units for all pynapple objects as well
as making sure that timestamps are property sorted before initializing any objects.
as making sure that timestamps are property sorted before initializing any objects.
- `us`: microseconds
- `ms`: milliseconds
- `s`: seconds (overall default)
Expand Down
64 changes: 15 additions & 49 deletions pynapple/core/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ def save(self, filename):
filtered them. You can save the filtered channels as a npz to avoid
reprocessing it.
You can load the object with numpy.load. Keys are 't', 'd', 'start', 'end', 'type'
You can load the object with `nap.load_file`. Keys are 't', 'd', 'start', 'end', 'type'
and 'columns' for columns names.
Parameters
Expand All @@ -673,21 +673,9 @@ def save(self, filename):
>>> tsdtensor = nap.TsdTensor(t=np.array([0., 1.]), d = np.zeros((2,3,4)))
>>> tsdtensor.save("my_path/my_tsdtensor.npz")
Here I can retrieve my data with numpy directly:
>>> file = np.load("my_path/my_tsdtensor.npz")
>>> print(list(file.keys()))
['t', 'd', 'start', 'end', ''type']
>>> print(file['t'])
[0. 1.]
It is then easy to recreate the TsdTensor object.
>>> time_support = nap.IntervalSet(file['start'], file['end'])
>>> nap.TsdTensor(t=file['t'], d=file['d'], time_support=time_support)
Time (s)
0.0 [[[0.0 ...]]]
1.0 [[[0.0 ...]]]
To load you file, you can use the `nap.load_file` function :
>>> tsdtensor = nap.load_file("my_path/my_tsdtensor.npz")
Raises
------
Expand Down Expand Up @@ -914,7 +902,7 @@ def save(self, filename):
filtered them. You can save the filtered channels as a npz to avoid
reprocessing it.
You can load the object with numpy.load. Keys are 't', 'd', 'start', 'end', 'type'
You can load the object with `nap.load_file`. Keys are 't', 'd', 'start', 'end', 'type'
and 'columns' for columns names.
Parameters
Expand All @@ -929,17 +917,10 @@ def save(self, filename):
>>> tsdframe = nap.TsdFrame(t=np.array([0., 1.]), d = np.array([[2, 3],[4,5]]), columns=['a', 'b'])
>>> tsdframe.save("my_path/my_tsdframe.npz")
Here I can retrieve my data with numpy directly:
>>> file = np.load("my_path/my_tsdframe.npz")
>>> print(list(file.keys()))
['t', 'd', 'start', 'end', 'columns', 'type']
>>> print(file['t'])
[0. 1.]
To load you file, you can use the `nap.load_file` function :
It is then easy to recreate the Tsd object.
>>> time_support = nap.IntervalSet(file['start'], file['end'])
>>> nap.TsdFrame(t=file['t'], d=file['d'], time_support=time_support, columns=file['columns'])
>>> tsdframe = nap.load_file("my_path/my_tsdframe.npz")
>>> tsdframe
a b
Time (s)
0.0 2 3
Expand Down Expand Up @@ -1236,7 +1217,7 @@ def save(self, filename):
filtered it. You can save the filtered channel as a npz to avoid
reprocessing it.
You can load the object with numpy.load. Keys are 't', 'd', 'start', 'end' and 'type'.
You can load the object with `nap.load_file`. Keys are 't', 'd', 'start', 'end' and 'type'.
See the example below.
Parameters
Expand All @@ -1251,17 +1232,10 @@ def save(self, filename):
>>> tsd = nap.Tsd(t=np.array([0., 1.]), d = np.array([2, 3]))
>>> tsd.save("my_path/my_tsd.npz")
Here I can retrieve my data with numpy directly:
To load you file, you can use the `nap.load_file` function :
>>> file = np.load("my_path/my_tsd.npz")
>>> print(list(file.keys()))
['t', 'd', 'start', 'end', 'type']
>>> print(file['t'])
[0. 1.]
It is then easy to recreate the Tsd object.
>>> time_support = nap.IntervalSet(file['start'], file['end'])
>>> nap.Tsd(t=file['t'], d=file['d'], time_support=time_support)
>>> tsd = nap.load_file("my_path/my_tsd.npz")
>>> tsd
Time (s)
0.0 2
1.0 3
Expand Down Expand Up @@ -1530,7 +1504,7 @@ def save(self, filename):
The main purpose of this function is to save small/medium sized timestamps
object.
You can load the object with numpy.load. Keys are 't', 'start' and 'end' and 'type'.
You can load the object with `nap.load_file`. Keys are 't', 'start' and 'end' and 'type'.
See the example below.
Parameters
Expand All @@ -1545,23 +1519,15 @@ def save(self, filename):
>>> ts = nap.Ts(t=np.array([0., 1., 1.5]))
>>> ts.save("my_path/my_ts.npz")
Here I can retrieve my data with numpy directly:
>>> file = np.load("my_path/my_ts.npz")
>>> print(list(file.keys()))
['t', 'start', 'end', 'type']
>>> print(file['t'])
[0. 1. 1.5]
To load you file, you can use the `nap.load_file` function :
It is then easy to recreate the Tsd object.
>>> time_support = nap.IntervalSet(file['start'], file['end'])
>>> nap.Ts(t=file['t'], time_support=time_support)
>>> ts = nap.load_file("my_path/my_ts.npz")
>>> ts
Time (s)
0.0
1.0
1.5
Raises
------
RuntimeError
Expand Down
55 changes: 24 additions & 31 deletions pynapple/core/ts_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,28 +888,33 @@ def save(self, filename):
and assigning to each the corresponding index. Typically, a TsGroup like
this :
TsGroup({
0 : Tsd(t=[0, 2, 4], d=[1, 2, 3])
1 : Tsd(t=[1, 5], d=[5, 6])
})
``` py
TsGroup({
0 : Tsd(t=[0, 2, 4], d=[1, 2, 3])
1 : Tsd(t=[1, 5], d=[5, 6])
})
```
will be saved as npz with the following keys:
{
't' : [0, 1, 2, 4, 5],
'd' : [1, 5, 2, 3, 5],
'index' : [0, 1, 0, 0, 1],
'start' : [0],
'end' : [5],
'type' : 'TsGroup'
}
``` py
{
't' : [0, 1, 2, 4, 5],
'd' : [1, 5, 2, 3, 5],
'index' : [0, 1, 0, 0, 1],
'start' : [0],
'end' : [5],
'keys' : [0, 1],
'type' : 'TsGroup'
}
```
Metadata are saved by columns with the column name as the npz key. To avoid
potential conflicts, make sure the columns name of the metadata are different
from ['t', 'd', 'start', 'end', 'index']
from ['t', 'd', 'start', 'end', 'index', 'keys']
You can load the object with numpy.load. Default keys are 't', 'd'(optional),
'start', 'end', 'index' and 'type'.
You can load the object with `nap.load_file`. Default keys are 't', 'd'(optional),
'start', 'end', 'index', 'keys' and 'type'.
See the example below.
Parameters
Expand All @@ -935,21 +940,9 @@ def save(self, filename):
6 0.4 1 left foot
>>> tsgroup.save("my_tsgroup.npz")
Here I can retrieve my data with numpy directly:
>>> file = np.load("my_tsgroup.npz")
>>> print(list(file.keys()))
['rate', 'group', 'location', 't', 'index', 'start', 'end', 'type']
>>> print(file['index'])
[0 6 0 0 6]
To get back to pynapple, you can use the `nap.load_file` function :
In the case where TsGroup is a set of Ts objects, it is very direct to
recreate the TsGroup by using the function to_tsgroup :
>>> time_support = nap.IntervalSet(file['start'], file['end'])
>>> tsd = nap.Tsd(t=file['t'], d=file['index'], time_support = time_support)
>>> tsgroup = tsd.to_tsgroup()
>>> tsgroup.set_info(group = file['group'], location = file['location'])
>>> tsgroup = nap.load_file("my_tsgroup.npz")
>>> tsgroup
Index rate group location
------- ------ ------- ----------
Expand Down Expand Up @@ -981,7 +974,7 @@ def save(self, filename):

dicttosave = {"type": np.array(["TsGroup"], dtype=np.str_)}
for k in self._metadata.columns:
if k not in ["t", "d", "start", "end", "index"]:
if k not in ["t", "d", "start", "end", "index", "keys"]:
tmp = self._metadata[k].values
if tmp.dtype == np.dtype("O"):
tmp = tmp.astype(np.str_)
Expand Down Expand Up @@ -1012,7 +1005,7 @@ def save(self, filename):
dicttosave["index"] = index
if not np.all(np.isnan(data)):
dicttosave["d"] = data[idx]

dicttosave["keys"] = np.array(self.keys())
dicttosave["start"] = self.time_support.start
dicttosave["end"] = self.time_support.end

Expand Down
35 changes: 29 additions & 6 deletions pynapple/io/interface_npz.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# @Author: Guillaume Viejo
# @Date: 2023-07-05 16:03:25
# @Last Modified by: Guillaume Viejo
# @Last Modified time: 2023-09-26 18:00:54
# @Last Modified time: 2024-04-02 14:32:25

"""
File classes help to validate and load pynapple objects or NWB files.
Expand Down Expand Up @@ -89,12 +89,35 @@ def load(self):
else:
time_support = nap.IntervalSet(self.file["start"], self.file["end"])
if self.type == "TsGroup":
tsd = nap.Tsd(
t=self.file["t"], d=self.file["index"], time_support=time_support
)
tsgroup = tsd.to_tsgroup()

times = self.file["t"]
index = self.file["index"]
has_data = False
if "d" in self.file.keys():
print("TODO")
data = self.file["data"]
has_data = True

if "keys" in self.file.keys():
keys = self.file["keys"]
else:
keys = np.unique(index)

group = {}
for k in keys:
if has_data:
group[k] = nap.Tsd(
t=times[index == k],
d=data[index == k],
time_support=time_support,
)
else:
group[k] = nap.Ts(
t=times[index == k], time_support=time_support
)

tsgroup = nap.TsGroup(
group, time_support=time_support, bypass_check=True
)

metainfo = {}
for k in set(self.file.keys()) - {
Expand Down

0 comments on commit 4039000

Please sign in to comment.