
Commit 0bccb1a

Make universal scale...

1 parent 695279a · commit 0bccb1a

File tree

5 files changed: +74 −31 lines changed

src/dartsort/peel/matching.py

Lines changed: 15 additions & 10 deletions
@@ -37,6 +37,7 @@ def __init__(
         featurization_pipeline,
         motion_est=None,
         pairwise_conv_db=None,
+        low_rank_templates=None,
         svd_compression_rank=10,
         coarse_objective=True,
         temporal_upsampling_factor=8,
@@ -78,6 +79,7 @@ def __init__(
         # main properties
         self.template_data = template_data
         self.pairwise_conv_db = pairwise_conv_db
+        self.low_rank_templates = low_rank_templates
         self.coarse_objective = coarse_objective
         self.temporal_upsampling_factor = temporal_upsampling_factor
         self.upsampling_peak_window_radius = upsampling_peak_window_radius
@@ -244,14 +246,15 @@ def build_template_data(
         if not have_groups:
             # the logic of coarsening is not needed.
             self.coarse_objective = False
-        low_rank_templates = template_util.svd_compress_templates(
-            template_data,
-            min_channel_amplitude=min_channel_amplitude,
-            rank=svd_compression_rank,
-        )
-        temporal_components = low_rank_templates.temporal_components.astype(dtype)
-        singular_values = low_rank_templates.singular_values.astype(dtype)
-        spatial_components = low_rank_templates.spatial_components.astype(dtype)
+        if self.low_rank_templates is None:
+            self.low_rank_templates = template_util.svd_compress_templates(
+                template_data,
+                min_channel_amplitude=min_channel_amplitude,
+                rank=svd_compression_rank,
+            )
+        temporal_components = self.low_rank_templates.temporal_components.astype(dtype)
+        singular_values = self.low_rank_templates.singular_values.astype(dtype)
+        spatial_components = self.low_rank_templates.spatial_components.astype(dtype)
         self.register_buffer("temporal_components", torch.tensor(temporal_components))
         self.register_buffer("singular_values", torch.tensor(singular_values))
         self.register_buffer("spatial_components", torch.tensor(spatial_components))
@@ -292,7 +295,7 @@ def build_template_data(
             self.obj_n_templates = spatial_components.shape[0]
         else:
             objective_temp_data = template_data
-            objective_low_rank_temps = low_rank_templates
+            objective_low_rank_temps = self.low_rank_templates
             self.objective_template_depths_um = self.registered_template_depths_um
             self.register_buffer(
                 "objective_temporal_components", self.temporal_components
@@ -302,6 +305,7 @@ def build_template_data(
                 "objective_spatial_components", self.spatial_components
             )
             self.obj_n_templates = self.n_templates
+
         self.handle_template_groups(
             objective_temp_data.unit_ids, self.template_data.unit_ids
         )
@@ -322,13 +326,14 @@ def build_template_data(
         chunk_centers_s = self.recording._recording_segments[0].sample_index_to_time(
             chunk_centers_samples
         )
+
         if self.pairwise_conv_db is None:
             self.pairwise_conv_db = CompressedPairwiseConv.from_template_data(
                 save_folder / "pconv.h5",
                 template_data=objective_temp_data,
                 low_rank_templates=objective_low_rank_temps,
                 template_data_b=template_data,
-                low_rank_templates_b=low_rank_templates,
+                low_rank_templates_b=self.low_rank_templates,
                 compressed_upsampled_temporal=compressed_upsampled_temporal,
                 chunk_time_centers_s=chunk_centers_s,
                 motion_est=self.motion_est,
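The net effect of this change is dependency injection for the SVD compression: `low_rank_templates` can now be passed in at construction time, and `build_template_data` only falls back to `template_util.svd_compress_templates` when nothing was supplied. A minimal usage sketch; the `recording`, `template_data`, `channel_index`, and `featurization_pipeline` objects are assumed placeholders, not part of this commit:

from dartsort.peel.matching import ObjectiveUpdateTemplateMatchingPeeler
from dartsort.templates import template_util

# Compress once, up front...
low_rank = template_util.svd_compress_templates(
    template_data, min_channel_amplitude=0.0, rank=10
)
# ...and inject it; build_template_data() then skips its internal
# svd_compress_templates call because self.low_rank_templates is set.
peeler = ObjectiveUpdateTemplateMatchingPeeler(
    recording,
    template_data,
    channel_index,
    featurization_pipeline,
    low_rank_templates=low_rank,
)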

src/dartsort/peel/universal.py

Lines changed: 24 additions & 4 deletions
@@ -1,9 +1,11 @@
 import torch
+import numpy as np
 
 from ..util import universal_util, waveform_util
 from ..transform import WaveformPipeline
 from .matching import ObjectiveUpdateTemplateMatchingPeeler
 from ..templates.pairwise import SeparablePairwiseConv
+from ..templates.template_util import LowRankTemplates
 
 
 class UniversalTemplatesMatchingPeeler(ObjectiveUpdateTemplateMatchingPeeler):
@@ -30,7 +32,6 @@ def __init__(
         threshold=50.0,
         trough_offset_samples=42,
         spike_length_samples=121,
-        svd_compression_rank=10,
         amplitude_scaling_variance=100.0,
         amplitude_scaling_boundary=500.0,
         detection_threshold=6.0,
@@ -73,17 +74,34 @@ def __init__(
                 kmeanspp_initial="random",
             )
         )
+
+        Nf = len(footprints)
+        Ns = len(shapes)
+        shapes_ixd = torch.asarray(shapes)[None]
+        shapes_ixd = shapes_ixd.broadcast_to((Nf, Ns, *shapes.shape[1:]))
+        shapes_ixd = shapes_ixd.reshape(Nf * Ns, *shapes.shape[1:], 1)
+        footprints_ixd = torch.asarray(footprints)[:, None]
+        footprints_ixd = footprints_ixd.broadcast_to((Nf, Ns, *footprints.shape[1:]))
+        footprints_ixd = footprints_ixd.reshape(Nf * Ns, 1, *footprints.shape[1:])
+        low_rank_templates = LowRankTemplates(
+            temporal_components=shapes_ixd.numpy(),
+            singular_values=shapes_ixd.new_ones(Nf * Ns, 1).numpy(),
+            spatial_components=footprints_ixd.numpy(),
+            spike_counts_by_channel=np.broadcast_to(
+                np.atleast_2d([100]), (Nf * Ns, footprints.shape[1])
+            ),
+        )
         pairwise_conv_db = SeparablePairwiseConv(footprints, shapes)
         super().__init__(
             recording,
             template_data,
             channel_index,
             featurization_pipeline,
             pairwise_conv_db=pairwise_conv_db,
+            low_rank_templates=low_rank_templates,
             threshold=threshold,
             amplitude_scaling_variance=amplitude_scaling_variance,
             amplitude_scaling_boundary=amplitude_scaling_boundary,
-            svd_compression_rank=svd_compression_rank,
             # usual gizmos
             trough_offset_samples=trough_offset_samples,
             chunk_length_samples=chunk_length_samples,
@@ -94,10 +112,11 @@ def __init__(
             fit_sampling=fit_sampling,
             dtype=dtype,
             # matching params which don't need setting
-            min_channel_amplitude=1.0,
+            svd_compression_rank=1,
+            min_channel_amplitude=0.0,
             motion_est=None,
             coarse_approx_error_threshold=0.0,
-            conv_ignore_threshold=5.0,
+            conv_ignore_threshold=0.0,
             coarse_objective=True,
             temporal_upsampling_factor=1,
             refractory_radius_frames=10,
@@ -127,6 +146,7 @@ def from_config(
         )
         return cls(
             recording,
+            threshold=subtraction_config.universal_threshold,
             channel_index=channel_index,
             featurization_pipeline=featurization_pipeline,
             trough_offset_samples=trough_offset_samples,
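The `LowRankTemplates` built here is exact, not an approximation: each universal template is the outer product of one temporal shape and one spatial footprint, so it is rank 1 and can be stored with unit singular values. A toy numpy check of the footprint-major indexing and the reconstruction identity (sizes are made up; `spike_counts_by_channel` is omitted):

import numpy as np

Nf, Ns, nt, nc = 3, 2, 121, 16
footprints = np.random.rand(Nf, nc)  # (n_footprints, n_channels)
shapes = np.random.randn(Ns, nt)     # (n_shapes, n_time)

# footprint-major flat index i: f(i) = i // Ns, s(i) = i % Ns
temporal = np.broadcast_to(shapes[None], (Nf, Ns, nt)).reshape(Nf * Ns, nt, 1)
spatial = np.broadcast_to(footprints[:, None], (Nf, Ns, nc)).reshape(Nf * Ns, 1, nc)
sv = np.ones((Nf * Ns, 1))

# reconstruct templates as temporal @ diag(singular_values) @ spatial
templates = (temporal * sv[:, None, :]) @ spatial  # (Nf * Ns, nt, nc)
i = 5
assert np.allclose(templates[i], np.outer(shapes[i % Ns], footprints[i // Ns]))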

src/dartsort/templates/pairwise.py

Lines changed: 28 additions & 14 deletions
@@ -4,6 +4,7 @@
 import h5py
 import numpy as np
 import torch
+from torch._functorch.vmap import _as_tuple
 import torch.nn.functional as F
 
 from .pairwise_util import compressed_convolve_to_h5
@@ -282,9 +283,9 @@ def __init__(self, spatial_footprints, temporal_shapes):
         Let Nf = len(spatial_footprints), Ns = len(temporal_shapes). Then
         indexing is footprint-major, so that
 
-        template[i] = spatial_footprints[i // Nf] * temporal_shapes[i - Nf * (i // Nf)]
+        template[i] = spatial_footprints[i // Ns] * temporal_shapes[i - Ns * (i // Ns)]
 
-        Let f(i) = i // Nf and s(i) = i - Nf * (i // Nf). Then the channel-summed
+        Let f(i) = i // Ns and s(i) = i - Ns * (i // Ns). Then the channel-summed
         convolution of templates i and j is given by
 
         conv(t; i, j) = (
@@ -299,6 +300,7 @@ def __init__(self, spatial_footprints, temporal_shapes):
         self.register_buffer("spatial_footprints", torch.asarray(spatial_footprints))
         self.register_buffer("temporal_shapes", torch.asarray(temporal_shapes))
         self.Nf = len(spatial_footprints)
+        self.Ns, self.nt = temporal_shapes.shape
 
         # convolve all pairs of temporal shapes
         # i is data, j is filter
@@ -311,7 +313,8 @@ def __init__(self, spatial_footprints, temporal_shapes):
         # spatial component
         sdot = self.spatial_footprints @ self.spatial_footprints.T
         self.register_buffer("sdot", sdot)
-        self.tia = torch.arange(len(temporal_shapes))
+        self.overlap = (self.sdot > 0).cpu()
+        self.tia = torch.arange(self.Ns * self.Nf)
 
     def query(
         self,
@@ -329,24 +332,35 @@ def query(
         if device is not None and device != self.spatial_footprints.device:
             self.to(device)
         assert shifts_a is shifts_b is None
-        assert upsampling_indices_b is None
-        del return_zero_convs  # choose not to implement this
+        assert upsampling_indices_b is None or (upsampling_indices_b == 0).all()
+        assert not return_zero_convs  # choose not to implement this
         assert grid  # only this case here. can probably do the same above.
         if template_indices_a is None:
             template_indices_a = self.tia.to(template_indices_b)
 
-        f_i = template_indices_b // self.Nf
-        f_j = template_indices_a // self.Nf
-        s_i = template_indices_b - self.Nf * f_i
-        s_j = template_indices_a - self.Nf * f_j
+        f_i = template_indices_a // self.Ns
+        f_j = template_indices_b // self.Ns
+        keep_i, keep_j = self.overlap[f_i[:, None], f_j[None, :]].nonzero(as_tuple=True)
 
-        pconvs = (
-            self.sdot[f_i[:, None], f_j[None, :]]
-            * self.tconv[s_i[:, None], s_j[None, :]]
-        )
+        template_indices_a = template_indices_a[keep_i]
+        template_indices_b = template_indices_b[keep_j]
+        if scalings_b is not None:
+            scalings_b = scalings_b[keep_j]
+        if times_b is not None:
+            times_b = times_b[keep_j]
 
+        f_i = template_indices_a // self.Ns
+        f_j = template_indices_b // self.Ns
+        s_i = template_indices_a - self.Ns * f_i
+        s_j = template_indices_b - self.Ns * f_j
+
+        sdot = self.sdot[f_i, f_j]
+        assert sdot.ndim == 1
         if scalings_b is not None:
-            pconvs.mul_(scalings_b[:, None])
+            sdot = sdot * scalings_b
+        tconv = self.tconv[s_i, s_j]
+        assert tconv.shape == (len(sdot), 2 * self.nt - 1)
+        pconvs = sdot.unsqueeze(1) * tconv
 
         if times_b is not None:
             return template_indices_a, template_indices_b, times_b, pconvs
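The query rewrite leans on the separability stated in the docstring: because every template factors as footprint × shape, the channel-summed cross-correlation of two templates is a spatial dot product times a purely temporal cross-correlation, and pairs with non-overlapping footprints (the new `self.overlap` mask) contribute exactly zero. A small numeric check of that identity with toy arrays:

import numpy as np

nt, nc = 7, 5
f_i, f_j = np.random.rand(nc), np.random.rand(nc)
s_i, s_j = np.random.randn(nt), np.random.randn(nt)

# channel-summed full cross-correlation of the two rank-1 templates
full = sum(
    np.correlate(s_i * f_i[c], s_j * f_j[c], mode="full") for c in range(nc)
)
# separable form: spatial dot product times temporal cross-correlation
separable = (f_i @ f_j) * np.correlate(s_i, s_j, mode="full")
assert np.allclose(full, separable)  # length 2 * nt - 1, as in query()'s assert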

src/dartsort/util/internal_config.py

Lines changed: 1 addition & 0 deletions
@@ -161,6 +161,7 @@ class SubtractionConfig:
     n_singlechan_templates: int = 10
     singlechan_alignment_padding_ms: float = 0.7
     use_universal_templates: bool = False
+    universal_threshold: float = 50.0
 
     # how will waveforms be denoised before subtraction?
     # users can also save waveforms/features during subtraction
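For reference, this new field is what `UniversalTemplatesMatchingPeeler.from_config` forwards as `threshold` in the universal.py hunk above. A sketch, assuming `SubtractionConfig` accepts keyword arguments for its fields as the annotated defaults suggest:

from dartsort.util.internal_config import SubtractionConfig

cfg = SubtractionConfig(use_universal_templates=True, universal_threshold=75.0)
# from_config then passes threshold=subtraction_config.universal_threshold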

src/dartsort/util/universal_util.py

Lines changed: 6 additions & 3 deletions
@@ -133,7 +133,7 @@ def get_singlechan_waveforms(
 
 
 def spatial_footprint_bank(
-    geom, n_sigmas=5, min_template_size=10.0, max_distance=32.0, dx=32.0
+    geom, n_sigmas=5, min_template_size=10.0, max_distance=32.0, dx=32.0, eps=0.025
 ):
     # this is just a single shank version, since I don't plan to use
     # this in production. but it is copied from KS' code and can be
@@ -168,6 +168,8 @@ def spatial_footprint_bank(
     sigmas = min_template_size * (1 + np.arange(n_sigmas))
     spatial_profiles = np.exp(-distsq[:, None, :] / sigmas[:, None] ** 2)
     spatial_profiles = spatial_profiles.reshape(n_centers * n_sigmas, len(geom))
+    spatial_profiles /= np.linalg.norm(spatial_profiles, axis=1, keepdims=True)
+    spatial_profiles[spatial_profiles < eps] = 0.0
 
     return spatial_profiles
 
@@ -189,6 +191,7 @@ def singlechan_to_library(
         max_distance=max_distance,
         dx=dx,
     )
+    footprints = footprints.astype(str(singlechan_templates.dtype).split(".")[1])
     if torch.is_tensor(singlechan_templates):
         singlechan_templates = singlechan_templates.numpy(force=True)
     nf, nc = footprints.shape
@@ -202,8 +205,8 @@ def singlechan_to_library(
 
     template_data = TemplateData(
         templates,
-        unit_ids=np.arange(nsct * nt),
-        spike_counts=np.ones(nsct * nt, dtype=int),
+        unit_ids=np.arange(nsct * nf),
+        spike_counts=np.ones(nsct * nf, dtype=int),
     )
     return footprints, template_data
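Taken together, the footprint-bank changes make each Gaussian spatial profile unit-norm and then sparsify it: entries below `eps` are zeroed, which is what lets `SeparablePairwiseConv`'s `(sdot > 0)` overlap mask prune template pairs whose footprints share no channels. A toy single-column version of that post-processing (the probe geometry and center spacing here are made up):

import numpy as np

geom = np.c_[np.zeros(16), 20.0 * np.arange(16)]  # toy probe, (n_channels, 2)
centers = geom[::2, 1]                            # hypothetical center depths
distsq = (geom[None, :, 1] - centers[:, None]) ** 2  # (n_centers, n_channels)

n_sigmas, min_template_size, eps = 5, 10.0, 0.025
sigmas = min_template_size * (1 + np.arange(n_sigmas))
profiles = np.exp(-distsq[:, None, :] / sigmas[:, None] ** 2)
profiles = profiles.reshape(len(centers) * n_sigmas, len(geom))
profiles /= np.linalg.norm(profiles, axis=1, keepdims=True)
profiles[profiles < eps] = 0.0  # sparse footprints => zero spatial dot products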
