Skip to content

Commit

Permalink
Batch EnforceDecrease
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Jan 22, 2024
1 parent 40561bd commit e82b42f
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 7 deletions.
24 changes: 17 additions & 7 deletions src/dartsort/transform/enforce_decrease.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(
self,
channel_index,
geom,
batch_size=2048,
name=None,
name_prefix="",
):
Expand All @@ -38,6 +39,7 @@ def __init__(
"_1",
torch.tensor(1.0),
)
self.batch_size = batch_size

def forward(self, waveforms, max_channels):
"""
Expand All @@ -54,14 +56,22 @@ def forward(self, waveforms, max_channels):

# pad with an extra channel to support indexing tricks
pad_ptps = F.pad(ptps, (0, 1), value=torch.inf)
parent_min_ptps = torch.zeros_like(ptps)
# get amplitudes of all parents for all channels -- (N, c, <=c-1)
# TODO batch this?
# TODO it may be possible to refactor using the new torch.Tensor.scatter_reduce_!
parent_ptps = pad_ptps[
torch.arange(n)[:, None, None],
self.parents_index[max_channels],
]
parent_min_ptps = parent_ptps.min(dim=2).values
# this is a gather_reduce, and it would be nice if torch had one.
# batching the following:
# parent_ptps = pad_ptps[
# torch.arange(n)[:, None, None],
# self.parents_index[max_channels],
# ]
# parent_min_ptps = parent_ptps.min(dim=2).values
for bs in range(0, n, self.batch_size):
be = min(n, bs + self.batch_size)
parent_ptps = pad_ptps[
torch.arange(bs, be)[:, None, None],
self.parents_index[max_channels[bs:be]],
]
parent_min_ptps[bs:be] = parent_ptps.min(dim=2).values

# what would we need to multiply by to ensure my amp is <= all parents?
rescaling = torch.minimum(parent_min_ptps / ptps, self._1)
Expand Down
8 changes: 8 additions & 0 deletions src/dartsort/util/drift_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,14 @@ def get_waveforms_on_static_channels(
):
"""Load a set of drifting waveforms on a static set of channels
Waveforms are by default detected waveforms on the full probe with channel
locations stored in geom. If main_channels and channel_index are supplied,
then waveforms[i] appears on the channels indexed by channel_index[main_channels[i]].
Now, the user wants to extract a subset of channels from each waveform such that
the same channel in each subsetted waveform is always at the same physical position,
accounting for drift. Here, the drift is inputted by n_pitches_shift.
Arguments
---------
waveforms : (n_spikes, t (optional), c) array
Expand Down

0 comments on commit e82b42f

Please sign in to comment.