diff --git a/src/dartsort/peel/__init__.py b/src/dartsort/peel/__init__.py index 9f90d541..7de8eeab 100644 --- a/src/dartsort/peel/__init__.py +++ b/src/dartsort/peel/__init__.py @@ -1,3 +1,4 @@ from .grab import GrabAndFeaturize from .matching import ObjectiveUpdateTemplateMatchingPeeler from .subtract import SubtractionPeeler, subtract_chunk +from .threshold import ThresholdAndFeaturize diff --git a/src/dartsort/peel/peel_base.py b/src/dartsort/peel/peel_base.py index 1a76d7d1..f18409b8 100644 --- a/src/dartsort/peel/peel_base.py +++ b/src/dartsort/peel/peel_base.py @@ -28,7 +28,7 @@ def __init__( self, recording, channel_index, - featurization_pipeline, + featurization_pipeline=None, chunk_length_samples=30_000, chunk_margin_samples=0, n_chunks_fit=40, @@ -46,7 +46,10 @@ def __init__( fit_subsampling_random_state ) self.register_buffer("channel_index", channel_index) - self.add_module("featurization_pipeline", featurization_pipeline) + if featurization_pipeline is not None: + self.add_module("featurization_pipeline", featurization_pipeline) + else: + self.featurization_pipeline = None # subclasses can append to this if they want to store more fixed # arrays in the output h5 file @@ -237,9 +240,10 @@ def out_datasets(self): SpikeDataset(name="times_seconds", shape_per_spike=(), dtype=float), SpikeDataset(name="channels", shape_per_spike=(), dtype=int), ] - for transformer in self.featurization_pipeline.transformers: - if transformer.is_featurizer: - datasets.append(transformer.spike_dataset) + if self.featurization_pipeline is not None: + for transformer in self.featurization_pipeline.transformers: + if transformer.is_featurizer: + datasets.append(transformer.spike_dataset) return datasets # -- utility methods which users likely won't touch @@ -247,6 +251,9 @@ def out_datasets(self): def featurize_collisioncleaned_waveforms( self, collisioncleaned_waveforms, max_channels ): + if self.featurization_pipeline is None: + return {} + waveforms, features = self.featurization_pipeline( collisioncleaned_waveforms, max_channels ) @@ -329,7 +336,10 @@ def gather_chunk_result( return n_new_spikes def needs_fit(self): - return self.peeling_needs_fit() or self.featurization_pipeline.needs_fit() + it_does = self.peeling_needs_fit() + if self.featurization_pipeline is not None: + it_does = it_does or self.featurization_pipeline.needs_fit() + return it_does def fit_models(self, save_folder, overwrite=False, n_jobs=0, device=None): with torch.no_grad(): @@ -349,6 +359,9 @@ def fit_models(self, save_folder, overwrite=False, n_jobs=0, device=None): assert not self.needs_fit() def fit_featurization_pipeline(self, save_folder, n_jobs=0, device=None): + if self.featurization_pipeline is None: + return + if not self.featurization_pipeline.needs_fit(): return diff --git a/src/dartsort/peel/subtract.py b/src/dartsort/peel/subtract.py index 3671802f..b4640e46 100644 --- a/src/dartsort/peel/subtract.py +++ b/src/dartsort/peel/subtract.py @@ -26,7 +26,7 @@ def __init__( spike_length_samples=121, detection_thresholds=[12, 10, 8, 6, 5, 4], chunk_length_samples=30_000, - peak_sign="neg", + peak_sign="both", spatial_dedup_channel_index=None, n_chunks_fit=40, fit_subsampling_random_state=0, @@ -274,7 +274,7 @@ def subtract_chunk( left_margin=0, right_margin=0, detection_thresholds=[12, 10, 8, 6, 5, 4], - peak_sign="neg", + peak_sign="both", spatial_dedup_channel_index=None, residnorm_decrease_threshold=3.162, # sqrt(10) ): diff --git a/src/dartsort/peel/threshold.py b/src/dartsort/peel/threshold.py new file mode 100644 
index 00000000..7379f72d --- /dev/null +++ b/src/dartsort/peel/threshold.py @@ -0,0 +1,95 @@ +import torch +from dartsort.detect import detect_and_deduplicate +from dartsort.util import spiketorch + +from .peel_base import BasePeeler + + +class ThresholdAndFeaturize(BasePeeler): + def __init__( + self, + recording, + channel_index, + featurization_pipeline=None, + trough_offset_samples=42, + spike_length_samples=121, + detection_threshold=5.0, + chunk_length_samples=30_000, + peak_sign="both", + spatial_dedup_channel_index=None, + n_chunks_fit=40, + fit_subsampling_random_state=0, + ): + super().__init__( + recording=recording, + channel_index=channel_index, + featurization_pipeline=featurization_pipeline, + chunk_length_samples=chunk_length_samples, + chunk_margin_samples=spike_length_samples, + n_chunks_fit=n_chunks_fit, + fit_subsampling_random_state=fit_subsampling_random_state, + ) + + self.trough_offset_samples = trough_offset_samples + self.spike_length_samples = spike_length_samples + self.peak_sign = peak_sign + if spatial_dedup_channel_index is not None: + self.register_buffer( + "spatial_dedup_channel_index", + spatial_dedup_channel_index, + ) + else: + self.spatial_dedup_channel_index = None + self.detection_threshold = detection_threshold + self.peel_kind = f"Threshold {detection_threshold}" + + def peel_chunk( + self, + traces, + chunk_start_samples=0, + left_margin=0, + right_margin=0, + return_residual=False, + ): + times_rel, channels = detect_and_deduplicate( + traces, + self.detection_threshold, + dedup_channel_index=self.spatial_dedup_channel_index, + peak_sign=self.peak_sign, + ) + if not times_rel.numel(): + return dict(n_spikes=0) + + # want only peaks in the chunk + min_time = max(left_margin, self.spike_length_samples) + max_time = traces.shape[0] - max( + right_margin, self.spike_length_samples - self.trough_offset_samples + ) + valid = (times_rel >= min_time) & (times_rel < max_time) + times_rel = times_rel[valid] + if not times_rel.numel(): + return dict(n_spikes=0) + channels = channels[valid] + + # load up the waveforms for this chunk + waveforms = spiketorch.grab_spikes( + traces, + times_rel, + channels, + self.channel_index, + trough_offset=self.trough_offset_samples, + spike_length_samples=self.spike_length_samples, + already_padded=False, + pad_value=torch.nan, + ) + + # get absolute times + times_samples = times_rel + chunk_start_samples - left_margin + + peel_result = dict( + n_spikes=times_rel.numel(), + times_samples=times_samples, + channels=channels, + collisioncleaned_waveforms=waveforms, + ) + return peel_result diff --git a/src/dartsort/templates/template_util.py b/src/dartsort/templates/template_util.py index 97f8f631..4054f394 100644 --- a/src/dartsort/templates/template_util.py +++ b/src/dartsort/templates/template_util.py @@ -109,11 +109,15 @@ def get_registered_templates( def get_realigned_sorting( recording, sorting, + realign_peaks=True, + low_rank_denoising=False, **kwargs, ): results = get_templates( recording, sorting, + realign_peaks=realign_peaks, + low_rank_denoising=low_rank_denoising, **kwargs, ) return results["sorting"] diff --git a/src/dartsort/templates/templates.py b/src/dartsort/templates/templates.py index 6dc5b08c..f742a8e2 100644 --- a/src/dartsort/templates/templates.py +++ b/src/dartsort/templates/templates.py @@ -105,6 +105,7 @@ def from_config( save_npz_name="template_data.npz", localizations_dataset_name="point_source_localizations", n_jobs=0, + units_per_job=8, device=None, trough_offset_samples=42, 
spike_length_samples=121, @@ -155,6 +156,7 @@ def from_config( denoising_fit_radius=template_config.denoising_fit_radius, denoising_snr_threshold=template_config.denoising_snr_threshold, device=device, + units_per_job=units_per_job, ) if template_config.registered_templates and motion_est is not None: kwargs["registered_geom"] = drift_util.registered_geometry( diff --git a/src/dartsort/transform/decollider.py b/src/dartsort/transform/decollider.py new file mode 100644 index 00000000..5e3d60f3 --- /dev/null +++ b/src/dartsort/transform/decollider.py @@ -0,0 +1,260 @@ +import torch +from torch import nn + +# TODO implement WaveformDenoiser versions + + +class Decollider(nn.Module): + """Implements save/load logic for subclasses.""" + + subclasses = {} + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + cls.subclasses[cls.__name__] = cls + + def save(self, pt_path): + data = self.state_dict() + data["decollider_subclass"] = self.__class__.__name__ + data["decollider_kwargs"] = self._kwargs + torch.save(data, pt_path) + + @classmethod + def load(cls, pt_path): + data = torch.load(pt_path, map_location="cpu") + cls_name = data.pop("decollider_subclass") + kwargs = data.pop("decollider_kwargs") + subcls = cls.subclasses[cls_name] + self = subcls(**kwargs) + self.load_state_dict(data) + return self + + def predict(self, noisy_waveforms, channel_masks=None): + # multi-chan prediction + # multi-chan nets below naturally implement this in their + # forward(), but single-chan nets need a little logic + return self.forward(noisy_waveforms, channel_masks=channel_masks) + + def n2n_predict(self, noisier_waveforms, channel_masks=None, alpha=1.0): + """See Noisier2Noise paper. This is their Eq. 6. + + If you plan to use this at inference time, then multiply your noise2 + during training by alpha. 
+ """ + expected_noisy_waveforms = self.predict( + noisier_waveforms, channel_masks=channel_masks + ) + if alpha == 1.0: + return 2.0 * expected_noisy_waveforms - noisier_waveforms + a2inv = 1.0 / (alpha * alpha) + a2p1 = 1.0 + alpha * alpha + return a2inv * (a2p1 * expected_noisy_waveforms - noisier_waveforms) + + +# -- single channel decolliders + + +class SingleChannelPredictor(Decollider): + def predict(self, waveforms, channel_masks=None): + """NCT -> NCT""" + n, c, t = waveforms.shape + waveforms = waveforms.reshape(n * c, 1, t) + preds = self.forward(waveforms) + return preds.reshape(n, c, t) + + +class SingleChannelDecollider(SingleChannelPredictor): + def forward(self, waveforms, channel_masks=None): + """N1T -> N1T""" + return self.net(waveforms) + + +class ConvToLinearSingleChannelDecollider(SingleChannelDecollider): + def __init__( + self, + out_channels=(16, 32, 64), + kernel_lengths=(5, 5, 11), + hidden_linear_dims=(), + spike_length_samples=121, + final_activation="relu", + ): + super().__init__() + in_channels = (1,) + out_channels[:-1] + is_hidden = [True] * (len(out_channels) - 1) + [False] + self.net = nn.Sequential() + for ic, oc, k, hid in zip( + in_channels, out_channels, kernel_lengths, is_hidden + ): + self.net.append(nn.Conv1d(ic, oc, k)) + if hid: + self.net.append(nn.ReLU()) + self.net.append(nn.Flatten()) + flat_dim = out_channels[-1] * ( + spike_length_samples - sum(kernel_lengths) + len(kernel_lengths) + ) + + lin_in_dims = (flat_dim,) + hidden_linear_dims + lin_out_dims = hidden_linear_dims + (spike_length_samples,) + is_final = [False] * len(hidden_linear_dims) + [True] + for inf, outf, fin in zip(lin_in_dims, lin_out_dims, is_final): + if fin and final_activation == "sigmoid": + self.net.append(nn.Sigmoid()) + if fin and final_activation == "tanh": + self.net.append(nn.Tanh()) + else: + self.net.append(nn.ReLU()) + self.net.append(nn.Linear(inf, outf)) + # add the empty channel dim back in + self.net.append(nn.Unflatten(1, (1, spike_length_samples))) + self._kwargs = dict( + out_channels=out_channels, + kernel_lengths=kernel_lengths, + hidden_linear_dims=hidden_linear_dims, + spike_length_samples=spike_length_samples, + final_activation=final_activation, + ) + + +class MLPSingleChannelDecollider(SingleChannelDecollider): + def __init__( + self, + hidden_sizes=(512, 256, 256), + spike_length_samples=121, + final_activation="relu", + ): + super().__init__() + self.net = nn.Sequential() + self.net.append(nn.Flatten()) + input_sizes = (spike_length_samples,) + hidden_sizes[:-1] + output_sizes = hidden_sizes + is_final = [False] * max(0, len(hidden_sizes) - 1) + [True] + for inf, outf, fin in zip(input_sizes, output_sizes, is_final): + self.net.append(nn.Linear(inf, outf)) + if fin and final_activation == "sigmoid": + self.net.append(nn.Sigmoid()) + if fin and final_activation == "tanh": + self.net.append(nn.Tanh()) + else: + self.net.append(nn.ReLU()) + self.net.append(nn.Linear(hidden_sizes[-1], spike_length_samples)) + # add the empty channel dim back in + self.net.append(nn.Unflatten(1, (1, spike_length_samples))) + self._kwargs = dict( + hidden_sizes=hidden_sizes, + spike_length_samples=spike_length_samples, + final_activation=final_activation, + ) + + +# -- multi channel decolliders + + +class MultiChannelDecollider(Decollider): + """NCT -> NCT + + self.net must map NC2T -> NCT + + Mask is added like so: + waveforms NCT -> N1CT \ + masks NC -> N1C1 -> N2CT (broadcast and concat) + """ + + def forward(self, waveforms, channel_masks=None): + # add the masks 
+ + +# -- multi channel decolliders + + +class MultiChannelDecollider(Decollider): + """NCT -> NCT + + self.net must map N2CT -> NCT + + Mask is added like so: + waveforms NCT -> N1CT \ + masks NC -> N1C1 -> N2CT (broadcast and concat) + """ + + def forward(self, waveforms, channel_masks=None): + # add the masks as an input channel + # I somehow feel that receiving a "badness indicator" is more useful, + # and the masks indicate good channels, so hence the flip below + if channel_masks is None: + masks = torch.zeros_like(waveforms[:, :, 0]) + else: + masks = torch.logical_not(channel_masks).to(waveforms) + # NCT -> N1CT (channels are height in Conv2D NCHW convention) + waveforms = waveforms[:, None, :, :] + # NC -> N1C1, broadcast to N1CT + masks = torch.broadcast_to(masks[:, None, :, None], waveforms.shape) + # -> N2CT, concatenate on channel dimension (NCHW) + combined = torch.concatenate((waveforms, masks), dim=1) + return self.net(combined) + + +class ConvToLinearMultiChannelDecollider(MultiChannelDecollider): + def __init__( + self, + out_channels=(16, 32), + kernel_heights=(4, 4), + kernel_lengths=(5, 5), + hidden_linear_dims=(1024,), + n_channels=1, + spike_length_samples=121, + final_activation="relu", + ): + super().__init__() + in_channels = (2,) + out_channels[:-1] + is_hidden = [True] * (len(out_channels) - 1) + [False] + self.net = nn.Sequential() + for ic, oc, kl, kh, hid in zip( + in_channels, out_channels, kernel_lengths, kernel_heights, is_hidden + ): + self.net.append(nn.Conv2d(ic, oc, (kh, kl))) + if hid: + self.net.append(nn.ReLU()) + self.net.append(nn.Flatten()) + out_w = spike_length_samples - sum(kernel_lengths) + len(kernel_lengths) + out_h = n_channels - sum(kernel_heights) + len(kernel_heights) + flat_dim = out_channels[-1] * out_w * out_h + lin_in_dims = (flat_dim,) + hidden_linear_dims + lin_out_dims = hidden_linear_dims + (n_channels * spike_length_samples,) + is_final = [False] * len(hidden_linear_dims) + [True] + for inf, outf, fin in zip(lin_in_dims, lin_out_dims, is_final): + if fin and final_activation == "sigmoid": + self.net.append(nn.Sigmoid()) + elif fin and final_activation == "tanh": + self.net.append(nn.Tanh()) + else: + self.net.append(nn.ReLU()) + self.net.append(nn.Linear(inf, outf)) + self.net.append(nn.Unflatten(1, (n_channels, spike_length_samples))) + self._kwargs = dict( + out_channels=out_channels, + kernel_heights=kernel_heights, + kernel_lengths=kernel_lengths, + hidden_linear_dims=hidden_linear_dims, + n_channels=n_channels, + spike_length_samples=spike_length_samples, + final_activation=final_activation, + ) + + +class MLPMultiChannelDecollider(MultiChannelDecollider): + def __init__( + self, + hidden_sizes=(1024, 512, 512), + n_channels=1, + spike_length_samples=121, + final_activation="relu", + ): + super().__init__() + self.net = nn.Sequential() + self.net.append(nn.Flatten()) + input_sizes = (2 * n_channels * spike_length_samples,) + hidden_sizes[ + :-1 + ] + output_sizes = hidden_sizes + is_final = [False] * max(0, len(hidden_sizes) - 1) + [True] + for inf, outf, fin in zip(input_sizes, output_sizes, is_final): + self.net.append(nn.Linear(inf, outf)) + if fin and final_activation == "sigmoid": + self.net.append(nn.Sigmoid()) + elif fin and final_activation == "tanh": + self.net.append(nn.Tanh()) + else: + self.net.append(nn.ReLU()) + self.net.append( + nn.Linear(hidden_sizes[-1], n_channels * spike_length_samples) + ) + self.net.append(nn.Unflatten(1, (n_channels, spike_length_samples))) + self._kwargs = dict( + hidden_sizes=hidden_sizes, + n_channels=n_channels, + spike_length_samples=spike_length_samples, + final_activation=final_activation, + )
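Because Decollider.__init_subclass__ registers every subclass by name and save() stores the constructor kwargs next to the state dict, checkpoints can be restored without knowing the concrete class in advance. A sketch of the intended round trip, assuming the classes above; the file path is hypothetical:

```python
net = ConvToLinearSingleChannelDecollider(out_channels=(8, 8), kernel_lengths=(5, 5))
net.save("decollider_checkpoint.pt")

# load() reads the saved subclass name, rebuilds the net from its
# stored _kwargs, then restores the weights
net2 = Decollider.load("decollider_checkpoint.pt")
assert isinstance(net2, ConvToLinearSingleChannelDecollider)
```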
diff --git a/src/dartsort/util/data_util.py b/src/dartsort/util/data_util.py index 4305e244..ab1565d0 100644 --- a/src/dartsort/util/data_util.py +++ b/src/dartsort/util/data_util.py @@ -1,5 +1,5 @@ from collections import namedtuple -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from pathlib import Path from typing import Optional from warnings import warn @@ -163,7 +163,11 @@ def from_peeling_hdf5( extra_features = None if load_simple_features: extra_features = {} - loaded = (times_samples_dataset, channels_dataset, labels_dataset) + loaded = ( + times_samples_dataset, + channels_dataset, + labels_dataset, + ) for k in h5: if ( k not in loaded @@ -240,3 +244,39 @@ def check_recording( failed = True return failed, avg_detections_per_second, max_abs + + +def subset_sorting_by_spike_count(sorting, min_spikes=0): + if not min_spikes: + return sorting + + units, counts = np.unique(sorting.labels, return_counts=True) + small_units = units[counts < min_spikes] + + new_labels = np.where( + np.isin(sorting.labels, small_units), -1, sorting.labels + ) + + return replace(sorting, labels=new_labels) + + +def subset_sorting_by_time_samples( + sorting, start_sample=0, end_sample=np.inf, reference_to_start_sample=True +): + new_times = sorting.times_samples.copy() + new_labels = sorting.labels.copy() + + in_range = (new_times >= start_sample) & (new_times < end_sample) + new_labels[~in_range] = -1 + + if reference_to_start_sample: + new_times -= start_sample + + return replace(sorting, labels=new_labels, times_samples=new_times) + + +def reindex_sorting_labels(sorting): + new_labels = sorting.labels.copy() + kept = np.flatnonzero(new_labels >= 0) + _, new_labels[kept] = np.unique(new_labels[kept], return_inverse=True) + return replace(sorting, labels=new_labels)
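The three new sorting helpers compose naturally: restrict to a time window, drop tiny units, then compact the label space. A hedged usage sketch; the sorting object is whatever dataclass these dataclasses.replace calls receive, and the numbers are arbitrary:

```python
# keep the first minute at 30 kHz, re-referencing times to the window start
sorting = subset_sorting_by_time_samples(sorting, start_sample=0, end_sample=60 * 30_000)
# mark units with fewer than 25 spikes as unsorted (-1)
sorting = subset_sorting_by_spike_count(sorting, min_spikes=25)
# relabel the surviving units as 0..K-1, leaving -1 spikes in place
sorting = reindex_sorting_labels(sorting)
```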
diff --git a/src/dartsort/util/decollider_util.py b/src/dartsort/util/decollider_util.py new file mode 100644 index 00000000..100856df --- /dev/null +++ b/src/dartsort/util/decollider_util.py @@ -0,0 +1,987 @@ +""" +Note: a lot of the waveforms in this file are NCT rather than NTC, because this +is the expected input format for conv1ds. +""" +import time +from dataclasses import dataclass +from typing import Optional + +import numpy as np +import pandas as pd +import torch +import torch.nn.functional as F +from tqdm.auto import tqdm, trange + +from ..transform.decollider import SingleChannelPredictor +from ..transform.single_channel_denoiser import SingleChannelDenoiser +from . import spikeio + + +def train_decollider( + net, + recordings, + templates_train=None, + templates_val=None, + detection_times_train=None, + detection_channels_train=None, + detection_times_val=None, + detection_channels_val=None, + channel_index=None, + channel_min_amplitude=0.0, + channel_jitter_index=None, + examples_per_epoch=10_000, + noise_same_chans=False, + noise2_alpha=1.0, + trough_offset_samples=42, + spike_length_samples=121, + data_random_seed=0, + noise_max_amplitude=np.inf, + validation_oversamples=3, + n_unsupervised_val_examples=2000, + max_n_epochs=500, + early_stop_decrease_epochs=5, + batch_size=64, + loss_class=torch.nn.MSELoss, + val_every=1, + device=None, + show_progress=True, +): + """ + Arguments + --------- + net : torch.nn.Module + Should expect input shape NCL, where C is the number of input + channels (i.e. channel_index.shape[1]) and L is spike_length_samples. + If C>1, will also receive an NC array "mask" containing 1s on valid + channels and 0s elsewhere. The loss will be ignored on the zeros. + recordings : List[BaseRecording] + templates_{train,val} : List[np.ndarray], length == len(recordings) + channel_index : Optional[np.ndarray] + For training multi-chan nets. Per-recording channel indices (subsets + of channel_index) are derived internally, to allow for holes in + individual recordings' channel sets. + + Returns + ------- + net : torch.nn.Module + training_dataframe : pd.DataFrame + validation_dataframe : pd.DataFrame + """ + rg = np.random.default_rng(data_random_seed) + + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + net = net.to(device) + + # initial validation + train_records = [] + val_records = [] + + # allow training on recordings with different channels missing + ( + n_channels_full, + recording_channel_indices, + channel_subsets, + channel_index, + ) = reconcile_channels(recordings, channel_index) + + # combine templates on different channels using NaN padding + # the NaNs inform masking below + # these are also padded with an extra channel of NaNs, to help + # with indexing below + templates_train_recording_origin = original_train_template_index = None + if templates_train is not None: + ( + templates_train, + templates_train_recording_origin, + original_train_template_index, + ) = combine_templates(templates_train, channel_subsets) + assert spike_length_samples == templates_train.shape[2] + + templates_val_recording_origin = original_val_template_index = None + if templates_val is not None: + ( + templates_val, + templates_val_recording_origin, + original_val_template_index, + ) = combine_templates(templates_val, channel_subsets) + assert spike_length_samples == templates_val.shape[2] + + opt = torch.optim.Adam(net.parameters()) + criterion = loss_class() + examples_seen = 0 + val_losses = [] + xrange = trange if show_progress else range + for epoch in xrange(max_n_epochs): + epoch_dt = 0.0 + + # get data + tic = time.perf_counter() + epoch_data = load_epoch( + recordings, + templates=templates_train, + detection_times=detection_times_train, + detection_channels=detection_channels_train, + channel_index=channel_index, + template_recording_origin=templates_train_recording_origin, + recording_channel_indices=recording_channel_indices, + channel_min_amplitude=channel_min_amplitude, + channel_jitter_index=channel_jitter_index, + examples_per_epoch=examples_per_epoch, + noise_same_chans=noise_same_chans, + noise2_alpha=noise2_alpha, + trough_offset_samples=trough_offset_samples, + spike_length_samples=spike_length_samples, + data_random_seed=rg, + noise_max_amplitude=noise_max_amplitude, + ) + toc = time.perf_counter() + epoch_data_load_wall_dt_s = toc - tic + + # train + epoch_losses = [] + for i0 in range( + 0, len(epoch_data.noisy_waveforms) - batch_size, batch_size + ): + tic = time.perf_counter() + + i1 = i0 + batch_size + noised_batch = epoch_data.noisier_waveforms[i0:i1].to(device) + target_batch = epoch_data.noisy_waveforms[i0:i1].to(device) + masks = None + if epoch_data.channel_masks is not None: + masks = epoch_data.channel_masks[i0:i1].to(device) + + opt.zero_grad() + pred = net(noised_batch, channel_masks=masks) + if masks is not None: + loss = criterion( + pred * masks[:, :, None], target_batch * masks[:, :, None] + ) + else: + loss = criterion(pred, target_batch) + loss.backward() + opt.step() + + toc = time.perf_counter() + batch_dt = toc - tic + + loss = float(loss.numpy(force=True)) + + epoch_losses.append(loss) + train_records.append( + dict( + loss=loss, + batch_train_wall_dt_s=batch_dt, + epoch=epoch, + samples=examples_seen, + ) + ) + + # learning trackers + examples_seen += noised_batch.shape[0] + epoch_dt += batch_dt + + if epoch % val_every: + continue + + # evaluate + val_record = evaluate_decollider( + net, + recordings, + templates=templates_val, + detection_times=detection_times_val, + detection_channels=detection_channels_val, + recording_channel_indices=recording_channel_indices, + template_recording_origin=templates_val_recording_origin, + original_template_index=original_val_template_index, + n_oversamples=validation_oversamples, + n_unsupervised_val_examples=n_unsupervised_val_examples, + channel_index=channel_index, + channel_min_amplitude=channel_min_amplitude, + channel_jitter_index=channel_jitter_index, + noise_same_chans=noise_same_chans, + noise2_alpha=noise2_alpha, + trough_offset_samples=trough_offset_samples, + spike_length_samples=spike_length_samples, + data_random_seed=rg, + noise_max_amplitude=noise_max_amplitude, + device=device, + summarize=True, + ) + val_record["epoch"] = epoch + val_record["epoch_train_wall_dt_s"] = epoch_dt + val_record["epoch_data_load_wall_dt_s"] = epoch_data_load_wall_dt_s + val_records.append(val_record) + val_losses.append(val_record["val_loss"]) + if show_progress: + summary = f"epoch {epoch}. " + summary += f"mean train loss: {np.mean(epoch_losses):0.3f}, " + summary += f"init train loss: {epoch_losses[0]:0.3f}, " + summary += ", ".join( + f"{k}: {v:0.3f}" for k, v in val_record.items() if k != "epoch" + ) + tqdm.write(summary) + + # stop early + if not early_stop_decrease_epochs or epoch < early_stop_decrease_epochs: + continue + + # See: Early Stopping -- But When? + best_epoch = np.argmin(val_losses) + if epoch - best_epoch > early_stop_decrease_epochs: + if show_progress: + tqdm.write(f"Early stopping at {epoch=}, since {best_epoch=}.") + break + + validation_dataframe = pd.DataFrame.from_records(val_records) + training_dataframe = pd.DataFrame.from_records(train_records) + return net, training_dataframe, validation_dataframe
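For orientation, a sketch of how train_decollider is meant to be driven in the hybrid setting; the recording and template objects are placeholders, and for unsupervised training you would pass detection_times_train/detection_channels_train instead of templates:

```python
net = MLPSingleChannelDecollider()
net, train_df, val_df = train_decollider(
    net,
    [recording],                     # list of spikeinterface-style recordings
    templates_train=[templates_a],   # one (n_units, time, chans) array per recording
    templates_val=[templates_b],
    max_n_epochs=100,
    device="cuda",                   # assumption: a GPU is available
)
```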
+ + +# -- data helpers + + +@dataclass +class EpochData: + noisy_waveforms: torch.Tensor + channels: torch.LongTensor + noisier_waveforms: torch.Tensor + channel_masks: Optional[torch.Tensor] = None + + # for hybrid experiments + gt_waveforms: Optional[torch.Tensor] = None + template_indices: Optional[torch.LongTensor] = None + + +def load_epoch( + recordings, + templates=None, + detection_times=None, + detection_channels=None, + template_recording_origin=None, + channel_index=None, + recording_channel_indices=None, + channel_min_amplitude=0.0, + channel_jitter_index=None, + examples_per_epoch=10_000, + n_oversamples=1, + noise_same_chans=False, + noise2_alpha=1.0, + trough_offset_samples=42, + spike_length_samples=121, + data_random_seed=0, + noise_max_amplitude=np.inf, +): + rg = np.random.default_rng(data_random_seed) + + if templates is not None: + ( + channels, + channel_masks, + gt_waveforms, + noisy_waveforms, + which_rec, + which_templates, + ) = get_noised_hybrid_waveforms( + templates, + recordings=recordings, + noise_same_chans=noise_same_chans, + recording_channel_indices=recording_channel_indices, + trough_offset_samples=trough_offset_samples, + template_recording_origin=template_recording_origin, + noise_max_amplitude=noise_max_amplitude, + channel_index=channel_index, + channel_jitter_index=channel_jitter_index, + n_oversamples=n_oversamples, + channel_min_amplitude=channel_min_amplitude, + examples_per_epoch=examples_per_epoch, + data_random_seed=rg, + ) + else: + assert detection_times is not None + assert detection_channels is not None + which_templates = gt_waveforms =
None + + noisy_waveforms, channels, which_rec = load_spikes( + recordings, + times=detection_times, + channels=detection_channels, + recording_channel_indices=recording_channel_indices, + trough_offset_samples=trough_offset_samples, + spike_length_samples=spike_length_samples, + n=examples_per_epoch, + rg=rg, + to_torch=True, + ) + + # double noise + noisier_waveforms = load_noise( + recordings, + channels=channels if noise_same_chans else None, + which_rec=which_rec, + recording_channel_indices=recording_channel_indices, + trough_offset_samples=trough_offset_samples, + spike_length_samples=spike_length_samples, + n=noisy_waveforms.shape[0], + max_abs_amp=noise_max_amplitude, + dtype=recordings[0].dtype, + rg=rg, + to_torch=True, + alpha=noise2_alpha, + ) + noisier_waveforms += noisy_waveforms + + channel_masks = np.isfinite(noisier_waveforms[:, :, 0]) + channel_masks = torch.as_tensor(channel_masks, dtype=torch.bool) + if gt_waveforms is not None: + gt_waveforms[~channel_masks] = 0.0 + noisy_waveforms[~channel_masks] = 0.0 + noisier_waveforms[~channel_masks] = 0.0 + + return EpochData( + noisy_waveforms=noisy_waveforms, + noisier_waveforms=noisier_waveforms, + channels=channels, + channel_masks=channel_masks, + gt_waveforms=gt_waveforms, + template_indices=which_templates, + ) + + +@torch.no_grad() +def evaluate_decollider( + net, + recordings, + templates=None, + detection_times=None, + detection_channels=None, + recording_channel_indices=None, + template_recording_origin=None, + original_template_index=None, + n_oversamples=1, + n_unsupervised_val_examples=2000, + channel_index=None, + channel_min_amplitude=0.0, + channel_jitter_index=None, + noise_same_chans=False, + trough_offset_samples=42, + spike_length_samples=121, + data_random_seed=0, + noise_max_amplitude=np.inf, + noise2_alpha=1.0, + device=None, + summarize=True, +): + tic = time.perf_counter() + val_data = load_epoch( + recordings, + templates=templates, + detection_times=detection_times, + detection_channels=detection_channels, + template_recording_origin=template_recording_origin, + channel_index=channel_index, + recording_channel_indices=recording_channel_indices, + channel_min_amplitude=channel_min_amplitude, + channel_jitter_index=channel_jitter_index, + examples_per_epoch=None + if templates is not None + else n_unsupervised_val_examples, + n_oversamples=n_oversamples, + noise_same_chans=noise_same_chans, + noise2_alpha=noise2_alpha, + trough_offset_samples=trough_offset_samples, + data_random_seed=data_random_seed, + noise_max_amplitude=noise_max_amplitude, + spike_length_samples=spike_length_samples, + ) + toc = time.perf_counter() + val_data_load_wall_dt_s = toc - tic + + # metrics timer + # unsupervised task for learning: predict wf from noised_wf + # preds_noised = net(val_data.noisier_waveforms) + preds_noisy = batched_infer( + net, + val_data.noisier_waveforms, + channel_masks=val_data.channel_masks, + device=device, + ) + if summarize: + val_loss = F.mse_loss(preds_noisy, val_data.noisy_waveforms) + + if templates is None: + assert summarize + return dict( + val_loss=float(val_loss), + val_data_load_wall_dt_s=val_data_load_wall_dt_s, + val_metrics_wall_dt_s=time.perf_counter() - tic, + ) + + # below here, summarize is True, and we are working with templates + # so that gt_waveforms exists + + # supervised task: predict gt_wf (template) from wf (template + noise) + # template_preds_naive = net(val_data.noisy_waveforms) + template_preds_naive = batched_infer( + net, + val_data.noisy_waveforms, + 
channel_masks=val_data.channel_masks, + device=device, + ) + naive_sup_max_err = torch.abs( + template_preds_naive - val_data.gt_waveforms + ).amax(dim=(1, 2)) + + # noisier2noise prediction of templates + # template_preds_n2n = net.n2n_forward( + # val_data.noisier_waveforms, + # channel_masks=val_data.channel_masks, + # alpha=noise2_alpha, + # ) + template_preds_n2n = batched_n2n_infer( + net, + val_data.noisier_waveforms, + channel_masks=val_data.channel_masks, + device=device, + alpha=noise2_alpha, + ) + n2n_sup_max_err = torch.abs( + template_preds_n2n - val_data.gt_waveforms + ).amax(dim=(1, 2)) + + if summarize: + naive_template_mean_max_err = naive_sup_max_err.mean() + naive_template_mse = F.mse_loss( + template_preds_naive, val_data.gt_waveforms + ) + + n2n_template_mean_max_err = n2n_sup_max_err.mean() + n2n_template_mse = F.mse_loss(template_preds_n2n, val_data.gt_waveforms) + + return dict( + naive_template_mean_max_err=naive_template_mean_max_err, + naive_template_mse=naive_template_mse, + n2n_template_mean_max_err=n2n_template_mean_max_err, + n2n_template_mse=n2n_template_mse, + val_loss=val_loss, + val_data_load_wall_dt_s=val_data_load_wall_dt_s, + val_metrics_wall_dt_s=time.perf_counter() - tic, + ) + + # more detailed per-example comparisons + # break down some covariates + gt_amplitude = templates.ptp(2)[ + val_data.template_indices, np.asarray(val_data.channels) + ] + noise1_norm = torch.linalg.norm( + val_data.noisy_waveforms - val_data.gt_waveforms, dim=(1, 2) + ) + noise2_norm = torch.linalg.norm( + val_data.noisier_waveforms - val_data.noisy_waveforms, dim=(1, 2) + ) + + # errors for naive template prediction + naive_diff = template_preds_naive - val_data.gt_waveforms + naive_template_mse = torch.square(naive_diff).mean(dim=(1, 2)) + naive_template_max_err = torch.abs(naive_diff).amax(dim=(1, 2)) + + # errors for noisier2noise template prediction + n2n_diff = template_preds_n2n - val_data.gt_waveforms + n2n_template_mse = torch.square(n2n_diff).mean(dim=(1, 2)) + n2n_template_max_err = torch.abs(n2n_diff).amax(dim=(1, 2)) + + return pd.DataFrame( + dict( + combined_template_index=val_data.template_indices, + recording_index=template_recording_origin[ + val_data.template_indices + ], + original_template_index=original_template_index[ + val_data.template_indices + ], + gt_amplitude=gt_amplitude, + noise1_norm=noise1_norm.numpy(force=True), + noise2_norm=noise2_norm.numpy(force=True), + naive_template_mse=naive_template_mse.numpy(force=True), + naive_template_max_err=naive_template_max_err.numpy(force=True), + n2n_template_mse=n2n_template_mse.numpy(force=True), + n2n_template_max_err=n2n_template_max_err.numpy(force=True), + ) + )
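One pitfall the corrections above address: torch.Tensor.max reduces over a single dim only, so the per-spike maxima over (channel, time) use torch.amax with a dim tuple. For example:

```python
import torch

err = torch.randn(32, 7, 121).abs()
per_spike_max = err.amax(dim=(1, 2))  # shape (32,): max over channels and time
```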
+ + +# -- hybrid data helper + + +def get_noised_hybrid_waveforms( + templates, + recordings=None, + noise_same_chans=False, + recording_channel_indices=None, + trough_offset_samples=42, + template_recording_origin=None, + noise_max_amplitude=np.inf, + channel_index=None, + channel_jitter_index=None, + n_oversamples=1, + channel_min_amplitude=0.0, + examples_per_epoch=10_000, + data_random_seed=0, +): + """ + templates are NCT here. + + If single_channel: + Random single channels of templates_train landing inside their channel-indexed + chans are returned, along with which channel each one was. + Else: + Random templates are selected + """ + rg = np.random.default_rng(data_random_seed) + + # channel logic + amplitude_vectors = templates.ptp(2) + max_channels = amplitude_vectors.argmax(1) + kept = np.arange(len(templates)) + if channel_min_amplitude > 0: + kept = kept[amplitude_vectors[kept].max(1) > channel_min_amplitude] + + # this is where the random sampling happens! + which = kept + if examples_per_epoch is not None: + which = rg.choice(kept, size=examples_per_epoch) + if n_oversamples != 1: + which = np.repeat(which, n_oversamples) + + # randomly jitter channel neighborhoods + # use a large channel_jitter_index when training single channel nets + channels = max_channels[which] + if channel_jitter_index is not None: + for i, c in enumerate(channels): + choices = channel_jitter_index[c] + choices = choices[choices < len(channel_jitter_index)] + channels[i] = rg.choice(choices) + + # load noisy_waveforms + waveform_channels = channel_index[channels][:, :, None] + # multi channel, or i guess you could just have one channel + # in your channel index if you're into that kind of thing + gt_waveforms = templates[ + which[:, None, None], + waveform_channels, + np.arange(templates.shape[2])[None, None, :], + ] + + # keep this fellow on CPU + channel_masks = np.isfinite(gt_waveforms[:, :, 0]) + gt_waveforms[~channel_masks] = 0.0 + channel_masks = torch.from_numpy(channel_masks) + + # to torch + channels = torch.from_numpy(channels) + gt_waveforms = torch.from_numpy(gt_waveforms) + + # apply noise + noise_chans = channels if noise_same_chans else None + which_rec = template_recording_origin[which] if noise_same_chans else None + noisy_waveforms = load_noise( + recordings, + channels=noise_chans, + which_rec=which_rec, + recording_channel_indices=recording_channel_indices, + trough_offset_samples=trough_offset_samples, + spike_length_samples=templates.shape[2], + n=gt_waveforms.shape[0], + max_abs_amp=noise_max_amplitude, + dtype=templates.dtype, + rg=rg, + to_torch=True, + ) + noisy_waveforms += gt_waveforms + + return ( + channels, + channel_masks, + gt_waveforms, + noisy_waveforms, + which_rec, + which, + ) + + +# -- noise helpers + + +def load_noise( + recordings, + channels=None, + which_rec=None, + recording_channel_indices=None, + trough_offset_samples=42, + spike_length_samples=121, + n=100, + max_abs_amp=np.inf, + dtype=None, + alpha=1.0, + rg=0, + to_torch=True, +): + """Get NCT noise arrays.""" + rg = np.random.default_rng(rg) + + if which_rec is None: + which_rec = rg.integers(len(recordings), size=n) + + if dtype is None: + dtype = recordings[0].dtype + + c = ( + recording_channel_indices[0].shape[1] + if recording_channel_indices is not None + else 1 + ) + + noise = np.full((n, c, spike_length_samples), np.nan, dtype=dtype) + for i, rec in enumerate(recordings): + mask = np.flatnonzero(which_rec == i) + rec_channels = None + rec_ci = recording_channel_indices[i] + if channels is not None: + rec_channels = channels[mask] + noise[mask] = load_noise_singlerec( + rec, + channels=rec_channels, + channel_index=rec_ci, + trough_offset_samples=trough_offset_samples, + spike_length_samples=spike_length_samples, + n=mask.size, + max_abs_amp=max_abs_amp, + dtype=dtype, + rg=rg, + alpha=alpha, + to_torch=False, + ) + + if to_torch: + noise = torch.from_numpy(noise) + + return noise + + +def load_noise_singlerec( + recording, + channels=None, + trough_offset_samples=42, + spike_length_samples=121, + channel_index=None, + n=100, + max_abs_amp=np.inf, + alpha=1.0, + dtype=None, + rg=0,
to_torch=True, +): + rg = np.random.default_rng(rg) + + if dtype is None: + dtype = recording.dtype + + c = channel_index.shape[1] if channel_index is not None else 1 + + noise = np.full((n, c, spike_length_samples), np.nan, dtype=dtype) + needs_load_ix = np.full((n,), True, dtype=bool) + + nc = recording.get_num_channels() + nt = recording.get_num_samples() + mint = trough_offset_samples + maxt = nt - (spike_length_samples - trough_offset_samples) + + while np.any(needs_load_ix): + n_load = needs_load_ix.sum() + times = rg.integers(mint, maxt, size=n_load) + order = np.argsort(times) + + if channels is None: + chans = rg.integers(0, nc, size=n_load) + else: + chans = channels[needs_load_ix] + + wfs = spikeio.read_waveforms_channel_index( + recording, + times[order], + channel_index, + chans[order], + trough_offset_samples=trough_offset_samples, + spike_length_samples=spike_length_samples, + ) + wfs = wfs.transpose(0, 2, 1) + + noise[needs_load_ix] = wfs[np.argsort(order)] + # re-draw any individual waveform whose amplitude is too large + needs_load_ix = np.nanmax(np.abs(noise), axis=(1, 2)) > max_abs_amp + + if alpha != 1.0: + noise *= alpha + + if to_torch: + noise = torch.from_numpy(noise) + + return noise + + +# -- signal helpers + + +def load_spikes( + recordings, + times, + channels, + which_rec=None, + recording_channel_indices=None, + trough_offset_samples=42, + spike_length_samples=121, + n=100, + dtype=None, + rg=0, + to_torch=True, +): + """Get NCT spike arrays.""" + rg = np.random.default_rng(rg) + + if which_rec is None: + which_rec = rg.integers(len(recordings), size=n) + + if dtype is None: + dtype = recordings[0].dtype + + c = ( + recording_channel_indices[0].shape[1] + if recording_channel_indices is not None + else 1 + ) + + spikes = np.full((n, c, spike_length_samples), np.nan, dtype=dtype) + channels_chosen = np.zeros(n, dtype=int) + for i, rec in enumerate(recordings): + mask = np.flatnonzero(which_rec == i) + rec_ci = recording_channel_indices[i] + + ( + spikes[mask], + channels_chosen[mask], + ) = load_spikes_singlerec( + rec, + times=times[i], + channels=channels[i], + channel_index=rec_ci, + trough_offset_samples=trough_offset_samples, + spike_length_samples=spike_length_samples, + n=mask.size, + dtype=dtype, + rg=rg, + to_torch=False, + ) + + if to_torch: + spikes = torch.from_numpy(spikes) + + return spikes, channels_chosen, which_rec + + +def load_spikes_singlerec( + recording, + times, + channels, + trough_offset_samples=42, + spike_length_samples=121, + channel_index=None, + n=100, + dtype=None, + rg=0, + to_torch=True, +): + rg = np.random.default_rng(rg) + + if dtype is None: + dtype = recording.dtype + + which = rg.choice(times.size, size=n, replace=False) + times = times[which] + channels = channels[which] + order = np.argsort(times) + + wfs = spikeio.read_waveforms_channel_index( + recording, + times[order], + channel_index, + channels[order], + trough_offset_samples=trough_offset_samples, + spike_length_samples=spike_length_samples, + ) + wfs = wfs.transpose(0, 2, 1) + + spikes = wfs[np.argsort(order)] + + if to_torch: + spikes = torch.from_numpy(spikes) + + return spikes, channels + + +# -- multi-recording channel logic + + +def reconcile_channels(recordings, channel_index): + """Validate that the multi-chan setup is workable, reconcile chans across recordings""" + full_channel_set = recordings[0].channel_ids + for rec in recordings: + ids = [int(i.lstrip("AP")) for i in rec.channel_ids] + assert np.array_equal(ids, np.sort(ids)) + full_channel_set = np.union1d(full_channel_set, rec.channel_ids) + n_channels_full = full_channel_set.size + + if channel_index is not
None: + assert n_channels_full == channel_index.shape[0] + else: + channel_index = np.arange(n_channels_full)[:, None] + + channel_subsets = [] + recording_channel_indices = [] + for rec in recordings: + subset = np.flatnonzero(np.isin(full_channel_set, rec.channel_ids)) + channel_subsets.append(subset) + subset_ci = subset_recording_channel_index(channel_index, subset) + recording_channel_indices.append(subset_ci) + + return ( + n_channels_full, + recording_channel_indices, + channel_subsets, + channel_index, + ) + + +def subset_recording_channel_index(full_channel_index, channel_subset): + """The output of this function is a channel index containing indices into the + recording, but matching full_channel_index, just with holes. + """ + subset_channel_index = np.full_like(full_channel_index, len(channel_subset)) + for recchan, fullchan in enumerate(channel_subset): + full_ci = full_channel_index[fullchan] + # channels missing from this recording stay pointed at the pad value + present = np.isin(full_ci, channel_subset) + row = np.searchsorted(channel_subset, full_ci) + row[~present] = len(channel_subset) + subset_channel_index[recchan] = row + return subset_channel_index + + +def combine_templates(templates, channel_subsets): + t = templates[0].shape[1] + c = 1 + max(int(subset.max()) for subset in channel_subsets) + n = sum(map(len, templates)) + + combined = np.full( + (n, c + 1, t), fill_value=np.nan, dtype=templates[0].dtype + ) + template_recording_origin = np.zeros(n, dtype=int) + original_template_index = np.zeros(n, dtype=int) + i = 0 + for r, (temps, subset) in enumerate(zip(templates, channel_subsets)): + j = i + temps.shape[0] + template_recording_origin[i:j] = r + original_template_index[i:j] = np.arange(j - i) + # per-recording templates come in NTC; combined is NCT + combined[i:j, subset] = temps.transpose(0, 2, 1) + i = j + + return combined, template_recording_origin, original_template_index + + +# -- inference utils + + +def batched_infer( + net, + noisy_waveforms, + channel_masks=None, + batch_size=16, + device=None, + show_progress=False, +): + is_tensor = torch.is_tensor(noisy_waveforms) + if is_tensor: + out = torch.empty_like(noisy_waveforms) + out = out.pin_memory() + else: + out = np.empty_like(noisy_waveforms) + + xrange = trange if show_progress else range + for batch_start in xrange(0, len(noisy_waveforms), batch_size): + wfs = noisy_waveforms[batch_start : batch_start + batch_size] + if not is_tensor: + wfs = torch.from_numpy(wfs) + wfs = wfs.to(device) + + cms = None + if channel_masks is not None: + cms = channel_masks[batch_start : batch_start + batch_size] + if not is_tensor: + cms = torch.from_numpy(cms).to(torch.bool) + cms = cms.to(device) + if cms is None and wfs.shape[1] > 1 and torch.isnan(wfs).any(): + cms = torch.isfinite(wfs[:, :, 0]) + + wfs = net.predict(wfs, channel_masks=cms) + + if is_tensor: + out[batch_start : batch_start + batch_size].copy_( + wfs, non_blocking=True + ) + else: + out[batch_start : batch_start + batch_size] = wfs.numpy(force=True) + + return out + + +def batched_n2n_infer( + net, + noisier_waveforms, + channel_masks=None, + batch_size=16, + device=None, + alpha=1.0, + show_progress=False, +): + is_tensor = torch.is_tensor(noisier_waveforms) + if is_tensor: + out = torch.empty_like(noisier_waveforms) + else: + out = np.empty_like(noisier_waveforms) + + xrange = trange if show_progress else range + for batch_start in xrange(0, len(noisier_waveforms), batch_size): + wfs = noisier_waveforms[batch_start : batch_start + batch_size] + wfs = torch.as_tensor(wfs).to(device) + + cms = None + if channel_masks is not None: + cms = channel_masks[batch_start : batch_start + batch_size] + if not is_tensor: + cms = torch.from_numpy(cms).to(torch.bool) + if cms is None and wfs.shape[1] > 1 and
torch.isnan(wfs).any(): + cms = torch.isfinite(wfs[:, :, 0]) + if cms is not None: + cms = cms.to(device) + + wfs = net.n2n_predict(wfs, channel_masks=cms, alpha=alpha) + + if is_tensor: + out[batch_start : batch_start + batch_size] = wfs.to(out.device) + else: + out[batch_start : batch_start + batch_size] = wfs.numpy(force=True) + + return out + + +# -- for testing + + +class SCDAsDecollider(SingleChannelDenoiser, SingleChannelPredictor): + def forward(self, x): + """N1T -> N1T""" + x = self.conv1(x) + x = self.conv2(x) + x = x.view(x.shape[0], -1) + x = self.out(x) + return x[:, None, :] diff --git a/src/dartsort/util/multiprocessing_util.py b/src/dartsort/util/multiprocessing_util.py index f88e4804..9f765975 100644 --- a/src/dartsort/util/multiprocessing_util.py +++ b/src/dartsort/util/multiprocessing_util.py @@ -1,11 +1,12 @@ import multiprocessing from concurrent.futures import ProcessPoolExecutor from multiprocessing import get_context # TODO: torch.multiprocessing? try: import cloudpickle + have_cloudpickle = True except ImportError: - pass + have_cloudpickle = False @@ -58,16 +60,20 @@ def __init__(self): self.get = lambda: self.q.pop(0) -def apply_cloudpickle(fn, /, *args, **kwargs): +def cloudpickle_run(fn, args): fn = cloudpickle.loads(fn) + args, kwargs = cloudpickle.loads(args) return fn(*args, **kwargs) class CloudpicklePoolExecutor(ProcessPoolExecutor): def submit(self, fn, /, *args, **kwargs): - return super().submit( - apply_cloudpickle, cloudpickle.dumps(fn), *args, **kwargs - ) + args = cloudpickle.dumps((args, kwargs)) + return super().submit(cloudpickle_run, cloudpickle.dumps(fn), args) + + +def rank_init(queue): + rank_init.rank = queue.get() def get_pool( @@ -75,22 +81,34 @@ def get_pool( context="spawn", cls=ProcessPoolExecutor, with_rank_queue=False, + n_tasks=None, + max_tasks_per_child=None, ): if n_jobs == -1: n_jobs = multiprocessing.cpu_count() do_parallel = n_jobs >= 1 n_jobs = max(1, n_jobs) + if cls == CloudpicklePoolExecutor and not have_cloudpickle: cls = ProcessPoolExecutor + Executor = cls if do_parallel else MockPoolExecutor context = get_context(context) + if with_rank_queue: if do_parallel: manager = context.Manager() rank_queue = manager.Queue() else: rank_queue = MockQueue() + + n_repeats = 1 + if n_tasks is not None and max_tasks_per_child is not None: + n_repeats = n_tasks // max_tasks_per_child + 1 + for _ in range(n_repeats): + for rank in range(n_jobs): + rank_queue.put(rank) + return n_jobs, Executor, context, rank_queue + return n_jobs, Executor, context
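A sketch of how the new rank-queue plumbing in get_pool is consumed (the task function and counts are hypothetical). Each worker pulls its rank in rank_init, and the repeated rank entries keep respawned workers supplied when the executor is configured to recycle children after a number of tasks:

```python
n_jobs, Executor, context, rank_queue = get_pool(
    4, with_rank_queue=True, n_tasks=100, max_tasks_per_child=25
)
with Executor(
    max_workers=n_jobs,
    mp_context=context,
    initializer=rank_init,
    initargs=(rank_queue,),
) as pool:
    results = list(pool.map(work_on_chunk, range(100)))  # work_on_chunk: hypothetical
```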