Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
julienboussard committed Feb 1, 2024
2 parents 4d6d6fa + 0a1ee31 commit 16959b8
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 13 deletions.
9 changes: 6 additions & 3 deletions src/dartsort/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from . import vis
from .config import *
from .localize.localize_util import (localize_amplitude_vectors, localize_hdf5,
localize_waveforms)
from .main import (DARTsortSorting, ObjectiveUpdateTemplateMatchingPeeler,
SubtractionPeeler, check_recording, cluster, dartsort,
match, run_peeler, split_merge, subtract)
from .main import (ObjectiveUpdateTemplateMatchingPeeler, SubtractionPeeler,
check_recording, cluster, dartsort, match, run_peeler,
split_merge, subtract)
from .peel.grab import GrabAndFeaturize
from .templates import TemplateData
from .transform import WaveformPipeline
from .util.data_util import DARTsortSorting
from .util.waveform_util import make_channel_index
24 changes: 23 additions & 1 deletion src/dartsort/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,28 @@ class SubtractionConfig:
)


@dataclass(frozen=True)
class MotionEstimationConfig:
"""Configure motion estimation.
You can also make your own and pass it to dartsort() to bypass this
"""
do_motion_estimation: bool = True

# sometimes spikes can be localized far away from the probe, causing
# issues with motion estimation, we will ignore such spikes
probe_boundary_padding_um: float = 100.0

# DREDge parameters
spatial_bin_length_um: float = 1.0
temporal_bin_length_s: float = 1.0
window_step_um: float = 400.0
window_scale_um: float = 450.0
window_margin_um: Optional[float] = None
max_dt_s: float = 0.1
max_disp_um: Optional[float] = None


@dataclass(frozen=True)
class TemplateConfig:
trough_offset_samples: int = 42
Expand Down Expand Up @@ -189,7 +211,7 @@ class ClusteringConfig:
log_c: float = 5.0
recursive: bool = False
remove_duplicates: bool = True
#remove large clusters in hdbscan?
# remove large clusters in hdbscan?
remove_big_units: bool = True
zstd_big_units: float = 50.0

Expand Down
2 changes: 1 addition & 1 deletion src/dartsort/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from dartsort.peel import (ObjectiveUpdateTemplateMatchingPeeler,
SubtractionPeeler)
from dartsort.templates import TemplateData
from dartsort.util.data_util import DARTsortSorting, check_recording
from dartsort.util.data_util import check_recording
from dartsort.util.peel_util import run_peeler


Expand Down
14 changes: 14 additions & 0 deletions src/dartsort/util/registration_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
try:
from dredge import dredge_ap
have_dredge = True
except ImportError:
have_dredge = False
pass


def estimate_motion(recording, sorting, motion_estimation_config=None, localizations_dataset_name="point_source_localizations"):
if not motion_estimation_config.do_motion_estimation:
return None

if not have_dredge:
raise ValueError("Please install DREDge to use motion estimation.")
2 changes: 1 addition & 1 deletion tests/test_grab_and_featurize.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def test_grab_and_featurize():
localize_hdf5(
Path(tempdir) / "grab.h5",
radius=50.0,
amplitude_vectors_dataset_name="amplitude_vectors",
amplitude_vectors_dataset_name="peak_amplitude_vectors",
)

with h5py.File(Path(tempdir) / "grab.h5") as h5:
Expand Down
15 changes: 8 additions & 7 deletions tests/test_subtract.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,12 @@ def test_fakedata_nonn():
subconf.subtraction_denoising_config, do_tpca_denoise=False
),
)
nolocfeatconf = dataclasses.replace(featconf, do_localization=False)
with tempfile.TemporaryDirectory() as tempdir:
st0, out_h5 = subtract(
rec,
tempdir,
featurization_config=dataclasses.replace(featconf, do_localization=False),
featurization_config=nolocfeatconf,
subtraction_config=subconf,
)
ns0 = len(st0)
Expand All @@ -263,13 +264,13 @@ def test_fakedata_nonn():
sta, out_h5 = subtract(
rec.frame_slice(start_frame=0, end_frame=int(20 * fs)),
tempdir,
featurization_config=dataclasses.replace(featconf, do_localization=False),
featurization_config=nolocfeatconf,
subtraction_config=subconf,
)
stb, out_h5 = subtract(
rec,
tempdir,
featurization_config=featconf,
featurization_config=nolocfeatconf,
subtraction_config=subconf,
)
assert len(sta) < ns0
Expand All @@ -279,13 +280,13 @@ def test_fakedata_nonn():
sta, out_h5 = subtract(
rec.frame_slice(start_frame=0, end_frame=int(25 * fs)),
tempdir,
featurization_config=dataclasses.replace(featconf, do_localization=False),
featurization_config=nolocfeatconf,
subtraction_config=subconf,
)
stb, out_h5 = subtract(
rec,
tempdir,
featurization_config=featconf,
featurization_config=nolocfeatconf,
subtraction_config=subconf,
)
assert len(sta) < ns0
Expand All @@ -295,13 +296,13 @@ def test_fakedata_nonn():
sta, out_h5 = subtract(
rec.frame_slice(start_frame=0, end_frame=int(30 * fs)),
tempdir,
featurization_config=dataclasses.replace(featconf, do_localization=False),
featurization_config=nolocfeatconf,
subtraction_config=subconf,
)
stb, out_h5 = subtract(
rec,
tempdir,
featurization_config=featconf,
featurization_config=nolocfeatconf,
subtraction_config=subconf,
)
assert len(sta) < ns0
Expand Down

0 comments on commit 16959b8

Please sign in to comment.