Merge pull request #11 from MannLabs/speed_improvements
Speed improvements when writing XML containing multiple wells
GeorgWa authored Sep 10, 2024
2 parents 0f455ad + ec5a40a commit 0005882
Showing 8 changed files with 1,566 additions and 133 deletions.
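In short: SegmentationLoader gains a processes argument that distributes whole cell sets over worker processes, a threads config key that parallelizes per-shape work (dilation, polygon calculation), and a precomputed coordinate index (coords_lookup) so the segmentation mask is scanned once rather than once per cell set. A minimal usage sketch under those assumptions; the "classes" and "well" cell-set keys and the Collection.save call come from the wider project rather than this diff, and the file names are hypothetical:

import numpy as np
from PIL import Image
from lmd.lib import SegmentationLoader

# hypothetical inputs: a labelled segmentation mask plus three calibration points
segmentation = np.array(Image.open("segmentation_cytosol.tiff")).astype(np.uint32)
calibration_points = np.array([[0, 0], [0, 1000], [1000, 1000]])

cell_sets = [{"classes": np.unique(segmentation)[1:], "well": "A1"}]  # assumed key names

config = {"threads": 4}  # new config key: worker processes for shape creation and polygon calculation

# new constructor argument: cell sets are processed in parallel across 2 workers
sl = SegmentationLoader(config=config, verbose=True, processes=2)
collection = sl(segmentation, cell_sets, calibration_points)
collection.save("cut_shapes.xml")  # assumed XML export call, not shown in this diff

Because the loader uses process-based parallelism with the spawn start method on Windows and macOS, this call should live under an if __name__ == "__main__": guard on those platforms.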
4 changes: 2 additions & 2 deletions .github/workflows/python-package.yml
@@ -7,7 +7,7 @@ on:
push:
branches: [ main, release ]
pull_request:
branches: [ main, release ]
workflow_dispatch:

jobs:
build:
@@ -35,4 +35,4 @@ jobs:
python -m pip install -e .
- name: Test with pytest
run: |
pytest
pytest
1,365 changes: 1,357 additions & 8 deletions docs_source/pages/notebooks/Image_Segmentation/Image_Segmentation_2.ipynb

Large diffs are not rendered by default.

23 changes: 7 additions & 16 deletions docs_source/pages/notebooks/generate_cutting_mask_svg.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,4 +1,4 @@
numpy
numpy<2.0.0
matplotlib
lxml
networkx >= 3.1
@@ -7,4 +7,4 @@ svgelements
setuptools
numba
tqdm
hilbertcurve
hilbertcurve
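The new numpy<2.0.0 pin presumably guards against NumPy 2.0 (released June 2024), whose ABI and API changes broke many downstream packages that had not yet been rebuilt against it, numba among them.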
5 changes: 5 additions & 0 deletions src/lmd/conftest.py
@@ -0,0 +1,5 @@
import matplotlib
import matplotlib.pyplot

matplotlib.use("Agg")
matplotlib.pyplot.ioff()
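This new conftest.py forces matplotlib onto the non-interactive Agg backend and disables interactive mode before any test imports pyplot, so figures created during pytest runs render off-screen instead of trying to open GUI windows, which would hang on headless CI runners.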
190 changes: 141 additions & 49 deletions src/lmd/lib.py
@@ -1,16 +1,18 @@
from __future__ import annotations
from typing import Optional
from functools import partial, reduce
import multiprocessing
import multiprocessing as mp
import numpy as np
import matplotlib.pyplot as plt
from lxml import etree as ET
from matplotlib import image
from skimage import data, color
import matplotlib.ticker as ticker
from svgelements import SVG
from lmd.segmentation import get_coordinate_form, tsp_greedy_solve, tsp_hilbert_solve, calc_len
from lmd.segmentation import get_coordinate_form, tsp_greedy_solve, tsp_hilbert_solve, calc_len, _create_coord_index, _filter_coord_index
from tqdm import tqdm
# import warnings
import warnings

from skimage.morphology import dilation as binary_dilation
from skimage.morphology import binary_erosion, disk
@@ -28,6 +30,40 @@
from scipy.signal import convolve2d

import gc
import sys
import platform

from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Callable
from tqdm.auto import tqdm

def execute_indexed_parallel(func: Callable, *, args: list, tqdm_args: dict = None, n_threads: int = 10) -> list:
    """Parallelize a function call over indexed arguments using a ProcessPoolExecutor.

    Args:
        func (Callable): function to execute; called as ``func(*arg)`` for each entry of ``args``.
        args (list): list of argument tuples, one tuple per call.
        tqdm_args (dict, optional): keyword arguments forwarded to the tqdm progress bar. Defaults to None.
        n_threads (int, optional): number of workers (despite the name, worker processes are used). Defaults to 10.

    Returns:
        list: results of the function calls, in the same order as the input arguments.
    """
    if tqdm_args is None:
        tqdm_args = {"total": len(args)}
    elif "total" not in tqdm_args:
        tqdm_args["total"] = len(args)

    results = [None for _ in range(len(args))]
    with ProcessPoolExecutor(n_threads) as executor:
        with tqdm(**tqdm_args) as pbar:
            # map each future to the index of its argument tuple so results keep input order
            futures = {executor.submit(func, *arg): i for i, arg in enumerate(args)}
            for future in as_completed(futures):
                index = futures[future]
                results[index] = future.result()
                pbar.update(1)

    return results
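A minimal sketch of how this helper behaves; the _square_plus worker is hypothetical, and because the pool is process-based, the worker must be defined at module top level and the call belongs under a main guard on spawn platforms:

def _square_plus(x, y):  # hypothetical worker; each args tuple is unpacked into func(*arg)
    return x ** 2 + y

if __name__ == "__main__":
    results = execute_indexed_parallel(_square_plus, args=[(1, 0), (2, 0), (3, 1)], n_threads=2)
    print(results)  # [1, 4, 10] -- ordered like the inputs, even if calls finish out of order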

class Collection:
"""Class which is used for creating shape collections for the Leica LMD6 & 7. Contains a coordinate system defined by calibration points and a collection of various shapes.
@@ -46,7 +82,7 @@ class Collection:
def __init__(self, calibration_points: Optional[np.ndarray] = None):


self.shapes: List[Shape] = []
self.shapes: list[Shape] = []

self.calibration_points: Optional[np.ndarray] = calibration_points

@@ -465,7 +501,6 @@ def to_xml(self,
return shape



class SegmentationLoader():
"""Select single cells from a segmentation and generate cutting data
@@ -555,9 +590,10 @@ class SegmentationLoader():
VALID_PATH_OPTIMIZERS = ["none", "hilbert", "greedy"]


def __init__(self, config = {}, verbose = False):
def __init__(self, config = {}, verbose = False, processes = 1):
self.config = config
self.verbose = verbose
self._get_context() #setup context for multiprocessing function calls to work with different operating systems

self.register_parameter('shape_dilation', 0)
self.register_parameter('shape_erosion', 0)
@@ -572,12 +608,22 @@ def __init__(self, config = {}, verbose = False):
self.register_parameter('processes', 10)
self.register_parameter('join_intersecting', True)
self.register_parameter('orientation_transform', np.eye(2))

def __call__(self, input_segmentation, cell_sets, calibration_points):

self.input_segmentation = input_segmentation
self.register_parameter('threads', 1)

self.coords_lookup = None
self.processes = processes

def _get_context(self):
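        # fork is unavailable on Windows and unsafe with macOS system frameworks, so both fall back to spawn; only Linux keeps the cheaper fork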
if platform.system() == 'Windows':
self.context = "spawn"
elif platform.system() == 'Darwin':
self.context = "spawn"
elif platform.system() == 'Linux':
self.context = "fork"

def __call__(self, input_segmentation, cell_sets, calibration_points, coords_lookup = None):

self.calibration_points = calibration_points

sets = []

# iterate over all defined sets, perform sanity checks and load external data
@@ -588,19 +634,51 @@ def __call__(self, input_segmentation, cell_sets, calibration_points):
sets.append(cell_set)
self.log(f"cell set {i} passed sanity check")

collections = []
for i, cell_set in enumerate(cell_sets):
collections.append(self.generate_cutting_data(cell_set))
self.input_segmentation = input_segmentation

if coords_lookup is None:
self.log("Calculating coordinate locations of all cells.")
self.coords_lookup = _create_coord_index(self.input_segmentation)
else:
self.log("Loading coordinates from external source")
self.coords_lookup = coords_lookup

        # process cell sets in parallel when more than one worker process is configured
        if self.processes > 1:
            self.log("Processing cell sets in parallel")
            args = []
            for i, cell_set in enumerate(cell_sets):
                args.append((i, cell_set))

            collections = execute_indexed_parallel(
                self.generate_cutting_data,
                args=args,
                tqdm_args=dict(
                    file=sys.stdout,
                    disable=not self.verbose,
                    desc=" collecting cell sets",
                ),
                n_threads=self.processes,
            )
        else:
            self.log("Processing cell sets in serial")
            collections = []
            for i, cell_set in enumerate(cell_sets):
                collections.append(self.generate_cutting_data(i, cell_set))

return reduce(lambda a, b: a.join(b), collections)

def generate_cutting_data(self, cell_set):

def generate_cutting_data(self, i, cell_set):

if 0 in cell_set["classes_loaded"]:
cell_set["classes_loaded"] = cell_set["classes_loaded"][cell_set["classes_loaded"] != 0]
warnings.warn("Class 0 is not a valid class and was removed from the cell set")

self.log("Convert label format into coordinate format")

center, length, coords = get_coordinate_form(self.input_segmentation, cell_set["classes_loaded"])

self.log("Conversion finished, sanity check")
center, length, coords = get_coordinate_form(self.input_segmentation, cell_set["classes_loaded"], self.coords_lookup)

self.log("Conversion finished, performing sanity check.")

# Sanity check 1
if len(center) == len(cell_set["classes_loaded"]):
@@ -627,45 +705,56 @@ def generate_cutting_data(self, cell_set):
else:
self.log("Check failed, returned coordinates contain empty elements. Please check if all classes specified are present in your segmentation")


if self.config['join_intersecting']:
print("Merging intersecting shapes")
center, length, coords = self.merge_dilated_shapes(center, length, coords,
dilation = self.config['shape_dilation'],
erosion = self.config['shape_erosion'])
dilation = self.config['shape_dilation'],
erosion = self.config['shape_erosion'])

# Calculate dilation and erosion based on if merging was activated
dilation = self.config['binary_smoothing'] if self.config['join_intersecting'] else self.config['binary_smoothing'] + self.config['shape_dilation']
erosion = self.config['binary_smoothing'] if self.config['join_intersecting'] else self.config['binary_smoothing'] + self.config['shape_erosion']

self.log("Create shapes for merged cells")
with multiprocessing.Pool(processes=self.config['processes']) as pool:
shapes = list(tqdm(pool.imap(partial(tranform_to_map,
erosion = erosion,
dilation = dilation,
coord_format = False),
coords), total=len(center), disable = not self.verbose))


if self.config["threads"] == 1:
shapes = []
for coord in tqdm(coords, desc = "creating shapes"):
shapes.append(tranform_to_map(coord, dilation = dilation, erosion = erosion, coord_format = False))
else:
with mp.get_context(self.context).Pool(processes=self.config['threads']) as pool:
shapes = list(tqdm(pool.imap(partial(tranform_to_map,
erosion = erosion,
dilation = dilation,
coord_format = False),
coords), total=len(center), disable = not self.verbose, desc = "creating shapes"))


self.log("Calculating polygons")
with multiprocessing.Pool(processes=self.config['processes']) as pool:
shapes = list(tqdm(pool.imap(partial(create_poly,
smoothing_filter_size = self.config['convolution_smoothing'],
poly_compression_factor = self.config['poly_compression_factor']
),
shapes), total=len(center), disable = not self.verbose))

if self.config["threads"] == 1:
polygons = []
for shape in tqdm(shapes, desc = "calculating polygons"):
polygons.append(create_poly(shape,
smoothing_filter_size = self.config['convolution_smoothing'],
poly_compression_factor = self.config['poly_compression_factor']))
else:
with mp.get_context(self.context).Pool(processes=self.config['threads']) as pool:
polygons = list(tqdm(pool.imap(partial(create_poly,
smoothing_filter_size = self.config['convolution_smoothing'],
poly_compression_factor = self.config['poly_compression_factor']
),
shapes), total=len(center), disable = not self.verbose, desc = "calculating polygons" ))


self.log("Polygon calculation finished")


center = np.array(center)
unoptimized_length = calc_len(center)
self.log(f"Current path length: {unoptimized_length:,.2f} units")

# check if optimizer key has been set
if 'path_optimization' in self.config:



optimization_method = self.config['path_optimization']
self.log(f"Path optimizer defined in config: {optimization_method}")

@@ -700,9 +789,7 @@ def generate_cutting_data(self, cell_set):
self.log(f"Optimization factor: {optimization_factor:,.1f}x")

# order list of shapes by the optimized index array
shapes = [x for _, x in sorted(zip(optimized_idx, shapes))]


polygons = [x for _, x in sorted(zip(optimized_idx, polygons))]

# Plot coordinates if in debug mode
if self.verbose:
Expand All @@ -717,7 +804,7 @@ def generate_cutting_data(self, cell_set):

ax.scatter(center[:,1],center[:,0], s=1)

for shape in shapes:
for shape in polygons:
ax.plot(shape[:,1],shape[:,0], color="red",linewidth=1)


Expand All @@ -732,7 +819,7 @@ def generate_cutting_data(self, cell_set):
ds = Collection(calibration_points = self.calibration_points)
ds.orientation_transform = self.config['orientation_transform']

for shape in shapes:
for shape in polygons:
# Check if well key is set in cell set definition
if "well" in cell_set:
ds.new_shape(shape, well=cell_set["well"])
@@ -751,10 +838,15 @@ def merge_dilated_shapes(self,
# coordinates are created as complex numbers to facilitate comparison with np.isin
dilated_coords = []

with multiprocessing.Pool(processes=self.config['processes']) as pool:
dilated_coords = list(tqdm(pool.imap(partial(tranform_to_map,
dilation = dilation),
input_coords), total=len(input_center)))
if self.config["threads"] == 1:
for coord in tqdm(input_coords, desc = "dilating shapes"):
dilated_coords.append(tranform_to_map(coord, dilation = dilation))

else:
with mp.get_context(self.context).Pool(processes=self.config['processes']) as pool:
dilated_coords = list(tqdm(pool.imap(partial(tranform_to_map,
dilation = dilation),
input_coords), total=len(input_center)))

dilated_coords = [np.apply_along_axis(lambda args: [complex(*args)], 1, d).flatten() for d in dilated_coords]
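Packing each (row, col) pair into a single complex number turns every coordinate into one hashable scalar, so membership tests between two coordinate lists reduce to np.isin over 1-D arrays. A tiny standalone illustration of the idea, not part of the diff:

import numpy as np

a = np.array([[1, 2], [3, 4]])
b = np.array([[3, 4], [5, 6]])
ac = a[:, 0] + 1j * a[:, 1]  # (row, col) -> row + col*j
bc = b[:, 0] + 1j * b[:, 1]
print(np.isin(ac, bc))       # [False  True] -- only (3, 4) occurs in both lists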

@@ -894,7 +986,7 @@ def register_parameter(self, key, value):
else:
            raise TypeError('Key must be a string or a list of strings')

if not key in config_handle:
if key not in config_handle:
self.log(f'No configuration for {key} found, parameter will be set to {value}')
config_handle[key] = value

10 changes: 5 additions & 5 deletions src/lmd/lmd_test.py
@@ -82,11 +82,11 @@ def test_text():
my_first_collection.join(identifier_3)

def test_segmentation_loader():
_dir = pathlib.Path(__file__).parent.resolve().absolute()
_dir = str(_dir).replace("src/lmd/", "docs_source/pages/notebooks")
im = Image.open(os.path.join(_dir, 'Image_Segmentation', 'segmentation_cytosol.tiff'))

package_base_path = pathlib.Path(__file__).parent.parent.parent.resolve().absolute()
test_segmentation_path = os.path.join(package_base_path, 'docs_source/pages/notebooks/Image_Segmentation/segmentation_cytosol.tiff')

im = Image.open(test_segmentation_path)
segmentation = np.array(im).astype(np.uint32)

all_classes = np.unique(segmentation)
Expand Down