diff --git a/src/dartsort/localize/localize_torch.py b/src/dartsort/localize/localize_torch.py
index d13ee78a..667032ae 100644
--- a/src/dartsort/localize/localize_torch.py
+++ b/src/dartsort/localize/localize_torch.py
@@ -19,6 +19,7 @@ def localize_amplitude_vectors(
     dtype=torch.double,
     y0=1.0,
     levenberg_marquardt_kwargs=None,
+    th_dipole_proj_dist=250.0,
 ):
     """Localize a bunch of amplitude vectors with torch
 
@@ -59,7 +60,7 @@ def localize_amplitude_vectors(
     # maybe this will become a wrapper function if we want more models.
     # and, this is why we return a dict, different models will have different
     # parameters
-    assert model in ("com", "pointsource")
+    assert model in ("com", "pointsource", "dipole")
     n_spikes, c = amplitude_vectors.shape
     n_channels_tot = len(geom)
     if channel_index is None:
@@ -68,7 +69,8 @@ def localize_amplitude_vectors(
         assert channel_index.shape == (n_channels_tot, c)
     assert main_channels.shape == (n_spikes,)
 
     # we'll return numpy if user sent numpy
-    is_numpy = not torch.is_tensor(amplitude_vectors)
+    is_numpy = not torch.is_tensor(amplitude_vectors)
+
     # handle channel subsetting
     if radius is not None or n_channels_subset is not None:
@@ -122,42 +124,89 @@ def localize_amplitude_vectors(
     # fixed constants in regularizers like the log barrier
     max_amplitudes = torch.max(amplitude_vectors, dim=1).values
     normalized_amp_vecs = amplitude_vectors / max_amplitudes[:, None]
-
+
     # -- torch optimize
     # initialize with center of mass
     locs = torch.column_stack((xcom, torch.full_like(xcom, y0), zcom))
 
-    if levenberg_marquardt_kwargs is None:
-        levenberg_marquardt_kwargs = {}
-    locs, i = batched_levenberg_marquardt(
-        locs,
-        vmap_point_source_grad_and_mse,
-        vmap_point_source_hessian,
-        extra_args=(normalized_amp_vecs, in_probe_mask, local_geoms),
-        **levenberg_marquardt_kwargs,
-    )
-
-    # finish: get alpha closed form
-    x, y0, z_rel = locs.T
-    y = F.softplus(y0)
-    alpha = vmap_point_source_find_alpha(
-        amplitude_vectors, in_probe_mask, x, y, z_rel, local_geoms
-    )
-    z_abs = z_rel + geom[main_channels, 1]
-
-    if is_numpy:
-        x = x.numpy(force=True)
-        y = y.numpy(force=True)
-        z_rel = z_rel.numpy(force=True)
-        z_abs = z_abs.numpy(force=True)
-        alpha = alpha.numpy(force=True)
-
-    return dict(x=x, y=y, z_rel=z_rel, z_abs=z_abs, alpha=alpha)
+    if model == "pointsource":
+        if levenberg_marquardt_kwargs is None:
+            levenberg_marquardt_kwargs = {}
+        locs, i = batched_levenberg_marquardt(
+            locs,
+            vmap_point_source_grad_and_mse,
+            vmap_point_source_hessian,
+            extra_args=(normalized_amp_vecs, in_probe_mask, local_geoms),
+            **levenberg_marquardt_kwargs,
+        )
+
+        # finish: get alpha closed form
+        x, y0, z_rel = locs.T
+        y = F.softplus(y0)
+        alpha = vmap_point_source_find_alpha(
+            amplitude_vectors, in_probe_mask, x, y, z_rel, local_geoms
+        )
+        z_abs = z_rel + geom[main_channels, 1]
+
+        if is_numpy:
+            x = x.numpy(force=True)
+            y = y.numpy(force=True)
+            z_rel = z_rel.numpy(force=True)
+            z_abs = z_abs.numpy(force=True)
+            alpha = alpha.numpy(force=True)
+
+        return dict(x=x, y=y, z_rel=z_rel, z_abs=z_abs, alpha=alpha)
+
+    if model == "dipole":
+        if levenberg_marquardt_kwargs is None:
+            levenberg_marquardt_kwargs = {}
+        locs, i = batched_levenberg_marquardt(
+            locs,
+            vmap_dipole_grad_and_mse,
+            vmap_dipole_hessian,
+            extra_args=(normalized_amp_vecs, local_geoms),
+            **levenberg_marquardt_kwargs,
+        )
+
+        x, y0, z_rel = locs.T
+        y = F.softplus(y0)
+        projected_dist = vmap_dipole_find_projection_distance(
+            normalized_amp_vecs, x, y, z_rel, local_geoms
+        )
+
+        # where projected_dist > th_dipole_proj_dist, fall back to the point-source fit
+        pointsource_spikes = torch.nonzero(
+            projected_dist > th_dipole_proj_dist, as_tuple=True
+        )
+
+        locs_pointsource_spikes, i = batched_levenberg_marquardt(
+            locs[pointsource_spikes],
+            vmap_point_source_grad_and_mse,
+            vmap_point_source_hessian,
+            extra_args=(
+                normalized_amp_vecs[pointsource_spikes],
+                in_probe_mask[pointsource_spikes],
+                local_geoms[pointsource_spikes],
+            ),
+            **levenberg_marquardt_kwargs,
+        )
+        x_pointsource_spikes, y0_pointsource_spikes, z_rel_pointsource_spikes = locs_pointsource_spikes.T
+        y_pointsource_spikes = F.softplus(y0_pointsource_spikes)
+
+        x[pointsource_spikes] = x_pointsource_spikes
+        y[pointsource_spikes] = y_pointsource_spikes
+        z_rel[pointsource_spikes] = z_rel_pointsource_spikes
+
+        z_abs = z_rel + geom[main_channels, 1]
+
+        if is_numpy:
+            x = x.numpy(force=True)
+            y = y.numpy(force=True)
+            z_rel = z_rel.numpy(force=True)
+            z_abs = z_abs.numpy(force=True)
+            projected_dist = projected_dist.numpy(force=True)
+
+        return dict(x=x, y=y, z_rel=z_rel, z_abs=z_abs, alpha=projected_dist)
 
 
-# -- point source model library functions
+# -- point source / dipole model library functions
 
 
 def point_source_amplitude_at(x, y, z, alpha, local_geom):
     """Point source model predicted amplitude at local_geom given location"""
     dxs = torch.square(x - local_geom[:, 0])
     dzs = torch.square(z - local_geom[:, 1])
@@ -174,6 +223,21 @@ def point_source_find_alpha(amp_vec, channel_mask, x, y, z, local_geoms):
     )
     return alpha
 
+
+def dipole_find_projection_distance(normalized_amp_vec, x, y, z, local_geom):
+    """Closed-form dipole fit at (x, y, z); returns the projection distance used to gate the dipole model"""
+    dxs = x - local_geom[:, 0]
+    dzs = z - local_geom[:, 1]
+    dys = torch.broadcast_to(y, dxs.shape)
+    duv = torch.stack((dxs, dys, dzs), dim=1)
+    X = duv / torch.pow(torch.square(duv).sum(dim=1, keepdim=True), 3 / 2)
+    beta = torch.linalg.solve(torch.matmul(X.T, X), torch.matmul(X.T, normalized_amp_vec))
+    beta /= torch.sqrt(torch.square(beta).sum())
+    dipole_planar_direction = torch.sqrt(torch.square(beta[[0, 2]]).sum())
+    closest_chan = torch.square(duv).sum(1).argmin()
+    min_duv = duv[closest_chan]
+    val_th = torch.sqrt(torch.square(min_duv).sum()) / dipole_planar_direction
+    return val_th
+
 
 def point_source_mse(
     loc, amplitude_vector, channel_mask, local_geom, logbarrier=True
@@ -213,8 +277,35 @@ def point_source_mse(
     #     obj -= torch.log(1000.0 - torch.sqrt(torch.square(x) + torch.square(z))).sum() / 10000.0
     return obj
 
+
+def dipole_mse(loc, amplitude_vector, local_geom, logbarrier=True):
+    """Dipole model MSE objective at location loc, with closed-form dipole moment"""
+    x, y0, z = loc
+    y = F.softplus(y0)
+
+    dxs = x - local_geom[:, 0]
+    dzs = z - local_geom[:, 1]
+    dys = torch.broadcast_to(y, dxs.shape)
+    duv = torch.stack((dxs, dys, dzs), dim=1)
+
+    X = duv / torch.pow(torch.square(duv).sum(dim=1, keepdim=True), 3 / 2)
+
+    # solve for the dipole moment given the candidate location
+    beta = torch.linalg.solve(torch.matmul(X.T, X), torch.matmul(X.T, amplitude_vector))
+    qtq = torch.matmul(X, beta)
+
+    obj = torch.square(amplitude_vector - qtq).mean()
+    if logbarrier:
+        obj -= torch.log(10.0 * y) / 10000.0
+
+    return obj
+
 
 # vmapped functions for use in the optimizer, and might be handy for users too
 vmap_point_source_grad_and_mse = vmap(grad_and_value(point_source_mse))
 vmap_point_source_hessian = vmap(hessian(point_source_mse))
 vmap_point_source_find_alpha = vmap(point_source_find_alpha)
+
+vmap_dipole_grad_and_mse = vmap(grad_and_value(dipole_mse))
+vmap_dipole_hessian = vmap(hessian(dipole_mse))
+vmap_dipole_find_projection_distance = vmap(dipole_find_projection_distance)
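For reference, a minimal NumPy sketch of the dipole solve implemented above by `dipole_mse` / `dipole_find_projection_distance` (and by `mse_dipole` in `localize_index.py` further below): given a candidate location, the dipole moment `beta` has a closed form, and the projection distance decides whether the dipole fit is trusted. The probe geometry, amplitudes, and candidate location here are made-up values for illustration only.

```python
import numpy as np

# toy local geometry (x, z of 8 channels) and a fake normalized ptp vector
local_geom = np.c_[np.tile([0.0, 32.0], 4), np.repeat(np.arange(4) * 20.0, 2)]
ptp = np.array([0.4, 0.9, 0.5, 1.0, 0.45, 0.8, 0.3, 0.5])
x, y, z = 16.0, 20.0, 30.0  # candidate source location (relative to main channel)

# displacement from the source to every channel, as in mse_dipole / dipole_mse
duv = np.c_[x - local_geom[:, 0], np.broadcast_to(y, ptp.shape), z - local_geom[:, 1]]
X = duv / np.power(np.square(duv).sum(axis=1, keepdims=True), 3 / 2)

# closed-form dipole moment: least-squares fit of X @ beta to the normalized amplitudes
beta = np.linalg.solve(X.T @ X, X.T @ ptp)
residual = np.square(ptp - X @ beta).mean()  # the quantity minimized over (x, y, z)

# projection distance used to decide whether to keep the dipole fit
beta = beta / np.sqrt(np.square(beta).sum())
planar = np.sqrt(np.square(beta[[0, 2]]).sum())
min_duv = duv[np.square(duv).sum(1).argmin()]
val_th = np.sqrt(np.square(min_duv).sum()) / planar
print(residual, val_th)
```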
diff --git a/src/dartsort/transform/localize.py b/src/dartsort/transform/localize.py
index 29786441..083318fe 100644
--- a/src/dartsort/transform/localize.py
+++ b/src/dartsort/transform/localize.py
@@ -63,3 +63,65 @@ def transform(self, waveforms, max_channels=None):
             ]
         )
         return localizations
+
+
+class DipoleLocalization(BaseWaveformFeaturizer):
+    """Order of output columns: x, y, z_abs, alpha"""
+
+    default_name = "dipole_localizations"
+    shape = (4,)
+    dtype = torch.double
+
+    def __init__(
+        self,
+        channel_index,
+        geom,
+        radius=None,
+        n_channels_subset=None,
+        logbarrier=True,
+        amplitude_kind="peak",
+        model="dipole",
+        name=None,
+        name_prefix="",
+    ):
+        assert amplitude_kind in ("peak", "ptp")
+        super().__init__(
+            geom=geom,
+            channel_index=channel_index,
+            name=name,
+            name_prefix=name_prefix,
+        )
+        self.amplitude_kind = amplitude_kind
+        self.radius = radius
+        self.n_channels_subset = n_channels_subset
+        self.logbarrier = logbarrier
+        self.model = model
+
+    def transform(self, waveforms, max_channels=None):
+        # get amplitude vectors
+        if self.amplitude_kind == "peak":
+            ampvecs = waveforms.abs().max(dim=1).values
+        elif self.amplitude_kind == "ptp":
+            ampvecs = ptp(waveforms, dim=1)
+
+        with torch.enable_grad():
+            loc_result = localize_amplitude_vectors(
+                ampvecs,
+                self.geom,
+                max_channels,
+                channel_index=self.channel_index,
+                radius=self.radius,
+                n_channels_subset=self.n_channels_subset,
+                logbarrier=self.logbarrier,
+                model=self.model,
+                dtype=self.dtype,
+            )
+
+        localizations = torch.column_stack(
+            [
+                loc_result["x"],
+                loc_result["y"],
+                loc_result["z_abs"],
+                loc_result["alpha"],
+            ]
+        )
+        return localizations
diff --git a/src/spike_psvae/chunk_features.py b/src/spike_psvae/chunk_features.py
index e7940148..a0defb89 100644
--- a/src/spike_psvae/chunk_features.py
+++ b/src/spike_psvae/chunk_features.py
@@ -392,6 +392,7 @@ def transform(
         else:
             if self.ptp_precision_decimals is not None:
                 ptps = np.round(ptps, decimals=self.ptp_precision_decimals)
+
         (
             xs,
             ys,
@@ -487,7 +488,7 @@ def raw_fit(self, wfs, max_channels):
         self.needs_fit = False
 
         self.dtype = self.tpca.components_.dtype
-        self.n_components = self.tpca.n_components
+        self.n_components = self.n_components
         self.components_ = self.tpca.components_
         self.mean_ = self.tpca.mean_
         if self.centered:  # otherwise SVD
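A hypothetical usage sketch of the new `DipoleLocalization` featurizer defined above. The probe geometry, `channel_index`, waveform shapes, and max-channel values are invented, and in practice dartsort's featurization pipeline constructs and runs these objects itself.

```python
import torch
from dartsort.transform.localize import DipoleLocalization

# toy two-column probe with 32 channels; channel_index maps each channel to the
# channels extracted around it (here: every channel, for simplicity)
geom = torch.column_stack(
    (torch.tensor([0.0, 32.0]).repeat(16), (torch.arange(32) // 2) * 20.0)
)
channel_index = torch.arange(32).repeat(32, 1)

dipole_loc = DipoleLocalization(channel_index, geom)
waveforms = torch.randn(10, 121, 32, dtype=torch.double)  # (spikes, samples, channels)
max_channels = torch.randint(0, 32, (10,))

# columns of the output: x, y, z_abs, alpha (projection distance for the dipole model)
xyza = dipole_loc.transform(waveforms, max_channels=max_channels)
```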
diff --git a/src/spike_psvae/denoise.py b/src/spike_psvae/denoise.py
index a42e992f..4ecf62cf 100644
--- a/src/spike_psvae/denoise.py
+++ b/src/spike_psvae/denoise.py
@@ -95,7 +95,9 @@ def phase_shift_and_hallucination_idx_preshift(waveforms_roll_denoise, waveforms
 
     which = slice(offset-10, offset+10)
 
-    d_s_corr = wfs_corr(waveforms_roll_denoise[:, which], waveforms_roll[:, which])#torch.sum(wfs_denoised[which]*chan_wfs[which], 1)/torch.sqrt(torch.sum(chan_wfs[which]*chan_wfs[which],1) * torch.sum(wfs_denoised[which]*wfs_denoised[which],1)) ## didn't use which at the beginning! check whether this changes the results
+    d_s_corr = wfs_corr(waveforms_roll_denoise[:, which], waveforms_roll[:, which])
+    # torch.sum(wfs_denoised[which]*chan_wfs[which], 1)/torch.sqrt(torch.sum(chan_wfs[which]*chan_wfs[which],1) * torch.sum(wfs_denoised[which]*wfs_denoised[which],1))
+    # didn't use which at the beginning! check whether this changes the results
 
     halu_idx = (ptp(waveforms_roll_denoise, 1)
             if decr_ptp[c] > decr_ptp[parents_rel].max():
                 decr_ptp[c] *= decr_ptp[parents_rel].max() / decr_ptp[c]
+
         # decreasing_ptps[i] = decr_ptp
 
     # apply decreasing ptps to the original waveforms
     rescale = (decreasing_ptps / orig_ptps)[:, None, :]
+
     if is_torch:
         rescale = torch.as_tensor(rescale, device=waveforms.device)
     if in_place:
diff --git a/src/spike_psvae/localize_index.py b/src/spike_psvae/localize_index.py
index ea8708a9..b90f8d73 100644
--- a/src/spike_psvae/localize_index.py
+++ b/src/spike_psvae/localize_index.py
@@ -65,9 +65,9 @@ def ptp_at_dipole(x1, y1, z1, alpha, x2, y2, z2):
             )
             - 1
             / np.sqrt(
-                np.square(x2 - local_geom[:, 0])
-                + np.square(z2 - local_geom[:, 1])
-                + np.square(y2)
+                np.square(x2 + x1 - local_geom[:, 0])
+                + np.square(z2 + z1 - local_geom[:, 1])
+                + np.square(y2 + y1)
             )
         )
         return ptp_dipole_out
@@ -107,18 +107,24 @@ def mse(loc):
         #     - (np.log1p(10.0 * y) / 10000.0 if logbarrier else 0)
         # )
 
-    def mse_dipole(x_in):
-        x1 = x_in[0]
-        y1 = x_in[1]
-        z1 = x_in[2]
-        x2 = x_in[3]
-        y2 = x_in[4]
-        z2 = x_in[5]
-        q = ptp_at_dipole(x1, y1, z1, 1.0, x2, y2, z2)
-        alpha = (q * ptp).sum() / (q * q).sum()
-        return np.square(
-            ptp - ptp_at_dipole(x1, y1, z1, alpha, x2, y2, z2)
-        ).mean() - (np.log1p(10.0 * y1) / 10000.0 if logbarrier else 0)
+    def mse_dipole(loc):
+        x, y, z = loc
+        # q = ptp_at(x, y, z, 1.0)
+        # alpha = (q * (ptp / maxptp - delta)).sum() / (q * q).sum()
+        duv = np.c_[
+            x - local_geom[:, 0],
+            np.broadcast_to(y, ptp.shape),
+            z - local_geom[:, 1],
+        ]
+        X = duv / np.power(np.square(duv).sum(axis=1, keepdims=True), 3 / 2)
+        beta = np.linalg.solve(X.T @ X, X.T @ (ptp / maxptp))
+        qtq = X @ beta
+        return (
+            np.square(ptp / maxptp - qtq).mean()
+            # np.square(ptp / maxptp - delta - ptp_at(x, y, z, alpha)).mean()
+            # np.square(np.maximum(0, ptp / maxptp - ptp_at(x, y, z, alpha))).mean()
+            - np.log1p(10.0 * y) / 10000.0
+        )
 
     if model == "pointsource":
         result = minimize(
@@ -146,24 +152,51 @@ def mse_dipole(x_in):
         result = minimize(
             mse_dipole,
-            x0=[xcom, Y0, zcom, xcom + 1, Y0 + 1, zcom + 1],
+            x0=[xcom, Y0, zcom],
             bounds=[
                 (local_geom[:, 0].min() - DX, local_geom[:, 0].max() + DX),
                 (1e-4, 250),
                 (-DZ, DZ),
-                (-100, 100),
-                (-100, 100),
-                (-100, 100),
             ],
         )
         # print(result)
-        bx, by, bz_rel, bpx, bpy, bpz = result.x
-
-        q = ptp_at_dipole(bx, by, bz_rel, 1.0, bpx, bpy, bpz)
-
-        balpha = (q * ptp).sum() / (q * q).sum()
-        return bx, by, bz_rel, balpha
+        bx, by, bz_rel = result.x
+
+        duv = np.c_[
+            bx - local_geom[:, 0],
+            np.broadcast_to(by, ptp.shape),
+            bz_rel - local_geom[:, 1],
+        ]
+        X = duv / np.power(np.square(duv).sum(axis=1, keepdims=True), 3 / 2)
+        beta = np.linalg.solve(X.T @ X, X.T @ (ptp / maxptp))
+        beta /= np.sqrt(np.square(beta).sum())
+        dipole_planar_direction = np.sqrt(np.square(beta[[0, 2]]).sum())
+        closest_chan = np.square(duv).sum(1).argmin()
+        min_duv = duv[closest_chan]
+
+        val_th = np.sqrt(np.square(min_duv).sum()) / dipole_planar_direction
+
+        # reparameterized_dist = np.sqrt(np.square(min_duv[0]/beta[2]) + np.square(min_duv[2]/beta[0])
+        #                                + np.square(min_duv[1]/beta[1]))
+
+        if val_th < 250:
+            return bx, by, bz_rel, val_th
+        else:
+            result = minimize(
+                mse,
+                x0=[xcom, Y0, zcom],
+                bounds=[
+                    (local_geom[:, 0].min() - DX, local_geom[:, 0].max() + DX),
+                    (1e-4, 250),
+                    (-DZ, DZ),
+                ],
+            )
+            # print(result)
+            bx, by, bz_rel = result.x
+            q = ptp_at(bx, by, bz_rel, 1.0)
+            balpha = (ptp * q).sum() / np.square(q).sum()
+            return bx, by, bz_rel, val_th
     else:
         raise NameError("Wrong localization model")
@@ -230,6 +263,5 @@ def localize_ptps_index(
         ys[n] = y
         z_rels[n] = z_rel
         alphas[n] = alpha
-
     z_abss = z_rels + geom[maxchans, 1]
     return xs, ys, z_rels, z_abss, alphas
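For comparison, a small self-contained sketch of the point-source model that the fallback branch above reverts to when `val_th` exceeds the threshold: predicted amplitudes fall off as `alpha / distance`, and `alpha` has the closed form used for `balpha`. The helper name and toy values are illustrative; in `localize_index.py` the equivalent closure is `ptp_at`, which captures `local_geom`.

```python
import numpy as np

def ptp_at_pointsource(x, y, z, alpha, local_geom):
    # point-source prediction: brightness alpha divided by distance to each channel
    return alpha / np.sqrt(
        np.square(x - local_geom[:, 0])
        + np.square(z - local_geom[:, 1])
        + np.square(y)
    )

# toy geometry and amplitudes; closed-form alpha for a candidate location,
# as in the fallback branch above
local_geom = np.c_[np.tile([0.0, 32.0], 4), np.repeat(np.arange(4) * 20.0, 2)]
ptp = np.array([0.4, 0.9, 0.5, 1.0, 0.45, 0.8, 0.3, 0.5])
q = ptp_at_pointsource(16.0, 20.0, 30.0, 1.0, local_geom)
alpha = (ptp * q).sum() / np.square(q).sum()
mse = np.square(ptp - ptp_at_pointsource(16.0, 20.0, 30.0, alpha, local_geom)).mean()
```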
diff --git a/src/spike_psvae/subtract.py b/src/spike_psvae/subtract.py
index d1e3d62a..0cc936e9 100644
--- a/src/spike_psvae/subtract.py
+++ b/src/spike_psvae/subtract.py
@@ -668,11 +668,11 @@ def subtraction_binary(
     n_channels = geom.shape[0]
     recording = sc.read_binary(
-        standardized_bin,
-        sampling_rate,
-        n_channels,
-        binary_dtype,
-        time_axis=time_axis,
+        file_paths=standardized_bin,
+        sampling_frequency=sampling_rate,
+        num_channels=n_channels,
+        dtype=binary_dtype,
+        time_axis=0,
         is_filtered=True,
     )
@@ -1077,7 +1077,7 @@ def subtraction_batch(
                 batch_data_folder / f"{prefix}{f.name}.npy",
                 feat,
             )
-
+
         denoised_wfs = full_denoising(
             cleaned_wfs,
             spike_index[:, 1],
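The `read_binary` call above now passes keyword arguments. A standalone sketch of the equivalent call with placeholder values (the path and parameters are made up, and older spikeinterface releases spell the channel-count argument `num_chan` rather than `num_channels`):

```python
import spikeinterface.core as sc

# keyword-argument form matching the updated subtraction_binary() call
recording = sc.read_binary(
    file_paths="standardized.bin",   # placeholder path
    sampling_frequency=30_000.0,
    num_channels=384,
    dtype="float32",
    time_axis=0,
    is_filtered=True,
)
print(recording.get_num_channels(), recording.get_num_frames())
```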