Refactor to pass recording during fit(), and improve subtraction denoiser fitting logic
cwindolf committed Oct 21, 2024
1 parent e2015b9 commit c3a356c
Showing 8 changed files with 177 additions and 85 deletions.
11 changes: 6 additions & 5 deletions src/dartsort/peel/grab.py
@@ -45,12 +45,13 @@ def out_datasets(self):
)
return datasets

def process_chunk(self, chunk_start_samples, return_residual=False):
def process_chunk(self, chunk_start_samples, chunk_end_samples=None, return_residual=False):
"""Override process_chunk to skip empties."""
chunk_end_samples = min(
self.recording.get_num_samples(),
chunk_start_samples + self.chunk_length_samples,
)
if chunk_end_samples is None:
chunk_end_samples = min(
self.recording.get_num_samples(),
chunk_start_samples + self.chunk_length_samples,
)
in_chunk = self.times_samples == self.times_samples.clip(chunk_start_samples, chunk_end_samples - 1)
if not in_chunk.any():
return dict(n_spikes=0)
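An illustrative sketch, not part of the diff: making chunk_end_samples optional keeps existing callers on the old behavior (end derived from chunk_length_samples) while letting new code ask for a shorter window explicitly. The peeler instance and sample values below are assumptions.

# Default behavior: end derived from chunk_length_samples, as before.
result = peeler.process_chunk(chunk_start_samples=0)

# New: an explicit end, e.g. only the first second of a 30 kHz recording.
result = peeler.process_chunk(chunk_start_samples=0, chunk_end_samples=30_000)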
2 changes: 1 addition & 1 deletion src/dartsort/peel/peel_base.py
@@ -521,7 +521,7 @@ def fit_featurization_pipeline(
channels = torch.as_tensor(channels, device=device)
waveforms = torch.as_tensor(waveforms, device=device)
featurization_pipeline = featurization_pipeline.to(device)
featurization_pipeline.fit(waveforms, max_channels=channels)
featurization_pipeline.fit(waveforms, max_channels=channels, recording=self.recording)
featurization_pipeline = featurization_pipeline.to("cpu")
self.featurization_pipeline = featurization_pipeline
finally:
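For context only, not code from this commit: passing recording= into fit() lets transformers that need recording-level information (sampling rate, channel geometry, duration) see it when they are fitted. A minimal sketch of such a transformer, assuming the pipeline forwards the keyword to each member's fit() and that the recording follows the SpikeInterface API:

import torch

class ExampleRecordingAwareDenoiser:
    """Hypothetical transformer for illustration, not a dartsort class."""

    def __init__(self):
        self._fitted = False

    def needs_fit(self):
        return not self._fitted

    def fit(self, waveforms, max_channels=None, recording=None):
        # recording is optional, so transformers that ignore it are unaffected
        if recording is not None:
            self.sampling_frequency = recording.get_sampling_frequency()
        # toy "fit": store an average waveform, ignoring NaN-padded channels
        self.mean_waveform = torch.nan_to_num(waveforms).mean(dim=0)
        self._fitted = True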
86 changes: 70 additions & 16 deletions src/dartsort/peel/subtract.py
@@ -14,6 +14,7 @@
relative_channel_subset_index)

from .peel_base import BasePeeler
from .threshold import threshold_chunk


class SubtractionPeeler(BasePeeler):
@@ -91,6 +92,10 @@ def __init__(
"subtraction_denoising_pipeline", subtraction_denoising_pipeline
)

# internal api for switching to thresholding during denoiser fitting
# when there are no pre-fit denoisers
self._turn_off_subtraction = False

def out_datasets(self):
datasets = super().out_datasets()

@@ -211,6 +216,7 @@ def peel_chunk(
spatial_dedup_channel_index=self.spatial_dedup_channel_index,
residnorm_decrease_threshold=self.residnorm_decrease_threshold,
persist_deduplication=self.persist_deduplication,
no_subtraction=self._turn_off_subtraction,
)

# add in chunk_start_samples
@@ -243,25 +249,31 @@ def fit_peeler_models(self, save_folder, tmp_dir=None, n_jobs=0, device=None):
# so we will cheat for now:
# just remove all the denoisers that need fitting, run peeling,
# and fit everything
self._fit_subtraction_transformers(
while self._fit_subtraction_transformers(
save_folder, tmp_dir=tmp_dir, n_jobs=n_jobs, device=device, which="denoisers"
)
):
pass
self._fit_subtraction_transformers(
save_folder, tmp_dir=tmp_dir, n_jobs=n_jobs, device=device, which="featurizers"
)

def _fit_subtraction_transformers(
self, save_folder, tmp_dir=None, n_jobs=0, device=None, which="denoisers"
):
"""Handle fitting either denoisers or featurizers since the logic is similar"""
if not any(
(
(t.is_denoiser if which == "denoisers" else t.is_featurizer)
and t.needs_fit()
)
for t in self.subtraction_denoising_pipeline
):
return
"""Fit models which are run during the subtraction step
These include denoisers and featurizers. Featurizers are easy, we just fit them
to the extracted waveforms output from a mini-subtraction. Denoisers are a bit
harder, since they influence the actual waveforms that are extracted. In that sense,
you need to fit them serially with a new mini subtraction each time.
"""
if which == "denoisers":
needs_fit = any(t.is_denoiser and t.needs_fit() for t in self.subtraction_denoising_pipeline)
elif which == "featurizers":
assert not any(t.is_denoiser and t.needs_fit() for t in self.subtraction_denoising_pipeline)
needs_fit = any(t.is_featurizer and t.needs_fit() for t in self.subtraction_denoising_pipeline)
if not needs_fit:
return False

if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -278,7 +290,19 @@ def _fit_subtraction_transformers(
)
ifeats = [init_voltage_feature, init_waveform_feature]
if which == "denoisers":
ffeats = [t for t in orig_denoise if (t.is_denoiser and not t.needs_fit())]
# add all the already fitted denoisers until we hit the next unfitted one
ffeats = []
for t in orig_denoise:
if t.is_denoiser:
if t.needs_fit():
break
ffeats.append(t)

# this is the sequence of transforms to actually use in fitting
fit_feats = ffeats + [t]

# if we have no denoisers yet, then definitely don't do subtraction!
self._turn_off_subtraction = not ffeats
else:
ffeats = [t for t in orig_denoise if t.is_denoiser]
self.subtraction_denoising_pipeline = WaveformPipeline(ifeats + ffeats)
@@ -313,15 +337,18 @@ def _fit_subtraction_transformers(

channels = torch.as_tensor(channels, device=device)
waveforms = torch.as_tensor(waveforms, device=device)
orig_denoise = orig_denoise.to(device)
orig_denoise.fit(waveforms, max_channels=channels)
orig_denoise = orig_denoise.to("cpu")
fit_denoise = WaveformPipeline(fit_feats)
fit_denoise = fit_denoise.to(device)
fit_denoise.fit(waveforms, max_channels=channels, recording=self.recording)
fit_denoise = fit_denoise.to("cpu")
self._turn_off_subtraction = False
self.subtraction_denoising_pipeline = orig_denoise
self.featurization_pipeline = orig_featurization_pipeline
finally:
self.to("cpu")
if temp_hdf5_filename.exists():
temp_hdf5_filename.unlink()
return True


ChunkSubtractionResult = namedtuple(
@@ -354,8 +381,34 @@ def subtract_chunk(
persist_deduplication=True,
relative_peak_radius=5,
dedup_temporal_radius=7,
no_subtraction=False,
):
"""Core peeling routine for subtraction"""
if no_subtraction:
times_rel, channels, voltages, waveforms = threshold_chunk(
traces,
channel_index,
detection_threshold=min(detection_thresholds),
peak_sign=peak_sign,
spatial_dedup_channel_index=spatial_dedup_channel_index,
trough_offset_samples=trough_offset_samples,
spike_length_samples=spike_length_samples,
left_margin=left_margin,
right_margin=right_margin,
relative_peak_radius=relative_peak_radius,
dedup_temporal_radius=dedup_temporal_radius,
max_spikes_per_chunk=None,
quiet=False,
)
waveforms, features = denoising_pipeline(waveforms, channels)
return ChunkSubtractionResult(
n_spikes=times_rel.numel(),
times_samples=times_rel,
channels=channels,
collisioncleaned_waveforms=waveforms,
residual=None,
features=features,
)
# validate arguments to avoid confusing error messages later
re_extract = extract_index is not None
if extract_index is None:
@@ -410,7 +463,8 @@ def subtract_chunk(
continue

# throw away spikes which cannot be subtracted
keep = (times_samples >= trough_offset_samples) & (
keep = torch.logical_and(
times_samples >= trough_offset_samples,
times_samples < max_trough_time
)
times_samples = times_samples[keep]
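A simplified sketch of the serial fitting scheme described in the new _fit_subtraction_transformers docstring above (hypothetical helpers, not code from this commit): each pass peels a little data using only the denoisers fitted so far, falling back to plain thresholding when there are none, fits the next unfitted denoiser on the resulting waveforms, and repeats until every denoiser is fitted, which is why fit_peeler_models now loops while the method returns True.

def fit_denoisers_serially(denoisers, run_mini_peel):
    # Hypothetical helper: run_mini_peel(fitted_prefix) -> (waveforms, channels),
    # where an empty prefix means thresholding instead of subtraction.
    while any(d.needs_fit() for d in denoisers):
        fitted_prefix, next_unfitted = [], None
        for d in denoisers:
            if d.needs_fit():
                next_unfitted = d
                break
            fitted_prefix.append(d)
        waveforms, channels = run_mini_peel(fitted_prefix)
        next_unfitted.fit(waveforms, max_channels=channels)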
135 changes: 86 additions & 49 deletions src/dartsort/peel/threshold.py
@@ -76,60 +76,24 @@ def peel_chunk(
right_margin=0,
return_residual=False,
):
times_rel, channels, energies = detect_and_deduplicate(
times_rel, channels, voltages, waveforms = threshold_chunk(
traces,
self.detection_threshold,
dedup_channel_index=self.spatial_dedup_channel_index,
peak_sign=self.peak_sign,
dedup_temporal_radius=self.dedup_temporal_radius_samples,
relative_peak_radius=self.relative_peak_radius_samples,
return_energies=True,
)
if not times_rel.numel():
return dict(n_spikes=0)

# want only peaks in the chunk
min_time = max(left_margin, self.spike_length_samples)
max_time = traces.shape[0] - max(
right_margin, self.spike_length_samples - self.trough_offset_samples
)
valid = (times_rel >= min_time) & (times_rel < max_time)
times_rel = times_rel[valid]
n_detect = times_rel.numel()
if not n_detect:
return dict(n_spikes=0)
channels = channels[valid]
voltages = traces[times_rel, channels]

if self.max_spikes_per_chunk is not None:
if n_detect > self.max_spikes_per_chunk:
warnings.warn(
f"{n_detect} spikes in chunk was larger than "
f"{self.max_spikes_per_chunk=}. Keeping the top ones."
)
energies = energies[valid]
best = torch.argsort(energies)[-self.max_spikes_per_chunk:]
best = best.sort().values
del energies

times_rel = times_rel[best]
channels = channels[best]
voltages = voltages[best]

# load up the waveforms for this chunk
waveforms = spiketorch.grab_spikes(
traces,
times_rel,
channels,
self.channel_index,
trough_offset=self.trough_offset_samples,
spike_length_samples=self.spike_length_samples,
already_padded=False,
pad_value=torch.nan,
detection_threshold=4,
peak_sign="both",
spatial_dedup_channel_index=None,
trough_offset_samples=42,
spike_length_samples=121,
left_margin=0,
right_margin=0,
relative_peak_radius=5,
dedup_temporal_radius=7,
max_spikes_per_chunk=None,
quiet=False,
)

# get absolute times
times_samples = times_rel + chunk_start_samples - left_margin
times_samples = times_rel + chunk_start_samples

peel_result = dict(
n_spikes=times_rel.numel(),
@@ -139,3 +103,76 @@
collisioncleaned_waveforms=waveforms,
)
return peel_result


def threshold_chunk(
traces,
channel_index,
detection_threshold=4,
peak_sign="both",
spatial_dedup_channel_index=None,
trough_offset_samples=42,
spike_length_samples=121,
left_margin=0,
right_margin=0,
relative_peak_radius=5,
dedup_temporal_radius=7,
max_spikes_per_chunk=None,
quiet=False,
):
times_rel, channels, energies = detect_and_deduplicate(
traces,
detection_threshold,
dedup_channel_index=spatial_dedup_channel_index,
peak_sign=peak_sign,
dedup_temporal_radius=dedup_temporal_radius,
relative_peak_radius=relative_peak_radius,
return_energies=True,
)
if not times_rel.numel():
return dict(n_spikes=0)

# want only peaks in the chunk
min_time = max(left_margin, spike_length_samples)
max_time = traces.shape[0] - max(
right_margin, spike_length_samples - trough_offset_samples
)
valid = (times_rel >= min_time) & (times_rel < max_time)
times_rel = times_rel[valid]
n_detect = times_rel.numel()
if not n_detect:
return dict(n_spikes=0)
channels = channels[valid]
voltages = traces[times_rel, channels]

if max_spikes_per_chunk is not None:
if n_detect > max_spikes_per_chunk and not quiet:
warnings.warn(
f"{n_detect} spikes in chunk was larger than "
f"{max_spikes_per_chunk=}. Keeping the top ones."
)
energies = energies[valid]
best = torch.argsort(energies)[-max_spikes_per_chunk:]
best = best.sort().values
del energies

times_rel = times_rel[best]
channels = channels[best]
voltages = voltages[best]

# load up the waveforms for this chunk
waveforms = spiketorch.grab_spikes(
traces,
times_rel,
channels,
channel_index,
trough_offset=trough_offset_samples,
spike_length_samples=spike_length_samples,
already_padded=False,
pad_value=torch.nan,
)

# offset times for caller
times_rel -= left_margin

return times_rel, channels, voltages, waveforms
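Purely illustrative usage of the new free function (the inputs and argument values below are assumptions, not taken from this diff): threshold_chunk returns chunk-relative times already shifted by left_margin, so a caller only adds the chunk start to get absolute sample indices. Note that when no peaks survive, the function as written returns dict(n_spikes=0) rather than this tuple.

# traces: (samples, channels) torch tensor for one padded chunk
times_rel, channels, voltages, waveforms = threshold_chunk(
    traces,
    channel_index,
    detection_threshold=4,
    peak_sign="both",
    trough_offset_samples=42,
    spike_length_samples=121,
    left_margin=0,
    right_margin=0,
)
times_samples = times_rel + chunk_start_samples  # absolute sample times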