Skip to content

Commit

Permalink
Merge pull request #266 from pynapple-org/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
gviejo authored May 7, 2024
2 parents eab8f5e + f71610d commit a51825c
Show file tree
Hide file tree
Showing 35 changed files with 1,499 additions and 1,398 deletions.
1 change: 1 addition & 0 deletions docs/HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ In 2021, Guillaume and other trainees in Adrien's lab decided to fork from neuro
- Fixed `TsGroup` saving method.
- `__getattr__` of `BaseTsd` allow numpy functions to be attached as attributes of Tsd objects
- Added `get` method for `TsGroup`
- Tsds can be concatenate vertically if time indexes matches.


0.6.1 (2024-03-03)
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/tutorial_pynapple_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
# The object project behaves like a nested dictionnary. It is then easy to loop and navigate through a hierarchy of folders when doing analyses. In this case, we are gonna take only the session A2929-200711.


session = project["sub-A2929"]["ses-A2929-200711"]
session = project["sub-A2929"]["A2929-200711"]

print(session)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/tutorial_pynapple_quick_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
# The object `data` is a [`Folder`](https://pynapple-org.github.io/pynapple/io.folder/) object that allows easy navigation and interaction with a dataset.
# In this case, we want to load the NWB file in the folder `/pynapplenwb`. Data are always lazy loaded. No time series is loaded until it's actually called.
# When calling the NWB file, the object `nwb` is an interface to the NWB file. All the data inside the NWB file that are compatible with one of the pynapple objects are shown with their corresponding keys.
nwb = data["sub-A2929"]["ses-A2929-200711"]["pynapplenwb"]["A2929-200711"]
nwb = data["sub-A2929"]["A2929-200711"]["pynapplenwb"]["A2929-200711"]
print(nwb)


Expand Down
17 changes: 17 additions & 0 deletions docs/external.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
Pynapple has been designed as a lightweight package for representing time series and epochs in system neuroscience.
As such, it can function as a foundational element for other analysis packages handling time series data. Here we keep track of external projects that uses pynapple.


## NEMOS

![image](https://raw.githubusercontent.com/flatironinstitute/nemos/main/docs/assets/glm_features_scheme.svg)

[NeMOs](https://nemos.readthedocs.io/en/stable/) is a statistical modeling framework optimized for systems neuroscience and powered by JAX. It streamlines the process of defining and selecting models, through a collection of easy-to-use methods for feature design.

The core of nemos includes GPU-accelerated, well-tested implementations of standard statistical models, currently focusing on the Generalized Linear Model (GLM).

Check out this [page](https://nemos.readthedocs.io/en/stable/generated/neural_modeling/) for many examples of neural modelling using nemos and pynapple.

!!! note
Nemos is build on top of [jax](https://jax.readthedocs.io/en/latest/index.html), a library for high-performance numerical computing.
To ensure full compatibility with nemos, consider installing [pynajax](https://github.com/pynapple-org/pynajax), a pynapple backend for jax.
65 changes: 65 additions & 0 deletions docs/pynajax.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
### Motivation

Multiple python packages exist for high-performance computing. Internally, pynapple makes extensive use of [numba](https://numba.pydata.org/) for accelerating some functions. Numba is a stable package that provide speed gains with minimal installation issues when running on CPUs.

Another high-performance toolbox for numerical analysis is
[jax](https://jax.readthedocs.io/en/latest/index.html). In addition to accelerating python code on CPUs, GPUs, and TPUs, it provides a special representation of arrays using the [jax Array object](https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html). Unfortunately, jax Array is incompatible with Numba. To solve this issue, we developped [pynajax](https://github.com/pynapple-org/pynajax).

Pynajax is an accelerated backend for pynapple built on top on jax. It offers a fast acceleration for some pynapple functions using CPU or GPU. Here is a minimal example on how to use pynajax:

``` bash
$ pip install pynajax
```



``` python
import pynapple as nap
import numpy as np

# Changed the backend from 'numba' to 'jax'
nap.nap_config.set_backend("jax")

# This will convert the numpy array to a jax Array.
tsd = nap.Tsd(t=np.arange(100), d=np.random.randn(100))

# This will run on GPU or CPU depending on the jax installation
tsd.convolve(np.ones(11))
```

This documentation page keeps tracks of the list of pynapple functions that can be jax-accelerated as well as their performances compared to pure numba.

### Installation issues

To get the best of the pynajax backend, jax needs to use the GPU.

While installing pynajax will install all the dependencies necessary to use jax, it does not guarantee
the use of the GPU.

To check if jax is using the GPU, you can run the following python commands :

- no GPU found :

```python
>>> import jax
>>> print(jax.devices())
[CpuDevice(id=0)]
```

- GPU found :

```python
>>> import jax
>>> print(jax.devices())
[cuda(id=0)]
```

Support for installing `JAX` for GPU users can be found in the [jax documentation](https://jax.readthedocs.io/en/latest/installation.html)


### Typical use-case


In addition to providing high performance numerical computing, jax can be used as a the backbone for a large scale machine learning model. Thus, pynajax can offer full compatibility between pynapple's time series representation and computational neuroscience models constructed using jax.

An example of a python package using both pynapple and jax is [NeMOs](https://nemos.readthedocs.io/en/stable/).
119 changes: 96 additions & 23 deletions draft_pynapple_fastplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,58 +15,131 @@
"""
# %%
# !!! warning
# This tutorial uses seaborn and matplotlib for displaying the figure
#
# You can install all with `pip install matplotlib seaborn tqdm`
#
# mkdocs_gallery_thumbnail_number = 1
#
# Now, import the necessary libraries:

# %qui qt
# %gui qt

import pynapple as nap
import numpy as np
import fastplotlib as fpl

import imageio.v3 as iio
import sys
# mkdocs_gallery_thumbnail_path = '../_static/fastplotlib_demo.png'

#nwb = nap.load_file("/Users/gviejo/pynapple/Mouse32-220101.nwb")
nwb = nap.load_file("your/path/to/MyProject/sub-A2929/ses-A2929-200711/pynapplenwb/A2929-200711.nwb")
def get_memory_map(filepath, nChannels, frequency=20000):
n_channels = int(nChannels)
f = open(filepath, 'rb')
startoffile = f.seek(0, 0)
endoffile = f.seek(0, 2)
bytes_size = 2
n_samples = int((endoffile-startoffile)/n_channels/bytes_size)
duration = n_samples/frequency
interval = 1/frequency
f.close()
fp = np.memmap(filepath, np.int16, 'r', shape = (n_samples, n_channels))
timestep = np.arange(0, n_samples)/frequency

return fp, timestep


#### LFP
data_array, time_array = get_memory_map("your/path/to/MyProject/sub-A2929/A2929-200711/A2929-200711.dat", 16)
lfp = nap.TsdFrame(t=time_array, d=data_array)

lfp2 = lfp.get(0, 20)[:,14]
lfp2 = np.vstack((lfp2.t, lfp2.d)).T

#### NWB
nwb = nap.load_file("your/path/to/MyProject/sub-A2929/A2929-200711/pynapplenwb/A2929-200711.nwb")
units = nwb['units']#.getby_category("location")['adn']
tmp = units.to_tsd().get(0, 20)
tmp = np.vstack((tmp.index.values, tmp.values)).T



fig = fpl.Figure(canvas="glfw", shape=(2,1))
fig[0,0].add_line(data=lfp2, thickness=1, cmap="autumn")
fig[1,0].add_scatter(tmp)
fig.show(maintain_aspect=False)
# fpl.run()




# grid_plot = fpl.GridPlot(shape=(2, 1), controller_ids="sync", names = ['lfp', 'wavelet'])
# grid_plot['lfp'].add_line(lfp.t, lfp[:,14].d)


import numpy as np
import fastplotlib as fpl

fig = fpl.Figure(canvas="glfw")#, shape=(2,1), controller_ids="sync")
fig[0,0].add_line(data=np.random.randn(1000))
fig.show(maintain_aspect=False)

fig2 = fpl.Figure(canvas="glfw", controllers=fig.controllers)#, shape=(2,1), controller_ids="sync")
fig2[0,0].add_line(data=np.random.randn(1000)*1000)
fig2.show(maintain_aspect=False)



# Not sure about this :
fig[1,0].controller.controls["mouse1"] = "pan", "drag", (1.0, 0.0)

fig[1,0].controller.controls.pop("mouse2")
fig[1,0].controller.controls.pop("mouse4")
fig[1,0].controller.controls.pop("wheel")

import pygfx

controller = pygfx.PanZoomController()
controller.controls.pop("mouse1")
controller.add_camera(fig[0, 0].camera)
controller.register_events(fig[0, 0].viewport)

controller2 = pygfx.PanZoomController()
controller2.add_camera(fig[1, 0].camera)
controller2.controls.pop("mouse1")
controller2.register_events(fig[1, 0].viewport)


tmp = units.to_tsd()














sys.exit()

#################################################################################################


nwb = nap.load_file("your/path/to/MyProject/sub-A2929/A2929-200711/pynapplenwb/A2929-200711.nwb")
units = nwb['units']#.getby_category("location")['adn']
tmp = units.to_tsd()
tmp = np.vstack((tmp.index.values, tmp.values)).T

# Example 1

fplot = fpl.Plot()

fplot.add_scatter(tmp)

fplot.graphics[0].cmap = "jet"

fplot.graphics[0].cmap.values = tmp[:, 1]

fplot.show(maintain_aspect=False)

# Example 2

names = [['raster'], ['position']]

grid_plot = fpl.GridPlot(shape=(2, 1), controller_ids="sync", names = names)

grid_plot['raster'].add_scatter(tmp)

grid_plot['position'].add_line(np.vstack((nwb['ry'].t, nwb['ry'].d)).T)

grid_plot.show(maintain_aspect=False)

grid_plot['raster'].auto_scale(maintain_aspect=False)


Expand Down
2 changes: 2 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ plugins:
nav:
- Overview: index.md
- Usage: generated/gallery
- External projects: external.md
- Pynajax - GPU acceleration: pynajax.md
- Modules : reference/
- Contributing: CONTRIBUTING.md
- Authors: AUTHORS.md
Expand Down
11 changes: 10 additions & 1 deletion pynapple/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
__version__ = "0.6.4"
from .core import IntervalSet, Ts, Tsd, TsdFrame, TsdTensor, TsGroup, TsIndex, config
from .core import (
IntervalSet,
Ts,
Tsd,
TsdFrame,
TsdTensor,
TsGroup,
TsIndex,
nap_config,
)
from .io import *
from .process import *
2 changes: 1 addition & 1 deletion pynapple/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import config
from .config import nap_config
from .interval_set import IntervalSet
from .time_index import TsIndex
from .time_series import Ts, Tsd, TsdFrame, TsdTensor
Expand Down
Loading

0 comments on commit a51825c

Please sign in to comment.