Skip to content

Commit 9f17b8e

Browse files
author
julienboussard
authored
Add files via upload
1 parent 809a462 commit 9f17b8e

File tree

1 file changed

+156
-47
lines changed

1 file changed

+156
-47
lines changed

scripts/spline_registration.py

Lines changed: 156 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,22 @@
1+
#!/usr/bin/env python
2+
# coding: utf-8
3+
4+
# In[15]:
5+
6+
17
from pathlib import Path
28
import numpy as np
39
import h5py
410
from tqdm.auto import tqdm, trange
511
import scipy.io
612
import hdbscan
13+
import os
14+
15+
from dredge import dredge_ap, motion_util as mu
16+
17+
18+
import spikeinterface.core as sc
19+
import spikeinterface.full as si
720

821
from spike_psvae import cluster_viz, cluster_utils
922
from spike_psvae.ibme import register_nonrigid, fast_raster
@@ -13,20 +26,35 @@
1326
from sklearn.preprocessing import PolynomialFeatures, SplineTransformer
1427
from sklearn.pipeline import make_pipeline
1528

29+
from dartsort.config import TemplateConfig, MatchingConfig, ClusteringConfig, SubtractionConfig, FeaturizationConfig
30+
from dartsort import cluster
31+
32+
33+
# In[40]:
1634

1735

1836
# Input
19-
fname_sub_h5 = "subtraction.h5"
20-
sampling_rate=30000
21-
rec_len_sec = 300 # Duration of the recording in seconds
22-
output_dir = "spline_drift_results"
23-
Path(output_dir).mkdir(exist_ok=True)
37+
data_dir = Path("UHD_DATA/ZYE_0021___2021-05-01___1")
38+
fname_sub_h5 = data_dir / "subtraction.h5"
39+
raw_data_name = data_dir / "standardized.bin"
40+
dtype_preprocessed = "float32"
41+
sampling_rate = 30000
42+
n_channels = 384
43+
44+
rec_len_sec = int(os.path.getsize(raw_data_name)/4/n_channels/sampling_rate)
45+
46+
output_dir = data_dir / "spline_drift_results"
47+
os.makedirs(Path(output_dir), exist_ok=True)
2448
savefigs=True
2549
if savefigs:
2650
import matplotlib.pyplot as plt
2751
import matplotlib.cm as cm
2852
vir = cm.get_cmap('viridis')
2953

54+
55+
# In[9]:
56+
57+
3058
#parameters of output displacement
3159
disp_resolution=30000
3260

@@ -39,82 +67,134 @@
3967

4068
sub_h5 = Path(fname_sub_h5)
4169

70+
clustering_config_uhd = ClusteringConfig(
71+
cluster_strategy="density_peaks",
72+
sigma_regional=25,
73+
chunk_size_s=300,
74+
cluster_selection_epsilon=1,
75+
min_cluster_size = 25,
76+
min_samples = 25,
77+
recursive=False,
78+
remove_big_units=True,
79+
zstd_big_units=50.0,
80+
)
81+
4282
with h5py.File(sub_h5, "r+") as h5:
43-
geom = h5["geom"][:]
44-
localization_results = np.array(h5["localizations"][:]) #f.get('localizations').value
45-
maxptps = np.array(h5["maxptps"][:])
46-
spike_index = np.array(h5["spike_index"][:])
47-
z_reg = np.array(h5["z_reg"][:])
48-
dispmap = np.array(h5["dispmap"][:])
83+
localization_results = np.array(h5["point_source_localizations"][:])
84+
maxptps = np.array(h5["denoised_ptp_amplitudes"][:])
85+
times_samples = np.array(h5["times_samples"][:])
86+
times_seconds = np.array(h5["times_seconds"][:])
87+
geom = np.array(h5["geom"][:])
88+
channels = np.array(h5["channels"][:])
89+
90+
91+
recording = sc.read_binary(
92+
raw_data_name,
93+
sampling_rate,
94+
dtype_preprocessed,
95+
num_channels=n_channels,
96+
is_filtered=True,
97+
)
98+
99+
recording.set_dummy_probe_from_locations(
100+
geom, shape_params=dict(radius=10)
101+
)
102+
103+
104+
# In[ ]:
105+
106+
49107

50-
if savefigs:
51-
idx = maxptps>threshold_ptp_rigid_reg
52-
ptp_col = np.log(maxptps[idx]+1)
53-
ptp_col[ptp_col>3.5]=3.5
54-
ptp_col -= ptp_col.min()
55-
ptp_col /= ptp_col.max()
56-
color_array = vir(ptp_col)
57-
plt.figure(figsize=(20, 10))
58-
plt.scatter(spike_index[idx, 0], z_reg[idx], s=1, color = color_array)
59-
plt.savefig(Path(output_dir) / "initial_detection_localization.png")
60-
plt.close()
108+
109+
110+
# In[12]:
111+
112+
113+
z = localization_results[:, 2]
114+
wh = (z > geom[:,1].min() - 100) & (z < geom[:,1].max() + 100)
115+
a = maxptps[wh]
116+
z = z[wh]
117+
t = times_seconds[wh]
118+
119+
me, extra = dredge_ap.register(a, z, t, max_disp_um=100, win_scale_um=300, win_step_um=200, rigid=False, mincorr=0.6)
120+
121+
122+
# In[13]:
123+
124+
125+
z_reg = me.correct_s(times_seconds, localization_results[:, 2])
126+
displacement_rigid = me.displacement
61127

62128
# Rigid reg
63129
idx = np.flatnonzero(maxptps>threshold_ptp_rigid_reg)
64130

65-
raster, dd, tt = fast_raster(maxptps[idx], z_reg[idx], spike_index[idx, 0]/sampling_rate)
66-
D, C = calc_corr_decent(raster, disp = 100)
67-
displacement_rigid = psolvecorr(D, C)
68-
69131
if savefigs:
70132
ptp_col = np.log(maxptps[idx]+1)
71133
ptp_col[ptp_col>3.5]=3.5
72134
ptp_col -= ptp_col.min()
73135
ptp_col /= ptp_col.max()
74136
color_array = vir(ptp_col)
75137
plt.figure(figsize=(20, 10))
76-
plt.scatter(spike_index[idx, 0], z_reg[idx], s=1, color = color_array)
138+
plt.scatter(times_seconds[idx], z_reg[idx], s=1, color = color_array)
77139
plt.plot(displacement_rigid+geom.max()/2, c='red')
78140
plt.savefig(Path(output_dir) / "initial_detection_localization.png")
79141
plt.close()
80142

81143
idx = np.flatnonzero(maxptps>threshold_ptp_spline)
144+
145+
146+
147+
# In[20]:
148+
149+
150+
sorting = cluster(
151+
sub_h5,
152+
recording,
153+
clustering_config=clustering_config_uhd,
154+
motion_est=me)
155+
156+
157+
# In[25]:
82158

83-
clusterer, cluster_centers, spike_index_cluster, x, z, maxptps_cluster, original_spike_ids = cluster_utils.cluster_spikes(
84-
localization_results[idx, 0],
85-
z_reg[idx]-displacement_rigid[spike_index[idx, 0]//sampling_rate],
86-
maxptps[idx], spike_index[idx], triage_quantile=100, do_copy_spikes=False, split_big=False, do_remove_dups=False)
87159

88160
if savefigs:
89-
cluster_viz.array_scatter(clusterer.labels_,
90-
geom, x, z,
91-
maxptps_cluster)
161+
cluster_viz.array_scatter(sorting.labels,
162+
geom, localization_results[:, 0], z_reg,
163+
maxptps, zlim=(-50, 332))
92164
plt.savefig(Path(output_dir) / "clustering_high_ptp_units.png")
93165
plt.close()
94166

95-
z_centered = z[clusterer.labels_>-1].copy()
96-
for k in range(clusterer.labels_.max()+1):
97-
idx_k = np.flatnonzero(clusterer.labels_[clusterer.labels_>-1] == k)
167+
168+
# In[24]:
169+
170+
171+
z_centered = localization_results[sorting.labels>-1, 2].copy()
172+
for k in range(sorting.labels.max()+1):
173+
idx_k = np.flatnonzero(sorting.labels[sorting.labels>-1] == k)
98174
z_centered[idx_k] -= z_centered[idx_k].mean()
99175
z_centered_std = z_centered.std()
100176

101177
values = z_centered[np.abs(z_centered)<z_centered_std*std_bound]
102-
idx_times = np.flatnonzero(clusterer.labels_>-1)
103-
times = spike_index_cluster[idx_times, 0][np.abs(z_centered)<z_centered_std*std_bound]/sampling_rate
178+
idx_times = np.flatnonzero(sorting.labels>-1)
179+
times = times_seconds[idx_times][np.abs(z_centered)<z_centered_std*std_bound]
104180

105181
transformer = SplineTransformer(
106182
degree=spline_degree,
107183
n_knots=int(rec_len_sec*2.5), #can find this automatically - 2 per period (of ~1.25sec)
108184
)
185+
109186
model = make_pipeline(transformer, Ridge(alpha=1e-3))
110187
model.fit(times.reshape(-1, 1), values.reshape(-1, 1))
111188

112-
if savefigs:
113-
plt.figure(figsize=(20, 5))
114-
plt.scatter(times, values, s=1, c='blue')
115-
plt.plot(times, model.predict(times.reshape(-1, 1))[:, 0], c='red')
116-
plt.savefig(Path(output_dir) / "spline_fit.png")
117-
plt.close()
189+
190+
# In[37]:
191+
192+
193+
rec_len_sec
194+
195+
196+
# In[ ]:
197+
118198

119199
print("Inference")
120200
spline_displacement = np.zeros(rec_len_sec*sampling_rate)
@@ -123,6 +203,10 @@
123203
idx_batch = np.arange(batch_id*batch_size_spline*sampling_rate, (batch_id+1)*batch_size_spline*sampling_rate, )
124204
spline_displacement[idx_batch] = model.predict(idx_batch.reshape(-1, 1)/sampling_rate)[:, 0]
125205

206+
207+
# In[ ]:
208+
209+
126210
if savefigs:
127211
idx = np.flatnonzero(maxptps>threshold_ptp_rigid_reg)
128212
ptp_col = np.log(maxptps[idx]+1)
@@ -131,9 +215,34 @@
131215
ptp_col /= ptp_col.max()
132216
color_array = vir(ptp_col)
133217
plt.figure(figsize=(20, 10))
134-
plt.scatter(spike_index[idx, 0], z_reg[idx]-displacement_rigid[spike_index[idx, 0]//30000] - spline_displacement[spike_index[idx, 0].astype('int')], s=1, color = color_array)
218+
plt.scatter(times_seconds[idx], z_reg[idx] - spline_displacement[times_samples[idx].astype('int')], s=1, color = color_array)
135219
plt.savefig(Path(output_dir) / "final_registered_raster.png")
136220
plt.close()
137221

138-
np.save("low_freq_disp_map.npy", dispmap + displacement_rigid[None, :])
222+
# np.save("low_freq_disp_map.npy", dispmap + displacement_rigid[None, :])
139223
np.save("high_freq_correction.npy", spline_displacement)
224+
225+
226+
# In[29]:
227+
228+
229+
spline_displacement.shape
230+
231+
232+
# In[30]:
233+
234+
235+
times_samples.max()
236+
237+
238+
# In[31]:
239+
240+
241+
4040*30_000/9000000
242+
243+
244+
# In[ ]:
245+
246+
247+
248+

0 commit comments

Comments
 (0)