Skip to content

Commit

Permalink
wip on subtract and localize numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
oliche committed Oct 22, 2024
1 parent 2d15db8 commit cdf6178
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 6 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
--extra-index-url https://download.pytorch.org/whl/cpu
click
dredge
h5py
hdbscan
matplotlib
Expand Down
2 changes: 1 addition & 1 deletion src/spike_psvae/chunk_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def handle_which_wfs(self, subtracted_wfs, cleaned_wfs, denoised_wfs):
f"which_waveforms={self.which_waveforms} not in ('subtracted', 'cleaned', denoised)"
)
if not self.tensor_ok and torch.is_tensor(wfs):
wfs = wfs.cpu().numpy()
wfs = wfs.cpu().detach().numpy()
return wfs


Expand Down
7 changes: 3 additions & 4 deletions src/spike_psvae/denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from sklearn.decomposition import PCA
from torch import nn
from tqdm.auto import trange
import dartsort

try:
from .denoise_temporal_decrease import (
Expand All @@ -20,9 +21,7 @@
pass


pretrained_path = (
Path(__file__).parent.parent.parent / "pretrained/single_chan_denoiser.pt"
)
pretrained_path = Path(dartsort.__file__).parent.joinpath('pretrained', 'single_chan_denoiser.pt')


class SingleChanDenoiser(nn.Module):
Expand Down Expand Up @@ -775,7 +774,7 @@ def enforce_decrease_shells(
# compute original ptps and allocate storage for decreasing ones
is_torch = False
if torch.is_tensor(waveforms):
orig_ptps = (waveforms.max(dim=1).values - waveforms.min(dim=1).values).cpu().numpy()
orig_ptps = (waveforms.max(dim=1).values - waveforms.min(dim=1).values).cpu().detach().numpy()
is_torch = True
else:
orig_ptps = waveforms.ptp(axis=1)
Expand Down
5 changes: 4 additions & 1 deletion src/spike_psvae/subtract.py
Original file line number Diff line number Diff line change
Expand Up @@ -1730,6 +1730,7 @@ def subtract_and_localize_numpy(
radial_parents=None,
tpca=None,
device=None,
enfdec=None,
probe="np1",
trough_offset=42,
spike_length_samples=121,
Expand Down Expand Up @@ -1772,6 +1773,7 @@ def subtract_and_localize_numpy(
residual,
threshold,
radial_parents,
enfdec, # enfdec
tpca,
dedup_channel_index,
extract_channel_index,
Expand Down Expand Up @@ -1852,4 +1854,5 @@ def subtract_and_localize_numpy(
],
columns=["sample", "trace", "x", "y", "z", "alpha"],
)
return df_localisation, cleaned_wfs.to("cpu")
np_waveforms = cleaned_wfs.to("cpu").detach().numpy()
return df_localisation, np_waveforms

0 comments on commit cdf6178

Please sign in to comment.