|
| 1 | +#!/usr/bin/env python |
| 2 | +# coding: utf-8 |
| 3 | + |
| 4 | +# In[15]: |
| 5 | + |
| 6 | + |
1 | 7 | from pathlib import Path
|
2 | 8 | import numpy as np
|
3 | 9 | import h5py
|
4 | 10 | from tqdm.auto import tqdm, trange
|
5 | 11 | import scipy.io
|
6 | 12 | import hdbscan
|
| 13 | +import os |
| 14 | + |
| 15 | +from dredge import dredge_ap, motion_util as mu |
| 16 | + |
| 17 | + |
| 18 | +import spikeinterface.core as sc |
| 19 | +import spikeinterface.full as si |
7 | 20 |
|
8 | 21 | from spike_psvae import cluster_viz, cluster_utils
|
9 | 22 | from spike_psvae.ibme import register_nonrigid, fast_raster
|
|
13 | 26 | from sklearn.preprocessing import PolynomialFeatures, SplineTransformer
|
14 | 27 | from sklearn.pipeline import make_pipeline
|
15 | 28 |
|
| 29 | +from dartsort.config import TemplateConfig, MatchingConfig, ClusteringConfig, SubtractionConfig, FeaturizationConfig |
| 30 | +from dartsort import cluster |
| 31 | + |
| 32 | + |
| 33 | +# In[40]: |
16 | 34 |
|
17 | 35 |
|
18 | 36 | # Input
|
19 |
| -fname_sub_h5 = "subtraction.h5" |
20 |
| -sampling_rate=30000 |
21 |
| -rec_len_sec = 300 # Duration of the recording in seconds |
22 |
| -output_dir = "spline_drift_results" |
23 |
| -Path(output_dir).mkdir(exist_ok=True) |
| 37 | +data_dir = Path("UHD_DATA/ZYE_0021___2021-05-01___1") |
| 38 | +fname_sub_h5 = data_dir / "subtraction.h5" |
| 39 | +raw_data_name = data_dir / "standardized.bin" |
| 40 | +dtype_preprocessed = "float32" |
| 41 | +sampling_rate = 30000 |
| 42 | +n_channels = 384 |
| 43 | + |
| 44 | +rec_len_sec = int(os.path.getsize(raw_data_name)/4/n_channels/sampling_rate) |
| 45 | + |
| 46 | +output_dir = data_dir / "spline_drift_results" |
| 47 | +os.makedirs(Path(output_dir), exist_ok=True) |
24 | 48 | savefigs=True
|
25 | 49 | if savefigs:
|
26 | 50 | import matplotlib.pyplot as plt
|
27 | 51 | import matplotlib.cm as cm
|
28 | 52 | vir = cm.get_cmap('viridis')
|
29 | 53 |
|
| 54 | + |
| 55 | +# In[9]: |
| 56 | + |
| 57 | + |
30 | 58 | #parameters of output displacement
|
31 | 59 | disp_resolution=30000
|
32 | 60 |
|
|
39 | 67 |
|
40 | 68 | sub_h5 = Path(fname_sub_h5)
|
41 | 69 |
|
| 70 | +clustering_config_uhd = ClusteringConfig( |
| 71 | + cluster_strategy="density_peaks", |
| 72 | + sigma_regional=25, |
| 73 | + chunk_size_s=300, |
| 74 | + cluster_selection_epsilon=1, |
| 75 | + min_cluster_size = 25, |
| 76 | + min_samples = 25, |
| 77 | + recursive=False, |
| 78 | + remove_big_units=True, |
| 79 | + zstd_big_units=50.0, |
| 80 | +) |
| 81 | + |
42 | 82 | with h5py.File(sub_h5, "r+") as h5:
|
43 |
| - geom = h5["geom"][:] |
44 |
| - localization_results = np.array(h5["localizations"][:]) #f.get('localizations').value |
45 |
| - maxptps = np.array(h5["maxptps"][:]) |
46 |
| - spike_index = np.array(h5["spike_index"][:]) |
47 |
| - z_reg = np.array(h5["z_reg"][:]) |
48 |
| - dispmap = np.array(h5["dispmap"][:]) |
| 83 | + localization_results = np.array(h5["point_source_localizations"][:]) |
| 84 | + maxptps = np.array(h5["denoised_ptp_amplitudes"][:]) |
| 85 | + times_samples = np.array(h5["times_samples"][:]) |
| 86 | + times_seconds = np.array(h5["times_seconds"][:]) |
| 87 | + geom = np.array(h5["geom"][:]) |
| 88 | + channels = np.array(h5["channels"][:]) |
| 89 | + |
| 90 | + |
| 91 | +recording = sc.read_binary( |
| 92 | + raw_data_name, |
| 93 | + sampling_rate, |
| 94 | + dtype_preprocessed, |
| 95 | + num_channels=n_channels, |
| 96 | + is_filtered=True, |
| 97 | +) |
| 98 | + |
| 99 | +recording.set_dummy_probe_from_locations( |
| 100 | + geom, shape_params=dict(radius=10) |
| 101 | +) |
| 102 | + |
| 103 | + |
| 104 | +# In[ ]: |
| 105 | + |
| 106 | + |
49 | 107 |
|
50 |
| -if savefigs: |
51 |
| - idx = maxptps>threshold_ptp_rigid_reg |
52 |
| - ptp_col = np.log(maxptps[idx]+1) |
53 |
| - ptp_col[ptp_col>3.5]=3.5 |
54 |
| - ptp_col -= ptp_col.min() |
55 |
| - ptp_col /= ptp_col.max() |
56 |
| - color_array = vir(ptp_col) |
57 |
| - plt.figure(figsize=(20, 10)) |
58 |
| - plt.scatter(spike_index[idx, 0], z_reg[idx], s=1, color = color_array) |
59 |
| - plt.savefig(Path(output_dir) / "initial_detection_localization.png") |
60 |
| - plt.close() |
| 108 | + |
| 109 | + |
| 110 | +# In[12]: |
| 111 | + |
| 112 | + |
| 113 | +z = localization_results[:, 2] |
| 114 | +wh = (z > geom[:,1].min() - 100) & (z < geom[:,1].max() + 100) |
| 115 | +a = maxptps[wh] |
| 116 | +z = z[wh] |
| 117 | +t = times_seconds[wh] |
| 118 | + |
| 119 | +me, extra = dredge_ap.register(a, z, t, max_disp_um=100, win_scale_um=300, win_step_um=200, rigid=False, mincorr=0.6) |
| 120 | + |
| 121 | + |
| 122 | +# In[13]: |
| 123 | + |
| 124 | + |
| 125 | +z_reg = me.correct_s(times_seconds, localization_results[:, 2]) |
| 126 | +displacement_rigid = me.displacement |
61 | 127 |
|
62 | 128 | # Rigid reg
|
63 | 129 | idx = np.flatnonzero(maxptps>threshold_ptp_rigid_reg)
|
64 | 130 |
|
65 |
| -raster, dd, tt = fast_raster(maxptps[idx], z_reg[idx], spike_index[idx, 0]/sampling_rate) |
66 |
| -D, C = calc_corr_decent(raster, disp = 100) |
67 |
| -displacement_rigid = psolvecorr(D, C) |
68 |
| - |
69 | 131 | if savefigs:
|
70 | 132 | ptp_col = np.log(maxptps[idx]+1)
|
71 | 133 | ptp_col[ptp_col>3.5]=3.5
|
72 | 134 | ptp_col -= ptp_col.min()
|
73 | 135 | ptp_col /= ptp_col.max()
|
74 | 136 | color_array = vir(ptp_col)
|
75 | 137 | plt.figure(figsize=(20, 10))
|
76 |
| - plt.scatter(spike_index[idx, 0], z_reg[idx], s=1, color = color_array) |
| 138 | + plt.scatter(times_seconds[idx], z_reg[idx], s=1, color = color_array) |
77 | 139 | plt.plot(displacement_rigid+geom.max()/2, c='red')
|
78 | 140 | plt.savefig(Path(output_dir) / "initial_detection_localization.png")
|
79 | 141 | plt.close()
|
80 | 142 |
|
81 | 143 | idx = np.flatnonzero(maxptps>threshold_ptp_spline)
|
| 144 | + |
| 145 | + |
| 146 | + |
| 147 | +# In[20]: |
| 148 | + |
| 149 | + |
| 150 | +sorting = cluster( |
| 151 | + sub_h5, |
| 152 | + recording, |
| 153 | + clustering_config=clustering_config_uhd, |
| 154 | + motion_est=me) |
| 155 | + |
| 156 | + |
| 157 | +# In[25]: |
82 | 158 |
|
83 |
| -clusterer, cluster_centers, spike_index_cluster, x, z, maxptps_cluster, original_spike_ids = cluster_utils.cluster_spikes( |
84 |
| - localization_results[idx, 0], |
85 |
| - z_reg[idx]-displacement_rigid[spike_index[idx, 0]//sampling_rate], |
86 |
| - maxptps[idx], spike_index[idx], triage_quantile=100, do_copy_spikes=False, split_big=False, do_remove_dups=False) |
87 | 159 |
|
88 | 160 | if savefigs:
|
89 |
| - cluster_viz.array_scatter(clusterer.labels_, |
90 |
| - geom, x, z, |
91 |
| - maxptps_cluster) |
| 161 | + cluster_viz.array_scatter(sorting.labels, |
| 162 | + geom, localization_results[:, 0], z_reg, |
| 163 | + maxptps, zlim=(-50, 332)) |
92 | 164 | plt.savefig(Path(output_dir) / "clustering_high_ptp_units.png")
|
93 | 165 | plt.close()
|
94 | 166 |
|
95 |
| -z_centered = z[clusterer.labels_>-1].copy() |
96 |
| -for k in range(clusterer.labels_.max()+1): |
97 |
| - idx_k = np.flatnonzero(clusterer.labels_[clusterer.labels_>-1] == k) |
| 167 | + |
| 168 | +# In[24]: |
| 169 | + |
| 170 | + |
| 171 | +z_centered = localization_results[sorting.labels>-1, 2].copy() |
| 172 | +for k in range(sorting.labels.max()+1): |
| 173 | + idx_k = np.flatnonzero(sorting.labels[sorting.labels>-1] == k) |
98 | 174 | z_centered[idx_k] -= z_centered[idx_k].mean()
|
99 | 175 | z_centered_std = z_centered.std()
|
100 | 176 |
|
101 | 177 | values = z_centered[np.abs(z_centered)<z_centered_std*std_bound]
|
102 |
| -idx_times = np.flatnonzero(clusterer.labels_>-1) |
103 |
| -times = spike_index_cluster[idx_times, 0][np.abs(z_centered)<z_centered_std*std_bound]/sampling_rate |
| 178 | +idx_times = np.flatnonzero(sorting.labels>-1) |
| 179 | +times = times_seconds[idx_times][np.abs(z_centered)<z_centered_std*std_bound] |
104 | 180 |
|
105 | 181 | transformer = SplineTransformer(
|
106 | 182 | degree=spline_degree,
|
107 | 183 | n_knots=int(rec_len_sec*2.5), #can find this automatically - 2 per period (of ~1.25sec)
|
108 | 184 | )
|
| 185 | + |
109 | 186 | model = make_pipeline(transformer, Ridge(alpha=1e-3))
|
110 | 187 | model.fit(times.reshape(-1, 1), values.reshape(-1, 1))
|
111 | 188 |
|
112 |
| -if savefigs: |
113 |
| - plt.figure(figsize=(20, 5)) |
114 |
| - plt.scatter(times, values, s=1, c='blue') |
115 |
| - plt.plot(times, model.predict(times.reshape(-1, 1))[:, 0], c='red') |
116 |
| - plt.savefig(Path(output_dir) / "spline_fit.png") |
117 |
| - plt.close() |
| 189 | + |
| 190 | +# In[37]: |
| 191 | + |
| 192 | + |
| 193 | +rec_len_sec |
| 194 | + |
| 195 | + |
| 196 | +# In[ ]: |
| 197 | + |
118 | 198 |
|
119 | 199 | print("Inference")
|
120 | 200 | spline_displacement = np.zeros(rec_len_sec*sampling_rate)
|
|
123 | 203 | idx_batch = np.arange(batch_id*batch_size_spline*sampling_rate, (batch_id+1)*batch_size_spline*sampling_rate, )
|
124 | 204 | spline_displacement[idx_batch] = model.predict(idx_batch.reshape(-1, 1)/sampling_rate)[:, 0]
|
125 | 205 |
|
| 206 | + |
| 207 | +# In[ ]: |
| 208 | + |
| 209 | + |
126 | 210 | if savefigs:
|
127 | 211 | idx = np.flatnonzero(maxptps>threshold_ptp_rigid_reg)
|
128 | 212 | ptp_col = np.log(maxptps[idx]+1)
|
|
131 | 215 | ptp_col /= ptp_col.max()
|
132 | 216 | color_array = vir(ptp_col)
|
133 | 217 | plt.figure(figsize=(20, 10))
|
134 |
| - plt.scatter(spike_index[idx, 0], z_reg[idx]-displacement_rigid[spike_index[idx, 0]//30000] - spline_displacement[spike_index[idx, 0].astype('int')], s=1, color = color_array) |
| 218 | + plt.scatter(times_seconds[idx], z_reg[idx] - spline_displacement[times_samples[idx].astype('int')], s=1, color = color_array) |
135 | 219 | plt.savefig(Path(output_dir) / "final_registered_raster.png")
|
136 | 220 | plt.close()
|
137 | 221 |
|
138 |
| -np.save("low_freq_disp_map.npy", dispmap + displacement_rigid[None, :]) |
| 222 | +# np.save("low_freq_disp_map.npy", dispmap + displacement_rigid[None, :]) |
139 | 223 | np.save("high_freq_correction.npy", spline_displacement)
|
| 224 | + |
| 225 | + |
| 226 | +# In[29]: |
| 227 | + |
| 228 | + |
| 229 | +spline_displacement.shape |
| 230 | + |
| 231 | + |
| 232 | +# In[30]: |
| 233 | + |
| 234 | + |
| 235 | +times_samples.max() |
| 236 | + |
| 237 | + |
| 238 | +# In[31]: |
| 239 | + |
| 240 | + |
| 241 | +4040*30_000/9000000 |
| 242 | + |
| 243 | + |
| 244 | +# In[ ]: |
| 245 | + |
| 246 | + |
| 247 | + |
| 248 | + |
0 commit comments