Skip to content

Commit

Permalink
merge DataLinks get_common_strains to ObjectLink (#250)
Browse files Browse the repository at this point in the history
* add method intersection to StrainCollection

* rename `shared_strains` to `common_strains`

* change `common_strains` to a property in ObjectLink

* remove useless code for getting common strains

now it's very easy to get common strains using the new `StrainCollection.intersection` method

* remove useless unit tests

* remove unused functions that use DataLink
  • Loading branch information
CunliangGeng authored Jun 10, 2024
1 parent 5ba9e31 commit 3bed728
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 155 deletions.
59 changes: 0 additions & 59 deletions src/nplinker/nplinker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import sys
from os import PathLike
from pprint import pformat
from typing import TYPE_CHECKING
from . import setup_logging
from .arranger import DatasetArranger
from .config import load_config
Expand All @@ -22,10 +21,6 @@
from .scoring.rosetta_scoring import RosettaScoring


if TYPE_CHECKING:
from collections.abc import Sequence
from .strain import Strain

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -88,8 +83,6 @@ def __init__(self, config_file: str | PathLike):
name: False for name in self._scoring_methods.keys()
}

self._datalinks = None

self._repro_data = {}
repro_file = self.config.get("repro_file")
if repro_file:
Expand Down Expand Up @@ -264,64 +257,12 @@ def get_links(
)
link_collection = method.get_links(*objects_for_method, link_collection=link_collection)

if not self._datalinks:
logger.debug("Creating internal datalinks object")
self._datalinks = self.scoring_method(MetcalfScoring.name).datalinks
logger.debug("Created internal datalinks object")

if len(link_collection) == 0:
logger.debug("No links found or remaining after merging all method results!")

# populate shared strain info
logger.debug("Calculating shared strain information...")
# TODO more efficient version?
for source, link_data in link_collection.links.items():
if isinstance(source, BGC):
logger.debug("Cannot determine shared strains for BGC input!")
break

targets = list(filter(lambda x: not isinstance(x, BGC), link_data.keys()))
if len(targets) > 0:
if isinstance(source, GCF):
shared_strains = self._datalinks.get_common_strains(targets, [source], True)
for target, link in link_data.items():
if (target, source) in shared_strains:
link.shared_strains = shared_strains[(target, source)]
else:
shared_strains = self._datalinks.get_common_strains([source], targets, True)
for target, link in link_data.items():
if (source, target) in shared_strains:
link.shared_strains = shared_strains[(source, target)]

logger.info("Finished calculating shared strain information")

logger.info("Final size of link collection is {}".format(len(link_collection)))
return link_collection

def get_common_strains(
self,
met: Sequence[Spectrum] | Sequence[MolecularFamily],
gcfs: Sequence[GCF],
filter_no_shared: bool = True,
) -> dict[tuple[Spectrum | MolecularFamily, GCF], list[Strain]]:
"""Get common strains between given spectra/molecular families and GCFs.
Args:
met:
A list of Spectrum or MolecularFamily objects.
gcfs: A list of GCF objects.
filter_no_shared: If True, the pairs of spectrum/mf and GCF
without common strains will be removed from the returned dict;
Returns:
A dict where the keys are tuples of (Spectrum/MolecularFamily, GCF)
and values are a list of shared Strain objects.
"""
if not self._datalinks:
self._datalinks = self.scoring_method(MetcalfScoring.name).datalinks
common_strains = self._datalinks.get_common_strains(met, gcfs, filter_no_shared)
return common_strains

def has_bgc(self, bgc_id):
"""Returns True if BGC ``bgc_id`` exists in the dataset."""
return bgc_id in self._bgc_lookup
Expand Down
61 changes: 0 additions & 61 deletions src/nplinker/process_output.py

This file was deleted.

26 changes: 24 additions & 2 deletions src/nplinker/scoring/object_link.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
from __future__ import annotations
from typing import TYPE_CHECKING


if TYPE_CHECKING:
from nplinker.genomics import GCF
from nplinker.metabolomics import MolecularFamily
from nplinker.metabolomics import Spectrum
from nplinker.scoring import ScoringBase
from nplinker.strain import StrainCollection


class ObjectLink:
"""Class which stores information about a single link between two objects.
Expand All @@ -15,10 +27,15 @@ class ObjectLink:
- the output of the scoring method(s) used for this link (e.g. a metcalf score)
"""

def __init__(self, source, target, method, data=None, shared_strains=[]):
def __init__(
self,
source: GCF | Spectrum | MolecularFamily,
target: GCF | Spectrum | MolecularFamily,
method: ScoringBase,
data=None,
):
self.source = source
self.target = target
self.shared_strains = shared_strains
self._method_data = {method: data}

def _merge(self, other_link):
Expand All @@ -28,6 +45,11 @@ def _merge(self, other_link):
def set_data(self, method, newdata):
self._method_data[method] = newdata

@property
def common_strains(self) -> StrainCollection:
"""Get the strains common to both source and target."""
return self.source.strains(self.target.strains)

@property
def method_count(self):
return len(self._method_data)
Expand Down
15 changes: 15 additions & 0 deletions src/nplinker/strain/strain_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,21 @@ def filter(self, strain_set: set[Strain]):
if strain not in strain_set:
self.remove(strain)

def intersection(self, other: StrainCollection) -> StrainCollection:
"""Get the intersection of two strain collections.
Args:
other: The other strain collection to compare.
Returns:
StrainCollection object containing the strains that are in both collections.
"""
intersection = StrainCollection()
for strain in self:
if strain in other:
intersection.add(strain)
return intersection

def has_name(self, name: str) -> bool:
"""Check if the strain collection contains the given strain name (id or alias).
Expand Down
7 changes: 1 addition & 6 deletions tests/unit/scoring/test_metcalf_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,7 @@ def test_setup_load_cache(mc, npl):


def test_calc_score_raw_score(mc):
"""Test `calc_score` method for `raw_score_spec_gcf` and `raw_score_mf_gcf`.
The expected values are calculated manually by using values from `test_init`
of `test_data_links.py` and the default scoring weights.
"""
"""Test `calc_score` method for `raw_score_spec_gcf` and `raw_score_mf_gcf`."""
# link type = 'spec-gcf'
mc.calc_score(link_type="spec-gcf")
assert_frame_equal(
Expand Down Expand Up @@ -185,7 +181,6 @@ def test_get_links_gcf_standardised_false(mc, gcfs, spectra, mfs):
assert len(links) == 3
assert {i.gcf_id for i in links.keys()} == {"gcf1", "gcf2", "gcf3"}
assert isinstance(links[gcfs[0]][spectra[0]], ObjectLink)
# expected values are from `test_get_links_gcf` of test_link_finder.py
assert links[gcfs[0]][spectra[0]].data(mc) == 12
assert links[gcfs[1]][spectra[0]].data(mc) == -9
assert links[gcfs[2]][spectra[0]].data(mc) == 11
Expand Down
27 changes: 0 additions & 27 deletions tests/unit/scoring/test_nplinker_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
from nplinker.scoring import ObjectLink


pytestmark = pytest.mark.skip(reason="Skipping all tests in this file temporarily for dev")


def test_get_links_gcf_standardised_false(npl, mc, gcfs, spectra, mfs, strains_list):
"""Test `get_links` method when input is GCF objects and `standardised` is False."""
# test raw scores (no standardisation)
Expand All @@ -20,20 +17,12 @@ def test_get_links_gcf_standardised_false(npl, mc, gcfs, spectra, mfs, strains_l
assert len(links) == 3
assert {i.gcf_id for i in links.keys()} == {"gcf1", "gcf2", "gcf3"}
assert isinstance(links[gcfs[0]][spectra[0]], ObjectLink)
# expected values are from `test_get_links_gcf` of test_link_finder.py
assert links[gcfs[0]][spectra[0]].data(mc) == 12
assert links[gcfs[1]][spectra[0]].data(mc) == -9
assert links[gcfs[2]][spectra[0]].data(mc) == 11
assert links[gcfs[0]][mfs[0]].data(mc) == 12
assert links[gcfs[1]][mfs[1]].data(mc) == 12
assert links[gcfs[2]][mfs[2]].data(mc) == 21
# expected values are from `test_get_common_strains_spec` of test_data_links.py
assert links[gcfs[0]][spectra[0]].shared_strains == [strains_list[0]]
assert links[gcfs[1]][spectra[0]].shared_strains == []
assert links[gcfs[2]][spectra[0]].shared_strains == [strains_list[0]]
assert links[gcfs[0]][mfs[0]].shared_strains == [strains_list[0]]
assert links[gcfs[1]][mfs[1]].shared_strains == [strains_list[1]]
assert set(links[gcfs[2]][mfs[2]].shared_strains) == set(strains_list[0:2])

# when test cutoff is 0, i.e. taking scores >= 0
mc.cutoff = 0
Expand All @@ -49,12 +38,6 @@ def test_get_links_gcf_standardised_false(npl, mc, gcfs, spectra, mfs, strains_l
assert links[gcfs[0]][mfs[0]].data(mc) == 12
assert links[gcfs[1]][mfs[1]].data(mc) == 12
assert links[gcfs[2]][mfs[2]].data(mc) == 21
# test shared strains
assert links[gcfs[0]][spectra[0]].shared_strains == [strains_list[0]]
assert links[gcfs[2]][spectra[0]].shared_strains == [strains_list[0]]
assert links[gcfs[0]][mfs[0]].shared_strains == [strains_list[0]]
assert links[gcfs[1]][mfs[1]].shared_strains == [strains_list[1]]
assert set(links[gcfs[2]][mfs[2]].shared_strains) == set(strains_list[0:2])


@pytest.mark.skip(reason="To add after refactoring relevant code.")
Expand All @@ -78,9 +61,6 @@ def test_get_links_spec_standardised_false(npl, mc, gcfs, spectra, strains_list)
assert links[spectra[0]][gcfs[0]].data(mc) == 12
assert links[spectra[0]][gcfs[1]].data(mc) == -9
assert links[spectra[0]][gcfs[2]].data(mc) == 11
assert links[spectra[0]][gcfs[0]].shared_strains == [strains_list[0]]
assert links[spectra[0]][gcfs[1]].shared_strains == []
assert links[spectra[0]][gcfs[2]].shared_strains == [strains_list[0]]

mc.cutoff = 0
links = npl.get_links(list(spectra), mc, and_mode=True)
Expand All @@ -92,8 +72,6 @@ def test_get_links_spec_standardised_false(npl, mc, gcfs, spectra, strains_list)
assert links[spectra[0]][gcfs[0]].data(mc) == 12
assert links[spectra[0]].get(gcfs[1]) is None
assert links[spectra[0]][gcfs[2]].data(mc) == 11
assert links[spectra[0]][gcfs[0]].shared_strains == [strains_list[0]]
assert links[spectra[0]][gcfs[2]].shared_strains == [strains_list[0]]


@pytest.mark.skip(reason="To add after refactoring relevant code.")
Expand All @@ -117,9 +95,6 @@ def test_get_links_mf_standardised_false(npl, mc, gcfs, mfs, strains_list):
assert links[mfs[0]][gcfs[0]].data(mc) == 12
assert links[mfs[0]][gcfs[1]].data(mc) == -9
assert links[mfs[0]][gcfs[2]].data(mc) == 11
assert links[mfs[0]][gcfs[0]].shared_strains == [strains_list[0]]
assert links[mfs[0]][gcfs[1]].shared_strains == []
assert links[mfs[0]][gcfs[2]].shared_strains == [strains_list[0]]

mc.cutoff = 0
links = npl.get_links(list(mfs), mc, and_mode=True)
Expand All @@ -131,8 +106,6 @@ def test_get_links_mf_standardised_false(npl, mc, gcfs, mfs, strains_list):
assert links[mfs[0]][gcfs[0]].data(mc) == 12
assert links[mfs[0]].get(gcfs[1]) is None
assert links[mfs[0]][gcfs[2]].data(mc) == 11
assert links[mfs[0]][gcfs[0]].shared_strains == [strains_list[0]]
assert links[mfs[0]][gcfs[2]].shared_strains == [strains_list[0]]


@pytest.mark.skip(reason="To add after refactoring relevant code.")
Expand Down
21 changes: 21 additions & 0 deletions tests/unit/strain/test_strain_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,27 @@ def test_filter(collection: StrainCollection, strain: Strain):
assert len(collection) == 1


def test_intersection(collection: StrainCollection, strain: Strain):
# test empty collection
other = StrainCollection()
actual = collection.intersection(other)
assert len(actual) == 0

# test no intersection
other = StrainCollection()
other.add(Strain("strain_2"))
actual = collection.intersection(other)
assert len(actual) == 0

# test intersection
other = StrainCollection()
other.add(strain)
other.add(Strain("strain_2"))
actual = collection.intersection(other)
assert len(actual) == 1
assert strain in actual


def test_has_name(collection: StrainCollection):
assert collection.has_name("strain_1")
assert collection.has_name("strain_1_a")
Expand Down

0 comments on commit 3bed728

Please sign in to comment.