Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
[no ci]
  • Loading branch information
robomics committed Jan 13, 2025
1 parent cb079e0 commit c68b477
Showing 2 changed files with 53 additions and 10 deletions.
23 changes: 17 additions & 6 deletions src/stripepy/stripepy.py
Original file line number Diff line number Diff line change
@@ -120,7 +120,9 @@ def _extract_RoIs(I: ss.csr_matrix, RoI: Dict[str, List[int]]) -> Optional[NDArr
return I_RoI


def _compute_global_pseudodistribution(T: ss.csr_matrix, smooth: bool = True) -> NDArray[float]:
def _compute_global_pseudodistribution(
T: ss.csr_matrix, smooth: bool = True, decimal_places: int = 10
) -> NDArray[float]:
"""
Given a sparse matrix T, marginalize it, scale the marginal so that maximum is 1, and then smooth it.
@@ -130,17 +132,26 @@ def _compute_global_pseudodistribution(T: ss.csr_matrix, smooth: bool = True) ->
the sparse matrix to be processed
smooth: bool
if set to True, smoothing is applied to the pseudo-distribution (default value is True)
decimal_places: int
the number of decimal places to truncate the pseudo-distribution to.
Pass -1 to not truncate the pseudo-distribution values
Returns
-------
NDArray[np.float64]
NDArray[float]
a vector with the re-scaled and smoothed marginals.
"""

pseudo_dist = np.squeeze(np.asarray(np.sum(T, axis=0))) # marginalization
pseudo_dist /= np.max(pseudo_dist) # scaling
if smooth:
pseudo_dist = np.maximum(regressions._compute_wQISA_predictions(pseudo_dist, 11), pseudo_dist) # smoothing

if decimal_places >= 0:
# We need to truncate floating-point numbers to ensure that later steps generate consistent
# results even in the presence of very minor numeric differences across platforms.
return common.truncate_np(pseudo_dist, decimal_places)

return pseudo_dist


@@ -290,10 +301,10 @@ def step_2(
# so that each maximum is still paired to its minimum.

# Maximum and minimum points sorted w.r.t. coordinates (NOTATION: cs = coordinate-sorted):
LT_mPs, LT_pers_of_mPs = common.sort_based_on_arg0(LT_ps_mPs, pers_of_LT_ps_mPs)
LT_MPs, LT_pers_of_MPs = common.sort_based_on_arg0(LT_ps_MPs, pers_of_LT_ps_MPs)
UT_mPs, UT_pers_of_mPs = common.sort_based_on_arg0(UT_ps_mPs, pers_of_UT_ps_mPs)
UT_MPs, UT_pers_of_MPs = common.sort_based_on_arg0(UT_ps_MPs, pers_of_UT_ps_MPs)
LT_mPs, LT_pers_of_mPs = common.sort_values(LT_ps_mPs, pers_of_LT_ps_mPs)
LT_MPs, LT_pers_of_MPs = common.sort_values(LT_ps_MPs, pers_of_LT_ps_MPs)
UT_mPs, UT_pers_of_mPs = common.sort_values(UT_ps_mPs, pers_of_UT_ps_mPs)
UT_MPs, UT_pers_of_MPs = common.sort_values(UT_ps_MPs, pers_of_UT_ps_MPs)

logger.bind(step=(2, 2, 3)).info("removing seeds overlapping sparse regions")
LT_mask = _check_neighborhood(_compute_global_pseudodistribution(L, smooth=False))
40 changes: 36 additions & 4 deletions src/stripepy/utils/common.py
Original file line number Diff line number Diff line change
@@ -2,16 +2,19 @@
#
# SPDX-License-Identifier: MIT

import decimal
import time
from typing import Optional, Sequence, Tuple

import numpy as np
import pandas as pd
from numpy.typing import NDArray


def sort_based_on_arg0(*vectors: Sequence) -> Tuple[NDArray]:
def sort_values(*vectors: Sequence) -> Tuple[NDArray]:
"""
Sort two or more sequences of objects based on the first sequence of objects.
Sort two or more sequences of objects as if each sequence was a column in a
table and the table was being sorted row-by-row based on values from all columns.
Parameters
----------
@@ -31,9 +34,10 @@ def sort_based_on_arg0(*vectors: Sequence) -> Tuple[NDArray]:
if len(vectors[0]) == 0:
return tuple((np.array(v) for v in vectors)) # noqa

permutation = np.argsort(vectors[0], stable=True)
df = pd.DataFrame({i: v for i, v in enumerate(vectors)})
df.sort_values(by=df.columns.tolist(), inplace=True, kind="stable")

return tuple((np.array(v)[permutation] for v in vectors)) # noqa
return tuple(df[col].to_numpy() for col in df.columns) # noqa


def pretty_format_elapsed_time(t0: float, t1: Optional[float] = None) -> str:
@@ -71,6 +75,34 @@ def pretty_format_elapsed_time(t0: float, t1: Optional[float] = None) -> str:
return f"{hours:.0f}h:{minutes:.0f}m:{seconds:.3f}s"


def truncate_np(v: NDArray[float], places: int) -> NDArray[float]:
    """
    Truncate (i.e. round toward zero) the values of a numpy array to the given number of decimal places.

    Implementation based on https://stackoverflow.com/a/28323804

    Truncation goes through the shortest decimal representation of each value (``str(x)``)
    rather than through binary floating-point arithmetic, so that e.g. 0.29 truncated to
    two places stays 0.29 instead of becoming 0.28.

    Parameters
    ----------
    v: NDArray[float]
        the numpy array to be truncated. The input is never modified.
    places: int
        the number of decimal places to truncate to (must be >= 0)

    Returns
    -------
    NDArray[float]
        numpy array with truncated values and the same shape as the input.
        Non-finite values (NaN, +inf, -inf) are returned unchanged.
    """
    assert places >= 0, "places must be a non-negative integer"

    values = np.asarray(v)

    if places == 0:
        # Truncating to zero decimal places means dropping the fractional part.
        # Note that round() would be wrong here, as it rounds half to even (0.7 -> 1.0).
        return np.trunc(values)

    # Build the quantum exactly as 1E-places: going through the float 10**-places
    # would underflow to 0.0 when places is very large.
    exponent = decimal.Decimal((0, (1,), -places))

    truncated = [
        (
            float(decimal.Decimal(str(x)).quantize(exponent, rounding=decimal.ROUND_DOWN))
            if np.isfinite(x)
            else float(x)  # quantize() raises on infinities: pass non-finite values through
        )
        for x in values.ravel()
    ]

    return np.array(truncated, dtype=float).reshape(values.shape)


def _import_matplotlib():
"""
Helper function to import matplotlib.

0 comments on commit c68b477

Please sign in to comment.