Skip to content

Commit aca7ec9

Browse files
adding support for spike_times data in gui (#27)
1 parent aa5fe2d commit aca7ec9

File tree

7 files changed

+133
-15
lines changed

7 files changed

+133
-15
lines changed

README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,18 @@ Most of the time you will input to `Rastermap().fit` a matrix of neurons by time
198198
* **itrain** : array, shape (n_features,) (optional, default None)
199199
fit embedding on timepoints itrain only
200200

201+
If you have a `spike_times.npy` and `spike_clusters.npy`, create your time-binned data
202+
matrix with, where the bin size `st_bin` is in milliseconds (assuming your spike times are in seconds):
203+
204+
```
205+
from rastermap import io
206+
207+
# bin spike times into neurons by time matrix
208+
data = io.load_spike_times("spike_times.npy", "spike_clusters.npy", st_bin=100)
209+
```
210+
211+
You can also load these matrices into the GUI with the `File > Load spike_times...` option.
212+
201213
# Settings
202214

203215
These are inputs to the `Rastermap` class initialization, the settings are sorted in order of importance

rastermap/__main__.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import argparse
77
import os
88
from rastermap import Rastermap
9-
from rastermap.io import load_activity
9+
from rastermap.io import load_activity, load_spike_times
1010

1111
try:
1212
from rastermap.gui import gui
@@ -24,15 +24,23 @@
2424
if __name__ == "__main__":
2525
parser = argparse.ArgumentParser(description="spikes")
2626
parser.add_argument("--S", default=[], type=str, help="spiking matrix")
27+
parser.add_argument("--spike_times", default=[], type=str, help="spike_times.npy")
28+
parser.add_argument("--spike_clusters", default=[], type=str, help="spike_clusters.npy")
29+
parser.add_argument("--st_bin", default=100, type=float, help="bin size in milliseconds for spike times")
2730
parser.add_argument("--proc", default=[], type=str,
2831
help="processed data file 'embedding.npy'")
2932
parser.add_argument("--ops", default=[], type=str, help="options file 'ops.npy'")
3033
parser.add_argument("--iscell", default=[], type=str,
3134
help="which cells to select for processing")
3235
args = parser.parse_args()
3336

34-
if len(args.ops) > 0 and len(args.S) > 0:
35-
X, Usv, Vsv, xy = load_activity(args.S)
37+
if len(args.ops) > 0 and (len(args.S) > 0 or
38+
(len(args.spike_times) > 0 and len(args.spike_clusters) > 0)):
39+
if len(args.S) > 0:
40+
X, Usv, Vsv, xy = load_activity(args.S)
41+
else:
42+
Usv, Vsv, xy = None, None, None
43+
X = load_spike_times(args.spike_times, args.spike_clusters, args.st_bin)
3644
ops = np.load(args.ops, allow_pickle=True).item()
3745
if len(args.iscell) > 0:
3846
iscell = np.load(args.iscell)
@@ -62,14 +70,16 @@
6270
model.fit(data=X, Usv=Usv, Vsv=Vsv)
6371

6472
proc = {
65-
"filename": args.S,
66-
"save_path": os.path.split(args.S)[0],
73+
"filename": args.S if len(args.S) > 0 else args.spike_times,
74+
"filename_cluid": args.spike_clusters if args.spike_clusters else None,
75+
"st_bin": args.st_bin if args.spike_clusters else None,
76+
"save_path": os.path.split(args.S)[0] if args.S else os.path.split(args.spike_times)[0],
6777
"isort": model.isort,
6878
"embedding": model.embedding,
6979
"user_clusters": None,
7080
"ops": ops,
7181
}
72-
basename, fname = os.path.split(args.S)
82+
basename, fname = os.path.split(args.S) if args.S else os.path.split(args.spike_times)
7383
fname = os.path.splitext(fname)[0]
7484
try:
7585
np.save(os.path.join(basename, f"{fname}_embedding.npy"), proc)

rastermap/gui/gui.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ def __init__(self, filename=None, proc=False):
178178
# Default variables
179179
self.tpos = -0.5
180180
self.tsize = 1
181+
self.from_spike_times = False
181182
self.reset_variables()
182183

183184
self.init_time_roi()

rastermap/gui/io.py

Lines changed: 79 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
import os
55
import numpy as np
66
from qtpy import QtGui, QtCore, QtWidgets
7-
from qtpy.QtWidgets import QFileDialog, QInputDialog, QMainWindow, QApplication, QWidget, QScrollBar, QSlider, QComboBox, QGridLayout, QPushButton, QFrame, QCheckBox, QLabel, QProgressBar, QLineEdit, QMessageBox, QGroupBox
7+
from qtpy.QtWidgets import QFileDialog, QDialog, QInputDialog, QMainWindow, QApplication, QWidget, QScrollBar, QSlider, QComboBox, QGridLayout, QPushButton, QFrame, QCheckBox, QLabel, QProgressBar, QLineEdit, QMessageBox, QGroupBox
88
import pyqtgraph as pg
99
from scipy.stats import zscore
10+
from scipy.sparse import csr_array
1011
import scipy.io as sio
1112
from . import guiparts
12-
from ..io import _load_iscell, _load_stat, load_activity
13+
from ..io import _load_iscell, _load_stat, load_activity, load_spike_times
1314

1415
def _load_activity_gui(parent, X, Usv, Vsv, xy):
1516
parent.reset_variables()
@@ -54,6 +55,66 @@ def _load_activity_gui(parent, X, Usv, Vsv, xy):
5455
parent.sorting = np.arange(0, parent.n_samples).astype(np.int64)
5556
_load_sp(parent)
5657

58+
59+
class SpikeTimeLoad(QDialog):
60+
def __init__(self, parent=None):
61+
super().__init__()
62+
self.parent = parent
63+
self.layout = QGridLayout(self)
64+
self.files, self.file_labels = [], []
65+
lbl = ["spike_times", "spike_clusters"]
66+
for j in range(2):
67+
self.files.append(QPushButton(f"Choose {lbl[j]} file", self))
68+
self.files[-1].clicked.connect(lambda state, idx=j: self.get_file(idx))
69+
self.layout.addWidget(self.files[-1], j, 0, 1, 1)
70+
self.file_labels.append(QLabel("", self))
71+
self.layout.addWidget(self.file_labels[-1], j, 1, 1, 1)
72+
73+
# time bin size lineedit
74+
self.time_bin = QLineEdit("100", self)
75+
self.layout.addWidget(self.time_bin, 2, 1, 1, 1)
76+
self.layout.addWidget(QLabel("Time bin size (in millisec.)", self), 2, 0, 1, 1)
77+
78+
# Submit button
79+
self.submit_button = QPushButton("Submit", self)
80+
self.submit_button.clicked.connect(self.submit)
81+
self.layout.addWidget(self.submit_button, 3, 0, 1, 2)
82+
83+
self.setWindowTitle("spike time input")
84+
self.show()
85+
86+
def get_file(self, j):
87+
name = QFileDialog.getOpenFileName(self, "Open *.npy or *.mat",
88+
filter="*.npy *.mat")
89+
self.file_labels[j].setText(name[0])
90+
91+
def submit(self):
92+
if not self.file_labels[0].text() or not self.file_labels[1].text():
93+
QMessageBox.critical(self, 'Error', 'Both files are required!')
94+
else:
95+
fname = self.file_labels[0].text()
96+
fname_cluid = self.file_labels[1].text()
97+
st_bin = float(self.time_bin.text())
98+
st_bin = min(5000, max(5, st_bin))
99+
print(f"setting bin size to {st_bin} for visualization")
100+
_load_spike_times(self.parent, fname, fname_cluid, st_bin)
101+
102+
def _load_spike_times(parent, fname, fname_cluid, st_bin):
103+
spks = load_spike_times(fname, fname_cluid, st_bin)
104+
parent.fname = fname
105+
parent.fname_cluid = fname_cluid
106+
parent.st_bin = st_bin
107+
parent.from_spike_times = True
108+
_load_activity_gui(parent, spks, None, None, None)
109+
110+
def load_st_clu(parent, name=None):
111+
""" load spike times of neurons (*.npy or *.mat) """
112+
if name is None:
113+
st = SpikeTimeLoad(parent)
114+
st.exec_()
115+
116+
117+
57118
def load_mat(parent, name=None):
58119
""" load data matrix of neurons by time (*.npy or *.mat)
59120
@@ -361,13 +422,27 @@ def load_proc(parent, name=None):
361422
else:
362423
print(f"ERROR: {parent.proc['filename']} not found")
363424
return
425+
426+
if parent.proc["filename_cluid"]:
427+
if os.path.exists(parent.proc["filename_cluid"]):
428+
parent.fname_cluid = parent.proc["filename_cluid"]
429+
elif os.path.exists(os.path.join(foldername, filename)):
430+
parent.fname_cluid = os.path.join(foldername, filename)
431+
else:
432+
print(f"ERROR: {parent.proc['filename_cluid']} not found")
433+
return
434+
parent.st_bin = parent.proc["st_bin"]
435+
364436

365437
isort = parent.proc["isort"]
366438
y = parent.proc["embedding"]
367439
ops = parent.proc["ops"]
368440
user_clusters = parent.proc.get("user_clusters", None)
369-
370-
X, Usv, Vsv, xy = load_activity(parent.fname)
441+
if parent.proc["filename_cluid"]:
442+
Usv, Vsv, xy = None, None, None
443+
X = load_spike_times(parent.fname, parent.fname_cluid, parent.st_bin)
444+
else:
445+
X, Usv, Vsv, xy = load_activity(parent.fname)
371446
_load_activity_gui(parent, X, Usv, Vsv, xy)
372447

373448
except Exception as e:

rastermap/gui/menus.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,32 @@ def mainmenu(parent):
1515

1616
file_menu = main_menu.addMenu("&File")
1717

18-
loadMat = QAction("&Load data matrix", parent)
18+
loadMat = QAction("&Load data matrix (neurons by time)", parent)
1919
loadMat.setShortcut("Ctrl+L")
2020
loadMat.triggered.connect(lambda: io.load_mat(parent, name=None))
2121
parent.addAction(loadMat)
2222
file_menu.addAction(loadMat)
2323

24-
parent.loadXY = QAction("&Load xy(z) positions of neurons", parent)
24+
loadSt = QAction("Load spike_times and spike_&Clusters", parent)
25+
loadSt.setShortcut("Ctrl+C")
26+
loadSt.triggered.connect(lambda: io.load_st_clu(parent, name=None))
27+
parent.addAction(loadSt)
28+
file_menu.addAction(loadSt)
29+
30+
parent.loadXY = QAction("Load &XY(z) positions of neurons", parent)
2531
parent.loadXY.setShortcut("Ctrl+X")
2632
parent.loadXY.triggered.connect(lambda: io.load_neuron_pos(parent))
2733
parent.addAction(parent.loadXY)
2834
file_menu.addAction(parent.loadXY)
2935

3036
# load Z-stack
31-
parent.loadProc = QAction("&Load z-stack (mean images)", parent)
37+
parent.loadProc = QAction("Load &Z-stack (mean images)", parent)
3238
parent.loadProc.setShortcut("Ctrl+Z")
3339
parent.loadProc.triggered.connect(lambda: io.load_zstack(parent, name=None))
3440
parent.addAction(parent.loadProc)
3541
file_menu.addAction(parent.loadProc)
3642

37-
parent.loadNd = QAction("Load &n-d variable (times or cont.)", parent)
43+
parent.loadNd = QAction("Load &N-d variable (times or cont.)", parent)
3844
parent.loadNd.setShortcut("Ctrl+N")
3945
parent.loadNd.triggered.connect(lambda: io.get_behav_data(parent))
4046
parent.loadNd.setEnabled(False)

rastermap/gui/run.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,10 @@ def run_RMAP(self, parent):
8282
ops_path = os.path.join(os.getcwd(), "rmap_ops.npy")
8383
np.save(ops_path, self.ops)
8484
print("Running rastermap with command:")
85-
cmd = f"-u -W ignore -m rastermap --ops {ops_path} --S {parent.fname}"
85+
if parent.from_spike_times:
86+
cmd = f"-u -W ignore -m rastermap --ops {ops_path} --spike_times {parent.fname} --spike_clusters {parent.fname_cluid} --st_bin {parent.st_bin}"
87+
else:
88+
cmd = f"-u -W ignore -m rastermap --ops {ops_path} --S {parent.fname}"
8689
if parent.file_iscell is not None:
8790
cmd += f" --iscell {parent.file_iscell}"
8891
print("python " + cmd)

rastermap/io.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
import scipy.io as sio
77
from scipy.stats import zscore
8+
from scipy.sparse import csr_array
89

910
def _load_dict(dat, keys):
1011
X, Usv, Vsv, xpos, ypos, xy = None, None, None, None, None, None
@@ -156,6 +157,16 @@ def load_activity(filename):
156157

157158
return X, Usv, Vsv, xy
158159

160+
def load_spike_times(fname, fname_cluid, st_bin=100):
161+
print("Loading " + fname)
162+
st = np.load(fname).squeeze()
163+
clu = np.load(fname_cluid).squeeze()
164+
if len(st) != len(clu):
165+
raise ValueError("spike times and clusters must have same length")
166+
spks = csr_array((np.ones(len(st), "uint8"),
167+
(clu, np.floor(st / st_bin * 1000).astype("int"))))
168+
spks = spks.todense().astype("float32")
169+
return spks
159170

160171
def _cell_center(voxel_mask):
161172
x = np.median(np.array([v[0] for v in voxel_mask]))

0 commit comments

Comments
 (0)