Commit 9632b7a: Bugs
cwindolf committed Oct 22, 2024
1 parent aa942c4 commit 9632b7a
Showing 11 changed files with 76 additions and 37 deletions.
29 changes: 17 additions & 12 deletions src/dartsort/cluster/initial.py
@@ -56,6 +56,7 @@ def cluster_chunk(
     if sorting is None:
         sorting = DARTsortSorting.from_peeling_hdf5(peeling_hdf5_filename)
     xyza = getattr(sorting, localizations_dataset_name)
+    z_reg = motion_est.correct_s(sorting.times_seconds, xyza[:, 2])
     amps = getattr(sorting, amplitudes_dataset_name)

     if recording is None:
@@ -68,20 +69,24 @@ def cluster_chunk(
     to_cluster = np.setdiff1d(to_cluster, np.flatnonzero(sorting.labels < -1))
     labels = np.full_like(sorting.labels, -1)
     extra_features = sorting.extra_features
+    z_reg = z_reg[to_cluster]
+    xyza = xyza[to_cluster]
+    amps = amps[to_cluster]
+    t_s = sorting.times_seconds[to_cluster]

     if clustering_config.cluster_strategy == "closest_registered_channels":
         labels[to_cluster] = cluster_util.closest_registered_channels(
-            sorting.times_seconds[to_cluster],
-            xyza[to_cluster, 0],
-            xyza[to_cluster, 2],
+            t_s,
+            xyza[:, 0],
+            xyza[:, 2],
             geom,
             motion_est,
         )
     elif clustering_config.cluster_strategy == "grid_snap":
         labels[to_cluster] = cluster_util.grid_snap(
-            sorting.times_seconds[to_cluster],
-            xyza[to_cluster, 0],
-            xyza[to_cluster, 2],
+            t_s,
+            xyza[:, 0],
+            xyza[:, 2],
             geom,
             grid_dx=clustering_config.grid_dx,
             grid_dz=clustering_config.grid_dz,
@@ -90,11 +95,11 @@
     elif clustering_config.cluster_strategy == "hdbscan":
         labels[to_cluster] = cluster_util.hdbscan_clustering(
             recording,
-            sorting.times_seconds[to_cluster],
+            t_s,
             sorting.times_samples[to_cluster],
-            xyza[to_cluster, 0],
-            xyza[to_cluster, 2],
-            amps[to_cluster],
+            xyza[:, 0],
+            xyza[:, 2],
+            amps,
             geom,
             motion_est,
             min_cluster_size=clustering_config.min_cluster_size,
@@ -109,9 +114,9 @@
             zstd_big_units=clustering_config.zstd_big_units,
         )
     elif clustering_config.cluster_strategy == "dpc":
-        features = (xyza[:, [0, 2]][to_cluster],)
+        features = (xyza[:, 0], z_reg)
         if clustering_config.use_amplitude:
-            ampfeat = np.log(clustering_config.amp_log_c + amps[to_cluster])
+            ampfeat = np.log(clustering_config.amp_log_c + amps)
             ampfeat *= clustering_config.amp_scale
             features = (*features, ampfeat)
         if clustering_config.n_main_channel_pcs:
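
The theme of this file's changes: drift-corrected depths (z_reg) are computed once up front from the motion estimate, every per-spike feature array is subset by to_cluster exactly once, and downstream call sites index the subsets. A minimal sketch of the subset-once pattern; the helper name is hypothetical, and motion_est.correct_s follows the call shown in the diff:

import numpy as np

def prepare_cluster_features(times_s, xyza, amps, labels, motion_est):
    # drift-correct depths (column 2 of xyza) once, up front
    z_reg = motion_est.correct_s(times_s, xyza[:, 2])
    # spikes with labels < -1 are excluded from clustering, per the diff
    to_cluster = np.flatnonzero(labels >= -1)
    # subset every feature a single time; downstream code indexes the subsets
    return times_s[to_cluster], xyza[to_cluster], amps[to_cluster], z_reg[to_cluster]

The dpc branch also changes meaning: the old (xyza[:, [0, 2]][to_cluster],) packed x and unregistered z into a single two-column feature, while (xyza[:, 0], z_reg) passes x and drift-corrected depth as separate features.
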
1 change: 1 addition & 0 deletions src/dartsort/config.py
@@ -103,6 +103,7 @@ class FeaturizationConfig:
     # in the future we may add multi-channel or other nns
     nn_denoiser_class_name: str = "SingleChannelWaveformDenoiser"
     nn_denoiser_pretrained_path: Optional[str] = default_pretrained_path
+    nn_denoiser_train_epochs: int = 50

     # optionally restrict how many channels TPCA are fit on
     tpca_fit_radius: Optional[float] = None
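
The new field is forwarded to the NN denoiser by featurization_config_to_class_names_and_kwargs in pipeline.py below. A hedged usage sketch, assuming the remaining dataclass fields keep their defaults:

from dartsort.config import FeaturizationConfig

# train the learnable single-channel denoiser for 25 epochs instead of the default 50
fconf = FeaturizationConfig(nn_denoiser_train_epochs=25)
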
14 changes: 8 additions & 6 deletions src/dartsort/peel/subtract.py
@@ -291,21 +291,23 @@ def _fit_subtraction_transformers(
        ifeats = [init_voltage_feature, init_waveform_feature]
        if which == "denoisers":
            # add all the already fitted denoisers until we hit the next unfitted one
-            ffeats = []
+            already_fitted = []
+            fit_feats = []
            for t in orig_denoise:
                if t.is_denoiser:
                    if t.needs_fit():
+                        fit_feats = [t]
                        break
-                    ffeats.append(t)
+                    already_fitted.append(t)

            # this is the sequence of transforms to actually use in fitting
-            fit_feats = ffeats + [t]
+            fit_feats = already_fitted + fit_feats

            # if we have no denoisers yet, then definitely don't do subtraction!
-            self._turn_off_subtraction = not ffeats
+            self._turn_off_subtraction = not already_fitted
        else:
-            ffeats = [t for t in orig_denoise if t.is_denoiser]
-            self.subtraction_denoising_pipeline = WaveformPipeline(ifeats + ffeats)
+            already_fitted = [t for t in orig_denoise if t.is_denoiser]
+            self.subtraction_denoising_pipeline = WaveformPipeline(ifeats + already_fitted)

        # and we don't need any features for this
        orig_featurization_pipeline = self.featurization_pipeline
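
The old scan had a subtle bug: fit_feats = ffeats + [t] ran even when the loop exhausted orig_denoise without breaking, so t was simply the last transform visited, whether or not it was an unfitted denoiser. Recording the first unfitted denoiser inside the break branch fixes that. A self-contained sketch of the corrected logic; the tiny stand-in class is hypothetical:

class Transform:
    # hypothetical stand-in for dartsort's transform objects
    def __init__(self, name, is_denoiser, fitted):
        self.name, self.is_denoiser, self._fitted = name, is_denoiser, fitted

    def needs_fit(self):
        return not self._fitted


def split_denoisers(pipeline):
    """Return the fitted denoisers before the first unfitted one,
    plus that unfitted denoiser appended if there is one."""
    already_fitted, fit_feats = [], []
    for t in pipeline:
        if t.is_denoiser:
            if t.needs_fit():
                fit_feats = [t]
                break
            already_fitted.append(t)
    return already_fitted, already_fitted + fit_feats


pipe = [Transform("tpca", True, True), Transform("nn", True, False)]
fitted, to_use = split_denoisers(pipe)
assert [t.name for t in to_use] == ["tpca", "nn"]

If every denoiser is already fitted, fit_feats now stays empty instead of picking up a stray trailing transform.
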
16 changes: 16 additions & 0 deletions src/dartsort/templates/templates.py
@@ -66,6 +66,22 @@ def to_npz(self, npz_path):
            npz_path.parent.mkdir()
        np.savez(npz_path, **to_save)

+    def __getitem__(self, subset):
+        # need to implement other cases
+        assert np.array_equal(self.unit_ids, np.arange(len(self.unit_ids)))
+        return self.__class__(
+            templates=self.templates[subset],
+            unit_ids=self.unit_ids[subset],
+            spike_counts=self.spike_counts[subset],
+            spike_counts_by_channel=self.spike_counts_by_channel[subset] if self.spike_counts_by_channel is not None else None,
+            raw_std_dev=self.raw_std_dev[subset] if self.raw_std_dev is not None else None,
+            registered_geom=self.registered_geom,
+            registered_template_depths_um=self.registered_template_depths_um[subset] if self.registered_template_depths_um is not None else None,
+            localization_radius_um=self.localization_radius_um,
+            trough_offset_samples=self.trough_offset_samples,
+            spike_length_samples=self.spike_length_samples,
+        )
+
    def coarsen(self, with_locs=True):
        """Weighted average all templates that share a unit id and re-localize."""
        # update templates
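
This makes TemplateData subsettable with numpy-style indices (index arrays, boolean masks, slices): per-unit fields are sliced, global metadata is carried over, and the assert guards the not-yet-implemented case where unit_ids are not a plain arange. Hypothetical usage on a TemplateData instance td:

import numpy as np

# keep only units with at least 50 spikes (illustrative threshold)
td_big = td[np.flatnonzero(td.spike_counts >= 50)]

# boolean masks work the same way
td_first_ten = td[np.arange(len(td.unit_ids)) < 10]

Note the subset keeps the original unit ids, so indexing the result a second time would trip the arange assert.
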
12 changes: 7 additions & 5 deletions src/dartsort/transform/amortized_localization.py
@@ -204,10 +204,12 @@ def loss_function(self, recon_x, x, mask, mu, var):
        recon_x_masked = recon_x * mask
        x_masked = x * mask
        if self.scale_loss_by_mean:
-            # 1/mean amplitude
-            rescale = mask.sum(1, keepdim=True) / x_masked.sum(1, keepdim=True)
-            x_masked *= rescale
-            recon_x_masked *= rescale
+            # 1/(n_chans_retained*mean amplitude)
+            rescale = 1.0 / x_masked.sum(1, keepdim=True)
+        else:
+            rescale = 1.0 / mask.sum(1, keepdim=True)
+        x_masked *= rescale
+        recon_x_masked *= rescale
        mse = F.mse_loss(recon_x_masked, x_masked, reduction="sum") / self.batch_size
        kld = 0.0
        if self.variational:
@@ -308,7 +310,7 @@ def _fit(self, waveforms, channels):
                pbar.set_description(f"Localizer converged at epoch={epoch} {desc}")
                break

-    def fit(self, waveforms, max_channels):
+    def fit(self, waveforms, max_channels, recording=None):
        with torch.enable_grad():
            self._fit(waveforms, max_channels)
        self._needs_fit = False
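
loss_function previously normalized only when scale_loss_by_mean was set; otherwise waveforms entered the MSE raw. Now both branches produce a per-waveform rescale (inverse summed masked amplitude, or inverse retained-channel count) and the scaling is applied outside the branch. A small torch check of the two normalizers, with toy values:

import torch

x = torch.tensor([[2.0, 4.0, 0.0, 0.0]])      # one waveform feature row
mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])   # two channels retained
x_masked = x * mask

# scale_loss_by_mean=True: divide by the summed masked amplitude
rescale_amp = 1.0 / x_masked.sum(1, keepdim=True)    # 1/6

# scale_loss_by_mean=False: divide by the retained-channel count
rescale_count = 1.0 / mask.sum(1, keepdim=True)      # 1/2

print(x_masked * rescale_amp, x_masked * rescale_count)

The fit() signature change is plumbing: WaveformPipeline.fit (pipeline.py below) passes recording= to every transformer, so transformers that don't need it accept and ignore it.
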
6 changes: 3 additions & 3 deletions src/dartsort/transform/decollider.py
@@ -28,7 +28,7 @@ def __init__(
        inference_kind="amortized",
        batch_size=32,
        learning_rate=1e-3,
-        epochs=50,
+        n_epochs=50,
        channelwise_dropout_p=0.2,
        inference_z_samples=10,
        n_data_workers=4,
@@ -54,7 +54,7 @@ def __init__(
        self.n_channels = len(geom)
        self.batch_size = batch_size
        self.learning_rate = learning_rate
-        self.epochs = epochs
+        self.n_epochs = n_epochs
        self.channelwise_dropout_p = channelwise_dropout_p
        self.n_data_workers = n_data_workers
        self.with_conv_fullheight = with_conv_fullheight
@@ -268,7 +268,7 @@ def _fit(self, waveforms, channels, recording):
        )
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)

-        with trange(self.epochs, desc="Epochs", unit="epoch") as pbar:
+        with trange(self.n_epochs, desc="Epochs", unit="epoch") as pbar:
            for epoch in pbar:
                epoch_losses = {}
                for (waveform_batch, channels_batch), noise_batch in dataloader:
10 changes: 7 additions & 3 deletions src/dartsort/transform/pipeline.py
@@ -72,8 +72,9 @@ def fit(self, waveforms, max_channels, recording):
            return

        for transformer in self.transformers:
-            transformer.train()
-            transformer.fit(waveforms, max_channels=max_channels, recording=recording)
+            if transformer.needs_fit():
+                transformer.train()
+                transformer.fit(waveforms, max_channels=max_channels, recording=recording)
            transformer.eval()

        # if we're done already, stop before denoising
@@ -160,7 +161,10 @@ def featurization_config_to_class_names_and_kwargs(fconf):
        class_names_and_kwargs.append(
            (
                fconf.nn_denoiser_class_name,
-                {"pretrained_path": fconf.nn_denoiser_pretrained_path},
+                {
+                    "pretrained_path": fconf.nn_denoiser_pretrained_path,
+                    "n_epochs": fconf.nn_denoiser_train_epochs,
+                },
            )
        )
        if fconf.do_tpca_denoise:
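
Two fixes here: fit() no longer retrains transformers that are already fitted (refitting a loaded pretrained denoiser was wasted work), and the config's nn_denoiser_train_epochs is forwarded as n_epochs. A sketch of the guard pattern, mirroring the diff:

def fit_pipeline(transformers, waveforms, max_channels, recording):
    """Fit only the transformers that still need it (sketch of the diff's logic)."""
    for transformer in transformers:
        if transformer.needs_fit():
            transformer.train()  # torch training mode only while fitting
            transformer.fit(waveforms, max_channels=max_channels, recording=recording)
        transformer.eval()       # everything ends in eval mode either way

The second hunk is why decollider.py renames epochs to n_epochs and why single_channel_denoiser.py (below) grows an n_epochs parameter: every denoiser constructed from the config must accept the same kwarg, even if it ignores it.
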
1 change: 1 addition & 0 deletions src/dartsort/transform/single_channel_denoiser.py
@@ -31,6 +31,7 @@ def __init__(
        name=None,
        name_prefix="",
        clsname="SingleChannelDenoiser",
+        n_epochs=None,
    ):
        super().__init__(
            channel_index=channel_index, name=name, name_prefix=name_prefix
3 changes: 3 additions & 0 deletions src/dartsort/transform/transform_base.py
@@ -35,6 +35,9 @@ def needs_precompute(self):
    def precompute(self):
        pass

+    def extra_repr(self):
+        return f"name={self.name},needs_fit={self.needs_fit()}"
+

class BaseWaveformDenoiser(BaseWaveformModule):
    is_denoiser = True
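
extra_repr is the standard torch.nn.Module hook for enriching print(module) output, so every waveform module now reports its name and fit state when a pipeline is printed. A small demonstration with a hypothetical module (this assumes the base class is an nn.Module, consistent with the train()/eval() calls above):

import torch

class Demo(torch.nn.Module):
    def __init__(self, name):
        super().__init__()
        self.name = name

    def needs_fit(self):
        return True

    def extra_repr(self):
        return f"name={self.name},needs_fit={self.needs_fit()}"

print(Demo("tpca"))  # -> Demo(name=tpca,needs_fit=True)
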
14 changes: 9 additions & 5 deletions src/dartsort/util/comparison.py
@@ -71,9 +71,8 @@ def unit_info_dataframe(self):
        df = df.astype(float)  # not sure what the problem was...
        df['gt_ptp_amplitude'] = amplitudes
        df['gt_firing_rate'] = firing_rates
-        dist = self.template_distances[
-            np.arange(len(self.template_distances)), self.comparison.best_match_12.astype(int)
-        ]
+        print(f"{self.template_distances.shape=}")
+        dist = np.diagonal(self.template_distances)
        df['temp_dist'] = dist
        rec = []
        for uid in df.index:
@@ -115,7 +114,11 @@ def _calculate_template_distances(self):
            return

        gt_td = self.gt_analysis.coarse_template_data
-        tested_td = self.tested_analysis.coarse_template_data
+        nugt = gt_td.templates.shape[0]
+        matches = self.comparison.best_match_12.astype(int).values
+        matched = np.flatnonzero(matches >= 0)
+        matches = matches[matched]
+        tested_td = self.tested_analysis.coarse_template_data[matches]

        dists, shifts, snrs_a, snrs_b = merge.cross_match_distance_matrix(
            gt_td,
@@ -124,7 +127,8 @@
            n_jobs=self.n_jobs,
            device=self.device,
        )
-        self._template_distances = dists
+        self._template_distances = np.full((nugt, nugt), np.inf)
+        self._template_distances[np.arange(nugt)[:, None], matched[None, :]] = dists

    def _calculate_unsorted_detection(self):
        if self._unsorted_detection is not None:
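
Previously the GT-by-tested distance matrix was fancy-indexed as if its columns were tested units. Now distances are computed only against each GT unit's matched tested unit (using the new TemplateData.__getitem__ above) and scattered into a square (nugt, nugt) matrix initialized to inf, so unmatched columns stay inf and each GT unit's match distance lands on the diagonal; that is why unit_info_dataframe switches to np.diagonal. A numpy sketch of the scatter step, with hypothetical shapes:

import numpy as np

nugt = 4
matches = np.array([2, -1, 0, 1])        # best tested match per GT unit; -1 = none
matched = np.flatnonzero(matches >= 0)   # GT units that actually matched

# pretend distances between all GT units and the matched tested templates
dists = np.abs(np.random.default_rng(0).normal(size=(nugt, matched.size)))

full = np.full((nugt, nugt), np.inf)
full[np.arange(nugt)[:, None], matched[None, :]] = dists
print(np.diagonal(full))  # per-unit distance to its own match (inf if unmatched)
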
7 changes: 4 additions & 3 deletions src/dartsort/vis/gt.py
@@ -432,8 +432,10 @@ def draw(self, panel, comparison):

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)
+            df_show = df[np.isfinite(df[self.y].values)]
+            df_show = df_show[np.isfinite(df[self.x].values)]
            sns.regplot(
-                data=df,
+                data=df_show,
                x=self.x,
                y=self.y,
                logistic=True,
@@ -521,8 +523,7 @@ def __init__(self, cmap=plt.cm.magma_r):
    def draw(self, panel, comparison):
        agreement = comparison.comparison.get_ordered_agreement_scores()
        row_order = agreement.index
-        col_order = np.array(agreement.columns)[:agreement.shape[0]]
-        dist = comparison.template_distances[row_order, :][:, col_order]
+        dist = comparison.template_distances[row_order, :]

        ax = panel.subplots()
        log1p_norm = FuncNorm((np.log1p, np.expm1), vmin=0)
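
sns.regplot with logistic=True hands the data to statsmodels, which fails on NaN/inf rows, hence the filtering to finite values before plotting. One detail worth flagging: the second filter builds its mask from the original df, whose length no longer matches df_show once rows were dropped. A pandas sketch of an alternative mask construction that always stays aligned, with hypothetical column names:

import numpy as np
import pandas as pd

df = pd.DataFrame({"x": [0.1, np.nan, 0.5], "acc": [1.0, 0.0, np.inf]})

# build both masks against the same frame so lengths always agree
finite = np.isfinite(df["x"].values) & np.isfinite(df["acc"].values)
df_show = df[finite]

The second hunk pairs with the comparison.py change: template_distances is now square with GT-ordered columns, so the old column reordering by tested-unit agreement columns no longer applies.
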
