diff --git a/src/stripepy/stripepy.py b/src/stripepy/stripepy.py
index a17d5c9..9bf8687 100644
--- a/src/stripepy/stripepy.py
+++ b/src/stripepy/stripepy.py
@@ -120,7 +120,9 @@ def _extract_RoIs(I: ss.csr_matrix, RoI: Dict[str, List[int]]) -> Optional[NDArr
     return I_RoI
 
 
-def _compute_global_pseudodistribution(T: ss.csr_matrix, smooth: bool = True) -> NDArray[float]:
+def _compute_global_pseudodistribution(
+    T: ss.csr_matrix, smooth: bool = True, decimal_places: int = 10
+) -> NDArray[float]:
     """
     Given a sparse matrix T, marginalize it, scale the marginal so that maximum is 1, and then smooth it.
 
@@ -130,10 +132,13 @@ def _compute_global_pseudodistribution(T: ss.csr_matrix, smooth: bool = True) ->
         the sparse matrix to be processed
     smooth: bool
         if set to True, smoothing is applied to the pseudo-distribution (default value is True)
+    decimal_places: int
+        the number of decimal places to truncate the pseudo-distribution to.
+        Pass -1 to not truncate the pseudo-distribution values
 
     Returns
     -------
-    NDArray[np.float64]
+    NDArray[float]
         a vector with the re-scaled and smoothed marginals.
     """
 
@@ -141,6 +146,12 @@ def _compute_global_pseudodistribution(T: ss.csr_matrix, smooth: bool = True) ->
     pseudo_dist /= np.max(pseudo_dist)  # scaling
     if smooth:
         pseudo_dist = np.maximum(regressions._compute_wQISA_predictions(pseudo_dist, 11), pseudo_dist)  # smoothing
+
+    if decimal_places >= 0:
+        # We need to truncate FP numbers to ensure that later steps generate consistent results
+        # even in the presence of very minor numeric differences on different platforms.
+        return common.truncate_np(pseudo_dist, decimal_places)
+
     return pseudo_dist
 
 
@@ -290,10 +301,10 @@ def step_2(
     # so that each maximum is still paired to its minimum.
     # Maximum and minimum points sorted w.r.t.
     # coordinates (NOTATION: cs = coordinate-sorted):
-    LT_mPs, LT_pers_of_mPs = common.sort_based_on_arg0(LT_ps_mPs, pers_of_LT_ps_mPs)
-    LT_MPs, LT_pers_of_MPs = common.sort_based_on_arg0(LT_ps_MPs, pers_of_LT_ps_MPs)
-    UT_mPs, UT_pers_of_mPs = common.sort_based_on_arg0(UT_ps_mPs, pers_of_UT_ps_mPs)
-    UT_MPs, UT_pers_of_MPs = common.sort_based_on_arg0(UT_ps_MPs, pers_of_UT_ps_MPs)
+    LT_mPs, LT_pers_of_mPs = common.sort_values(LT_ps_mPs, pers_of_LT_ps_mPs)
+    LT_MPs, LT_pers_of_MPs = common.sort_values(LT_ps_MPs, pers_of_LT_ps_MPs)
+    UT_mPs, UT_pers_of_mPs = common.sort_values(UT_ps_mPs, pers_of_UT_ps_mPs)
+    UT_MPs, UT_pers_of_MPs = common.sort_values(UT_ps_MPs, pers_of_UT_ps_MPs)
 
     logger.bind(step=(2, 2, 3)).info("removing seeds overlapping sparse regions")
     LT_mask = _check_neighborhood(_compute_global_pseudodistribution(L, smooth=False))
diff --git a/src/stripepy/utils/common.py b/src/stripepy/utils/common.py
index 8293de8..70acef7 100644
--- a/src/stripepy/utils/common.py
+++ b/src/stripepy/utils/common.py
@@ -2,16 +2,19 @@
 #
 # SPDX-License-Identifier: MIT
 
+import decimal
 import time
 from typing import Optional, Sequence, Tuple
 
 import numpy as np
+import pandas as pd
 from numpy.typing import NDArray
 
 
-def sort_based_on_arg0(*vectors: Sequence) -> Tuple[NDArray]:
+def sort_values(*vectors: Sequence) -> Tuple[NDArray]:
     """
-    Sort two or more sequences of objects based on the first sequence of objects.
+    Sort two or more sequences of objects as if each sequence was a column in a
+    table and the table was being sorted row-by-row based on values from all columns.
 
     Parameters
     ----------
@@ -31,9 +34,10 @@ def sort_based_on_arg0(*vectors: Sequence) -> Tuple[NDArray]:
     if len(vectors[0]) == 0:
         return tuple((np.array(v) for v in vectors))  # noqa
 
-    permutation = np.argsort(vectors[0], stable=True)
+    df = pd.DataFrame({i: v for i, v in enumerate(vectors)})
+    df.sort_values(by=df.columns.tolist(), inplace=True, kind="stable")
 
-    return tuple((np.array(v)[permutation] for v in vectors))  # noqa
+    return tuple(df[col].to_numpy() for col in df.columns)  # noqa
 
 
 def pretty_format_elapsed_time(t0: float, t1: Optional[float] = None) -> str:
@@ -71,6 +75,34 @@ def pretty_format_elapsed_time(t0: float, t1: Optional[float] = None) -> str:
     return f"{hours:.0f}h:{minutes:.0f}m:{seconds:.3f}s"
 
 
+def truncate_np(v: NDArray[float], places: int) -> NDArray[float]:
+    """
+    Truncate a numpy array to the given number of decimal places.
+    Implementation based on https://stackoverflow.com/a/28323804
+
+    Parameters
+    ----------
+    v: NDArray[float]
+        the numpy array to be truncated
+    places: int
+        the number of decimal places to truncate to
+
+    Returns
+    -------
+    NDArray[float]
+        numpy array with truncated values
+    """
+    assert places >= 0
+
+    if places == 0:
+        return v.round(0)
+
+    with decimal.localcontext() as context:
+        context.rounding = decimal.ROUND_DOWN
+        exponent = decimal.Decimal(str(10**-places))
+        return np.array([float(decimal.Decimal(str(n)).quantize(exponent)) for n in v], dtype=float)
+
+
 def _import_matplotlib():
     """
     Helper function to import matplotlib.