@@ -69,8 +69,7 @@ def localize_amplitude_vectors(
69
69
assert channel_index .shape == (n_channels_tot , c )
70
70
assert main_channels .shape == (n_spikes ,)
71
71
# we'll return numpy if user sent numpy
72
- is_numpy = not torch .is_tensor (amplitude_vectors )
73
-
72
+ is_numpy = not torch .is_tensor (amplitude_vectors )
74
73
75
74
# handle channel subsetting
76
75
if radius is not None or n_channels_subset is not None :
@@ -114,26 +113,22 @@ def localize_amplitude_vectors(
114
113
115
114
if model == "com" :
116
115
z_abs_com = zcom + geom [main_channels , 1 ]
117
- nancom = torch .full_like (xcom , torch .nan )
118
- return dict (
119
- x = xcom , y = nancom , z_rel = zcom , z_abs = z_abs_com , alpha = nancom
120
- )
116
+ return dict (x = xcom , z_rel = zcom , z_abs = z_abs_com )
121
117
122
118
# normalized PTP vectors
123
119
# this helps to keep the objective in a similar range, so we can use
124
120
# fixed constants in regularizers like the log barrier
125
121
max_amplitudes = torch .max (amplitude_vectors , dim = 1 ).values
126
122
normalized_amp_vecs = amplitude_vectors / max_amplitudes [:, None ]
127
-
123
+
128
124
# -- torch optimize
125
+ if levenberg_marquardt_kwargs is None :
126
+ levenberg_marquardt_kwargs = {}
127
+
129
128
# initialize with center of mass
130
129
locs = torch .column_stack ((xcom , torch .full_like (xcom , y0 ), zcom ))
131
130
132
-
133
131
if model == "pointsource" :
134
-
135
- if levenberg_marquardt_kwargs is None :
136
- levenberg_marquardt_kwargs = {}
137
132
locs , i = batched_levenberg_marquardt (
138
133
locs ,
139
134
vmap_point_source_grad_and_mse ,
@@ -149,52 +144,53 @@ def localize_amplitude_vectors(
149
144
amplitude_vectors , in_probe_mask , x , y , z_rel , local_geoms
150
145
)
151
146
z_abs = z_rel + geom [main_channels , 1 ]
152
-
147
+
148
+ results = dict (x = x , y = y , z_rel = z_rel , z_abs = z_abs , alpha = alpha )
153
149
if is_numpy :
154
- x = x .numpy (force = True )
155
- y = y .numpy (force = True )
156
- z_rel = z_rel .numpy (force = True )
157
- z_abs = z_abs .numpy (force = True )
158
- alpha = alpha .numpy (force = True )
159
- return dict (x = x , y = y , z_rel = z_rel , z_abs = z_abs , alpha = alpha )
160
-
161
- if model == "dipole" :
162
- if levenberg_marquardt_kwargs is None :
163
- levenberg_marquardt_kwargs = {}
150
+ results = {k : v .numpy (force = True ) for k , v in results .items ()}
151
+ return results
152
+
153
+ elif model == "dipole" :
164
154
locs , i = batched_levenberg_marquardt (
165
155
locs ,
166
156
vmap_dipole_grad_and_mse ,
167
157
vmap_dipole_hessian ,
168
158
extra_args = (normalized_amp_vecs , local_geoms ),
169
159
** levenberg_marquardt_kwargs ,
170
160
)
171
-
161
+
172
162
x , y0 , z_rel = locs .T
173
163
y = F .softplus (y0 )
174
164
projected_dist = vmap_dipole_find_projection_distance (
175
165
normalized_amp_vecs , x , y , z_rel , local_geoms
176
- )
177
-
166
+ )
167
+
178
168
# if projected_dist>th_dipole_proj_dist: return the loc values from pointsource
179
169
180
- pointsource_spikes = torch .nonzero (projected_dist > th_dipole_proj_dist , as_tuple = True )
181
-
170
+ pointsource_spikes = torch .nonzero (
171
+ projected_dist > th_dipole_proj_dist , as_tuple = True
172
+ )
173
+
182
174
locs_pointsource_spikes , i = batched_levenberg_marquardt (
183
175
locs [pointsource_spikes ],
184
176
vmap_point_source_grad_and_mse ,
185
177
vmap_point_source_hessian ,
186
- extra_args = (normalized_amp_vecs [pointsource_spikes ], in_probe_mask , local_geoms [pointsource_spikes ]),
178
+ extra_args = (
179
+ normalized_amp_vecs [pointsource_spikes ],
180
+ in_probe_mask ,
181
+ local_geoms [pointsource_spikes ],
182
+ ),
187
183
** levenberg_marquardt_kwargs ,
188
184
)
189
185
x_pointsource_spikes , y0_pointsource_spikes , z_rel_pointsource_spikes = locs .T
190
186
y_pointsource_spikes = F .softplus (y0_pointsource_spikes )
191
-
187
+
192
188
x [pointsource_spikes ] = x_pointsource_spikes
193
189
y [pointsource_spikes ] = y_pointsource_spikes
194
190
z_rel [pointsource_spikes ] = z_rel_pointsource_spikes
195
-
191
+
196
192
z_abs = z_rel + geom [main_channels , 1 ]
197
-
193
+
198
194
if is_numpy :
199
195
x = x .numpy (force = True )
200
196
y = y .numpy (force = True )
@@ -204,9 +200,14 @@ def localize_amplitude_vectors(
204
200
205
201
return dict (x = x , y = y , z_rel = z_rel , z_abs = z_abs , alpha = projected_dist )
206
202
203
+ else :
204
+ assert False
205
+
207
206
208
207
# -- point source / dipole model library functions
209
- def point_source_amplitude_at (x , y , z , local_geom ):
208
+
209
+
210
+ def point_source_amplitude_at (x , y , z , alpha , local_geom ):
210
211
"""Point source model predicted amplitude at local_geom given location"""
211
212
dxs = torch .square (x - local_geom [:, 0 ])
212
213
dzs = torch .square (z - local_geom [:, 1 ])
@@ -223,25 +224,8 @@ def point_source_find_alpha(amp_vec, channel_mask, x, y, z, local_geoms):
223
224
)
224
225
return alpha
225
226
226
- def dipole_find_projection_distance (normalized_amp_vec , x , y , z , local_geom ):
227
- """We can solve for the brightness (alpha) of the source in closed form given x,y,z"""
228
-
229
- dxs = x - local_geom [:, 0 ]
230
- dzs = z - local_geom [:, 1 ]
231
- dys = y
232
- duv = torch .tensor ([dxs , dys , dzs ])
233
- X = duv / torch .pow (torch .sum (torch .square (duv )), 3 / 2 )
234
- beta = torch .linalg .solve (torch .matmul (X .T , X ), torch .matmul (X .T , normalized_amp_vec ))
235
- beta /= torch .sqrt (torch .square (beta ).sum ())
236
- dipole_planar_direction = torch .sqrt (np .torch (beta [[0 , 2 ]]).sum ())
237
- closest_chan = torch .square (duv ).sum (1 ).argmin ()
238
- min_duv = duv [closest_chan ]
239
- val_th = torch .sqrt (torch .square (min_duv ).sum ())/ dipole_planar_direction
240
- return val_th
241
227
242
- def point_source_mse (
243
- loc , amplitude_vector , channel_mask , local_geom , logbarrier = True
244
- ):
228
+ def point_source_mse (loc , amplitude_vector , channel_mask , local_geom , logbarrier = True ):
245
229
"""Objective in point source model
246
230
247
231
Arguments
@@ -264,36 +248,53 @@ def point_source_mse(
264
248
x , y0 , z = loc
265
249
y = F .softplus (y0 )
266
250
267
- alpha = point_source_find_alpha (
268
- amplitude_vector , channel_mask , x , y , z , local_geom
269
- )
251
+ alpha = point_source_find_alpha (amplitude_vector , channel_mask , x , y , z , local_geom )
270
252
obj = torch .square (
271
- amplitude_vector
272
- - point_source_amplitude_at (x , y , z , alpha , local_geom )
253
+ amplitude_vector - point_source_amplitude_at (x , y , z , alpha , local_geom )
273
254
).mean ()
274
255
if logbarrier :
275
256
obj -= torch .log (10.0 * y ) / 10000.0
276
257
# idea for logbarrier on points which run away
277
258
# obj -= torch.log(1000.0 - torch.sqrt(torch.square(x) + torch.square(z))).sum() / 10000.0
278
259
return obj
279
260
261
+
262
+ def dipole_find_projection_distance (normalized_amp_vec , x , y , z , local_geom ):
263
+ """We can solve for the brightness (alpha) of the source in closed form given x,y,z"""
264
+
265
+ dxs = x - local_geom [:, 0 ]
266
+ dzs = z - local_geom [:, 1 ]
267
+ dys = y
268
+ duv = torch .tensor ([dxs , dys , dzs ])
269
+ X = duv / torch .pow (torch .sum (torch .square (duv )), 3 / 2 )
270
+ beta = torch .linalg .solve (
271
+ torch .matmul (X .T , X ), torch .matmul (X .T , normalized_amp_vec )
272
+ )
273
+ beta /= torch .sqrt (torch .square (beta ).sum ())
274
+ dipole_planar_direction = torch .sqrt (np .torch (beta [[0 , 2 ]]).sum ())
275
+ closest_chan = torch .square (duv ).sum (1 ).argmin ()
276
+ min_duv = duv [closest_chan ]
277
+ val_th = torch .sqrt (torch .square (min_duv ).sum ()) / dipole_planar_direction
278
+ return val_th
279
+
280
+
280
281
def dipole_mse (loc , amplitude_vector , local_geom , logbarrier = True ):
281
282
"""Dipole model predicted amplitude at local_geom given location"""
282
-
283
+
283
284
x , y0 , z = loc
284
285
y = F .softplus (y0 )
285
286
286
287
dxs = x - local_geom [:, 0 ]
287
288
dzs = z - local_geom [:, 1 ]
288
289
dys = y
289
-
290
+
290
291
duv = torch .tensor ([dxs , dys , dzs ])
291
292
292
- X = duv / torch .pow (torch .sum (torch .square (duv )), 3 / 2 )
293
-
293
+ X = duv / torch .pow (torch .sum (torch .square (duv )), 3 / 2 )
294
+
294
295
beta = torch .linalg .solve (torch .matmul (X .T , X ), torch .matmul (X .T , (ptp / maxptp )))
295
296
qtq = torch .matmul (X , beta )
296
-
297
+
297
298
obj = torch .square (ptp / maxptp - qtq ).mean ()
298
299
if logbarrier :
299
300
obj -= torch .log (10.0 * y ) / 10000.0
0 commit comments