Commit 5960352

Threading-based parallelism in the peelers
1 parent 6664428 commit 5960352

7 files changed with 78 additions and 40 deletions

src/dartsort/config.py

Lines changed: 5 additions & 6 deletions
@@ -142,7 +142,7 @@ class FeaturizationConfig:
     tpca_fit_radius: float = 75.0
     tpca_rank: int = 8
     tpca_centered: bool = False
-    # todo: use a WaveformConfig...
+    learn_cleaned_tpca_basis: bool = False
     input_tpca_waveform_config: WaveformConfig | None = WaveformConfig(
         ms_before=0.75, ms_after=1.25
     )
@@ -343,9 +343,9 @@ class ClusteringConfig:
 class ComputationConfig:
     n_jobs_cpu: int = 0
     n_jobs_gpu: int = 0
-    device: torch.device | None = argfield(default=None, arg_type=torch.device)
+    executor: str = "ThreadPoolExecutor"
+    device: str | None = argfield(default=None, arg_type=str)

-    @property
     def actual_device(self):
         if self.device is None:
             have_cuda = torch.cuda.is_available()
@@ -354,9 +354,8 @@ def actual_device(self):
             return torch.device("cpu")
         return torch.device(self.device)

-    @property
-    def actual_n_jobs_gpu(self):
-        if self.actual_device.type == "cuda":
+    def actual_n_jobs(self):
+        if self.actual_device().type == "cuda":
            return self.n_jobs_gpu
        return self.n_jobs_cpu
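
Note: ComputationConfig now carries an executor name alongside a plain device string, and actual_device() / actual_n_jobs() are ordinary methods rather than properties. A minimal standalone sketch of that resolution logic follows; it mirrors the diff above with a stand-in class name and is not dartsort's own code.

from dataclasses import dataclass

import torch


@dataclass
class ComputationConfigSketch:
    # stand-in for dartsort's ComputationConfig; defaults copied from the diff
    n_jobs_cpu: int = 0
    n_jobs_gpu: int = 0
    executor: str = "ThreadPoolExecutor"  # name of the pool class get_pool should use
    device: str | None = None

    def actual_device(self):
        # fall back to CUDA when available, otherwise CPU
        if self.device is None:
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return torch.device(self.device)

    def actual_n_jobs(self):
        # pick the GPU or CPU job count depending on the resolved device
        if self.actual_device().type == "cuda":
            return self.n_jobs_gpu
        return self.n_jobs_cpu


cfg = ComputationConfigSketch(n_jobs_cpu=4, n_jobs_gpu=1, device="cpu")
print(cfg.actual_device(), cfg.actual_n_jobs())  # cpu 4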

src/dartsort/main.py

Lines changed: 5 additions & 6 deletions
@@ -16,6 +16,7 @@
     default_subtraction_config,
     default_template_config,
     default_waveform_config,
+    default_computation_config,
 )
 from dartsort.peel import ObjectiveUpdateTemplateMatchingPeeler, SubtractionPeeler
 from dartsort.templates import TemplateData
@@ -143,12 +144,11 @@ def subtract(
     waveform_config=default_waveform_config,
     featurization_config=default_featurization_config,
     subtraction_config=default_subtraction_config,
+    computation_config=default_computation_config,
     chunk_starts_samples=None,
-    n_jobs=0,
     overwrite=False,
     residual_filename=None,
     show_progress=True,
-    device=None,
     hdf5_filename="subtraction.h5",
     model_subdir="subtraction_models",
 ):
@@ -163,14 +163,13 @@ def subtract(
         subtraction_peeler,
         output_directory,
         hdf5_filename,
-        model_subdir,
-        featurization_config,
+        model_subdir=model_subdir,
+        featurization_config=featurization_config,
         chunk_starts_samples=chunk_starts_samples,
         overwrite=overwrite,
-        n_jobs=n_jobs,
+        computation_config=computation_config,
         residual_filename=residual_filename,
         show_progress=show_progress,
-        device=device,
     )
     return detections, output_hdf5_filename
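
Note: subtract() no longer takes n_jobs= or device=; parallelism and device selection are configured through a ComputationConfig. A hypothetical call sketch follows: the positional recording/output_directory arguments, the package-level import of subtract, and the use of a SpikeInterface-generated recording are assumptions not shown in this diff, while computation_config=, overwrite=, and the two-value return come from the diff.

# hypothetical usage sketch -- not code from this commit
import spikeinterface.full as si

from dartsort import subtract
from dartsort.config import ComputationConfig

# toy recording stands in for real data (assumption for illustration only)
recording = si.generate_recording(num_channels=32, durations=[10.0])

cfg = ComputationConfig(n_jobs_cpu=4, n_jobs_gpu=2, executor="ThreadPoolExecutor")
detections, h5_path = subtract(
    recording,
    "out/subtraction",
    computation_config=cfg,  # replaces the old n_jobs= / device= keywords
    overwrite=True,
)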

src/dartsort/peel/subtract.py

Lines changed: 13 additions & 12 deletions
@@ -289,15 +289,15 @@ def precompute_peeler_models(self):
         self.subtraction_denoising_pipeline.precompute()

     def fit_featurization_pipeline(
-        self, save_folder, tmp_dir=None, n_jobs=0, device=None
+        self, save_folder, tmp_dir=None, computation_config=None
     ):
         super().fit_featurization_pipeline(
-            save_folder, tmp_dir=tmp_dir, n_jobs=n_jobs, device=device
+            save_folder, tmp_dir=tmp_dir, computation_config=computation_config
         )
         if self.use_singlechan_templates:
             self.have_singlechan_templates = True

-    def fit_peeler_models(self, save_folder, tmp_dir=None, n_jobs=0, device=None):
+    def fit_peeler_models(self, save_folder, tmp_dir=None, computation_config=None):
         # when fitting peelers for subtraction, there are basically
         # two cases. fitting featurizers is easy -- they don't modify
         # the waveforms. fitting denoisers is hard -- they do. each
@@ -309,24 +309,23 @@ def fit_peeler_models(self, save_folder, tmp_dir=None, n_jobs=0, device=None):
         # so we will cheat for now:
         # just remove all the denoisers that need fitting, run peeling,
         # and fit everything
+
         while self._fit_subtraction_transformers(
             save_folder,
             tmp_dir=tmp_dir,
-            n_jobs=n_jobs,
-            device=device,
+            computation_config=computation_config,
             which="denoisers",
         ):
             pass
         self._fit_subtraction_transformers(
             save_folder,
             tmp_dir=tmp_dir,
-            n_jobs=n_jobs,
-            device=device,
+            computation_config=computation_config,
             which="featurizers",
         )

     def _fit_subtraction_transformers(
-        self, save_folder, tmp_dir=None, n_jobs=0, device=None, which="denoisers"
+        self, save_folder, tmp_dir=None, computation_config=None, which="denoisers"
     ):
         """Fit models which are run during the subtraction step

@@ -352,9 +351,11 @@ def _fit_subtraction_transformers(
         if not needs_fit:
             return False

-        if device is None:
+        if computation_config is None:
             device = "cuda" if torch.cuda.is_available() else "cpu"
-        device = torch.device(device)
+            device = torch.device(device)
+        else:
+            device = computation_config.actual_device()

         orig_denoise = self.subtraction_denoising_pipeline
         init_voltage_feature = Voltage(
@@ -396,8 +397,7 @@ def _fit_subtraction_transformers(
         try:
             self.run_subsampled_peeling(
                 temp_hdf5_filename,
-                n_jobs=n_jobs,
-                device=device,
+                computation_config=computation_config,
                 task_name=f"Load examples for {which[:-1]} fitting",
             )

@@ -563,6 +563,7 @@ def subtract_chunk(
             times_samples >= trough_offset_samples, times_samples < max_trough_time
         )
         (keep,) = keep.nonzero(as_tuple=True)
+
         if not keep.numel():
             break
         times_samples = times_samples[keep]
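
Note on the fit_peeler_models control flow above: _fit_subtraction_transformers returns False once nothing needs fitting, so the `while ...: pass` loop keeps re-running it until every denoiser has been fit, after which featurizers (which don't modify waveforms, per the comment in the diff) get a single pass. A schematic standalone sketch of that pattern follows; it assumes the helper returns a truthy value while unfit denoisers remain, and is not dartsort code.

def fit_all_transformers(fit_step):
    # denoisers change the waveforms they see, so each pass can expose
    # another denoiser that still needs examples collected for fitting;
    # repeat until the step reports there is nothing left to fit
    while fit_step(which="denoisers"):
        pass
    # featurizers don't modify waveforms, so one pass suffices
    fit_step(which="featurizers")


# toy fit_step: pretends two denoisers need fitting, then reports done
state = {"unfit": 2}

def toy_fit_step(which):
    if which == "denoisers" and state["unfit"] > 0:
        state["unfit"] -= 1
        return True
    return False

fit_all_transformers(toy_fit_step)
print(state)  # {'unfit': 0}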

src/dartsort/transform/enforce_decrease.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ def __init__(
         self,
         channel_index,
         geom,
-        batch_size=32,
+        batch_size=1024,
         name=None,
         name_prefix="",
     ):

src/dartsort/transform/single_channel_denoiser.py

Lines changed: 4 additions & 4 deletions
@@ -25,8 +25,8 @@ def __init__(
         channel_index,
         geom=None,
         denoiser=None,
-        batch_size=32,
-        in_place=True,
+        batch_size=1024,
+        in_place=False,
         pretrained_path=default_pretrained_path,
         name=None,
         name_prefix="",
@@ -41,8 +41,6 @@ def __init__(

         if denoiser is None:
             denoiser = dnclss[clsname]().load(pretrained_path)
-        denoiser.eval()
-        denoiser.requires_grad_(False)
         self.denoiser = denoiser

     def forward(self, waveforms, max_channels=None):
@@ -97,6 +95,8 @@ def forward(self, x):
     def load(self, pretrained_path=default_pretrained_path):
         checkpoint = torch.load(pretrained_path, map_location="cpu", weights_only=True)
         self.load_state_dict(checkpoint)
+        self.eval()
+        self.requires_grad_(False)
         return self

src/dartsort/util/multiprocessing_util.py

Lines changed: 38 additions & 0 deletions
@@ -3,6 +3,7 @@
 from concurrent.futures import ThreadPoolExecutor as _ThreadPoolExecutor
 from multiprocessing import get_context

+import torch
 import torch.multiprocessing as torchmp

 # TODO: torch.multiprocessing?
@@ -125,6 +126,25 @@ def rank_init(queue):
     print(f"rank init got {rank_init.rank=}")


+def pool_from_cfg(computation_config=None, with_rank_queue=False, check_local=False):
+    if computation_config is None:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        device = torch.device(device)
+        n_jobs = 0
+        cls = MockPoolExecutor
+    else:
+        device = computation_config.actual_device()
+        n_jobs = computation_config.actual_n_jobs()
+        cls = computation_config.executor
+
+    return get_pool(
+        n_jobs,
+        cls=cls,
+        with_rank_queue=with_rank_queue,
+        check_local=check_local,
+    )
+
+
 def get_pool(
     n_jobs,
     context="spawn",
@@ -133,16 +153,30 @@
     rank_queue_empty=False,
     n_tasks=None,
     max_tasks_per_child=None,
+    check_local=False,
 ):
     if n_jobs == -1:
         n_jobs = multiprocessing.cpu_count()
     do_parallel = n_jobs >= 1
     n_jobs = max(1, n_jobs)

+    if isinstance(cls, str):
+        if cls == "CloudpicklePoolExecutor":
+            cls = CloudpicklePoolExecutor
+        elif cls == "ThreadPoolExecutor":
+            cls = ThreadPoolExecutor
+        elif cls == "ProcessPoolExecutor":
+            cls = ProcessPoolExecutor
+        elif cls == "MockPoolExecutor":
+            cls = MockPoolExecutor
+        else:
+            assert False
+
     if cls == CloudpicklePoolExecutor and not have_cloudpickle:
         cls = ProcessPoolExecutor

     Executor = cls if do_parallel else MockPoolExecutor
+    is_local = cls in (MockPoolExecutor, ThreadPoolExecutor)
     if context == "torchspawn":
         context = torchmp.get_context("spawn")
     else:
@@ -163,6 +197,10 @@
         for rank in range(n_jobs):
             rank_queue.put(rank)

+        if check_local:
+            return n_jobs, Executor, context, rank_queue, is_local
         return n_jobs, Executor, context, rank_queue

+    if check_local:
+        return n_jobs, Executor, context, is_local
     return n_jobs, Executor, context
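
Note: the new check_local flag lets callers find out whether the chosen executor runs in the parent process (thread and mock pools) or in subprocesses. A hypothetical caller sketch follows; it assumes dartsort is importable and relies on the 4-tuple return shown above for with_rank_queue=False, check_local=True.

# hypothetical caller -- not part of this commit
from dartsort.util.multiprocessing_util import get_pool

n_jobs, Executor, context, is_local = get_pool(
    2, cls="ThreadPoolExecutor", check_local=True
)

if is_local:
    # thread and mock pools share the parent's memory, so per-worker
    # device setup and model re-loading can be skipped
    print(f"{n_jobs} local workers via {Executor.__name__}")
else:
    print(f"{n_jobs} subprocess workers via context {context}")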

src/dartsort/util/peel_util.py

Lines changed: 12 additions & 11 deletions
@@ -14,12 +14,11 @@ def run_peeler(
     hdf5_filename,
     model_subdir,
     featurization_config,
+    computation_config,
     chunk_starts_samples=None,
     overwrite=False,
-    n_jobs=0,
     residual_filename=None,
     show_progress=True,
-    device=None,
     localization_dataset_name="point_source_localizations",
 ):
     output_directory = Path(output_directory)
@@ -49,20 +48,21 @@ def run_peeler(

     # fit models if needed
     peeler.load_or_fit_and_save_models(
-        model_dir, overwrite=overwrite, n_jobs=n_jobs, device=device
+        model_dir, overwrite=overwrite, computation_config=computation_config
     )

     # run main
     peeler.peel(
         output_hdf5_filename,
         chunk_starts_samples=chunk_starts_samples,
-        n_jobs=n_jobs,
         overwrite=overwrite,
         residual_filename=residual_filename,
         show_progress=show_progress,
-        device=device,
+        computation_config=computation_config,
+    )
+    _gc(
+        computation_config.actual_n_jobs(), computation_config.actual_device()
     )
-    _gc(n_jobs, device)

     # do localization
     if do_localization:
@@ -74,21 +74,22 @@
             amplitude_vectors_dataset_name=f"{wf_name}_{loc_amp_type}_amplitude_vectors",
             output_dataset_name=localization_dataset_name,
             show_progress=show_progress,
-            n_jobs=n_jobs,
-            device=device,
+            n_jobs=computation_config.actual_n_jobs(),
+            device=computation_config.actual_device(),
             localization_model=featurization_config.localization_model,
         )
-        _gc(n_jobs, device)
+        _gc(
+            computation_config.actual_n_jobs(), computation_config.actual_device()
+        )

     if featurization_config.n_residual_snips:
         peeler.run_subsampled_peeling(
             output_hdf5_filename,
-            n_jobs=n_jobs,
             chunk_length_samples=peeler.spike_length_samples,
             residual_to_h5=True,
             skip_features=True,
             ignore_resuming=True,
-            device=device,
+            computation_config=computation_config,
             n_chunks=featurization_config.n_residual_snips,
             task_name="Residual snips",
             overwrite=False,
