Skip to content

Commit e4cb9aa

Browse files
author
julienboussard
committed
split merge batch updated
1 parent 5a8c9c6 commit e4cb9aa

File tree

4 files changed

+96
-67
lines changed

4 files changed

+96
-67
lines changed

src/dartsort/cluster/merge.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def merge_templates(
2424
amplitude_scaling_boundary=0.5,
2525
svd_compression_rank=10,
2626
min_channel_amplitude=0.0,
27-
conv_batch_size=1024,
27+
conv_batch_size=128,
2828
units_batch_size=8,
2929
device=None,
3030
n_jobs=0,
@@ -189,7 +189,7 @@ def get_deconv_resid_norm_iter(
189189
amplitude_scaling_boundary=0.5,
190190
svd_compression_rank=10,
191191
min_channel_amplitude=0.0,
192-
conv_batch_size=1024,
192+
conv_batch_size=128,
193193
units_batch_size=8,
194194
device=None,
195195
n_jobs=0,

src/dartsort/cluster/split.py

Lines changed: 54 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -296,58 +296,75 @@ def get_registered_channels(self, in_unit):
296296

297297
return max_registered_channel, n_pitches_shift, reloc_amplitudes, kept
298298

299-
def pca_features(self, in_unit, max_registered_channel, n_pitches_shift):
299+
def pca_features(self, in_unit, max_registered_channel, n_pitches_shift, batch_size=20_000, max_pca_batch=50_000):
300300
"""Compute relocated PCA features on a drift-invariant channel set"""
301301
# figure out which set of channels to use
302302
# we use the stored amplitudes to do this rather than computing a
303303
# template, which can be expensive
304+
# max_batch_size set to avoid memory errors
304305
pca_channels = self.registered_channel_index[max_registered_channel]
305306
pca_channels = pca_channels[pca_channels < len(self.registered_geom)]
306-
307+
307308
# load waveform embeddings and invert TPCA if we are relocating
308-
waveforms = batched_h5_read(self.tpca_features, in_unit)
309-
n, rank, c = waveforms.shape
310-
if self.relocated:
311-
waveforms = waveforms.transpose(0, 2, 1).reshape(n * c, rank)
312-
waveforms = self.tpca.inverse_transform(waveforms)
313-
t = waveforms.shape[1]
314-
waveforms = waveforms.reshape(n, c, t).transpose(0, 2, 1)
315309

316-
# relocate or just restrict to channel subset
317-
if self.relocated:
318-
waveforms = relocate.relocated_waveforms_on_static_channels(
319-
waveforms,
320-
main_channels=self.channels[in_unit],
321-
channel_index=self.channel_index,
322-
target_channels=pca_channels,
323-
xyza_from=self.xyza[in_unit],
324-
z_to=self.z_reg[in_unit],
325-
geom=self.geom,
326-
registered_geom=self.registered_geom,
327-
)
328-
else:
329-
waveforms = drift_util.get_waveforms_on_static_channels(
330-
waveforms,
331-
self.geom,
332-
main_channels=self.channels[in_unit],
333-
channel_index=self.channel_index,
334-
target_channels=pca_channels,
335-
n_pitches_shift=n_pitches_shift,
336-
registered_geom=self.registered_geom,
337-
)
338-
# ravel t,c dims -- everything below is spatiotemporal
339-
waveforms = waveforms.reshape(n, t * waveforms.shape[2])
310+
waveforms_all = batched_h5_read(self.tpca_features, in_unit) #in tpca manifold
311+
n, rank, c = waveforms_all.shape
312+
313+
for bs in range(0, n, batch_size):
314+
be = min(n, bs + batch_size)
315+
waveforms = waveforms_all[bs:be]
316+
n_batch = int(be-bs)
317+
if self.relocated:
318+
waveforms = waveforms.transpose(0, 2, 1).reshape(n_batch * c, rank)
319+
waveforms = self.tpca.inverse_transform(waveforms) #breaks here if too many waveforms
320+
t = waveforms.shape[1]
321+
waveforms = waveforms.reshape(n_batch, c, t).transpose(0, 2, 1)
322+
323+
# relocate or just restrict to channel subset
324+
if self.relocated:
325+
waveforms = relocate.relocated_waveforms_on_static_channels(
326+
waveforms,
327+
main_channels=self.channels[in_unit][bs:be],
328+
channel_index=self.channel_index,
329+
target_channels=pca_channels,
330+
xyza_from=self.xyza[in_unit][bs:be],
331+
z_to=self.z_reg[in_unit][bs:be],
332+
geom=self.geom,
333+
registered_geom=self.registered_geom,
334+
)
335+
else:
336+
waveforms = drift_util.get_waveforms_on_static_channels(
337+
waveforms,
338+
self.geom,
339+
main_channels=self.channels[in_unit][bs:be],
340+
channel_index=self.channel_index,
341+
target_channels=pca_channels,
342+
n_pitches_shift=n_pitches_shift,
343+
registered_geom=self.registered_geom,
344+
)
345+
# ravel t,c dims -- everything below is spatiotemporal
346+
if bs==0:
347+
wfs_out = np.empty((n, t*len(pca_channels)), dtype=waveforms.dtype)
348+
wfs_out[bs:be] = waveforms.reshape(n_batch, t * waveforms.shape[2])
340349

341350
# figure out which waveforms actually overlap with the requested channels
342-
no_nan = np.flatnonzero(~np.isnan(waveforms).any(axis=1))
351+
no_nan = np.flatnonzero(~np.isnan(wfs_out).any(axis=1))
343352
if no_nan.size < max(self.min_cluster_size, self.n_pca_features):
344353
return False, no_nan, None
345354

346355
# fit pca and embed
347356
pca = PCA(self.n_pca_features, random_state=self.random_state, whiten=True)
348-
pca_projs = np.full((n, self.n_pca_features), np.nan, dtype=waveforms.dtype)
349-
pca_projs[no_nan] = pca.fit_transform(waveforms[no_nan])
350-
357+
pca_projs = np.full((n, self.n_pca_features), np.nan, dtype=wfs_out.dtype)
358+
359+
if len(no_nan)>max_pca_batch:
360+
idx_pca = np.random.choice(no_nan, max_pca_batch, replace=False)
361+
pca.fit(wfs_out[idx_pca])
362+
for bs in range(0, len(no_nan), max_pca_batch):
363+
be = min(len(no_nan), bs + max_pca_batch)
364+
idx_pca = no_nan[bs:be]
365+
pca_projs[idx_pca] = pca.transform(wfs_out[idx_pca])
366+
else:
367+
pca_projs[no_nan] = pca.fit_transform(wfs_out[no_nan])
351368
return True, no_nan, pca_projs
352369

353370
def initialize_from_h5(

src/dartsort/templates/pairwise_util.py

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def compressed_convolve_to_h5(
3030
geom: Optional[np.ndarray] = None,
3131
conv_ignore_threshold=0.0,
3232
coarse_approx_error_threshold=0.0,
33-
conv_batch_size=1024,
33+
conv_batch_size=128,
3434
units_batch_size=8,
3535
overwrite=False,
3636
device=None,
@@ -174,7 +174,7 @@ def iterate_compressed_pairwise_convolutions(
174174
amplitude_scaling_variance=0.0,
175175
amplitude_scaling_boundary=0.5,
176176
reduce_deconv_resid_norm=False,
177-
conv_batch_size=1024,
177+
conv_batch_size=128,
178178
units_batch_size=8,
179179
device=None,
180180
n_jobs=0,
@@ -401,7 +401,7 @@ def compressed_convolve_pairs(
401401
amplitude_scaling_boundary=0.5,
402402
reduce_deconv_resid_norm=False,
403403
max_shift="full",
404-
batch_size=1024,
404+
batch_size=128,
405405
device=None,
406406
) -> Optional[CompressedConvResult]:
407407
"""Compute compressed pairwise convolutions between template pairs
@@ -469,12 +469,14 @@ def compressed_convolve_pairs(
469469

470470
# handle upsampling
471471
# each pair will be duplicated by the b unit's number of upsampled copies
472+
472473
(
473474
ix_b,
474475
compression_index,
475476
conv_ix,
476477
conv_upsampling_indices_b,
477-
conv_temporal_components_up_b,
478+
conv_temporal_components_up_b, #Need to change this conv_temporal_components_up_b[conv_compressed_upsampled_ix_b]
479+
conv_compressed_upsampled_ix_b,
478480
compression_dup_ix,
479481
) = compressed_upsampled_pairs(
480482
ix_b,
@@ -491,10 +493,14 @@ def compressed_convolve_pairs(
491493
# run convolutions
492494
temporal_a = low_rank_templates_a.temporal_components[temp_ix_a]
493495
pconv, kept = correlate_pairs_lowrank(
494-
torch.as_tensor(spatial_singular_a[ix_a[conv_ix]], device=device),
495-
torch.as_tensor(spatial_singular_b[ix_b[conv_ix]], device=device),
496-
torch.as_tensor(temporal_a[ix_a[conv_ix]], device=device),
496+
torch.as_tensor(spatial_singular_a, device=device),
497+
torch.as_tensor(spatial_singular_b, device=device),
498+
torch.as_tensor(temporal_a, device=device),
497499
torch.as_tensor(conv_temporal_components_up_b, device=device),
500+
ix_a=ix_a,
501+
ix_b=ix_b,
502+
conv_ix=conv_ix,
503+
conv_compressed_upsampled_ix_b=conv_compressed_upsampled_ix_b,
498504
max_shift=max_shift,
499505
conv_ignore_threshold=conv_ignore_threshold,
500506
batch_size=batch_size,
@@ -558,9 +564,13 @@ def correlate_pairs_lowrank(
558564
spatial_b,
559565
temporal_a,
560566
temporal_b,
567+
ix_a,
568+
ix_b,
569+
conv_ix,
570+
conv_compressed_upsampled_ix_b,
561571
max_shift="full",
562572
conv_ignore_threshold=0.0,
563-
batch_size=1024,
573+
batch_size=128,
564574
):
565575
"""Convolve pairs of low rank templates
566576
@@ -580,15 +590,19 @@ def correlate_pairs_lowrank(
580590
-------
581591
pconv, kept
582592
"""
583-
n_pairs, rank, nchan = spatial_a.shape
584-
n_pairs_, rank_, nchan_ = spatial_b.shape
593+
594+
# Now need to take ix_a/b[conv_ix] of spatial_a, spatial_b, temporal_a
595+
_, rank, nchan = spatial_a.shape
596+
_, rank_, nchan_ = spatial_b.shape
597+
n_pairs = conv_ix.shape[0]
585598
assert rank == rank_
586599
assert nchan == nchan_
587-
assert n_pairs == n_pairs_
588-
n_pairs_, t, rank_ = temporal_a.shape
589-
assert n_pairs == n_pairs_
600+
# assert n_pairs == n_pairs_
601+
_, t, rank_ = temporal_a.shape
602+
# assert n_pairs == n_pairs_
590603
assert rank_ == rank
591-
n_pairs_, t_, rank_ = temporal_b.shape
604+
_, t_, rank_ = temporal_b.shape
605+
n_pairs_ = conv_compressed_upsampled_ix_b.shape[0]
592606
assert n_pairs == n_pairs_
593607
assert t == t_
594608
assert rank == rank_
@@ -609,12 +623,12 @@ def correlate_pairs_lowrank(
609623
ix = slice(istart, iend)
610624

611625
# want conv filter: nco, 1, rank, t
612-
template_a = torch.bmm(temporal_a[ix], spatial_a[ix])
613-
conv_filt = torch.bmm(spatial_b[ix], template_a.mT)
626+
template_a = torch.bmm(temporal_a[ix_a[conv_ix][ix]], spatial_a[ix_a[conv_ix][ix]])
627+
conv_filt = torch.bmm(spatial_b[ix_b[conv_ix][ix]], template_a.mT)
614628
conv_filt = conv_filt[:, None] # (nco, 1, rank, t)
615629

616630
# 1, nco, rank, t
617-
conv_in = temporal_b[ix].mT[None]
631+
conv_in = temporal_b[conv_compressed_upsampled_ix_b[ix]].mT[None]
618632

619633
# conv2d:
620634
# depthwise, chans=nco. batch=1. h=rank. w=t. out: nup=1, nco, 1, 2p+1.
@@ -951,10 +965,10 @@ def compressed_upsampled_pairs(
951965
compression_dup_ix = slice(None)
952966
if up_factor == 1:
953967
upinds = np.zeros(len(conv_ix), dtype=int)
954-
temp_comps = compressed_upsampled_temporal.compressed_upsampled_templates[
955-
np.atleast_1d(temp_ix_b[ix_b[conv_ix]])
956-
]
957-
return ix_b, compression_index, conv_ix, upinds, temp_comps, compression_dup_ix
968+
# temp_comps = compressed_upsampled_temporal.compressed_upsampled_templates[
969+
# np.atleast_1d(temp_ix_b[ix_b[conv_ix]])
970+
# ]
971+
return ix_b, compression_index, conv_ix, upinds, compressed_upsampled_temporal.compressed_upsampled_templates, np.atleast_1d(temp_ix_b[ix_b[conv_ix]]), compression_dup_ix
958972

959973
# each conv_ix needs to be duplicated as many times as its b template has
960974
# upsampled copies
@@ -991,18 +1005,16 @@ def compressed_upsampled_pairs(
9911005
conv_compressed_upsampled_ix
9921006
]
9931007
)
994-
conv_temporal_components_up_b = (
995-
compressed_upsampled_temporal.compressed_upsampled_templates[
996-
conv_compressed_upsampled_ix
997-
]
998-
)
1008+
1009+
# conv_temporal_components_up_b = compressed_upsampled_temporal.compressed_upsampled_templates
9991010

10001011
return (
10011012
ix_b_up,
10021013
compression_index_up,
10031014
conv_ix_up,
10041015
conv_upsampling_indices_b,
1005-
conv_temporal_components_up_b,
1016+
compressed_upsampled_temporal.compressed_upsampled_templates,
1017+
conv_compressed_upsampled_ix,
10061018
compression_dup_ix,
10071019
)
10081020

src/spike_psvae/cluster_viz.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import matplotlib.pyplot as plt
66
import matplotlib.transforms as transforms
77
import numpy as np
8-
import seaborn as sns
8+
# import seaborn as sns
99
from matplotlib.patches import Ellipse
1010
# %%
1111
# matplotlib.use('Agg')

0 commit comments

Comments
 (0)