From 7b532579affb2cd2b8e535472d64e54894412b66 Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Fri, 31 May 2024 16:56:27 +0200 Subject: [PATCH 1/6] add method intersection to StrainCollection --- src/nplinker/strain/strain_collection.py | 15 +++++++++++++++ tests/unit/strain/test_strain_collection.py | 21 +++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/src/nplinker/strain/strain_collection.py b/src/nplinker/strain/strain_collection.py index 74933665..26cd8852 100644 --- a/src/nplinker/strain/strain_collection.py +++ b/src/nplinker/strain/strain_collection.py @@ -121,6 +121,21 @@ def filter(self, strain_set: set[Strain]): if strain not in strain_set: self.remove(strain) + def intersection(self, other: StrainCollection) -> StrainCollection: + """Get the intersection of two strain collections. + + Args: + other: The other strain collection to compare. + + Returns: + StrainCollection object containing the strains that are in both collections. + """ + intersection = StrainCollection() + for strain in self: + if strain in other: + intersection.add(strain) + return intersection + def has_name(self, name: str) -> bool: """Check if the strain collection contains the given strain name (id or alias). 
diff --git a/tests/unit/strain/test_strain_collection.py b/tests/unit/strain/test_strain_collection.py index 50b844c4..145d1a99 100644 --- a/tests/unit/strain/test_strain_collection.py +++ b/tests/unit/strain/test_strain_collection.py @@ -165,6 +165,27 @@ def test_filter(collection: StrainCollection, strain: Strain): assert len(collection) == 1 +def test_intersection(collection: StrainCollection, strain: Strain): + # test empty collection + other = StrainCollection() + actual = collection.intersection(other) + assert len(actual) == 0 + + # test no intersection + other = StrainCollection() + other.add(Strain("strain_2")) + actual = collection.intersection(other) + assert len(actual) == 0 + + # test intersection + other = StrainCollection() + other.add(strain) + other.add(Strain("strain_2")) + actual = collection.intersection(other) + assert len(actual) == 1 + assert strain in actual + + def test_has_name(collection: StrainCollection): assert collection.has_name("strain_1") assert collection.has_name("strain_1_a") From 15f61eb0ff1fc81d9a55875be4903e56a447cb17 Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Fri, 31 May 2024 16:58:55 +0200 Subject: [PATCH 2/6] rename `shared_strains` to `common_strains` --- src/nplinker/nplinker.py | 12 +++--- src/nplinker/scoring/object_link.py | 4 +- tests/unit/scoring/test_nplinker_scoring.py | 42 ++++++++++----------- 3 files changed, 29 insertions(+), 29 deletions(-) diff --git a/src/nplinker/nplinker.py b/src/nplinker/nplinker.py index e4e008b6..5624222d 100644 --- a/src/nplinker/nplinker.py +++ b/src/nplinker/nplinker.py @@ -283,15 +283,15 @@ def get_links( targets = list(filter(lambda x: not isinstance(x, BGC), link_data.keys())) if len(targets) > 0: if isinstance(source, GCF): - shared_strains = self._datalinks.get_common_strains(targets, [source], True) + common_strains = self._datalinks.get_common_strains(targets, [source], True) for target, link in link_data.items(): - if (target, source) in shared_strains: - 
link.shared_strains = shared_strains[(target, source)] + if (target, source) in common_strains: + link.common_strains = common_strains[(target, source)] else: - shared_strains = self._datalinks.get_common_strains([source], targets, True) + common_strains = self._datalinks.get_common_strains([source], targets, True) for target, link in link_data.items(): - if (source, target) in shared_strains: - link.shared_strains = shared_strains[(source, target)] + if (source, target) in common_strains: + link.common_strains = common_strains[(source, target)] logger.info("Finished calculating shared strain information") diff --git a/src/nplinker/scoring/object_link.py b/src/nplinker/scoring/object_link.py index b0f6f5af..722f4093 100644 --- a/src/nplinker/scoring/object_link.py +++ b/src/nplinker/scoring/object_link.py @@ -15,10 +15,10 @@ class ObjectLink: - the output of the scoring method(s) used for this link (e.g. a metcalf score) """ - def __init__(self, source, target, method, data=None, shared_strains=[]): + def __init__(self, source, target, method, data=None, common_strains=[]): self.source = source self.target = target - self.shared_strains = shared_strains + self.common_strains = common_strains self._method_data = {method: data} def _merge(self, other_link): diff --git a/tests/unit/scoring/test_nplinker_scoring.py b/tests/unit/scoring/test_nplinker_scoring.py index fbfeac53..8062cd75 100644 --- a/tests/unit/scoring/test_nplinker_scoring.py +++ b/tests/unit/scoring/test_nplinker_scoring.py @@ -28,12 +28,12 @@ def test_get_links_gcf_standardised_false(npl, mc, gcfs, spectra, mfs, strains_l assert links[gcfs[1]][mfs[1]].data(mc) == 12 assert links[gcfs[2]][mfs[2]].data(mc) == 21 # expected values are from `test_get_common_strains_spec` of test_data_links.py - assert links[gcfs[0]][spectra[0]].shared_strains == [strains_list[0]] - assert links[gcfs[1]][spectra[0]].shared_strains == [] - assert links[gcfs[2]][spectra[0]].shared_strains == [strains_list[0]] - assert 
links[gcfs[0]][mfs[0]].shared_strains == [strains_list[0]] - assert links[gcfs[1]][mfs[1]].shared_strains == [strains_list[1]] - assert set(links[gcfs[2]][mfs[2]].shared_strains) == set(strains_list[0:2]) + assert links[gcfs[0]][spectra[0]].common_strains == [strains_list[0]] + assert links[gcfs[1]][spectra[0]].common_strains == [] + assert links[gcfs[2]][spectra[0]].common_strains == [strains_list[0]] + assert links[gcfs[0]][mfs[0]].common_strains == [strains_list[0]] + assert links[gcfs[1]][mfs[1]].common_strains == [strains_list[1]] + assert set(links[gcfs[2]][mfs[2]].common_strains) == set(strains_list[0:2]) # when test cutoff is 0, i.e. taking scores >= 0 mc.cutoff = 0 @@ -50,11 +50,11 @@ def test_get_links_gcf_standardised_false(npl, mc, gcfs, spectra, mfs, strains_l assert links[gcfs[1]][mfs[1]].data(mc) == 12 assert links[gcfs[2]][mfs[2]].data(mc) == 21 # test shared strains - assert links[gcfs[0]][spectra[0]].shared_strains == [strains_list[0]] - assert links[gcfs[2]][spectra[0]].shared_strains == [strains_list[0]] - assert links[gcfs[0]][mfs[0]].shared_strains == [strains_list[0]] - assert links[gcfs[1]][mfs[1]].shared_strains == [strains_list[1]] - assert set(links[gcfs[2]][mfs[2]].shared_strains) == set(strains_list[0:2]) + assert links[gcfs[0]][spectra[0]].common_strains == [strains_list[0]] + assert links[gcfs[2]][spectra[0]].common_strains == [strains_list[0]] + assert links[gcfs[0]][mfs[0]].common_strains == [strains_list[0]] + assert links[gcfs[1]][mfs[1]].common_strains == [strains_list[1]] + assert set(links[gcfs[2]][mfs[2]].common_strains) == set(strains_list[0:2]) @pytest.mark.skip(reason="To add after refactoring relevant code.") @@ -78,9 +78,9 @@ def test_get_links_spec_standardised_false(npl, mc, gcfs, spectra, strains_list) assert links[spectra[0]][gcfs[0]].data(mc) == 12 assert links[spectra[0]][gcfs[1]].data(mc) == -9 assert links[spectra[0]][gcfs[2]].data(mc) == 11 - assert links[spectra[0]][gcfs[0]].shared_strains == [strains_list[0]] - 
assert links[spectra[0]][gcfs[1]].shared_strains == [] - assert links[spectra[0]][gcfs[2]].shared_strains == [strains_list[0]] + assert links[spectra[0]][gcfs[0]].common_strains == [strains_list[0]] + assert links[spectra[0]][gcfs[1]].common_strains == [] + assert links[spectra[0]][gcfs[2]].common_strains == [strains_list[0]] mc.cutoff = 0 links = npl.get_links(list(spectra), mc, and_mode=True) @@ -92,8 +92,8 @@ def test_get_links_spec_standardised_false(npl, mc, gcfs, spectra, strains_list) assert links[spectra[0]][gcfs[0]].data(mc) == 12 assert links[spectra[0]].get(gcfs[1]) is None assert links[spectra[0]][gcfs[2]].data(mc) == 11 - assert links[spectra[0]][gcfs[0]].shared_strains == [strains_list[0]] - assert links[spectra[0]][gcfs[2]].shared_strains == [strains_list[0]] + assert links[spectra[0]][gcfs[0]].common_strains == [strains_list[0]] + assert links[spectra[0]][gcfs[2]].common_strains == [strains_list[0]] @pytest.mark.skip(reason="To add after refactoring relevant code.") @@ -117,9 +117,9 @@ def test_get_links_mf_standardised_false(npl, mc, gcfs, mfs, strains_list): assert links[mfs[0]][gcfs[0]].data(mc) == 12 assert links[mfs[0]][gcfs[1]].data(mc) == -9 assert links[mfs[0]][gcfs[2]].data(mc) == 11 - assert links[mfs[0]][gcfs[0]].shared_strains == [strains_list[0]] - assert links[mfs[0]][gcfs[1]].shared_strains == [] - assert links[mfs[0]][gcfs[2]].shared_strains == [strains_list[0]] + assert links[mfs[0]][gcfs[0]].common_strains == [strains_list[0]] + assert links[mfs[0]][gcfs[1]].common_strains == [] + assert links[mfs[0]][gcfs[2]].common_strains == [strains_list[0]] mc.cutoff = 0 links = npl.get_links(list(mfs), mc, and_mode=True) @@ -131,8 +131,8 @@ def test_get_links_mf_standardised_false(npl, mc, gcfs, mfs, strains_list): assert links[mfs[0]][gcfs[0]].data(mc) == 12 assert links[mfs[0]].get(gcfs[1]) is None assert links[mfs[0]][gcfs[2]].data(mc) == 11 - assert links[mfs[0]][gcfs[0]].shared_strains == [strains_list[0]] - assert 
links[mfs[0]][gcfs[2]].shared_strains == [strains_list[0]] + assert links[mfs[0]][gcfs[0]].common_strains == [strains_list[0]] + assert links[mfs[0]][gcfs[2]].common_strains == [strains_list[0]] @pytest.mark.skip(reason="To add after refactoring relevant code.") From 3afcfd71595a66ced89d13167d9854bda9c89a6f Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Fri, 31 May 2024 17:07:40 +0200 Subject: [PATCH 3/6] change `common_strains` to a property in ObjectLink --- src/nplinker/scoring/object_link.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/src/nplinker/scoring/object_link.py b/src/nplinker/scoring/object_link.py index 722f4093..ab493c77 100644 --- a/src/nplinker/scoring/object_link.py +++ b/src/nplinker/scoring/object_link.py @@ -1,3 +1,15 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from nplinker.genomics import GCF + from nplinker.metabolomics import MolecularFamily + from nplinker.metabolomics import Spectrum + from nplinker.scoring import ScoringBase + from nplinker.strain import StrainCollection + + class ObjectLink: """Class which stores information about a single link between two objects. @@ -15,10 +27,15 @@ class ObjectLink: - the output of the scoring method(s) used for this link (e.g. 
a metcalf score) """ - def __init__(self, source, target, method, data=None, common_strains=[]): + def __init__( + self, + source: GCF | Spectrum | MolecularFamily, + target: GCF | Spectrum | MolecularFamily, + method: ScoringBase, + data=None, + ): self.source = source self.target = target - self.common_strains = common_strains self._method_data = {method: data} def _merge(self, other_link): @@ -28,6 +45,11 @@ def _merge(self, other_link): def set_data(self, method, newdata): self._method_data[method] = newdata + @property + def common_strains(self) -> StrainCollection: + """Get the strains common to both source and target.""" + return self.source.strains.intersection(self.target.strains) + @property def method_count(self): return len(self._method_data) From 53ba6fd416bd0c801c5a3c5e711d6e8d618880ea Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Fri, 31 May 2024 17:31:37 +0200 Subject: [PATCH 4/6] remove useless code for getting common strains now it's very easy to get common strains using the new `StrainCollection.intersection` method --- src/nplinker/nplinker.py | 59 ---------------------------------------- 1 file changed, 59 deletions(-) diff --git a/src/nplinker/nplinker.py b/src/nplinker/nplinker.py index 5624222d..537f68fa 100644 --- a/src/nplinker/nplinker.py +++ b/src/nplinker/nplinker.py @@ -3,7 +3,6 @@ import sys from os import PathLike from pprint import pformat -from typing import TYPE_CHECKING from .
import setup_logging from .arranger import DatasetArranger from .config import load_config @@ -22,10 +21,6 @@ from .scoring.rosetta_scoring import RosettaScoring -if TYPE_CHECKING: - from collections.abc import Sequence - from .strain import Strain - logger = logging.getLogger(__name__) @@ -88,8 +83,6 @@ def __init__(self, config_file: str | PathLike): name: False for name in self._scoring_methods.keys() } - self._datalinks = None - self._repro_data = {} repro_file = self.config.get("repro_file") if repro_file: @@ -264,64 +257,12 @@ def get_links( ) link_collection = method.get_links(*objects_for_method, link_collection=link_collection) - if not self._datalinks: - logger.debug("Creating internal datalinks object") - self._datalinks = self.scoring_method(MetcalfScoring.name).datalinks - logger.debug("Created internal datalinks object") - if len(link_collection) == 0: logger.debug("No links found or remaining after merging all method results!") - # populate shared strain info - logger.debug("Calculating shared strain information...") - # TODO more efficient version? 
- for source, link_data in link_collection.links.items(): - if isinstance(source, BGC): - logger.debug("Cannot determine shared strains for BGC input!") - break - - targets = list(filter(lambda x: not isinstance(x, BGC), link_data.keys())) - if len(targets) > 0: - if isinstance(source, GCF): - common_strains = self._datalinks.get_common_strains(targets, [source], True) - for target, link in link_data.items(): - if (target, source) in common_strains: - link.common_strains = common_strains[(target, source)] - else: - common_strains = self._datalinks.get_common_strains([source], targets, True) - for target, link in link_data.items(): - if (source, target) in common_strains: - link.common_strains = common_strains[(source, target)] - - logger.info("Finished calculating shared strain information") - logger.info("Final size of link collection is {}".format(len(link_collection))) return link_collection - def get_common_strains( - self, - met: Sequence[Spectrum] | Sequence[MolecularFamily], - gcfs: Sequence[GCF], - filter_no_shared: bool = True, - ) -> dict[tuple[Spectrum | MolecularFamily, GCF], list[Strain]]: - """Get common strains between given spectra/molecular families and GCFs. - - Args: - met: - A list of Spectrum or MolecularFamily objects. - gcfs: A list of GCF objects. - filter_no_shared: If True, the pairs of spectrum/mf and GCF - without common strains will be removed from the returned dict; - - Returns: - A dict where the keys are tuples of (Spectrum/MolecularFamily, GCF) - and values are a list of shared Strain objects. 
- """ - if not self._datalinks: - self._datalinks = self.scoring_method(MetcalfScoring.name).datalinks - common_strains = self._datalinks.get_common_strains(met, gcfs, filter_no_shared) - return common_strains - def has_bgc(self, bgc_id): """Returns True if BGC ``bgc_id`` exists in the dataset.""" return bgc_id in self._bgc_lookup From 35f5a88486dfbad97bc35ab0139e9fd4060cad94 Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Fri, 31 May 2024 17:33:12 +0200 Subject: [PATCH 5/6] remove useless unit tests --- tests/unit/scoring/test_metcalf_scoring.py | 7 +----- tests/unit/scoring/test_nplinker_scoring.py | 27 --------------------- 2 files changed, 1 insertion(+), 33 deletions(-) diff --git a/tests/unit/scoring/test_metcalf_scoring.py b/tests/unit/scoring/test_metcalf_scoring.py index 74cdd93f..b1cf4c7a 100644 --- a/tests/unit/scoring/test_metcalf_scoring.py +++ b/tests/unit/scoring/test_metcalf_scoring.py @@ -129,11 +129,7 @@ def test_setup_load_cache(mc, npl): def test_calc_score_raw_score(mc): - """Test `calc_score` method for `raw_score_spec_gcf` and `raw_score_mf_gcf`. - - The expected values are calculated manually by using values from `test_init` - of `test_data_links.py` and the default scoring weights. 
- """ + """Test `calc_score` method for `raw_score_spec_gcf` and `raw_score_mf_gcf`.""" # link type = 'spec-gcf' mc.calc_score(link_type="spec-gcf") assert_frame_equal( @@ -185,7 +181,6 @@ def test_get_links_gcf_standardised_false(mc, gcfs, spectra, mfs): assert len(links) == 3 assert {i.gcf_id for i in links.keys()} == {"gcf1", "gcf2", "gcf3"} assert isinstance(links[gcfs[0]][spectra[0]], ObjectLink) - # expected values are from `test_get_links_gcf` of test_link_finder.py assert links[gcfs[0]][spectra[0]].data(mc) == 12 assert links[gcfs[1]][spectra[0]].data(mc) == -9 assert links[gcfs[2]][spectra[0]].data(mc) == 11 diff --git a/tests/unit/scoring/test_nplinker_scoring.py b/tests/unit/scoring/test_nplinker_scoring.py index 8062cd75..7464eef7 100644 --- a/tests/unit/scoring/test_nplinker_scoring.py +++ b/tests/unit/scoring/test_nplinker_scoring.py @@ -4,9 +4,6 @@ from nplinker.scoring import ObjectLink -pytestmark = pytest.mark.skip(reason="Skipping all tests in this file temporarily for dev") - - def test_get_links_gcf_standardised_false(npl, mc, gcfs, spectra, mfs, strains_list): """Test `get_links` method when input is GCF objects and `standardised` is False.""" # test raw scores (no standardisation) @@ -20,20 +17,12 @@ def test_get_links_gcf_standardised_false(npl, mc, gcfs, spectra, mfs, strains_l assert len(links) == 3 assert {i.gcf_id for i in links.keys()} == {"gcf1", "gcf2", "gcf3"} assert isinstance(links[gcfs[0]][spectra[0]], ObjectLink) - # expected values are from `test_get_links_gcf` of test_link_finder.py assert links[gcfs[0]][spectra[0]].data(mc) == 12 assert links[gcfs[1]][spectra[0]].data(mc) == -9 assert links[gcfs[2]][spectra[0]].data(mc) == 11 assert links[gcfs[0]][mfs[0]].data(mc) == 12 assert links[gcfs[1]][mfs[1]].data(mc) == 12 assert links[gcfs[2]][mfs[2]].data(mc) == 21 - # expected values are from `test_get_common_strains_spec` of test_data_links.py - assert links[gcfs[0]][spectra[0]].common_strains == [strains_list[0]] - assert 
links[gcfs[1]][spectra[0]].common_strains == [] - assert links[gcfs[2]][spectra[0]].common_strains == [strains_list[0]] - assert links[gcfs[0]][mfs[0]].common_strains == [strains_list[0]] - assert links[gcfs[1]][mfs[1]].common_strains == [strains_list[1]] - assert set(links[gcfs[2]][mfs[2]].common_strains) == set(strains_list[0:2]) # when test cutoff is 0, i.e. taking scores >= 0 mc.cutoff = 0 @@ -49,12 +38,6 @@ def test_get_links_gcf_standardised_false(npl, mc, gcfs, spectra, mfs, strains_l assert links[gcfs[0]][mfs[0]].data(mc) == 12 assert links[gcfs[1]][mfs[1]].data(mc) == 12 assert links[gcfs[2]][mfs[2]].data(mc) == 21 - # test shared strains - assert links[gcfs[0]][spectra[0]].common_strains == [strains_list[0]] - assert links[gcfs[2]][spectra[0]].common_strains == [strains_list[0]] - assert links[gcfs[0]][mfs[0]].common_strains == [strains_list[0]] - assert links[gcfs[1]][mfs[1]].common_strains == [strains_list[1]] - assert set(links[gcfs[2]][mfs[2]].common_strains) == set(strains_list[0:2]) @pytest.mark.skip(reason="To add after refactoring relevant code.") @@ -78,9 +61,6 @@ def test_get_links_spec_standardised_false(npl, mc, gcfs, spectra, strains_list) assert links[spectra[0]][gcfs[0]].data(mc) == 12 assert links[spectra[0]][gcfs[1]].data(mc) == -9 assert links[spectra[0]][gcfs[2]].data(mc) == 11 - assert links[spectra[0]][gcfs[0]].common_strains == [strains_list[0]] - assert links[spectra[0]][gcfs[1]].common_strains == [] - assert links[spectra[0]][gcfs[2]].common_strains == [strains_list[0]] mc.cutoff = 0 links = npl.get_links(list(spectra), mc, and_mode=True) @@ -92,8 +72,6 @@ def test_get_links_spec_standardised_false(npl, mc, gcfs, spectra, strains_list) assert links[spectra[0]][gcfs[0]].data(mc) == 12 assert links[spectra[0]].get(gcfs[1]) is None assert links[spectra[0]][gcfs[2]].data(mc) == 11 - assert links[spectra[0]][gcfs[0]].common_strains == [strains_list[0]] - assert links[spectra[0]][gcfs[2]].common_strains == [strains_list[0]] 
@pytest.mark.skip(reason="To add after refactoring relevant code.") @@ -117,9 +95,6 @@ def test_get_links_mf_standardised_false(npl, mc, gcfs, mfs, strains_list): assert links[mfs[0]][gcfs[0]].data(mc) == 12 assert links[mfs[0]][gcfs[1]].data(mc) == -9 assert links[mfs[0]][gcfs[2]].data(mc) == 11 - assert links[mfs[0]][gcfs[0]].common_strains == [strains_list[0]] - assert links[mfs[0]][gcfs[1]].common_strains == [] - assert links[mfs[0]][gcfs[2]].common_strains == [strains_list[0]] mc.cutoff = 0 links = npl.get_links(list(mfs), mc, and_mode=True) @@ -131,8 +106,6 @@ def test_get_links_mf_standardised_false(npl, mc, gcfs, mfs, strains_list): assert links[mfs[0]][gcfs[0]].data(mc) == 12 assert links[mfs[0]].get(gcfs[1]) is None assert links[mfs[0]][gcfs[2]].data(mc) == 11 - assert links[mfs[0]][gcfs[0]].common_strains == [strains_list[0]] - assert links[mfs[0]][gcfs[2]].common_strains == [strains_list[0]] @pytest.mark.skip(reason="To add after refactoring relevant code.") From 230162e8651e879027d0726bfbcf84102c5c6ea9 Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Mon, 3 Jun 2024 10:18:32 +0200 Subject: [PATCH 6/6] remove unused functions that use DataLink --- src/nplinker/process_output.py | 61 ---------------------------------- 1 file changed, 61 deletions(-) delete mode 100644 src/nplinker/process_output.py diff --git a/src/nplinker/process_output.py b/src/nplinker/process_output.py deleted file mode 100644 index f981bd5f..00000000 --- a/src/nplinker/process_output.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2021 The NPLinker Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Some functions for processing results - should be put somewhere else -import numpy as np - - -def get_sig_links(scores, random_scores, p_threshold=0.95, direction="greater"): - # gcfs are columns - n_spec, n_gcf = scores.shape - sig_links = np.zeros((n_spec, n_gcf)) - for gpos in range(n_gcf): - max_rand = random_scores[:, gpos].max() - min_rand = random_scores[:, gpos].min() - if p_threshold > 1: - for i, s in enumerate(scores[:, gpos]): - if direction == "greater": - if s > max_rand: - sig_links[i, gpos] = 1 - else: - if s < min_rand: - sig_links[i, gpos] = 1 - else: - perc = np.percentile(random_scores[:, gpos], int(100 * p_threshold)) - for i, s in enumerate(scores[:, gpos]): - if direction == "greater": - if s >= perc: - sig_links[i, gpos] = 1 - else: - if s <= perc: - sig_links[i, gpos] = 1 - return sig_links - - -def get_sig_spec(data_link, sig_links, scores, gcf_pos, min_n_strains=2): - # Check if there are *any* strains in the GCF - # No strains = MiBIG - # Can also filter if only (e.g. 2 strains) - strain_sum = data_link.occurrence_gcf_strain[gcf_pos, :].sum() - if strain_sum < min_n_strains: - return [] - col = sig_links[:, gcf_pos] # get the column - sig_pos = np.where(col == 1)[0] - orig_ids = [] - for sp in sig_pos: - orig_ids.append( - (int(data_link.mapping_spec.iloc[sp]["original spec-id"]), scores[sp, gcf_pos]) - ) - orig_ids.sort(key=lambda x: x[1], reverse=True) - return orig_ids