Skip to content

Commit

Permalink
adaptive bin size fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
Julien Boussard committed Feb 19, 2024
2 parents faf0704 + 835f472 commit ac95c98
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 40 deletions.
15 changes: 13 additions & 2 deletions src/dartsort/templates/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def __post_init__(self):
assert self.shifts_b.shape == (
self.upsampled_shifted_template_index_b.shape[1],
)

self.a_shift_offset, self.offset_shift_a_to_ix = _get_shift_indexer(
self.shifts_a
)
Expand All @@ -69,6 +68,15 @@ def __post_init__(self):
)

def get_shift_ix_a(self, shifts_a):
    """Translate signed integer shifts into shift indices.

    A shift index addresses axis=1 of shifted_template_index_a (and,
    equally, self.shifts_a). For a shift that is present in self.shifts_a
    it is an int in [0, n_shifts_a) and agrees with
    np.searchsorted(self.shifts_a, shifts_a).

    searchsorted is slow, so instead of calling it here we read from the
    lookup table that _get_shift_indexer pre-baked in __post_init__.

    Scalar input is accepted and promoted to a 1-d tensor.
    """
    shift_tensor = torch.as_tensor(shifts_a)
    shift_tensor = torch.atleast_1d(shift_tensor)
    # Adding the offset moves the smallest known shift to table position 0.
    table_positions = shift_tensor.to(int) + self.a_shift_offset
    return self.offset_shift_a_to_ix[table_positions]

Expand Down Expand Up @@ -328,6 +336,7 @@ def query(
# device=self.device,
# )


def batched_h5_read(dataset, indices, batch_size=1000):
if indices.size < batch_size:
return dataset[indices]
Expand All @@ -341,14 +350,16 @@ def batched_h5_read(dataset, indices, batch_size=1000):

def _get_shift_indexer(shifts):
assert torch.equal(shifts, torch.sort(shifts).values)
# smallest shift (say, -5) becomes 5
shift_offset = -int(shifts[0])
offset_shift_to_ix = []

for j, shift in enumerate(shifts):
ix = shift + shift_offset
assert len(offset_shift_to_ix) <= ix
# assert 0 <= ix < len(shifts)
while len(offset_shift_to_ix) < ix:
offset_shift_to_ix.append(len(shifts))
offset_shift_to_ix.append(j)

offset_shift_to_ix = torch.tensor(offset_shift_to_ix, device=shifts.device)
return shift_offset, offset_shift_to_ix
45 changes: 32 additions & 13 deletions src/dartsort/util/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@
from ..templates import TemplateData
from ..transform import WaveformPipeline
from .data_util import DARTsortSorting
from .drift_util import (get_spike_pitch_shifts,
get_waveforms_on_static_channels, registered_average)
from .drift_util import (
get_spike_pitch_shifts,
get_waveforms_on_static_channels,
registered_average,
)
from .spikeio import read_waveforms_channel_index
from .waveform_util import make_channel_index

Expand Down Expand Up @@ -307,8 +310,10 @@ def show_geom(self):
show_geom = self.recording.get_channel_locations()
return show_geom

def show_channel_index(self, channel_channel_show_radius_um=50):
    """Return make_channel_index applied to self.show_geom and the radius.

    channel_channel_show_radius_um is forwarded unchanged as the second
    positional argument of make_channel_index.
    """
    geom = self.show_geom
    radius_um = channel_channel_show_radius_um
    return make_channel_index(geom, radius_um)
def show_channel_index(self, channel_show_radius_um=50, channel_dist_p=np.inf):
    """Return make_channel_index applied to self.show_geom.

    channel_show_radius_um is forwarded as the second positional argument
    and channel_dist_p as the keyword argument `p`.
    NOTE(review): `p` presumably selects the distance norm (np.inf by
    default) -- confirm against make_channel_index.
    """
    geom = self.show_geom
    radius_um = channel_show_radius_um
    return make_channel_index(geom, radius_um, p=channel_dist_p)

# spike feature loading methods

Expand Down Expand Up @@ -369,6 +374,7 @@ def unit_raw_waveforms(
channel_show_radius_um=75,
trough_offset_samples=42,
spike_length_samples=121,
channel_dist_p=np.inf,
relocated=False,
):
if which is None:
Expand All @@ -389,8 +395,9 @@ def unit_raw_waveforms(
if self.shifting:
load_ci = self.channel_index
else:
load_ci = make_channel_index(
self.recording.get_channel_locations(), channel_show_radius_um
load_ci = self.show_channel_index(
channel_show_radius_um=channel_show_radius_um,
channel_dist_p=channel_dist_p,
)
waveforms = read_waveforms_channel_index(
self.recording,
Expand All @@ -415,6 +422,7 @@ def unit_raw_waveforms(
waveforms,
load_ci,
channel_show_radius_um=channel_show_radius_um,
channel_dist_p=channel_dist_p,
relocated=relocated,
)
return which, waveforms, max_chan, show_geom, show_channel_index
Expand All @@ -441,8 +449,14 @@ def unit_tpca_waveforms(

tpca_embeds = self.tpca_features(which=which)
n, rank, c = tpca_embeds.shape
waveforms = tpca_embeds.transpose(0, 2, 1).reshape(n * c, rank)
waveforms = self.sklearn_tpca.inverse_transform(waveforms)
tpca_embeds = tpca_embeds.transpose(0, 2, 1).reshape(n * c, rank)
waveforms = np.full(
(n * c, self.sklearn_tpca.components_.shape[1]),
np.nan,
dtype=tpca_embeds.dtype,
)
valid = np.flatnonzero(np.isfinite(tpca_embeds[:, 0]))
waveforms[valid] = self.sklearn_tpca.inverse_transform(tpca_embeds[valid])
t = waveforms.shape[1]
waveforms = waveforms.reshape(n, c, t).transpose(0, 2, 1)

Expand Down Expand Up @@ -501,12 +515,15 @@ def unit_shift_or_relocate_channels(
waveforms,
load_channel_index,
channel_show_radius_um=75,
channel_dist_p=np.inf,
relocated=False,
):
geom = self.recording.get_channel_locations()
show_geom = self.template_data.registered_geom
if show_geom is None:
show_geom = geom
show_geom = self.show_geom
show_channel_index = self.show_channel_index(
channel_show_radius_um=channel_show_radius_um, channel_dist_p=channel_dist_p
)

temp = self.coarse_template_data.unit_templates(unit_id)
n_pitches_shift = None
if temp.shape[0]:
Expand All @@ -532,7 +549,7 @@ def unit_shift_or_relocate_channels(
else:
amp_template = np.nanmean(amps, axis=0)
max_chan = np.nanargmax(amp_template)
show_channel_index = make_channel_index(show_geom, channel_show_radius_um)

show_chans = show_channel_index[max_chan]
show_chans = show_chans[show_chans < len(show_geom)]
show_channel_index = np.broadcast_to(
Expand Down Expand Up @@ -580,7 +597,9 @@ def nearby_coarse_templates(self, unit_id, n_neighbors=5):
unit_ix = np.searchsorted(self.unit_ids, unit_id)
unit_dists = self.merge_dist[unit_ix]
distance_order = np.argsort(unit_dists)
distance_order = np.concatenate(([unit_ix], distance_order[distance_order != unit_ix]))
distance_order = np.concatenate(
([unit_ix], distance_order[distance_order != unit_ix])
)
# assert distance_order[0] == unit_ix
neighbor_ixs = distance_order[:n_neighbors]
neighbor_ids = self.unit_ids[neighbor_ixs]
Expand Down
Loading

0 comments on commit ac95c98

Please sign in to comment.