Skip to content

Commit 6fd3fb5

Browse files
committed
Optimize find_HIoIs()
The optimization results in a 10% improvement in single-threaded performance
1 parent 8e96618 commit 6fd3fb5

File tree

2 files changed

+75
-35
lines changed

2 files changed

+75
-35
lines changed

src/stripepy/stripepy.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -402,35 +402,45 @@ def step_3(
402402
LT_bounded_mPs = np.concatenate(LT_bounded_mPs, dtype=int)
403403
UT_bounded_mPs = np.concatenate(UT_bounded_mPs, dtype=int)
404404

405-
# List of pairs (pair = left and right boundaries):
406-
# Choose the variable criterion between max_ascent and max_perc_descent
407-
# ---> When variable criterion is set to max_ascent, set the variable max_ascent
408-
# ---> When variable criterion is set to max_perc_descent, set the variable max_perc_descent
405+
# DataFrame with the left and right boundaries for each seed site
409406
LT_HIoIs = finders.find_HIoIs(
410-
LT_pseudo_distrib, LT_MPs, LT_bounded_mPs, int(max_width / (2 * resolution)) + 1, map=map
407+
pseudodistribution=LT_pseudo_distrib,
408+
seed_sites=LT_MPs,
409+
seed_site_bounds=LT_bounded_mPs,
410+
max_width=int(max_width / (2 * resolution)) + 1,
411+
map_=map,
412+
logger=logger,
411413
)
412414
UT_HIoIs = finders.find_HIoIs(
413-
UT_pseudo_distrib, UT_MPs, UT_bounded_mPs, int(max_width / (2 * resolution)) + 1, map=map
415+
pseudodistribution=UT_pseudo_distrib,
416+
seed_sites=UT_MPs,
417+
seed_site_bounds=UT_bounded_mPs,
418+
max_width=int(max_width / (2 * resolution)) + 1,
419+
map_=map,
420+
logger=logger,
414421
)
415422

416-
# List of left or right boundaries:
417-
LT_L_bounds, LT_R_bounds = map(list, zip(*LT_HIoIs))
418-
UT_L_bounds, UT_R_bounds = map(list, zip(*UT_HIoIs))
419-
420423
logger.bind(step=(3, 1, 2)).info("updating candidate stripes with width information")
421424
stripes = result.get("stripes", "LT")
422-
for num_cand_stripe, (LT_L_bound, LT_R_bound) in enumerate(zip(LT_L_bounds, LT_R_bounds)):
423-
stripes[num_cand_stripe].set_horizontal_bounds(LT_L_bound, LT_R_bound)
425+
LT_HIoIs.apply(
426+
lambda seed: stripes[seed.name].set_horizontal_bounds(seed["left_bound"], seed["right_bound"]),
427+
axis="columns",
428+
)
424429

425430
stripes = result.get("stripes", "UT")
426-
for num_cand_stripe, (UT_L_bound, UT_R_bound) in enumerate(zip(UT_L_bounds, UT_R_bounds)):
427-
stripes[num_cand_stripe].set_horizontal_bounds(UT_L_bound, UT_R_bound)
431+
UT_HIoIs.apply(
432+
lambda seed: stripes[seed.name].set_horizontal_bounds(seed["left_bound"], seed["right_bound"]),
433+
axis="columns",
434+
)
428435

429436
logger.bind(step=(3, 1)).info("width estimation took %s", common.pretty_format_elapsed_time(start_time))
430437

431438
logger.bind(step=(3, 2)).info("height estimation")
432439
start_time = time.time()
433440

441+
LT_HIoIs = LT_HIoIs.to_numpy() # TODO remove
442+
UT_HIoIs = UT_HIoIs.to_numpy() # TODO remove
443+
434444
logger.bind(step=(3, 2, 1)).info("estimating candidate stripe heights")
435445
LT_VIoIs, LT_peaks_ids = finders.find_VIoIs(
436446
L,

src/stripepy/utils/finders.py

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,41 +2,58 @@
22
#
33
# SPDX-License-Identifier: MIT
44

5+
import itertools
6+
import time
57
from functools import partial
68
from typing import List, Optional, Tuple
79

810
import numpy as np
11+
import numpy.typing as npt
12+
import pandas as pd
13+
import scipy.sparse as ss
14+
import structlog
915

1016
from . import TDA
17+
from .common import pretty_format_elapsed_time
1118
from .regressions import _compute_wQISA_predictions
1219

1320

14-
def find_horizontal_domain(pd, coarse_h_domain, max_width=1e9):
21+
def find_horizontal_domain(
22+
profile: npt.NDArray[float],
23+
coarse_h_domain: Tuple[int, int, int],
24+
max_width: int = 1e9,
25+
) -> Tuple[int, int]:
26+
"""
27+
Returns
28+
-------
29+
Tuple[int, int]
30+
the left and right coordinates of the horizontal domain
31+
"""
1532

1633
# Unpacking:
1734
MP, L_mP, R_mP = coarse_h_domain
1835

1936
# Left and right sides of candidate:
20-
L_interval = np.flip(pd[L_mP : MP + 1])
21-
R_interval = pd[MP : R_mP + 1]
37+
L_interval = np.flip(profile[L_mP : MP + 1])
38+
R_interval = profile[MP : R_mP + 1]
2239

2340
# LEFT INTERVAL
24-
L_interval_shifted = np.append(L_interval[1:], [max(pd) + 1], axis=0)
41+
L_interval_shifted = np.append(L_interval[1:], [max(profile) + 1], axis=0)
2542
L_bound = np.where(L_interval - L_interval_shifted < 0)[0][0] + 1
2643
# L_interval_restr = L_interval[:L_bound]
2744
# L_interval_shifted_restr = L_interval_shifted[:L_bound]
2845
# L_bound = np.argmax(L_interval_restr - L_interval_shifted_restr) + 1
29-
L_bound = min(L_bound, max_width)
46+
L_bound = np.minimum(L_bound, max_width)
3047

3148
# RIGHT INTERVAL
32-
R_interval_shifted = np.append(R_interval[1:], [max(pd) + 1], axis=0)
49+
R_interval_shifted = np.append(R_interval[1:], [max(profile) + 1], axis=0)
3350
R_bound = np.where(R_interval - R_interval_shifted < 0)[0][0] + 1
3451
# R_interval_restr = R_interval[:R_bound]
3552
# R_interval_shifted_restr = R_interval_shifted[:R_bound]
3653
# R_bound = np.argmax(R_interval_restr - R_interval_shifted_restr) + 1
37-
R_bound = min(R_bound, max_width)
54+
R_bound = np.minimum(R_bound, max_width)
3855

39-
return [max(MP - L_bound, 0), min(MP + R_bound, len(pd))]
56+
return max(MP - L_bound, 0), min(MP + R_bound, len(profile))
4057

4158

4259
def find_lower_v_domain(I, threshold_cut, max_height, min_persistence, it) -> Tuple[List, Optional[List]]:
@@ -114,38 +131,51 @@ def find_upper_v_domain(I, threshold_cut, max_height, min_persistence, it) -> Tu
114131
return [seed_site - candida_bound[0], seed_site], list(seed_site - np.array(loc_Maxima[:-1]))
115132

116133

117-
def find_HIoIs(pd, seed_sites, seed_site_bounds, max_width, map=map):
134+
def find_HIoIs(
135+
pseudodistribution: npt.NDArray[float],
136+
seed_sites: npt.NDArray[int],
137+
seed_site_bounds: npt.NDArray[int],
138+
max_width: int,
139+
map_=map,
140+
logger=None,
141+
) -> pd.DataFrame:
118142
"""
119-
:param pd: acronym for pseudo-distribution, but can be any 1D array representing a uniformly-sample
120-
scalar function works
143+
:param pseudodistribution: 1D array representing a uniformly-sampled scalar function
121144
:param seed_sites: maximum values in the pseudo-distribution (i.e., genomic coordinates hosting linear
122145
patterns)
123146
:param seed_site_bounds: for the i-th entry of seed_sites:
124147
(*) seed_site_bounds[i] is the left boundary
125148
(*) seed_site_bounds[i+1] is the right boundary
126149
:param max_width: maximum width allowed
127-
:param map: alternative implementation of the built-in map function. Can be used to e.g. run this step in parallel by passing multiprocessing.Pool().map.
150+
:param map_: alternative implementation of the built-in map function. Can be used to e.g. run this step in parallel by passing multiprocessing.Pool().map.
128151
:return:
129-
HIoIs list of lists, where each sublist is a pair consisting of the left and right boundaries
152+
HIoIs a pd.DataFrame with the left and right boundaries for each seed site
130153
"""
131154
assert len(seed_site_bounds) == len(seed_sites) + 1
132155

156+
t0 = time.time()
157+
if logger is None:
158+
logger = structlog.get_logger()
159+
133160
iterable_input = [
134161
(seed_site, seed_site_bounds[num_MP], seed_site_bounds[num_MP + 1])
135162
for num_MP, seed_site in enumerate(seed_sites)
136163
]
137164

138-
HIoIs = list(map(partial(find_horizontal_domain, pd, max_width=max_width), iterable_input))
165+
tasks = map_(partial(find_horizontal_domain, pseudodistribution, max_width=max_width), iterable_input)
166+
# This efficiently constructs a 2D numpy array with shape (N, 2) from a list of 2-element tuples, where N is the number of seed sites.
167+
# The first and second columns contain the left and right boundaries of the horizontal domains, respectively.
168+
HIoIs = np.fromiter(itertools.chain.from_iterable(tasks), count=2 * len(seed_sites), dtype=int).reshape(-1, 2)
169+
170+
# Handle possible overlapping intervals by ensuring that the
171+
# left bound of interval i + 1 is always greater than or equal to the right bound of interval i
172+
HIoIs[1:, 0] = np.maximum(HIoIs[1:, 0], HIoIs[:-1, 1])
139173

140-
# Handle possible overlapping intervals:
141-
for i in range(len(HIoIs) - 1):
142-
current_pair = HIoIs[i]
143-
next_pair = HIoIs[i + 1]
174+
df = pd.DataFrame(data=HIoIs, columns=["left_bound", "right_bound"])
144175

145-
if current_pair[1] > next_pair[0]: # Check for intersection
146-
next_pair[0] = current_pair[1] # Modify the second pair
176+
logger.debug("find_HIoIs took %s", pretty_format_elapsed_time(t0))
147177

148-
return HIoIs
178+
return df
149179

150180

151181
def find_VIoIs(

0 commit comments

Comments
 (0)