
Commit

Merge branch 'main' of github.com:cwindolf/spike-psvae
cwindolf committed Nov 6, 2023
2 parents 443b61a + be3af49 commit de39125
Showing 6 changed files with 258 additions and 66 deletions.
147 changes: 119 additions & 28 deletions src/dartsort/localize/localize_torch.py
@@ -19,6 +19,7 @@ def localize_amplitude_vectors(
dtype=torch.double,
y0=1.0,
levenberg_marquardt_kwargs=None,
th_dipole_proj_dist=250.0,
):
"""Localize a bunch of amplitude vectors with torch
@@ -59,7 +60,7 @@ def localize_amplitude_vectors(
# maybe this will become a wrapper function if we want more models.
# and, this is why we return a dict, different models will have different
# parameters
assert model in ("com", "pointsource")
assert model in ("com", "pointsource", "dipole")
n_spikes, c = amplitude_vectors.shape
n_channels_tot = len(geom)
if channel_index is None:
@@ -68,7 +69,8 @@
assert channel_index.shape == (n_channels_tot, c)
assert main_channels.shape == (n_spikes,)
# we'll return numpy if user sent numpy
is_numpy = not torch.is_tensor(amplitude_vectors)

# handle channel subsetting
if radius is not None or n_channels_subset is not None:
@@ -122,42 +124,89 @@
# fixed constants in regularizers like the log barrier
max_amplitudes = torch.max(amplitude_vectors, dim=1).values
normalized_amp_vecs = amplitude_vectors / max_amplitudes[:, None]

# -- torch optimize
# initialize with center of mass
locs = torch.column_stack((xcom, torch.full_like(xcom, y0), zcom))
if levenberg_marquardt_kwargs is None:
levenberg_marquardt_kwargs = {}
locs, i = batched_levenberg_marquardt(
locs,
vmap_point_source_grad_and_mse,
vmap_point_source_hessian,
extra_args=(normalized_amp_vecs, in_probe_mask, local_geoms),
**levenberg_marquardt_kwargs,
)

# finish: get alpha closed form
x, y0, z_rel = locs.T
y = F.softplus(y0)
alpha = vmap_point_source_find_alpha(
amplitude_vectors, in_probe_mask, x, y, z_rel, local_geoms
)
z_abs = z_rel + geom[main_channels, 1]

if model == "pointsource":

if is_numpy:
x = x.numpy(force=True)
y = y.numpy(force=True)
z_rel = z_rel.numpy(force=True)
z_abs = z_abs.numpy(force=True)
alpha = alpha.numpy(force=True)
if levenberg_marquardt_kwargs is None:
levenberg_marquardt_kwargs = {}
locs, i = batched_levenberg_marquardt(
locs,
vmap_point_source_grad_and_mse,
vmap_point_source_hessian,
extra_args=(normalized_amp_vecs, in_probe_mask, local_geoms),
**levenberg_marquardt_kwargs,
)

return dict(x=x, y=y, z_rel=z_rel, z_abs=z_abs, alpha=alpha)
# finish: get alpha closed form
x, y0, z_rel = locs.T
y = F.softplus(y0)
alpha = vmap_point_source_find_alpha(
amplitude_vectors, in_probe_mask, x, y, z_rel, local_geoms
)
z_abs = z_rel + geom[main_channels, 1]

if is_numpy:
x = x.numpy(force=True)
y = y.numpy(force=True)
z_rel = z_rel.numpy(force=True)
z_abs = z_abs.numpy(force=True)
alpha = alpha.numpy(force=True)
return dict(x=x, y=y, z_rel=z_rel, z_abs=z_abs, alpha=alpha)

if model == "dipole":
if levenberg_marquardt_kwargs is None:
levenberg_marquardt_kwargs = {}
locs, i = batched_levenberg_marquardt(
locs,
vmap_dipole_grad_and_mse,
vmap_dipole_hessian,
extra_args=(normalized_amp_vecs, local_geoms),
**levenberg_marquardt_kwargs,
)

x, y0, z_rel = locs.T
y = F.softplus(y0)
projected_dist = vmap_dipole_find_projection_distance(
normalized_amp_vecs, x, y, z_rel, local_geoms
)

# for spikes with projected_dist > th_dipole_proj_dist, fall back to the point-source fit

pointsource_spikes = torch.nonzero(projected_dist > th_dipole_proj_dist, as_tuple=True)

locs_pointsource_spikes, i = batched_levenberg_marquardt(
locs[pointsource_spikes],
vmap_point_source_grad_and_mse,
vmap_point_source_hessian,
extra_args=(normalized_amp_vecs[pointsource_spikes], in_probe_mask, local_geoms[pointsource_spikes]),
**levenberg_marquardt_kwargs,
)
x_pointsource_spikes, y0_pointsource_spikes, z_rel_pointsource_spikes = locs_pointsource_spikes.T
y_pointsource_spikes = F.softplus(y0_pointsource_spikes)

x[pointsource_spikes] = x_pointsource_spikes
y[pointsource_spikes] = y_pointsource_spikes
z_rel[pointsource_spikes] = z_rel_pointsource_spikes

z_abs = z_rel + geom[main_channels, 1]

if is_numpy:
x = x.numpy(force=True)
y = y.numpy(force=True)
z_rel = z_rel.numpy(force=True)
z_abs = z_abs.numpy(force=True)
projected_dist = projected_dist.numpy(force=True)

# -- point source model library functions
return dict(x=x, y=y, z_rel=z_rel, z_abs=z_abs, alpha=projected_dist)


def point_source_amplitude_at(x, y, z, alpha, local_geom):
# -- point source / dipole model library functions
def point_source_amplitude_at(x, y, z, local_geom):
"""Point source model predicted amplitude at local_geom given location"""
dxs = torch.square(x - local_geom[:, 0])
dzs = torch.square(z - local_geom[:, 1])
@@ -174,6 +223,21 @@ def point_source_find_alpha(amp_vec, channel_mask, x, y, z, local_geoms):
)
return alpha

def dipole_find_projection_distance(normalized_amp_vec, x, y, z, local_geom):
    """Closed-form dipole fit: distance to the closest channel, scaled by the planar dipole size

    Solves the dipole least squares problem given x, y, z, then returns the distance to the
    closest channel divided by the planar (x, z) magnitude of the unit dipole direction;
    large values indicate the point source model should be used instead.
    """
    dxs = x - local_geom[:, 0]
    dzs = z - local_geom[:, 1]
    dys = y * torch.ones_like(dxs)
    duv = torch.column_stack((dxs, dys, dzs))
    # dipole design matrix: displacement to each channel scaled by 1 / r^3
    X = duv / torch.pow(torch.square(duv).sum(dim=1, keepdim=True), 3 / 2)
    beta = torch.linalg.solve(torch.matmul(X.T, X), torch.matmul(X.T, normalized_amp_vec))
    beta /= torch.sqrt(torch.square(beta).sum())
    dipole_planar_direction = torch.sqrt(torch.square(beta[[0, 2]]).sum())
    closest_chan = torch.square(duv).sum(1).argmin()
    min_duv = duv[closest_chan]
    val_th = torch.sqrt(torch.square(min_duv).sum()) / dipole_planar_direction
    return val_th

def point_source_mse(
loc, amplitude_vector, channel_mask, local_geom, logbarrier=True
@@ -213,8 +277,35 @@ def point_source_mse(
# obj -= torch.log(1000.0 - torch.sqrt(torch.square(x) + torch.square(z))).sum() / 10000.0
return obj

def dipole_mse(loc, amplitude_vector, local_geom, logbarrier=True):
    """Dipole model MSE objective at local_geom given location

    amplitude_vector is already max-normalized by the caller (normalized_amp_vecs).
    """
    x, y0, z = loc
    y = F.softplus(y0)

    dxs = x - local_geom[:, 0]
    dzs = z - local_geom[:, 1]
    dys = y * torch.ones_like(dxs)
    duv = torch.column_stack((dxs, dys, dzs))

    # dipole design matrix: displacement to each channel scaled by 1 / r^3
    X = duv / torch.pow(torch.square(duv).sum(dim=1, keepdim=True), 3 / 2)

    # closed-form least squares dipole moment and its predicted amplitudes
    beta = torch.linalg.solve(torch.matmul(X.T, X), torch.matmul(X.T, amplitude_vector))
    qtq = torch.matmul(X, beta)

    obj = torch.square(amplitude_vector - qtq).mean()
    if logbarrier:
        obj -= torch.log(10.0 * y) / 10000.0

    return obj


# vmapped functions for use in the optimizer, and might be handy for users too
vmap_point_source_grad_and_mse = vmap(grad_and_value(point_source_mse))
vmap_point_source_hessian = vmap(hessian(point_source_mse))
vmap_point_source_find_alpha = vmap(point_source_find_alpha)

vmap_dipole_grad_and_mse = vmap(grad_and_value(dipole_mse))
vmap_dipole_hessian = vmap(hessian(dipole_mse))
vmap_dipole_find_projection_distance = vmap(dipole_find_projection_distance)
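A note on the new dipole code path: dipole_find_projection_distance and dipole_mse both rest on the same closed-form least squares fit. In notation introduced only for this note (r_i, q, p, and C are not names from the source), let r_i be the displacement from the candidate location (x, y, z) to channel i and q the max-normalized amplitude vector over the C local channels:

\[
q_i \approx \frac{\mathbf{r}_i \cdot \mathbf{p}}{\lVert \mathbf{r}_i \rVert^{3}} = X_i \mathbf{p},
\qquad
X_i = \frac{\mathbf{r}_i^{\top}}{\lVert \mathbf{r}_i \rVert^{3}},
\qquad
\hat{\mathbf{p}} = (X^{\top} X)^{-1} X^{\top} \mathbf{q},
\]
\[
\mathcal{L}(x, y, z) = \frac{1}{C} \bigl\lVert \mathbf{q} - X \hat{\mathbf{p}} \bigr\rVert^{2},
\qquad
d_{\mathrm{proj}} = \frac{\min_i \lVert \mathbf{r}_i \rVert}{\sqrt{\hat{p}_x^{2} + \hat{p}_z^{2}}}.
\]

dipole_mse minimizes L over candidate locations inside batched_levenberg_marquardt, dipole_find_projection_distance returns d_proj after normalizing the fitted moment to unit length, and localize_amplitude_vectors falls back to the point-source fit for spikes with d_proj > th_dipole_proj_dist.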
62 changes: 62 additions & 0 deletions src/dartsort/transform/localize.py
@@ -63,3 +63,65 @@ def transform(self, waveforms, max_channels=None):
]
)
return localizations

class DipoleLocalization(BaseWaveformFeaturizer):
"""Order of output columns: x, y, z_abs, alpha"""

default_name = "dipole_localizations"
shape = (4,)
dtype = torch.double

def __init__(
self,
channel_index,
geom,
radius=None,
n_channels_subset=None,
logbarrier=True,
amplitude_kind="peak",
model="dipole",
name=None,
name_prefix="",
):
assert amplitude_kind in ("peak", "ptp")
super().__init__(
geom=geom,
channel_index=channel_index,
name=name,
name_prefix=name_prefix,
)
self.amplitude_kind = amplitude_kind
self.radius = radius
self.n_channels_subset = n_channels_subset
self.logbarrier = logbarrier
self.model = model

def transform(self, waveforms, max_channels=None):
# get amplitude vectors
if self.amplitude_kind == "peak":
ampvecs = waveforms.abs().max(dim=1).values
elif self.amplitude_kind == "ptp":
ampvecs = ptp(waveforms, dim=1)

with torch.enable_grad():
loc_result = localize_amplitude_vectors(
ampvecs,
self.geom,
max_channels,
channel_index=self.channel_index,
radius=self.radius,
n_channels_subset=self.n_channels_subset,
logbarrier=self.logbarrier,
model=self.model,
dtype=self.dtype,
)

localizations = torch.column_stack(
[
loc_result["x"],
loc_result["y"],
loc_result["z_abs"],
loc_result["alpha"],
]
)
return localizations
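A minimal usage sketch for the new featurizer, assuming the module path dartsort.transform.localize and that BaseWaveformFeaturizer needs nothing beyond the arguments shown in this diff; the probe geometry, channel index, and waveforms below are toy values invented for illustration, not taken from the source.

import torch
from dartsort.transform.localize import DipoleLocalization

# toy 8-channel single-column probe at 20 um pitch, with a dense channel index
geom = torch.tensor([[0.0, 20.0 * i] for i in range(8)], dtype=torch.double)
channel_index = torch.arange(8).unsqueeze(0).repeat(8, 1)

dipole_loc = DipoleLocalization(
    channel_index,
    geom,
    amplitude_kind="ptp",  # or "peak", as in the featurizer above
)

# fake waveforms: 16 spikes, 121 samples, 8 channels, all detected on channel 3
waveforms = torch.randn(16, 121, 8, dtype=torch.double)
max_channels = torch.full((16,), 3)

features = dipole_loc.transform(waveforms, max_channels=max_channels)
# features has shape (16, 4): columns x, y, z_abs, alpha (projected distance)

Because model="dipole" is the constructor default here, transform routes through the dipole branch of localize_amplitude_vectors shown in the previous file.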
3 changes: 2 additions & 1 deletion src/spike_psvae/chunk_features.py
@@ -392,6 +392,7 @@ def transform(
else:
if self.ptp_precision_decimals is not None:
ptps = np.round(ptps, decimals=self.ptp_precision_decimals)

(
xs,
ys,
@@ -487,7 +488,7 @@ def raw_fit(self, wfs, max_channels):

self.needs_fit = False
self.dtype = self.tpca.components_.dtype
self.n_components = self.tpca.n_components
self.n_components = self.n_components
self.components_ = self.tpca.components_
self.mean_ = self.tpca.mean_
if self.centered: # otherwise SVD
16 changes: 11 additions & 5 deletions src/spike_psvae/denoise.py
@@ -95,7 +95,9 @@ def phase_shift_and_hallucination_idx_preshift(waveforms_roll_denoise, waveforms

which = slice(offset-10, offset+10)

d_s_corr = wfs_corr(waveforms_roll_denoise[:, which], waveforms_roll[:, which])#torch.sum(wfs_denoised[which]*chan_wfs[which], 1)/torch.sqrt(torch.sum(chan_wfs[which]*chan_wfs[which],1) * torch.sum(wfs_denoised[which]*wfs_denoised[which],1)) ## didn't use which at the beginning! check whether this changes the results
d_s_corr = wfs_corr(waveforms_roll_denoise[:, which], waveforms_roll[:, which])
# torch.sum(wfs_denoised[which]*chan_wfs[which], 1)/torch.sqrt(torch.sum(chan_wfs[which]*chan_wfs[which],1) * torch.sum(wfs_denoised[which]*wfs_denoised[which],1))
# didn't use which at the beginning! check whether this changes the results

halu_idx = (ptp(waveforms_roll_denoise, 1)<small_threshold) & (d_s_corr<corr_th)
halu_idx = halu_idx.long()
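For readers who want the formula from the comment above in runnable form, here is a self-contained sketch of the normalized correlation that wfs_corr presumably computes; wfs_corr's body is not shown in this diff, so the exact signature and the eps guard are assumptions.

import torch

def wfs_corr_sketch(wfs_a, wfs_b, eps=1e-10):
    """Per-waveform normalized correlation: sum(a*b) / sqrt(sum(a*a) * sum(b*b)).

    wfs_a, wfs_b: (n_waveforms, n_samples) slices, e.g. waveforms_roll_denoise[:, which]
    and waveforms_roll[:, which] from the code above.
    """
    num = torch.sum(wfs_a * wfs_b, dim=1)
    denom = torch.sqrt(torch.sum(wfs_a * wfs_a, dim=1) * torch.sum(wfs_b * wfs_b, dim=1))
    return num / (denom + eps)

The hallucination flag then combines a small denoised PTP with a low correlation, exactly as in the halu_idx line above.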
@@ -201,6 +203,7 @@ def multichan_phase_shift_denoise_preshift(waveforms, ci_graph_all_maxCH_uniq, m

CH_checked = F.pad(CH_checked, (0, 1), 'constant', 1)

# TODO: do we need this many roll offsets?
waveforms_roll_all = torch.cat((waveforms,
torch.roll(waveforms, -15, 1),
torch.roll(waveforms, -12, 1),
@@ -343,9 +346,6 @@

Q.insert(0, torch.cat((unfold_idx[seek_idx][:, None], Q_neighbors[seek_idx][:, None]), 1))







@@ -764,7 +764,11 @@ def make_radial_order_parents(
def enforce_decrease_shells(
waveforms, maxchans, radial_parents, in_place=False
):
"""Radial enforce decrease"""
"""
Radial enforce decrease
What if we localize with peak?
"""

N, T, C = waveforms.shape
assert maxchans.shape == (N,)

@@ -783,9 +787,11 @@ def enforce_decrease_shells(
for c, parents_rel in radial_parents[maxchans[i]]:
if decr_ptp[c] > decr_ptp[parents_rel].max():
decr_ptp[c] *= decr_ptp[parents_rel].max() / decr_ptp[c]
# decreasing_ptps[i] = decr_ptp

# apply decreasing ptps to the original waveforms
rescale = (decreasing_ptps / orig_ptps)[:, None, :]

if is_torch:
rescale = torch.as_tensor(rescale, device=waveforms.device)
if in_place: