Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

merge DataLinks get_common_strains to ObjectLink #250

Merged
merged 6 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 0 additions & 59 deletions src/nplinker/nplinker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import sys
from os import PathLike
from pprint import pformat
from typing import TYPE_CHECKING
from . import setup_logging
from .arranger import DatasetArranger
from .config import load_config
Expand All @@ -22,10 +21,6 @@
from .scoring.rosetta_scoring import RosettaScoring


if TYPE_CHECKING:
from collections.abc import Sequence
from .strain import Strain

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -88,8 +83,6 @@ def __init__(self, config_file: str | PathLike):
name: False for name in self._scoring_methods.keys()
}

self._datalinks = None

self._repro_data = {}
repro_file = self.config.get("repro_file")
if repro_file:
Expand Down Expand Up @@ -264,64 +257,12 @@ def get_links(
)
link_collection = method.get_links(*objects_for_method, link_collection=link_collection)

if not self._datalinks:
logger.debug("Creating internal datalinks object")
self._datalinks = self.scoring_method(MetcalfScoring.name).datalinks
logger.debug("Created internal datalinks object")

if len(link_collection) == 0:
logger.debug("No links found or remaining after merging all method results!")

# populate shared strain info
logger.debug("Calculating shared strain information...")
# TODO more efficient version?
for source, link_data in link_collection.links.items():
if isinstance(source, BGC):
logger.debug("Cannot determine shared strains for BGC input!")
break

targets = list(filter(lambda x: not isinstance(x, BGC), link_data.keys()))
if len(targets) > 0:
if isinstance(source, GCF):
shared_strains = self._datalinks.get_common_strains(targets, [source], True)
for target, link in link_data.items():
if (target, source) in shared_strains:
link.shared_strains = shared_strains[(target, source)]
else:
shared_strains = self._datalinks.get_common_strains([source], targets, True)
for target, link in link_data.items():
if (source, target) in shared_strains:
link.shared_strains = shared_strains[(source, target)]

logger.info("Finished calculating shared strain information")

logger.info("Final size of link collection is {}".format(len(link_collection)))
return link_collection

def get_common_strains(
self,
met: Sequence[Spectrum] | Sequence[MolecularFamily],
gcfs: Sequence[GCF],
filter_no_shared: bool = True,
) -> dict[tuple[Spectrum | MolecularFamily, GCF], list[Strain]]:
"""Get common strains between given spectra/molecular families and GCFs.

Args:
met:
A list of Spectrum or MolecularFamily objects.
gcfs: A list of GCF objects.
filter_no_shared: If True, the pairs of spectrum/mf and GCF
without common strains will be removed from the returned dict;

Returns:
A dict where the keys are tuples of (Spectrum/MolecularFamily, GCF)
and values are a list of shared Strain objects.
"""
if not self._datalinks:
self._datalinks = self.scoring_method(MetcalfScoring.name).datalinks
common_strains = self._datalinks.get_common_strains(met, gcfs, filter_no_shared)
return common_strains

def has_bgc(self, bgc_id):
"""Returns True if BGC ``bgc_id`` exists in the dataset."""
return bgc_id in self._bgc_lookup
Expand Down
61 changes: 0 additions & 61 deletions src/nplinker/process_output.py

This file was deleted.

26 changes: 24 additions & 2 deletions src/nplinker/scoring/object_link.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
from __future__ import annotations
from typing import TYPE_CHECKING


if TYPE_CHECKING:
gcroci2 marked this conversation as resolved.
Show resolved Hide resolved
from nplinker.genomics import GCF
from nplinker.metabolomics import MolecularFamily
from nplinker.metabolomics import Spectrum
from nplinker.scoring import ScoringBase
from nplinker.strain import StrainCollection


class ObjectLink:
"""Class which stores information about a single link between two objects.

Expand All @@ -15,10 +27,15 @@ class ObjectLink:
- the output of the scoring method(s) used for this link (e.g. a metcalf score)
"""

def __init__(self, source, target, method, data=None, shared_strains=[]):
def __init__(
self,
source: GCF | Spectrum | MolecularFamily,
target: GCF | Spectrum | MolecularFamily,
method: ScoringBase,
data=None,
):
self.source = source
self.target = target
self.shared_strains = shared_strains
self._method_data = {method: data}

def _merge(self, other_link):
Expand All @@ -28,6 +45,11 @@ def _merge(self, other_link):
def set_data(self, method, newdata):
self._method_data[method] = newdata

@property
def common_strains(self) -> StrainCollection:
"""Get the strains common to both source and target."""
return self.source.strains(self.target.strains)

@property
def method_count(self):
return len(self._method_data)
Expand Down
15 changes: 15 additions & 0 deletions src/nplinker/strain/strain_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,21 @@ def filter(self, strain_set: set[Strain]):
if strain not in strain_set:
self.remove(strain)

def intersection(self, other: StrainCollection) -> StrainCollection:
"""Get the intersection of two strain collections.

Args:
other: The other strain collection to compare.

Returns:
StrainCollection object containing the strains that are in both collections.
"""
intersection = StrainCollection()
for strain in self:
if strain in other:
intersection.add(strain)
return intersection

def has_name(self, name: str) -> bool:
"""Check if the strain collection contains the given strain name (id or alias).

Expand Down
7 changes: 1 addition & 6 deletions tests/unit/scoring/test_metcalf_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,7 @@ def test_setup_load_cache(mc, npl):


def test_calc_score_raw_score(mc):
"""Test `calc_score` method for `raw_score_spec_gcf` and `raw_score_mf_gcf`.

The expected values are calculated manually by using values from `test_init`
of `test_data_links.py` and the default scoring weights.
"""
"""Test `calc_score` method for `raw_score_spec_gcf` and `raw_score_mf_gcf`."""
# link type = 'spec-gcf'
mc.calc_score(link_type="spec-gcf")
assert_frame_equal(
Expand Down Expand Up @@ -185,7 +181,6 @@ def test_get_links_gcf_standardised_false(mc, gcfs, spectra, mfs):
assert len(links) == 3
assert {i.gcf_id for i in links.keys()} == {"gcf1", "gcf2", "gcf3"}
assert isinstance(links[gcfs[0]][spectra[0]], ObjectLink)
# expected values are from `test_get_links_gcf` of test_link_finder.py
assert links[gcfs[0]][spectra[0]].data(mc) == 12
assert links[gcfs[1]][spectra[0]].data(mc) == -9
assert links[gcfs[2]][spectra[0]].data(mc) == 11
Expand Down
27 changes: 0 additions & 27 deletions tests/unit/scoring/test_nplinker_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
from nplinker.scoring import ObjectLink


pytestmark = pytest.mark.skip(reason="Skipping all tests in this file temporarily for dev")


def test_get_links_gcf_standardised_false(npl, mc, gcfs, spectra, mfs, strains_list):
"""Test `get_links` method when input is GCF objects and `standardised` is False."""
# test raw scores (no standardisation)
Expand All @@ -20,20 +17,12 @@ def test_get_links_gcf_standardised_false(npl, mc, gcfs, spectra, mfs, strains_l
assert len(links) == 3
assert {i.gcf_id for i in links.keys()} == {"gcf1", "gcf2", "gcf3"}
assert isinstance(links[gcfs[0]][spectra[0]], ObjectLink)
# expected values are from `test_get_links_gcf` of test_link_finder.py
assert links[gcfs[0]][spectra[0]].data(mc) == 12
assert links[gcfs[1]][spectra[0]].data(mc) == -9
assert links[gcfs[2]][spectra[0]].data(mc) == 11
assert links[gcfs[0]][mfs[0]].data(mc) == 12
assert links[gcfs[1]][mfs[1]].data(mc) == 12
assert links[gcfs[2]][mfs[2]].data(mc) == 21
# expected values are from `test_get_common_strains_spec` of test_data_links.py
assert links[gcfs[0]][spectra[0]].shared_strains == [strains_list[0]]
assert links[gcfs[1]][spectra[0]].shared_strains == []
assert links[gcfs[2]][spectra[0]].shared_strains == [strains_list[0]]
assert links[gcfs[0]][mfs[0]].shared_strains == [strains_list[0]]
assert links[gcfs[1]][mfs[1]].shared_strains == [strains_list[1]]
assert set(links[gcfs[2]][mfs[2]].shared_strains) == set(strains_list[0:2])

# when test cutoff is 0, i.e. taking scores >= 0
mc.cutoff = 0
Expand All @@ -49,12 +38,6 @@ def test_get_links_gcf_standardised_false(npl, mc, gcfs, spectra, mfs, strains_l
assert links[gcfs[0]][mfs[0]].data(mc) == 12
assert links[gcfs[1]][mfs[1]].data(mc) == 12
assert links[gcfs[2]][mfs[2]].data(mc) == 21
# test shared strains
assert links[gcfs[0]][spectra[0]].shared_strains == [strains_list[0]]
assert links[gcfs[2]][spectra[0]].shared_strains == [strains_list[0]]
assert links[gcfs[0]][mfs[0]].shared_strains == [strains_list[0]]
assert links[gcfs[1]][mfs[1]].shared_strains == [strains_list[1]]
assert set(links[gcfs[2]][mfs[2]].shared_strains) == set(strains_list[0:2])


@pytest.mark.skip(reason="To add after refactoring relevant code.")
Expand All @@ -78,9 +61,6 @@ def test_get_links_spec_standardised_false(npl, mc, gcfs, spectra, strains_list)
assert links[spectra[0]][gcfs[0]].data(mc) == 12
assert links[spectra[0]][gcfs[1]].data(mc) == -9
assert links[spectra[0]][gcfs[2]].data(mc) == 11
assert links[spectra[0]][gcfs[0]].shared_strains == [strains_list[0]]
assert links[spectra[0]][gcfs[1]].shared_strains == []
assert links[spectra[0]][gcfs[2]].shared_strains == [strains_list[0]]

mc.cutoff = 0
links = npl.get_links(list(spectra), mc, and_mode=True)
Expand All @@ -92,8 +72,6 @@ def test_get_links_spec_standardised_false(npl, mc, gcfs, spectra, strains_list)
assert links[spectra[0]][gcfs[0]].data(mc) == 12
assert links[spectra[0]].get(gcfs[1]) is None
assert links[spectra[0]][gcfs[2]].data(mc) == 11
assert links[spectra[0]][gcfs[0]].shared_strains == [strains_list[0]]
assert links[spectra[0]][gcfs[2]].shared_strains == [strains_list[0]]


@pytest.mark.skip(reason="To add after refactoring relevant code.")
Expand All @@ -117,9 +95,6 @@ def test_get_links_mf_standardised_false(npl, mc, gcfs, mfs, strains_list):
assert links[mfs[0]][gcfs[0]].data(mc) == 12
assert links[mfs[0]][gcfs[1]].data(mc) == -9
assert links[mfs[0]][gcfs[2]].data(mc) == 11
assert links[mfs[0]][gcfs[0]].shared_strains == [strains_list[0]]
assert links[mfs[0]][gcfs[1]].shared_strains == []
assert links[mfs[0]][gcfs[2]].shared_strains == [strains_list[0]]

mc.cutoff = 0
links = npl.get_links(list(mfs), mc, and_mode=True)
Expand All @@ -131,8 +106,6 @@ def test_get_links_mf_standardised_false(npl, mc, gcfs, mfs, strains_list):
assert links[mfs[0]][gcfs[0]].data(mc) == 12
assert links[mfs[0]].get(gcfs[1]) is None
assert links[mfs[0]][gcfs[2]].data(mc) == 11
assert links[mfs[0]][gcfs[0]].shared_strains == [strains_list[0]]
assert links[mfs[0]][gcfs[2]].shared_strains == [strains_list[0]]


@pytest.mark.skip(reason="To add after refactoring relevant code.")
Expand Down
21 changes: 21 additions & 0 deletions tests/unit/strain/test_strain_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,27 @@ def test_filter(collection: StrainCollection, strain: Strain):
assert len(collection) == 1


def test_intersection(collection: StrainCollection, strain: Strain):
# test empty collection
other = StrainCollection()
actual = collection.intersection(other)
assert len(actual) == 0

# test no intersection
other = StrainCollection()
other.add(Strain("strain_2"))
actual = collection.intersection(other)
assert len(actual) == 0

# test intersection
other = StrainCollection()
other.add(strain)
other.add(Strain("strain_2"))
actual = collection.intersection(other)
assert len(actual) == 1
assert strain in actual


def test_has_name(collection: StrainCollection):
assert collection.has_name("strain_1")
assert collection.has_name("strain_1_a")
Expand Down
Loading