Skip to content

Commit

Permalink
make naming conventions clearer
Browse files Browse the repository at this point in the history
use multiprocessing for cell sets
Each of these processes can then spawn multiple worker threads underneath it
  • Loading branch information
sophiamaedler committed Jul 10, 2024
1 parent 94c4116 commit a0176bc
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions src/lmd/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import sys
import platform

from concurrent.futures import ThreadPoolExecutor, as_completed
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Callable
from tqdm.auto import tqdm

Expand All @@ -53,7 +53,7 @@ def execute_indexed_parallel(func: Callable, *, args: list, tqdm_args: dict = No
tqdm_args["total"] = len(args)

results = [None for _ in range(len(args))]
with ThreadPoolExecutor(n_threads) as executor:
with ProcessPoolExecutor(n_threads) as executor:
with tqdm(**tqdm_args) as pbar:
futures = {executor.submit(func, *arg): i for i, arg in enumerate(args)}
for future in as_completed(futures):
Expand Down Expand Up @@ -588,7 +588,7 @@ class SegmentationLoader():
VALID_PATH_OPTIMIZERS = ["none", "hilbert", "greedy"]


def __init__(self, config = {}, verbose = False, threads = None):
def __init__(self, config = {}, verbose = False, processes = 1):
self.config = config
self.verbose = verbose
self._get_context() #setup context for multiprocessing function calls to work with different operating systems
Expand All @@ -608,7 +608,7 @@ def __init__(self, config = {}, verbose = False, threads = None):
self.register_parameter('orientation_transform', np.eye(2))

self.coords_lookup = None
self.threads = threads
self.processes = processes

def _get_context(self):
if platform.system() == 'Windows':
Expand Down Expand Up @@ -642,8 +642,8 @@ def __call__(self, input_segmentation, cell_sets, calibration_points, coords_loo

#try multithreading

if self.threads is not None:
self.log("Multithreading the processing of cell sets")
if self.processes > 1:
self.log("Processing cell sets in parallel")
args = []
for i, cell_set in enumerate(cell_sets):
args.append((i, cell_set))
Expand Down Expand Up @@ -708,12 +708,12 @@ def generate_cutting_data(self, i, cell_set):

self.log("Create shapes for merged cells")

if self.config["processes"] == 1:
if self.config["threads"] == 1:
shapes = []
for coord in tqdm(coords, desc = "creating shapes"):
shapes.append(tranform_to_map(coord, dilation = dilation, erosion = erosion, coord_format = False))
else:
with mp.get_context(self.context).Pool(processes=self.config['processes']) as pool:
with mp.get_context(self.context).Pool(processes=self.config['threads']) as pool:
shapes = list(tqdm(pool.imap(partial(tranform_to_map,
erosion = erosion,
dilation = dilation,
Expand All @@ -722,14 +722,14 @@ def generate_cutting_data(self, i, cell_set):


self.log("Calculating polygons")
if self.config["processes"] == 1:
if self.config["threads"] == 1:
shapes = []
for shape in tqdm(shapes, desc = "calculating polygons"):
shapes.append(create_poly(shape,
smoothing_filter_size = self.config['convolution_smoothing'],
poly_compression_factor = self.config['poly_compression_factor']))
else:
with mp.get_context(self.context).Pool(processes=self.config['processes']) as pool:
with mp.get_context(self.context).Pool(processes=self.config['threads']) as pool:
shapes = list(tqdm(pool.imap(partial(create_poly,
smoothing_filter_size = self.config['convolution_smoothing'],
poly_compression_factor = self.config['poly_compression_factor']
Expand Down Expand Up @@ -829,7 +829,7 @@ def merge_dilated_shapes(self,
# coordinates are created as complex numbers to facilitate comparison with np.isin
dilated_coords = []

if self.config["processes"] == 1:
if self.config["threads"] == 1:
for coord in tqdm(input_coords, desc = "dilating shapes"):
dilated_coords.append(tranform_to_map(coord, dilation = dilation))

Expand Down

0 comments on commit a0176bc

Please sign in to comment.