Skip to content

Commit 2ecc274

Browse files
committed
Fix tests after merge
1 parent de39125 commit 2ecc274

File tree

1 file changed

+61
-60
lines changed

1 file changed

+61
-60
lines changed

src/dartsort/localize/localize_torch.py

Lines changed: 61 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,7 @@ def localize_amplitude_vectors(
6969
assert channel_index.shape == (n_channels_tot, c)
7070
assert main_channels.shape == (n_spikes,)
7171
# we'll return numpy if user sent numpy
72-
is_numpy = not torch.is_tensor(amplitude_vectors)
73-
72+
is_numpy = not torch.is_tensor(amplitude_vectors)
7473

7574
# handle channel subsetting
7675
if radius is not None or n_channels_subset is not None:
@@ -114,26 +113,22 @@ def localize_amplitude_vectors(
114113

115114
if model == "com":
116115
z_abs_com = zcom + geom[main_channels, 1]
117-
nancom = torch.full_like(xcom, torch.nan)
118-
return dict(
119-
x=xcom, y=nancom, z_rel=zcom, z_abs=z_abs_com, alpha=nancom
120-
)
116+
return dict(x=xcom, z_rel=zcom, z_abs=z_abs_com)
121117

122118
# normalized PTP vectors
123119
# this helps to keep the objective in a similar range, so we can use
124120
# fixed constants in regularizers like the log barrier
125121
max_amplitudes = torch.max(amplitude_vectors, dim=1).values
126122
normalized_amp_vecs = amplitude_vectors / max_amplitudes[:, None]
127-
123+
128124
# -- torch optimize
125+
if levenberg_marquardt_kwargs is None:
126+
levenberg_marquardt_kwargs = {}
127+
129128
# initialize with center of mass
130129
locs = torch.column_stack((xcom, torch.full_like(xcom, y0), zcom))
131130

132-
133131
if model == "pointsource":
134-
135-
if levenberg_marquardt_kwargs is None:
136-
levenberg_marquardt_kwargs = {}
137132
locs, i = batched_levenberg_marquardt(
138133
locs,
139134
vmap_point_source_grad_and_mse,
@@ -149,52 +144,53 @@ def localize_amplitude_vectors(
149144
amplitude_vectors, in_probe_mask, x, y, z_rel, local_geoms
150145
)
151146
z_abs = z_rel + geom[main_channels, 1]
152-
147+
148+
results = dict(x=x, y=y, z_rel=z_rel, z_abs=z_abs, alpha=alpha)
153149
if is_numpy:
154-
x = x.numpy(force=True)
155-
y = y.numpy(force=True)
156-
z_rel = z_rel.numpy(force=True)
157-
z_abs = z_abs.numpy(force=True)
158-
alpha = alpha.numpy(force=True)
159-
return dict(x=x, y=y, z_rel=z_rel, z_abs=z_abs, alpha=alpha)
160-
161-
if model == "dipole":
162-
if levenberg_marquardt_kwargs is None:
163-
levenberg_marquardt_kwargs = {}
150+
results = {k: v.numpy(force=True) for k, v in results.items()}
151+
return results
152+
153+
elif model == "dipole":
164154
locs, i = batched_levenberg_marquardt(
165155
locs,
166156
vmap_dipole_grad_and_mse,
167157
vmap_dipole_hessian,
168158
extra_args=(normalized_amp_vecs, local_geoms),
169159
**levenberg_marquardt_kwargs,
170160
)
171-
161+
172162
x, y0, z_rel = locs.T
173163
y = F.softplus(y0)
174164
projected_dist = vmap_dipole_find_projection_distance(
175165
normalized_amp_vecs, x, y, z_rel, local_geoms
176-
)
177-
166+
)
167+
178168
# if projected_dist>th_dipole_proj_dist: return the loc values from pointsource
179169

180-
pointsource_spikes = torch.nonzero(projected_dist>th_dipole_proj_dist, as_tuple=True)
181-
170+
pointsource_spikes = torch.nonzero(
171+
projected_dist > th_dipole_proj_dist, as_tuple=True
172+
)
173+
182174
locs_pointsource_spikes, i = batched_levenberg_marquardt(
183175
locs[pointsource_spikes],
184176
vmap_point_source_grad_and_mse,
185177
vmap_point_source_hessian,
186-
extra_args=(normalized_amp_vecs[pointsource_spikes], in_probe_mask, local_geoms[pointsource_spikes]),
178+
extra_args=(
179+
normalized_amp_vecs[pointsource_spikes],
180+
in_probe_mask,
181+
local_geoms[pointsource_spikes],
182+
),
187183
**levenberg_marquardt_kwargs,
188184
)
189185
x_pointsource_spikes, y0_pointsource_spikes, z_rel_pointsource_spikes = locs.T
190186
y_pointsource_spikes = F.softplus(y0_pointsource_spikes)
191-
187+
192188
x[pointsource_spikes] = x_pointsource_spikes
193189
y[pointsource_spikes] = y_pointsource_spikes
194190
z_rel[pointsource_spikes] = z_rel_pointsource_spikes
195-
191+
196192
z_abs = z_rel + geom[main_channels, 1]
197-
193+
198194
if is_numpy:
199195
x = x.numpy(force=True)
200196
y = y.numpy(force=True)
@@ -204,9 +200,14 @@ def localize_amplitude_vectors(
204200

205201
return dict(x=x, y=y, z_rel=z_rel, z_abs=z_abs, alpha=projected_dist)
206202

203+
else:
204+
assert False
205+
207206

208207
# -- point source / dipole model library functions
209-
def point_source_amplitude_at(x, y, z, local_geom):
208+
209+
210+
def point_source_amplitude_at(x, y, z, alpha, local_geom):
210211
"""Point source model predicted amplitude at local_geom given location"""
211212
dxs = torch.square(x - local_geom[:, 0])
212213
dzs = torch.square(z - local_geom[:, 1])
@@ -223,25 +224,8 @@ def point_source_find_alpha(amp_vec, channel_mask, x, y, z, local_geoms):
223224
)
224225
return alpha
225226

226-
def dipole_find_projection_distance(normalized_amp_vec, x, y, z, local_geom):
227-
"""We can solve for the brightness (alpha) of the source in closed form given x,y,z"""
228-
229-
dxs = x - local_geom[:, 0]
230-
dzs = z - local_geom[:, 1]
231-
dys = y
232-
duv = torch.tensor([dxs, dys, dzs])
233-
X = duv / torch.pow(torch.sum(torch.square(duv)), 3/2)
234-
beta = torch.linalg.solve(torch.matmul(X.T, X), torch.matmul(X.T, normalized_amp_vec))
235-
beta /= torch.sqrt(torch.square(beta).sum())
236-
dipole_planar_direction = torch.sqrt(np.torch(beta[[0, 2]]).sum())
237-
closest_chan = torch.square(duv).sum(1).argmin()
238-
min_duv = duv[closest_chan]
239-
val_th = torch.sqrt(torch.square(min_duv).sum())/dipole_planar_direction
240-
return val_th
241227

242-
def point_source_mse(
243-
loc, amplitude_vector, channel_mask, local_geom, logbarrier=True
244-
):
228+
def point_source_mse(loc, amplitude_vector, channel_mask, local_geom, logbarrier=True):
245229
"""Objective in point source model
246230
247231
Arguments
@@ -264,36 +248,53 @@ def point_source_mse(
264248
x, y0, z = loc
265249
y = F.softplus(y0)
266250

267-
alpha = point_source_find_alpha(
268-
amplitude_vector, channel_mask, x, y, z, local_geom
269-
)
251+
alpha = point_source_find_alpha(amplitude_vector, channel_mask, x, y, z, local_geom)
270252
obj = torch.square(
271-
amplitude_vector
272-
- point_source_amplitude_at(x, y, z, alpha, local_geom)
253+
amplitude_vector - point_source_amplitude_at(x, y, z, alpha, local_geom)
273254
).mean()
274255
if logbarrier:
275256
obj -= torch.log(10.0 * y) / 10000.0
276257
# idea for logbarrier on points which run away
277258
# obj -= torch.log(1000.0 - torch.sqrt(torch.square(x) + torch.square(z))).sum() / 10000.0
278259
return obj
279260

261+
262+
def dipole_find_projection_distance(normalized_amp_vec, x, y, z, local_geom):
263+
"""We can solve for the brightness (alpha) of the source in closed form given x,y,z"""
264+
265+
dxs = x - local_geom[:, 0]
266+
dzs = z - local_geom[:, 1]
267+
dys = y
268+
duv = torch.tensor([dxs, dys, dzs])
269+
X = duv / torch.pow(torch.sum(torch.square(duv)), 3 / 2)
270+
beta = torch.linalg.solve(
271+
torch.matmul(X.T, X), torch.matmul(X.T, normalized_amp_vec)
272+
)
273+
beta /= torch.sqrt(torch.square(beta).sum())
274+
dipole_planar_direction = torch.sqrt(np.torch(beta[[0, 2]]).sum())
275+
closest_chan = torch.square(duv).sum(1).argmin()
276+
min_duv = duv[closest_chan]
277+
val_th = torch.sqrt(torch.square(min_duv).sum()) / dipole_planar_direction
278+
return val_th
279+
280+
280281
def dipole_mse(loc, amplitude_vector, local_geom, logbarrier=True):
281282
"""Dipole model predicted amplitude at local_geom given location"""
282-
283+
283284
x, y0, z = loc
284285
y = F.softplus(y0)
285286

286287
dxs = x - local_geom[:, 0]
287288
dzs = z - local_geom[:, 1]
288289
dys = y
289-
290+
290291
duv = torch.tensor([dxs, dys, dzs])
291292

292-
X = duv / torch.pow(torch.sum(torch.square(duv)), 3/2)
293-
293+
X = duv / torch.pow(torch.sum(torch.square(duv)), 3 / 2)
294+
294295
beta = torch.linalg.solve(torch.matmul(X.T, X), torch.matmul(X.T, (ptp / maxptp)))
295296
qtq = torch.matmul(X, beta)
296-
297+
297298
obj = torch.square(ptp / maxptp - qtq).mean()
298299
if logbarrier:
299300
obj -= torch.log(10.0 * y) / 10000.0

0 commit comments

Comments
 (0)