
Decollider testing #9
Merged: 26 commits, Feb 21, 2024
Commits (26)
ca6bd91  Working on decollider training (cwindolf, Jan 23, 2024)
40e6dbf  Flesh out neural net decolliders and training (cwindolf, Jan 23, 2024)
cfc5883  Some spike train helpers (cwindolf, Jan 23, 2024)
a0faca6  Wrong no_grad context (cwindolf, Jan 23, 2024)
dd4fa17  Add Noisier2Noise prediction, training, evaluation (cwindolf, Jan 24, 2024)
8f74e2e  Merge branch 'decollider-testing' of github.com:cwindolf/dartsort int… (cwindolf, Jan 24, 2024)
18a6406  Add a classic thresholding, just because I want to detect spikes for … (cwindolf, Jan 24, 2024)
8cf36b6  Imports (cwindolf, Jan 24, 2024)
bfd16b4  Friendlier (cwindolf, Jan 24, 2024)
4e2dc3d  Debug edges (cwindolf, Jan 24, 2024)
8790ee9  Improve handling of varying channel neighborhoods; sketch out fully u… (cwindolf, Jan 25, 2024)
740f217  Improve handling of varying channel neighborhoods; sketch out fully u… (cwindolf, Jan 25, 2024)
ba9f960  Shim for old denoiser (cwindolf, Jan 25, 2024)
c41abec  Checking (cwindolf, Jan 25, 2024)
a050647  Sketch out unsupervised training (cwindolf, Jan 25, 2024)
82c4756  Sketch out unsupervised training (cwindolf, Jan 25, 2024)
2e3d9bf  Fix channel logic and improve batching helpers (cwindolf, Jan 25, 2024)
5acb501  Debug decollider and improve the training and metrics code (cwindolf, Jan 26, 2024)
3a70503  quality of life (cwindolf, Jan 27, 2024)
71a6527  Fix save/load (cwindolf, Jan 27, 2024)
d7aa5f6  Improve training, saving (cwindolf, Jan 29, 2024)
c802cec  Trying this sigmoid idea (cwindolf, Jan 29, 2024)
448b02e  Realignment stuff (cwindolf, Feb 20, 2024)
0d131dc  Cloudpickle things (cwindolf, Feb 20, 2024)
39693c8  Net training improvements (cwindolf, Feb 20, 2024)
2b8bf46  Merge upstream (cwindolf, Feb 21, 2024)
src/dartsort/peel/__init__.py (1 addition, 0 deletions)
@@ -1,3 +1,4 @@
 from .grab import GrabAndFeaturize
 from .matching import ObjectiveUpdateTemplateMatchingPeeler
 from .subtract import SubtractionPeeler, subtract_chunk
+from .threshold import ThresholdAndFeaturize
src/dartsort/peel/peel_base.py (19 additions, 6 deletions)
@@ -28,7 +28,7 @@ def __init__(
         self,
         recording,
         channel_index,
-        featurization_pipeline,
+        featurization_pipeline=None,
         chunk_length_samples=30_000,
         chunk_margin_samples=0,
         n_chunks_fit=40,
@@ -46,7 +46,10 @@
             fit_subsampling_random_state
         )
         self.register_buffer("channel_index", channel_index)
-        self.add_module("featurization_pipeline", featurization_pipeline)
+        if featurization_pipeline is not None:
+            self.add_module("featurization_pipeline", featurization_pipeline)
+        else:
+            self.featurization_pipeline = None
 
         # subclasses can append to this if they want to store more fixed
         # arrays in the output h5 file
@@ -237,16 +240,20 @@ def out_datasets(self):
             SpikeDataset(name="times_seconds", shape_per_spike=(), dtype=float),
             SpikeDataset(name="channels", shape_per_spike=(), dtype=int),
         ]
-        for transformer in self.featurization_pipeline.transformers:
-            if transformer.is_featurizer:
-                datasets.append(transformer.spike_dataset)
+        if self.featurization_pipeline is not None:
+            for transformer in self.featurization_pipeline.transformers:
+                if transformer.is_featurizer:
+                    datasets.append(transformer.spike_dataset)
         return datasets
 
     # -- utility methods which users likely won't touch
 
     def featurize_collisioncleaned_waveforms(
         self, collisioncleaned_waveforms, max_channels
     ):
+        if self.featurization_pipeline is None:
+            return {}
+
         waveforms, features = self.featurization_pipeline(
             collisioncleaned_waveforms, max_channels
         )
@@ -329,7 +336,10 @@ def gather_chunk_result(
         return n_new_spikes
 
     def needs_fit(self):
-        return self.peeling_needs_fit() or self.featurization_pipeline.needs_fit()
+        it_does = self.peeling_needs_fit()
+        if self.featurization_pipeline is not None:
+            it_does = it_does or self.featurization_pipeline.needs_fit()
+        return it_does
 
     def fit_models(self, save_folder, overwrite=False, n_jobs=0, device=None):
         with torch.no_grad():
@@ -349,6 +359,9 @@ def fit_models(self, save_folder, overwrite=False, n_jobs=0, device=None):
         assert not self.needs_fit()
 
     def fit_featurization_pipeline(self, save_folder, n_jobs=0, device=None):
+        if self.featurization_pipeline is None:
+            return
+
         if not self.featurization_pipeline.needs_fit():
             return
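With the featurization pipeline now optional, the base peeler short-circuits cleanly: out_datasets skips featurizer datasets, featurize_collisioncleaned_waveforms returns an empty dict, and needs_fit consults the pipeline only when one is attached. A minimal sketch of that contract (MyPeeler, recording, channel_index, waveforms, and max_channels are hypothetical stand-ins, not names from this PR):

# Sketch under assumptions: MyPeeler is any concrete BasePeeler subclass.
peeler = MyPeeler(recording, channel_index, featurization_pipeline=None)

# No pipeline attached, so featurization is a no-op with no features.
assert peeler.featurize_collisioncleaned_waveforms(waveforms, max_channels) == {}

# needs_fit() now reduces to the peeler's own models.
assert peeler.needs_fit() == peeler.peeling_needs_fit()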
src/dartsort/peel/subtract.py (2 additions, 2 deletions)
@@ -26,7 +26,7 @@ def __init__(
         spike_length_samples=121,
         detection_thresholds=[12, 10, 8, 6, 5, 4],
         chunk_length_samples=30_000,
-        peak_sign="neg",
+        peak_sign="both",
         spatial_dedup_channel_index=None,
         n_chunks_fit=40,
         fit_subsampling_random_state=0,
@@ -274,7 +274,7 @@ def subtract_chunk(
     left_margin=0,
     right_margin=0,
     detection_thresholds=[12, 10, 8, 6, 5, 4],
-    peak_sign="neg",
+    peak_sign="both",
     spatial_dedup_channel_index=None,
     residnorm_decrease_threshold=3.162,  # sqrt(10)
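The subtraction peeler's default detection polarity changes from negative-only to both signs. Callers who depend on the old behavior can pin it explicitly; a sketch (recording and channel_index are assumed inputs, other arguments left at their defaults):

# Sketch: restore the pre-change negative-only detection.
peeler = SubtractionPeeler(
    recording,
    channel_index,
    peak_sign="neg",  # old default; the new default is "both"
)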
src/dartsort/peel/threshold.py (new file, 95 additions)
@@ -0,0 +1,95 @@
import torch
from dartsort.detect import detect_and_deduplicate
from dartsort.util import spiketorch

from .peel_base import BasePeeler


class ThresholdAndFeaturize(BasePeeler):
    def __init__(
        self,
        recording,
        channel_index,
        featurization_pipeline=None,
        trough_offset_samples=42,
        spike_length_samples=121,
        detection_threshold=5.0,
        chunk_length_samples=30_000,
        peak_sign="both",
        spatial_dedup_channel_index=None,
        n_chunks_fit=40,
        fit_subsampling_random_state=0,
    ):
        super().__init__(
            recording=recording,
            channel_index=channel_index,
            featurization_pipeline=featurization_pipeline,
            chunk_length_samples=chunk_length_samples,
            chunk_margin_samples=spike_length_samples,
            n_chunks_fit=n_chunks_fit,
            fit_subsampling_random_state=fit_subsampling_random_state,
        )

        self.trough_offset_samples = trough_offset_samples
        self.spike_length_samples = spike_length_samples
        self.peak_sign = peak_sign
        if spatial_dedup_channel_index is not None:
            self.register_buffer(
                "spatial_dedup_channel_index",
                spatial_dedup_channel_index,
            )
        else:
            self.spatial_dedup_channel_index = None
        self.detection_threshold = detection_threshold
        self.peel_kind = f"Threshold {detection_threshold}"

    def peel_chunk(
        self,
        traces,
        chunk_start_samples=0,
        left_margin=0,
        right_margin=0,
        return_residual=False,
    ):
        times_rel, channels = detect_and_deduplicate(
            traces,
            self.detection_threshold,
            dedup_channel_index=self.spatial_dedup_channel_index,
            peak_sign=self.peak_sign,
        )
        if not times_rel.numel():
            return dict(n_spikes=0)

        # want only peaks in the chunk
        min_time = max(left_margin, self.spike_length_samples)
        max_time = traces.shape[0] - max(
            right_margin, self.spike_length_samples - self.trough_offset_samples
        )
        valid = (times_rel >= min_time) & (times_rel < max_time)
        times_rel = times_rel[valid]
        if not times_rel.numel():
            return dict(n_spikes=0)
        channels = channels[valid]

        # load up the waveforms for this chunk
        waveforms = spiketorch.grab_spikes(
            traces,
            times_rel,
            channels,
            self.channel_index,
            trough_offset=self.trough_offset_samples,
            spike_length_samples=self.spike_length_samples,
            already_padded=False,
            pad_value=torch.nan,
        )

        # get absolute times
        times_samples = times_rel + chunk_start_samples - left_margin

        peel_result = dict(
            n_spikes=times_rel.numel(),
            times_samples=times_samples,
            channels=channels,
            collisioncleaned_waveforms=waveforms,
        )
        return peel_result
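A minimal usage sketch for the new peeler (recording, channel_index, save_folder, and traces are assumed to exist; fit_models follows the peel_base.py signature shown above):

# Sketch: detection-only peeling with no featurization pipeline.
peeler = ThresholdAndFeaturize(
    recording,
    channel_index,
    detection_threshold=5.0,
    peak_sign="both",
)
if peeler.needs_fit():
    peeler.fit_models(save_folder)

# Per-chunk result layout, per peel_chunk above.
result = peeler.peel_chunk(traces)  # traces: (n_samples, n_channels) tensor
if result["n_spikes"]:
    times, chans = result["times_samples"], result["channels"]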
src/dartsort/templates/template_util.py (4 additions, 0 deletions)
@@ -109,11 +109,15 @@ def get_registered_templates(
 def get_realigned_sorting(
     recording,
     sorting,
+    realign_peaks=True,
+    low_rank_denoising=False,
     **kwargs,
 ):
     results = get_templates(
         recording,
         sorting,
+        realign_peaks=realign_peaks,
+        low_rank_denoising=low_rank_denoising,
         **kwargs,
     )
     return results["sorting"]
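get_realigned_sorting now names realign_peaks and low_rank_denoising explicitly instead of burying them in **kwargs, so they carry visible defaults. A sketch (recording and sorting are assumed inputs):

# Sketch: the new explicit arguments are forwarded to get_templates.
realigned = get_realigned_sorting(
    recording,
    sorting,
    realign_peaks=True,        # realign spike times during template estimation
    low_rank_denoising=False,  # skip low-rank template denoising
)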
src/dartsort/templates/templates.py (2 additions, 0 deletions)
@@ -105,6 +105,7 @@ def from_config(
         save_npz_name="template_data.npz",
         localizations_dataset_name="point_source_localizations",
         n_jobs=0,
+        units_per_job=8,
         device=None,
         trough_offset_samples=42,
         spike_length_samples=121,
@@ -155,6 +156,7 @@ def from_config(
             denoising_fit_radius=template_config.denoising_fit_radius,
             denoising_snr_threshold=template_config.denoising_snr_threshold,
             device=device,
+            units_per_job=units_per_job,
         )
         if template_config.registered_templates and motion_est is not None:
             kwargs["registered_geom"] = drift_util.registered_geometry(
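from_config threads a new units_per_job knob down into template computation, controlling how many units each parallel job handles. A sketch (the TemplateData class name and the positional arguments beyond this diff are assumptions, not confirmed by the PR):

# Sketch under assumptions: names outside this diff are illustrative.
template_data = TemplateData.from_config(
    recording,
    sorting,
    template_config,
    n_jobs=4,
    units_per_job=8,  # new: units per parallel job (default 8)
)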