@@ -113,26 +113,22 @@ def localize_amplitude_vectors(
113
113
114
114
if model == "com" :
115
115
z_abs_com = zcom + geom [main_channels , 1 ]
116
- nancom = torch .full_like (xcom , torch .nan )
117
- return dict (
118
- x = xcom , y = nancom , z_rel = zcom , z_abs = z_abs_com , alpha = nancom
119
- )
116
+ return dict (x = xcom , z_rel = zcom , z_abs = z_abs_com )
120
117
121
118
# normalized PTP vectors
122
119
# this helps to keep the objective in a similar range, so we can use
123
120
# fixed constants in regularizers like the log barrier
124
121
max_amplitudes = torch .max (amplitude_vectors , dim = 1 ).values
125
122
normalized_amp_vecs = amplitude_vectors / max_amplitudes [:, None ]
126
-
123
+
127
124
# -- torch optimize
125
+ if levenberg_marquardt_kwargs is None :
126
+ levenberg_marquardt_kwargs = {}
127
+
128
128
# initialize with center of mass
129
129
locs = torch .column_stack ((xcom , torch .full_like (xcom , y0 ), zcom ))
130
130
131
-
132
131
if model == "pointsource" :
133
-
134
- if levenberg_marquardt_kwargs is None :
135
- levenberg_marquardt_kwargs = {}
136
132
locs , i = batched_levenberg_marquardt (
137
133
locs ,
138
134
vmap_point_source_grad_and_mse ,
@@ -148,34 +144,29 @@ def localize_amplitude_vectors(
148
144
amplitude_vectors , in_probe_mask , x , y , z_rel , local_geoms
149
145
)
150
146
z_abs = z_rel + geom [main_channels , 1 ]
151
-
147
+
148
+ results = dict (x = x , y = y , z_rel = z_rel , z_abs = z_abs , alpha = alpha )
152
149
if is_numpy :
153
- x = x .numpy (force = True )
154
- y = y .numpy (force = True )
155
- z_rel = z_rel .numpy (force = True )
156
- z_abs = z_abs .numpy (force = True )
157
- alpha = alpha .numpy (force = True )
158
- return dict (x = x , y = y , z_rel = z_rel , z_abs = z_abs , alpha = alpha )
159
-
160
- if model == "dipole" :
161
- if levenberg_marquardt_kwargs is None :
162
- levenberg_marquardt_kwargs = {}
150
+ results = {k : v .numpy (force = True ) for k , v in results .items ()}
151
+ return results
152
+
153
+ elif model == "dipole" :
163
154
locs , i = batched_levenberg_marquardt (
164
155
locs ,
165
156
vmap_dipole_grad_and_mse ,
166
157
vmap_dipole_hessian ,
167
158
extra_args = (normalized_amp_vecs , local_geoms ),
168
159
** levenberg_marquardt_kwargs ,
169
160
)
170
-
161
+
171
162
x , y0 , z_rel = locs .T
172
163
y = F .softplus (y0 )
173
164
projected_dist = vmap_dipole_find_projection_distance (
174
165
normalized_amp_vecs , x , y , z_rel , local_geoms
175
166
)
176
167
177
168
z_abs = z_rel + geom [main_channels , 1 ]
178
-
169
+
179
170
if is_numpy :
180
171
x = x .numpy (force = True )
181
172
y = y .numpy (force = True )
@@ -185,6 +176,9 @@ def localize_amplitude_vectors(
185
176
186
177
return dict (x = x , y = y , z_rel = z_rel , z_abs = z_abs , alpha = projected_dist )
187
178
179
+ else :
180
+ assert False
181
+
188
182
189
183
# -- point source / dipole model library functions
190
184
def point_source_amplitude_at (x , y , z , alpha , local_geom ):
@@ -221,9 +215,7 @@ def dipole_find_projection_distance(normalized_amp_vec, x, y, z, local_geom):
221
215
val_th = torch .sqrt (torch .square (min_duv ).sum ())/ dipole_planar_direction
222
216
return val_th
223
217
224
- def point_source_mse (
225
- loc , amplitude_vector , channel_mask , local_geom , logbarrier = True
226
- ):
218
+ def point_source_mse (loc , amplitude_vector , channel_mask , local_geom , logbarrier = True ):
227
219
"""Objective in point source model
228
220
229
221
Arguments
@@ -246,22 +238,39 @@ def point_source_mse(
246
238
x , y0 , z = loc
247
239
y = F .softplus (y0 )
248
240
249
- alpha = point_source_find_alpha (
250
- amplitude_vector , channel_mask , x , y , z , local_geom
251
- )
241
+ alpha = point_source_find_alpha (amplitude_vector , channel_mask , x , y , z , local_geom )
252
242
obj = torch .square (
253
- amplitude_vector
254
- - point_source_amplitude_at (x , y , z , alpha , local_geom )
243
+ amplitude_vector - point_source_amplitude_at (x , y , z , alpha , local_geom )
255
244
).mean ()
256
245
if logbarrier :
257
246
obj -= torch .log (10.0 * y ) / 10000.0
258
247
# idea for logbarrier on points which run away
259
248
# obj -= torch.log(1000.0 - torch.sqrt(torch.square(x) + torch.square(z))).sum() / 10000.0
260
249
return obj
261
250
251
+
252
+ def dipole_find_projection_distance (normalized_amp_vec , x , y , z , local_geom ):
253
+ """We can solve for the brightness (alpha) of the source in closed form given x,y,z"""
254
+
255
+ dxs = x - local_geom [:, 0 ]
256
+ dzs = z - local_geom [:, 1 ]
257
+ dys = y
258
+ duv = torch .tensor ([dxs , dys , dzs ])
259
+ X = duv / torch .pow (torch .sum (torch .square (duv )), 3 / 2 )
260
+ beta = torch .linalg .solve (
261
+ torch .matmul (X .T , X ), torch .matmul (X .T , normalized_amp_vec )
262
+ )
263
+ beta /= torch .sqrt (torch .square (beta ).sum ())
264
+ dipole_planar_direction = torch .sqrt (np .torch (beta [[0 , 2 ]]).sum ())
265
+ closest_chan = torch .square (duv ).sum (1 ).argmin ()
266
+ min_duv = duv [closest_chan ]
267
+ val_th = torch .sqrt (torch .square (min_duv ).sum ()) / dipole_planar_direction
268
+ return val_th
269
+
270
+
262
271
def dipole_mse (loc , amplitude_vector , local_geom , logbarrier = True ):
263
272
"""Dipole model predicted amplitude at local_geom given location"""
264
-
273
+
265
274
x , y0 , z = loc
266
275
y = F .softplus (y0 )
267
276
0 commit comments