Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
julienboussard committed Feb 1, 2024
2 parents e1b82b2 + b70e2cb commit f5e75f9
Show file tree
Hide file tree
Showing 13 changed files with 396 additions and 73 deletions.
59 changes: 48 additions & 11 deletions src/dartsort/cluster/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def merge_templates(
motion_est=None,
max_shift_samples=20,
superres_linkage=np.max,
sym_function=np.minimum,
merge_distance_threshold=0.25,
temporal_upsampling_factor=8,
amplitude_scaling_variance=0.0,
Expand Down Expand Up @@ -68,6 +69,50 @@ def merge_templates(
save_npz_name=template_npz_filename,
)

units, dists, shifts, template_snrs = calculate_merge_distances(
template_data,
superres_linkage=superres_linkage,
sym_function=sym_function,
max_shift_samples=max_shift_samples,
temporal_upsampling_factor=temporal_upsampling_factor,
amplitude_scaling_variance=amplitude_scaling_variance,
amplitude_scaling_boundary=amplitude_scaling_boundary,
svd_compression_rank=svd_compression_rank,
min_channel_amplitude=min_channel_amplitude,
conv_batch_size=conv_batch_size,
units_batch_size=units_batch_size,
device=device,
n_jobs=n_jobs,
show_progress=show_progress,
)

# now run hierarchical clustering
return recluster(
sorting,
units,
dists,
shifts,
template_snrs,
merge_distance_threshold=merge_distance_threshold,
)


def calculate_merge_distances(
template_data,
superres_linkage=np.max,
sym_function=np.minimum,
max_shift_samples=20,
temporal_upsampling_factor=8,
amplitude_scaling_variance=0.0,
amplitude_scaling_boundary=0.5,
svd_compression_rank=10,
min_channel_amplitude=0.0,
conv_batch_size=128,
units_batch_size=8,
device=None,
n_jobs=0,
show_progress=True,
):
# allocate distance + shift matrices. shifts[i,j] is trough[j]-trough[i].
n_templates = template_data.templates.shape[0]
sup_dists = np.full((n_templates, n_templates), np.inf)
Expand Down Expand Up @@ -116,15 +161,9 @@ def merge_templates(
template_data.templates.ptp(1).max(1) / template_data.spike_counts
)

# now run hierarchical clustering
return recluster(
sorting,
units,
dists,
shifts,
template_snrs,
merge_distance_threshold=merge_distance_threshold,
)
dists = sym_function(dists, dists.T)

return units, dists, shifts, template_snrs


def recluster(
Expand All @@ -134,9 +173,7 @@ def recluster(
shifts,
template_snrs,
merge_distance_threshold=0.25,
sym_function=np.minimum,
):
dists = sym_function(dists, dists.T)

# upper triangle not including diagonal, aka condensed distance matrix in scipy
pdist = dists[np.triu_indices(dists.shape[0], k=1)]
Expand Down
14 changes: 6 additions & 8 deletions src/dartsort/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
default_pretrained_path = default_pretrained_path.joinpath("single_chan_denoiser.pt")


# TODO: integrate this in the other configs
@dataclass(frozen=True)
class WaveformConfig:
"""Defaults yield 42 sample trough offset and 121 total at 30kHz."""
Expand All @@ -44,7 +43,10 @@ def trough_offset_samples(self, sampling_frequency=30_000):

def spike_length_samples(self, sampling_frequency=30_000):
spike_len_ms = self.ms_before + self.ms_after
return int(spike_len_ms * (sampling_frequency / 1000))
length = int(spike_len_ms * (sampling_frequency / 1000))
# odd is better for convolution arithmetic elsewhere
length = 2 * (length // 2) + 1
return length


@dataclass(frozen=True)
Expand Down Expand Up @@ -103,8 +105,6 @@ class FeaturizationConfig:

@dataclass(frozen=True)
class SubtractionConfig:
trough_offset_samples: int = 42
spike_length_samples: int = 121
detection_thresholds: List[int] = (10, 8, 6, 5, 4)
chunk_length_samples: int = 30_000
peak_sign: str = "both"
Expand Down Expand Up @@ -147,8 +147,6 @@ class MotionEstimationConfig:

@dataclass(frozen=True)
class TemplateConfig:
trough_offset_samples: int = 42
spike_length_samples: int = 121
spikes_per_unit = 500

# -- template construction parameters
Expand Down Expand Up @@ -177,8 +175,6 @@ class TemplateConfig:

@dataclass(frozen=True)
class MatchingConfig:
trough_offset_samples: int = 42
spike_length_samples: int = 121
chunk_length_samples: int = 30_000
extract_radius: float = 200.0
n_chunks_fit: int = 40
Expand Down Expand Up @@ -235,10 +231,12 @@ class SplitMergeConfig:
merge_distance_threshold: float = 0.25


default_waveform_config = WaveformConfig()
default_featurization_config = FeaturizationConfig()
default_subtraction_config = SubtractionConfig()
default_template_config = TemplateConfig()
default_clustering_config = ClusteringConfig()
default_split_merge_config = SplitMergeConfig()
coarse_template_config = TemplateConfig(superres_templates=False)
default_matching_config = MatchingConfig()
default_motion_estimation_config = MotionEstimationConfig()
33 changes: 29 additions & 4 deletions src/dartsort/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import asdict
from pathlib import Path

from dartsort.cluster.initial import ensemble_chunks
Expand All @@ -6,14 +7,16 @@
from dartsort.config import (default_clustering_config,
default_featurization_config,
default_matching_config,
default_motion_estimation_config,
default_split_merge_config,
default_subtraction_config,
default_template_config)
default_template_config, default_waveform_config)
from dartsort.peel import (ObjectiveUpdateTemplateMatchingPeeler,
SubtractionPeeler)
from dartsort.templates import TemplateData
from dartsort.util.data_util import check_recording
from dartsort.util.peel_util import run_peeler
from dartsort.util.registration_util import estimate_motion


def dartsort_from_config(
Expand All @@ -26,7 +29,9 @@ def dartsort_from_config(
def dartsort(
recording,
output_directory,
waveform_config=default_waveform_config,
featurization_config=default_featurization_config,
motion_estimation_config=default_motion_estimation_config,
subtraction_config=default_subtraction_config,
matching_config=default_subtraction_config,
template_config=default_template_config,
Expand All @@ -43,15 +48,22 @@ def dartsort(
sorting, sub_h5 = subtract(
recording,
output_directory,
waveform_config=waveform_config,
featurization_config=featurization_config,
subtraction_config=subtraction_config,
n_jobs=n_jobs,
overwrite=overwrite,
device=device,
)
if motion_est is None:
# TODO
motion_est = estimate_motion()
motion_est = estimate_motion(
recording,
sorting,
output_directory,
overwrite=overwrite,
device=device,
**asdict(motion_estimation_config),
)
sorting = cluster(
sub_h5,
recording,
Expand All @@ -78,6 +90,7 @@ def dartsort(
output_directory,
motion_est=motion_est,
template_config=template_config,
waveform_config=waveform_config,
featurization_config=featurization_config,
matching_config=matching_config,
n_jobs_templates=n_jobs,
Expand All @@ -95,6 +108,7 @@ def dartsort(
def subtract(
recording,
output_directory,
waveform_config=default_waveform_config,
featurization_config=default_featurization_config,
subtraction_config=default_subtraction_config,
chunk_starts_samples=None,
Expand All @@ -109,6 +123,7 @@ def subtract(
check_recording(recording)
subtraction_peeler = SubtractionPeeler.from_config(
recording,
waveform_config=waveform_config,
subtraction_config=subtraction_config,
featurization_config=featurization_config,
)
Expand Down Expand Up @@ -174,6 +189,7 @@ def match(
sorting=None,
output_directory=None,
motion_est=None,
waveform_config=default_waveform_config,
template_config=default_template_config,
featurization_config=default_featurization_config,
matching_config=default_matching_config,
Expand All @@ -192,21 +208,30 @@ def match(
model_dir = Path(output_directory) / model_subdir

# compute templates
trough_offset_samples = waveform_config.trough_offset_samples(
recording.sampling_frequency
)
spike_length_samples = waveform_config.spike_length_samples(
recording.sampling_frequency
)
template_data = TemplateData.from_config(
recording,
sorting,
template_config,
template_config=template_config,
motion_est=motion_est,
n_jobs=n_jobs_templates,
save_folder=model_dir,
overwrite=overwrite,
device=device,
save_npz_name=template_npz_filename,
trough_offset_samples=trough_offset_samples,
spike_length_samples=spike_length_samples,
)

# instantiate peeler
matching_peeler = ObjectiveUpdateTemplateMatchingPeeler.from_config(
recording,
waveform_config,
matching_config,
featurization_config,
template_data,
Expand Down
6 changes: 5 additions & 1 deletion src/dartsort/peel/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ def handle_upsampling(
def from_config(
cls,
recording,
waveform_config,
matching_config,
featurization_config,
template_data,
Expand All @@ -377,6 +378,9 @@ def from_config(
featurization_pipeline = WaveformPipeline.from_config(
geom, channel_index, featurization_config
)
trough_offset_samples = waveform_config.trough_offset_samples(
recording.sampling_frequency
)
return cls(
recording,
template_data,
Expand All @@ -391,7 +395,7 @@ def from_config(
amplitude_scaling_boundary=matching_config.amplitude_scaling_boundary,
conv_ignore_threshold=matching_config.conv_ignore_threshold,
coarse_approx_error_threshold=matching_config.coarse_approx_error_threshold,
trough_offset_samples=matching_config.trough_offset_samples,
trough_offset_samples=trough_offset_samples,
threshold=matching_config.threshold,
chunk_length_samples=matching_config.chunk_length_samples,
n_chunks_fit=matching_config.n_chunks_fit,
Expand Down
Loading

0 comments on commit f5e75f9

Please sign in to comment.