diff --git a/scripts/spline_registration.py b/scripts/spline_registration.py index fb6cad2a..c706f091 100644 --- a/scripts/spline_registration.py +++ b/scripts/spline_registration.py @@ -1,9 +1,3 @@ -#!/usr/bin/env python -# coding: utf-8 - -# In[15]: - - from pathlib import Path import numpy as np import h5py @@ -29,10 +23,6 @@ from dartsort.config import TemplateConfig, MatchingConfig, ClusteringConfig, SubtractionConfig, FeaturizationConfig from dartsort import cluster - -# In[40]: - - # Input data_dir = Path("UHD_DATA/ZYE_0021___2021-05-01___1") fname_sub_h5 = data_dir / "subtraction.h5" @@ -50,11 +40,7 @@ import matplotlib.pyplot as plt import matplotlib.cm as cm vir = cm.get_cmap('viridis') - - -# In[9]: - - + #parameters of output displacement disp_resolution=30000 @@ -100,16 +86,6 @@ geom, shape_params=dict(radius=10) ) - -# In[ ]: - - - - - -# In[12]: - - z = localization_results[:, 2] wh = (z > geom[:,1].min() - 100) & (z < geom[:,1].max() + 100) a = maxptps[wh] @@ -118,10 +94,6 @@ me, extra = dredge_ap.register(a, z, t, max_disp_um=100, win_scale_um=300, win_step_um=200, rigid=False, mincorr=0.6) - -# In[13]: - - z_reg = me.correct_s(times_seconds, localization_results[:, 2]) displacement_rigid = me.displacement @@ -141,22 +113,13 @@ plt.close() idx = np.flatnonzero(maxptps>threshold_ptp_spline) - - - -# In[20]: - - + sorting = cluster( sub_h5, recording, clustering_config=clustering_config_uhd, motion_est=me) - -# In[25]: - - if savefigs: cluster_viz.array_scatter(sorting.labels, geom, localization_results[:, 0], z_reg, @@ -164,10 +127,6 @@ plt.savefig(Path(output_dir) / "clustering_high_ptp_units.png") plt.close() - -# In[24]: - - z_centered = localization_results[sorting.labels>-1, 2].copy() for k in range(sorting.labels.max()+1): idx_k = np.flatnonzero(sorting.labels[sorting.labels>-1] == k) @@ -186,16 +145,6 @@ model = make_pipeline(transformer, Ridge(alpha=1e-3)) model.fit(times.reshape(-1, 1), values.reshape(-1, 1)) - -# In[37]: - - -rec_len_sec - - -# In[ ]: - - print("Inference") spline_displacement = np.zeros(rec_len_sec*sampling_rate) n_batches = rec_len_sec//batch_size_spline @@ -203,10 +152,7 @@ idx_batch = np.arange(batch_id*batch_size_spline*sampling_rate, (batch_id+1)*batch_size_spline*sampling_rate, ) spline_displacement[idx_batch] = model.predict(idx_batch.reshape(-1, 1)/sampling_rate)[:, 0] - -# In[ ]: - - +z_reg_spline = z_reg - spline_displacement[times_samples.astype('int')] if savefigs: idx = np.flatnonzero(maxptps>threshold_ptp_rigid_reg) ptp_col = np.log(maxptps[idx]+1) @@ -215,34 +161,11 @@ ptp_col /= ptp_col.max() color_array = vir(ptp_col) plt.figure(figsize=(20, 10)) - plt.scatter(times_seconds[idx], z_reg[idx] - spline_displacement[times_samples[idx].astype('int')], s=1, color = color_array) + plt.scatter(times_seconds[idx], z_reg_spline[idx], s=1, color = color_array) plt.savefig(Path(output_dir) / "final_registered_raster.png") plt.close() -# np.save("low_freq_disp_map.npy", dispmap + displacement_rigid[None, :]) -np.save("high_freq_correction.npy", spline_displacement) - - -# In[29]: - - -spline_displacement.shape - - -# In[30]: - - -times_samples.max() - - -# In[31]: - - -4040*30_000/9000000 - - -# In[ ]: - - +np.save("high_freq_correction.npy", spline_displacement) +np.save("spline_registered_z.npy", z_reg_spline)