Skip to content

Commit

Permalink
Make sure pseudodistribution profiles are reproducible across platforms
Browse files Browse the repository at this point in the history
Basically we truncate pseudodistribution values at the 10th decimal place
to avoid producing different results on different platforms/architectures
due to minor differences in FP values.
  • Loading branch information
robomics committed Jan 13, 2025
1 parent cb079e0 commit 62e3f68
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 10 deletions.
23 changes: 17 additions & 6 deletions src/stripepy/stripepy.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ def _extract_RoIs(I: ss.csr_matrix, RoI: Dict[str, List[int]]) -> Optional[NDArr
return I_RoI


def _compute_global_pseudodistribution(T: ss.csr_matrix, smooth: bool = True) -> NDArray[float]:
def _compute_global_pseudodistribution(
T: ss.csr_matrix, smooth: bool = True, decimal_places: int = 10
) -> NDArray[float]:
"""
Given a sparse matrix T, marginalize it, scale the marginal so that maximum is 1, and then smooth it.
Expand All @@ -130,17 +132,26 @@ def _compute_global_pseudodistribution(T: ss.csr_matrix, smooth: bool = True) ->
the sparse matrix to be processed
smooth: bool
if set to True, smoothing is applied to the pseudo-distribution (default value is True)
decimal_places: int
the number of decimal places to truncate the pseudo-distribution to.
Pass -1 to not truncate the pseudo-distribution values
Returns
-------
NDArray[np.float64]
NDArray[float]
a vector with the re-scaled and smoothed marginals.
"""

pseudo_dist = np.squeeze(np.asarray(np.sum(T, axis=0))) # marginalization
pseudo_dist /= np.max(pseudo_dist) # scaling
if smooth:
pseudo_dist = np.maximum(regressions._compute_wQISA_predictions(pseudo_dist, 11), pseudo_dist) # smoothing

if decimal_places >= 0:
# We need to truncate FP numbers to ensure that later steps generate consistent results
# even in the presence of very minor numeric differences on different platforms.
return common.truncate_np(pseudo_dist, decimal_places)

return pseudo_dist


Expand Down Expand Up @@ -290,10 +301,10 @@ def step_2(
# so that each maximum is still paired to its minimum.

# Maximum and minimum points sorted w.r.t. coordinates (NOTATION: cs = coordinate-sorted):
LT_mPs, LT_pers_of_mPs = common.sort_based_on_arg0(LT_ps_mPs, pers_of_LT_ps_mPs)
LT_MPs, LT_pers_of_MPs = common.sort_based_on_arg0(LT_ps_MPs, pers_of_LT_ps_MPs)
UT_mPs, UT_pers_of_mPs = common.sort_based_on_arg0(UT_ps_mPs, pers_of_UT_ps_mPs)
UT_MPs, UT_pers_of_MPs = common.sort_based_on_arg0(UT_ps_MPs, pers_of_UT_ps_MPs)
LT_mPs, LT_pers_of_mPs = common.sort_values(LT_ps_mPs, pers_of_LT_ps_mPs)
LT_MPs, LT_pers_of_MPs = common.sort_values(LT_ps_MPs, pers_of_LT_ps_MPs)
UT_mPs, UT_pers_of_mPs = common.sort_values(UT_ps_mPs, pers_of_UT_ps_mPs)
UT_MPs, UT_pers_of_MPs = common.sort_values(UT_ps_MPs, pers_of_UT_ps_MPs)

logger.bind(step=(2, 2, 3)).info("removing seeds overlapping sparse regions")
LT_mask = _check_neighborhood(_compute_global_pseudodistribution(L, smooth=False))
Expand Down
40 changes: 36 additions & 4 deletions src/stripepy/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@
#
# SPDX-License-Identifier: MIT

import decimal
import time
from typing import Optional, Sequence, Tuple

import numpy as np
import pandas as pd
from numpy.typing import NDArray


def sort_based_on_arg0(*vectors: Sequence) -> Tuple[NDArray]:
def sort_values(*vectors: Sequence) -> Tuple[NDArray]:
"""
Sort two or more sequences of objects based on the first sequence of objects.
Sort two or more sequences of objects as if each sequence was a column in a
table and the table was being sorted row-by-row based on values from all columns.
Parameters
----------
Expand All @@ -31,9 +34,10 @@ def sort_based_on_arg0(*vectors: Sequence) -> Tuple[NDArray]:
if len(vectors[0]) == 0:
return tuple((np.array(v) for v in vectors)) # noqa

permutation = np.argsort(vectors[0], stable=True)
df = pd.DataFrame({i: v for i, v in enumerate(vectors)})
df.sort_values(by=df.columns.tolist(), inplace=True, kind="stable")

return tuple((np.array(v)[permutation] for v in vectors)) # noqa
return tuple(df[col].to_numpy() for col in df.columns) # noqa


def pretty_format_elapsed_time(t0: float, t1: Optional[float] = None) -> str:
Expand Down Expand Up @@ -71,6 +75,34 @@ def pretty_format_elapsed_time(t0: float, t1: Optional[float] = None) -> str:
return f"{hours:.0f}h:{minutes:.0f}m:{seconds:.3f}s"


def truncate_np(v: NDArray[float], places: int) -> NDArray[float]:
    """
    Truncate a numpy array to the given number of decimal places.

    Values are truncated toward zero (not rounded to nearest), so that the
    result is reproducible regardless of minor platform-specific FP noise.
    Implementation based on https://stackoverflow.com/a/28323804

    Parameters
    ----------
    v: NDArray[float]
        the numpy array to be truncated
    places: int
        the number of decimal places to truncate to

    Returns
    -------
    NDArray[float]
        numpy array with truncated values

    Raises
    ------
    ValueError
        if places is negative
    """
    if places < 0:
        raise ValueError("places cannot be negative")

    if places == 0:
        # np.trunc discards the fractional part (rounds toward zero), which is
        # what "truncate" means. v.round(0) would round to nearest instead
        # (e.g. 0.9 -> 1.0), contradicting the ROUND_DOWN behavior below.
        return np.trunc(v)

    with decimal.localcontext() as context:
        # ROUND_DOWN truncates toward zero instead of rounding to nearest.
        context.rounding = decimal.ROUND_DOWN
        # Build the exponent exactly in base 10. Computing 10**-places in
        # binary FP and stringifying it could itself introduce representation
        # noise, which is exactly what this function is meant to remove.
        exponent = decimal.Decimal(f"1e-{places}")
        # Go through Decimal (exact base-10 arithmetic) rather than scaling by
        # a power of 10 in binary FP, where e.g. 0.29 * 100 != 29.0 exactly.
        return np.array([float(decimal.Decimal(str(n)).quantize(exponent)) for n in v], dtype=float)


def _import_matplotlib():
"""
Helper function to import matplotlib.
Expand Down

0 comments on commit 62e3f68

Please sign in to comment.