Skip to content

Commit d79d7c8

Browse files
author
julien
committed
dipole
2 parents 949a0b0 + eb23cd3 commit d79d7c8

File tree

8 files changed

+1315
-1059
lines changed

8 files changed

+1315
-1059
lines changed

src/dartsort/localize/localize_torch.py

Lines changed: 40 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -113,26 +113,22 @@ def localize_amplitude_vectors(
113113

114114
if model == "com":
115115
z_abs_com = zcom + geom[main_channels, 1]
116-
nancom = torch.full_like(xcom, torch.nan)
117-
return dict(
118-
x=xcom, y=nancom, z_rel=zcom, z_abs=z_abs_com, alpha=nancom
119-
)
116+
return dict(x=xcom, z_rel=zcom, z_abs=z_abs_com)
120117

121118
# normalized PTP vectors
122119
# this helps to keep the objective in a similar range, so we can use
123120
# fixed constants in regularizers like the log barrier
124121
max_amplitudes = torch.max(amplitude_vectors, dim=1).values
125122
normalized_amp_vecs = amplitude_vectors / max_amplitudes[:, None]
126-
123+
127124
# -- torch optimize
125+
if levenberg_marquardt_kwargs is None:
126+
levenberg_marquardt_kwargs = {}
127+
128128
# initialize with center of mass
129129
locs = torch.column_stack((xcom, torch.full_like(xcom, y0), zcom))
130130

131-
132131
if model == "pointsource":
133-
134-
if levenberg_marquardt_kwargs is None:
135-
levenberg_marquardt_kwargs = {}
136132
locs, i = batched_levenberg_marquardt(
137133
locs,
138134
vmap_point_source_grad_and_mse,
@@ -148,34 +144,29 @@ def localize_amplitude_vectors(
148144
amplitude_vectors, in_probe_mask, x, y, z_rel, local_geoms
149145
)
150146
z_abs = z_rel + geom[main_channels, 1]
151-
147+
148+
results = dict(x=x, y=y, z_rel=z_rel, z_abs=z_abs, alpha=alpha)
152149
if is_numpy:
153-
x = x.numpy(force=True)
154-
y = y.numpy(force=True)
155-
z_rel = z_rel.numpy(force=True)
156-
z_abs = z_abs.numpy(force=True)
157-
alpha = alpha.numpy(force=True)
158-
return dict(x=x, y=y, z_rel=z_rel, z_abs=z_abs, alpha=alpha)
159-
160-
if model == "dipole":
161-
if levenberg_marquardt_kwargs is None:
162-
levenberg_marquardt_kwargs = {}
150+
results = {k: v.numpy(force=True) for k, v in results.items()}
151+
return results
152+
153+
elif model == "dipole":
163154
locs, i = batched_levenberg_marquardt(
164155
locs,
165156
vmap_dipole_grad_and_mse,
166157
vmap_dipole_hessian,
167158
extra_args=(normalized_amp_vecs, local_geoms),
168159
**levenberg_marquardt_kwargs,
169160
)
170-
161+
171162
x, y0, z_rel = locs.T
172163
y = F.softplus(y0)
173164
projected_dist = vmap_dipole_find_projection_distance(
174165
normalized_amp_vecs, x, y, z_rel, local_geoms
175166
)
176167

177168
z_abs = z_rel + geom[main_channels, 1]
178-
169+
179170
if is_numpy:
180171
x = x.numpy(force=True)
181172
y = y.numpy(force=True)
@@ -185,6 +176,9 @@ def localize_amplitude_vectors(
185176

186177
return dict(x=x, y=y, z_rel=z_rel, z_abs=z_abs, alpha=projected_dist)
187178

179+
else:
180+
assert False
181+
188182

189183
# -- point source / dipole model library functions
190184
def point_source_amplitude_at(x, y, z, alpha, local_geom):
@@ -221,9 +215,7 @@ def dipole_find_projection_distance(normalized_amp_vec, x, y, z, local_geom):
221215
val_th = torch.sqrt(torch.square(min_duv).sum())/dipole_planar_direction
222216
return val_th
223217

224-
def point_source_mse(
225-
loc, amplitude_vector, channel_mask, local_geom, logbarrier=True
226-
):
218+
def point_source_mse(loc, amplitude_vector, channel_mask, local_geom, logbarrier=True):
227219
"""Objective in point source model
228220
229221
Arguments
@@ -246,22 +238,39 @@ def point_source_mse(
246238
x, y0, z = loc
247239
y = F.softplus(y0)
248240

249-
alpha = point_source_find_alpha(
250-
amplitude_vector, channel_mask, x, y, z, local_geom
251-
)
241+
alpha = point_source_find_alpha(amplitude_vector, channel_mask, x, y, z, local_geom)
252242
obj = torch.square(
253-
amplitude_vector
254-
- point_source_amplitude_at(x, y, z, alpha, local_geom)
243+
amplitude_vector - point_source_amplitude_at(x, y, z, alpha, local_geom)
255244
).mean()
256245
if logbarrier:
257246
obj -= torch.log(10.0 * y) / 10000.0
258247
# idea for logbarrier on points which run away
259248
# obj -= torch.log(1000.0 - torch.sqrt(torch.square(x) + torch.square(z))).sum() / 10000.0
260249
return obj
261250

251+
252+
def dipole_find_projection_distance(normalized_amp_vec, x, y, z, local_geom):
253+
"""We can solve for the brightness (alpha) of the source in closed form given x,y,z"""
254+
255+
dxs = x - local_geom[:, 0]
256+
dzs = z - local_geom[:, 1]
257+
dys = y
258+
duv = torch.tensor([dxs, dys, dzs])
259+
X = duv / torch.pow(torch.sum(torch.square(duv)), 3 / 2)
260+
beta = torch.linalg.solve(
261+
torch.matmul(X.T, X), torch.matmul(X.T, normalized_amp_vec)
262+
)
263+
beta /= torch.sqrt(torch.square(beta).sum())
264+
dipole_planar_direction = torch.sqrt(np.torch(beta[[0, 2]]).sum())
265+
closest_chan = torch.square(duv).sum(1).argmin()
266+
min_duv = duv[closest_chan]
267+
val_th = torch.sqrt(torch.square(min_duv).sum()) / dipole_planar_direction
268+
return val_th
269+
270+
262271
def dipole_mse(loc, amplitude_vector, local_geom, logbarrier=True):
263272
"""Dipole model predicted amplitude at local_geom given location"""
264-
273+
265274
x, y0, z = loc
266275
y = F.softplus(y0)
267276

src/dartsort/peel/matching.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,6 @@ def __init__(
125125
("temporal_components", temporal_components),
126126
("singular_values", singular_values),
127127
("spatial_components", spatial_components),
128-
(
129-
"upsampled_temporal_components",
130-
self.upsampled_temporal_components.numpy(force=True).copy(),
131-
),
132128
]
133129
if self.is_drifting:
134130
self.fixed_output_data.append(

0 commit comments

Comments
 (0)