Skip to content

Commit

Permalink
New spline reg script
Browse files Browse the repository at this point in the history
  • Loading branch information
julienboussard authored Feb 20, 2024
1 parent 9f17b8e commit d1094c1
Showing 1 changed file with 6 additions and 83 deletions.
89 changes: 6 additions & 83 deletions scripts/spline_registration.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
#!/usr/bin/env python
# coding: utf-8

# In[15]:


from pathlib import Path
import numpy as np
import h5py
Expand All @@ -29,10 +23,6 @@
from dartsort.config import TemplateConfig, MatchingConfig, ClusteringConfig, SubtractionConfig, FeaturizationConfig
from dartsort import cluster


# In[40]:


# Input
data_dir = Path("UHD_DATA/ZYE_0021___2021-05-01___1")
fname_sub_h5 = data_dir / "subtraction.h5"
Expand All @@ -50,11 +40,7 @@
import matplotlib.pyplot as plt
import matplotlib.cm as cm
vir = cm.get_cmap('viridis')


# In[9]:



#parameters of output displacement
disp_resolution=30000

Expand Down Expand Up @@ -100,16 +86,6 @@
geom, shape_params=dict(radius=10)
)


# In[ ]:





# In[12]:


z = localization_results[:, 2]
wh = (z > geom[:,1].min() - 100) & (z < geom[:,1].max() + 100)
a = maxptps[wh]
Expand All @@ -118,10 +94,6 @@

me, extra = dredge_ap.register(a, z, t, max_disp_um=100, win_scale_um=300, win_step_um=200, rigid=False, mincorr=0.6)


# In[13]:


z_reg = me.correct_s(times_seconds, localization_results[:, 2])
displacement_rigid = me.displacement

Expand All @@ -141,33 +113,20 @@
plt.close()

idx = np.flatnonzero(maxptps>threshold_ptp_spline)



# In[20]:



sorting = cluster(
sub_h5,
recording,
clustering_config=clustering_config_uhd,
motion_est=me)


# In[25]:


if savefigs:
cluster_viz.array_scatter(sorting.labels,
geom, localization_results[:, 0], z_reg,
maxptps, zlim=(-50, 332))
plt.savefig(Path(output_dir) / "clustering_high_ptp_units.png")
plt.close()


# In[24]:


z_centered = localization_results[sorting.labels>-1, 2].copy()
for k in range(sorting.labels.max()+1):
idx_k = np.flatnonzero(sorting.labels[sorting.labels>-1] == k)
Expand All @@ -186,27 +145,14 @@
model = make_pipeline(transformer, Ridge(alpha=1e-3))
model.fit(times.reshape(-1, 1), values.reshape(-1, 1))


# In[37]:


rec_len_sec


# In[ ]:


print("Inference")
spline_displacement = np.zeros(rec_len_sec*sampling_rate)
n_batches = rec_len_sec//batch_size_spline
for batch_id in tqdm(range(n_batches)):
idx_batch = np.arange(batch_id*batch_size_spline*sampling_rate, (batch_id+1)*batch_size_spline*sampling_rate, )
spline_displacement[idx_batch] = model.predict(idx_batch.reshape(-1, 1)/sampling_rate)[:, 0]


# In[ ]:


z_reg_spline = z_reg - spline_displacement[times_samples.astype('int')]
if savefigs:
idx = np.flatnonzero(maxptps>threshold_ptp_rigid_reg)
ptp_col = np.log(maxptps[idx]+1)
Expand All @@ -215,34 +161,11 @@
ptp_col /= ptp_col.max()
color_array = vir(ptp_col)
plt.figure(figsize=(20, 10))
plt.scatter(times_seconds[idx], z_reg[idx] - spline_displacement[times_samples[idx].astype('int')], s=1, color = color_array)
plt.scatter(times_seconds[idx], z_reg_spline[idx], s=1, color = color_array)
plt.savefig(Path(output_dir) / "final_registered_raster.png")
plt.close()

# np.save("low_freq_disp_map.npy", dispmap + displacement_rigid[None, :])
np.save("high_freq_correction.npy", spline_displacement)


# In[29]:


spline_displacement.shape


# In[30]:


times_samples.max()


# In[31]:


4040*30_000/9000000


# In[ ]:



np.save("high_freq_correction.npy", spline_displacement)
np.save("spline_registered_z.npy", z_reg_spline)

0 comments on commit d1094c1

Please sign in to comment.