Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev #266

Merged
merged 49 commits into from
May 7, 2024
Merged

Dev #266

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
a2fe7c1
Working branch when passing jax for d
gviejo Feb 27, 2024
28b8b66
Adding not_implemented_in_pynajax
gviejo Feb 27, 2024
a2a0d76
adding test
gviejo Feb 27, 2024
7237a19
comparing
gviejo Feb 27, 2024
1168ab4
speed test with great results on CPU
BalzaniEdoardo Feb 28, 2024
0854886
massive speedup on convolve
BalzaniEdoardo Feb 28, 2024
8dcd108
maxi-speedup
BalzaniEdoardo Feb 28, 2024
1d79b6a
working version of convolve
gviejo Feb 28, 2024
f7eed3d
Update
gviejo Feb 29, 2024
cb01bc2
Update
gviejo Mar 5, 2024
d8b9850
Merge branch 'dev' into pynajax
gviejo Mar 6, 2024
55827f7
Working version with jax backend
gviejo Mar 7, 2024
76b8bf2
Adding count
gviejo Mar 8, 2024
bd0a46b
Merging main into pynajax
gviejo Apr 2, 2024
ba4a294
Merge branch 'main' into pynajax
gviejo Apr 2, 2024
e2271b2
Updating history
gviejo Apr 3, 2024
39debbb
FIxing conflict
gviejo Apr 3, 2024
24450a8
Merge pull request #240 from pynapple-org/pynajax
gviejo Apr 3, 2024
4f1bf12
working version with backend dispatch
gviejo Apr 4, 2024
d132b0c
Merge branch 'main' into dev
gviejo Apr 4, 2024
ba758df
Fixing conflicts
gviejo Apr 4, 2024
df52c53
Merge main
gviejo Apr 10, 2024
f8fc6a8
Merge branch 'main' into dev
gviejo Apr 10, 2024
bdc7fa2
Merge branch 'main' into dev
gviejo Apr 11, 2024
f238161
keeping t as numpy array
gviejo Apr 12, 2024
c6d8818
Merging main
gviejo Apr 17, 2024
e542d75
Updating
gviejo Apr 18, 2024
e88f371
Updating branch with 0.6.4
gviejo Apr 18, 2024
73da1c1
Updating jittedfucntions
gviejo Apr 18, 2024
fb9877c
Working version of dev with minimal jitted functions
gviejo Apr 19, 2024
4082032
Adding _process_functions.py
gviejo Apr 23, 2024
fcdf0e2
linting
gviejo Apr 23, 2024
e0457ee
changing perievent
gviejo Apr 23, 2024
667f348
Linting
gviejo Apr 23, 2024
01e5e28
Update
gviejo Apr 23, 2024
ca97c64
Update
gviejo Apr 24, 2024
4c4a14c
Pasing tests for pynajax
gviejo Apr 29, 2024
ad4ef35
Updating docs
gviejo Apr 29, 2024
d86965c
testing fastplotlib
gviejo May 1, 2024
bf866f0
changing tests for pynajax
gviejo May 6, 2024
ef0151c
CHanged perivent continuous
gviejo May 7, 2024
1bff1c8
linting
gviejo May 7, 2024
8847a27
Update docs/pynajax.md
gviejo May 7, 2024
bc51b40
Update docs/pynajax.md
gviejo May 7, 2024
c2f72b3
Update pynapple/core/time_series.py
gviejo May 7, 2024
6e28b61
Update pynapple/core/time_series.py
gviejo May 7, 2024
7282d0d
Update pynapple/core/time_series.py
gviejo May 7, 2024
95d078a
Update
gviejo May 7, 2024
f71610d
Final update
gviejo May 7, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ In 2021, Guillaume and other trainees in Adrien's lab decided to fork from neuro
- Fixed `TsGroup` saving method.
- `__getattr__` of `BaseTsd` allow numpy functions to be attached as attributes of Tsd objects
- Added `get` method for `TsGroup`
- Tsds can be concatenate vertically if time indexes matches.


0.6.1 (2024-03-03)
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/tutorial_pynapple_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
# The object project behaves like a nested dictionnary. It is then easy to loop and navigate through a hierarchy of folders when doing analyses. In this case, we are gonna take only the session A2929-200711.


session = project["sub-A2929"]["ses-A2929-200711"]
session = project["sub-A2929"]["A2929-200711"]

print(session)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/tutorial_pynapple_quick_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
# The object `data` is a [`Folder`](https://pynapple-org.github.io/pynapple/io.folder/) object that allows easy navigation and interaction with a dataset.
# In this case, we want to load the NWB file in the folder `/pynapplenwb`. Data are always lazy loaded. No time series is loaded until it's actually called.
# When calling the NWB file, the object `nwb` is an interface to the NWB file. All the data inside the NWB file that are compatible with one of the pynapple objects are shown with their corresponding keys.
nwb = data["sub-A2929"]["ses-A2929-200711"]["pynapplenwb"]["A2929-200711"]
nwb = data["sub-A2929"]["A2929-200711"]["pynapplenwb"]["A2929-200711"]
print(nwb)


Expand Down
17 changes: 17 additions & 0 deletions docs/external.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
Pynapple has been designed as a lightweight package for representing time series and epochs in system neuroscience.
As such, it can function as a foundational element for other analysis packages handling time series data. Here we keep track of external projects that uses pynapple.


## NEMOS

![image](https://raw.githubusercontent.com/flatironinstitute/nemos/main/docs/assets/glm_features_scheme.svg)

[NeMOs](https://nemos.readthedocs.io/en/stable/) is a statistical modeling framework optimized for systems neuroscience and powered by JAX. It streamlines the process of defining and selecting models, through a collection of easy-to-use methods for feature design.

The core of nemos includes GPU-accelerated, well-tested implementations of standard statistical models, currently focusing on the Generalized Linear Model (GLM).

Check out this [page](https://nemos.readthedocs.io/en/stable/generated/neural_modeling/) for many examples of neural modelling using nemos and pynapple.

!!! note
Nemos is build on top of [jax](https://jax.readthedocs.io/en/latest/index.html), a library for high-performance numerical computing.
To ensure full compatibility with nemos, consider installing [pynajax](https://github.com/pynapple-org/pynajax), a pynapple backend for jax.
65 changes: 65 additions & 0 deletions docs/pynajax.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
### Motivation

Multiple python packages exist for high-performance computing. Internally, pynapple makes extensive use of [numba](https://numba.pydata.org/) for accelerating some functions. Numba is a stable package that provide speed gains with minimal installation issues when running on CPUs.

Another high-performance toolbox for numerical analysis is
[jax](https://jax.readthedocs.io/en/latest/index.html). In addition to accelerating python code on CPUs, GPUs, and TPUs, it provides a special representation of arrays using the [jax Array object](https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html). Unfortunately, jax Array is incompatible with Numba. To solve this issue, we developped [pynajax](https://github.com/pynapple-org/pynajax).

Pynajax is an accelerated backend for pynapple built on top on jax. It offers a fast acceleration for some pynapple functions using CPU or GPU. Here is a minimal example on how to use pynajax:

``` bash
$ pip install pynajax
```



``` python
import pynapple as nap
import numpy as np

# Changed the backend from 'numba' to 'jax'
nap.nap_config.set_backend("jax")

# This will convert the numpy array to a jax Array.
tsd = nap.Tsd(t=np.arange(100), d=np.random.randn(100))

# This will run on GPU or CPU depending on the jax installation
tsd.convolve(np.ones(11))
```

This documentation page keeps tracks of the list of pynapple functions that can be jax-accelerated as well as their performances compared to pure numba.

### Installation issues

To get the best of the pynajax backend, jax needs to use the GPU.

While installing pynajax will install all the dependencies necessary to use jax, it does not guarantee
the use of the GPU.

To check if jax is using the GPU, you can run the following python commands :

- no GPU found :

```python
>>> import jax
>>> print(jax.devices())
[CpuDevice(id=0)]
```

- GPU found :

```python
>>> import jax
>>> print(jax.devices())
[cuda(id=0)]
```

Support for installing `JAX` for GPU users can be found in the [jax documentation](https://jax.readthedocs.io/en/latest/installation.html)


### Typical use-case


In addition to providing high performance numerical computing, jax can be used as a the backbone for a large scale machine learning model. Thus, pynajax can offer full compatibility between pynapple's time series representation and computational neuroscience models constructed using jax.

An example of a python package using both pynapple and jax is [NeMOs](https://nemos.readthedocs.io/en/stable/).
119 changes: 96 additions & 23 deletions draft_pynapple_fastplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,58 +15,131 @@

"""
# %%
# !!! warning
# This tutorial uses seaborn and matplotlib for displaying the figure
#
# You can install all with `pip install matplotlib seaborn tqdm`
#
# mkdocs_gallery_thumbnail_number = 1
#
# Now, import the necessary libraries:

# %qui qt
# %gui qt

import pynapple as nap
import numpy as np
import fastplotlib as fpl

import imageio.v3 as iio
import sys
# mkdocs_gallery_thumbnail_path = '../_static/fastplotlib_demo.png'

#nwb = nap.load_file("/Users/gviejo/pynapple/Mouse32-220101.nwb")
nwb = nap.load_file("your/path/to/MyProject/sub-A2929/ses-A2929-200711/pynapplenwb/A2929-200711.nwb")
def get_memory_map(filepath, nChannels, frequency=20000):
n_channels = int(nChannels)
f = open(filepath, 'rb')
startoffile = f.seek(0, 0)
endoffile = f.seek(0, 2)
bytes_size = 2
n_samples = int((endoffile-startoffile)/n_channels/bytes_size)
duration = n_samples/frequency
interval = 1/frequency
f.close()
fp = np.memmap(filepath, np.int16, 'r', shape = (n_samples, n_channels))
timestep = np.arange(0, n_samples)/frequency

return fp, timestep


#### LFP
data_array, time_array = get_memory_map("your/path/to/MyProject/sub-A2929/A2929-200711/A2929-200711.dat", 16)
lfp = nap.TsdFrame(t=time_array, d=data_array)

lfp2 = lfp.get(0, 20)[:,14]
lfp2 = np.vstack((lfp2.t, lfp2.d)).T

#### NWB
nwb = nap.load_file("your/path/to/MyProject/sub-A2929/A2929-200711/pynapplenwb/A2929-200711.nwb")
units = nwb['units']#.getby_category("location")['adn']
tmp = units.to_tsd().get(0, 20)
tmp = np.vstack((tmp.index.values, tmp.values)).T



fig = fpl.Figure(canvas="glfw", shape=(2,1))
fig[0,0].add_line(data=lfp2, thickness=1, cmap="autumn")
fig[1,0].add_scatter(tmp)
fig.show(maintain_aspect=False)
# fpl.run()




# grid_plot = fpl.GridPlot(shape=(2, 1), controller_ids="sync", names = ['lfp', 'wavelet'])
# grid_plot['lfp'].add_line(lfp.t, lfp[:,14].d)


import numpy as np
import fastplotlib as fpl

fig = fpl.Figure(canvas="glfw")#, shape=(2,1), controller_ids="sync")
fig[0,0].add_line(data=np.random.randn(1000))
fig.show(maintain_aspect=False)

fig2 = fpl.Figure(canvas="glfw", controllers=fig.controllers)#, shape=(2,1), controller_ids="sync")
fig2[0,0].add_line(data=np.random.randn(1000)*1000)
fig2.show(maintain_aspect=False)



# Not sure about this :
fig[1,0].controller.controls["mouse1"] = "pan", "drag", (1.0, 0.0)

fig[1,0].controller.controls.pop("mouse2")
fig[1,0].controller.controls.pop("mouse4")
fig[1,0].controller.controls.pop("wheel")

import pygfx

controller = pygfx.PanZoomController()
controller.controls.pop("mouse1")
controller.add_camera(fig[0, 0].camera)
controller.register_events(fig[0, 0].viewport)

controller2 = pygfx.PanZoomController()
controller2.add_camera(fig[1, 0].camera)
controller2.controls.pop("mouse1")
controller2.register_events(fig[1, 0].viewport)


tmp = units.to_tsd()














sys.exit()

#################################################################################################


nwb = nap.load_file("your/path/to/MyProject/sub-A2929/A2929-200711/pynapplenwb/A2929-200711.nwb")
units = nwb['units']#.getby_category("location")['adn']
tmp = units.to_tsd()
tmp = np.vstack((tmp.index.values, tmp.values)).T

# Example 1

fplot = fpl.Plot()

fplot.add_scatter(tmp)

fplot.graphics[0].cmap = "jet"

fplot.graphics[0].cmap.values = tmp[:, 1]

fplot.show(maintain_aspect=False)

# Example 2

names = [['raster'], ['position']]

grid_plot = fpl.GridPlot(shape=(2, 1), controller_ids="sync", names = names)

grid_plot['raster'].add_scatter(tmp)

grid_plot['position'].add_line(np.vstack((nwb['ry'].t, nwb['ry'].d)).T)

grid_plot.show(maintain_aspect=False)

grid_plot['raster'].auto_scale(maintain_aspect=False)


Expand Down
2 changes: 2 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ plugins:
nav:
- Overview: index.md
- Usage: generated/gallery
- External projects: external.md
- Pynajax - GPU acceleration: pynajax.md
- Modules : reference/
- Contributing: CONTRIBUTING.md
- Authors: AUTHORS.md
Expand Down
11 changes: 10 additions & 1 deletion pynapple/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
__version__ = "0.6.4"
from .core import IntervalSet, Ts, Tsd, TsdFrame, TsdTensor, TsGroup, TsIndex, config
from .core import (
IntervalSet,
Ts,
Tsd,
TsdFrame,
TsdTensor,
TsGroup,
TsIndex,
nap_config,
)
from .io import *
from .process import *
2 changes: 1 addition & 1 deletion pynapple/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import config
from .config import nap_config
from .interval_set import IntervalSet
from .time_index import TsIndex
from .time_series import Ts, Tsd, TsdFrame, TsdTensor
Expand Down
Loading
Loading