Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add abstract base class for scoring methods #247

Merged
merged 3 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/nplinker/nplinker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from .metabolomics import MolecularFamily
from .metabolomics import Spectrum
from .pickler import save_pickled_data
from .scoring.abc import ScoringBase
from .scoring.link_collection import LinkCollection
from .scoring.metcalf_scoring import MetcalfScoring
from .scoring.methods import ScoringMethod
from .scoring.np_class_scoring import NPClassScoring
from .scoring.rosetta_scoring import RosettaScoring

Expand All @@ -37,9 +37,9 @@ class NPLinker:
# default set of enabled scoring methods
# TODO: ideally these shouldn't be hardcoded like this
SCORING_METHODS = {
MetcalfScoring.NAME: MetcalfScoring,
RosettaScoring.NAME: RosettaScoring,
NPClassScoring.NAME: NPClassScoring,
MetcalfScoring.name: MetcalfScoring,
RosettaScoring.name: RosettaScoring,
NPClassScoring.name: NPClassScoring,
}

def __init__(self, config_file: str | PathLike):
Expand Down Expand Up @@ -266,7 +266,7 @@ def get_links(

if not self._datalinks:
logger.debug("Creating internal datalinks object")
self._datalinks = self.scoring_method(MetcalfScoring.NAME).datalinks
self._datalinks = self.scoring_method(MetcalfScoring.name).datalinks
logger.debug("Created internal datalinks object")

if len(link_collection) == 0:
Expand Down Expand Up @@ -318,7 +318,7 @@ def get_common_strains(
and values are a list of shared Strain objects.
"""
if not self._datalinks:
self._datalinks = self.scoring_method(MetcalfScoring.NAME).datalinks
self._datalinks = self.scoring_method(MetcalfScoring.name).datalinks
common_strains = self._datalinks.get_common_strains(met, gcfs, filter_no_shared)
return common_strains

Expand Down Expand Up @@ -401,7 +401,7 @@ def class_matches(self):
"""ClassMatches with the matched classes and scoring tables from MIBiG."""
return self._class_matches

def scoring_method(self, name: str) -> ScoringMethod | None:
def scoring_method(self, name: str) -> ScoringBase | None:
"""Return an instance of a scoring method.

Args:
Expand Down
4 changes: 2 additions & 2 deletions src/nplinker/scoring/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .abc import ScoringBase
from .link_collection import LinkCollection
from .metcalf_scoring import MetcalfScoring
from .methods import ScoringMethod
from .object_link import ObjectLink


__all__ = ["LinkCollection", "MetcalfScoring", "ScoringMethod", "ObjectLink"]
__all__ = ["LinkCollection", "MetcalfScoring", "ScoringBase", "ObjectLink"]
56 changes: 56 additions & 0 deletions src/nplinker/scoring/abc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from __future__ import annotations
import logging
from abc import ABC
from abc import abstractmethod
from typing import TYPE_CHECKING


if TYPE_CHECKING:
from nplinker import NPLinker
from . import LinkCollection

logger = logging.getLogger(__name__)


class ScoringBase(ABC):
    """Abstract base class of scoring methods.

    A concrete scoring method subclasses this class and must implement every
    method marked as abstract here (`setup`, `get_links`, `format_data` and
    `sort`); otherwise the subclass cannot be instantiated.

    Attributes:
        name: The name of the scoring method. It is a class-level attribute
            whose value on this base class is the placeholder `ScoringBase`.
        npl: The NPLinker object.
    """

    # Class-level identifier of the scoring method; the value below is the
    # base-class placeholder.
    name: str = "ScoringBase"

    def __init__(self, npl: NPLinker) -> None:
        """Initialize the scoring method.

        Args:
            npl: The NPLinker object. It is stored on the instance as
                `self.npl`.
        """
        self.npl = npl

    # `classmethod` is the outermost decorator and `abstractmethod` the
    # innermost, which is the order required for an abstract class method.
    @classmethod
    @abstractmethod
    def setup(cls, npl: NPLinker) -> None:
        """Set up class-level attributes.

        Args:
            npl: The NPLinker object.
        """

    @abstractmethod
    def get_links(self, *objects, link_collection: LinkCollection) -> LinkCollection:
        """Get links information for the given objects.

        Args:
            objects: A set of objects, passed as positional arguments.
            link_collection: The LinkCollection object. It must be passed by
                keyword because it follows `*objects` in the signature.

        Returns:
            The LinkCollection object.
        """

    @abstractmethod
    def format_data(self, data) -> str:
        """Format the scoring data to a string.

        Args:
            data: The scoring data to format. Its concrete type is defined by
                the implementing scoring method.

        Returns:
            The scoring data formatted as a string.
        """

    @abstractmethod
    def sort(self, objects, reverse: bool = True) -> list:
        """Sort the given objects based on the scoring data.

        Args:
            objects: The objects to sort.
            reverse: Sort-order flag, `True` by default.
                NOTE(review): presumably `True` means highest score first;
                confirm against the concrete implementations.

        Returns:
            A list of the sorted objects.
        """
39 changes: 18 additions & 21 deletions src/nplinker/scoring/metcalf_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from nplinker.metabolomics import Spectrum
from nplinker.pickler import load_pickled_data
from nplinker.pickler import save_pickled_data
from .abc import ScoringBase
from .linking import LINK_TYPES
from .linking import DataLinks
from .linking import LinkFinder
from .linking import isinstance_all
from .methods import ScoringMethod
from .object_link import ObjectLink


Expand All @@ -24,19 +24,19 @@
logger = logging.getLogger(__name__)


class MetcalfScoring(ScoringMethod):
class MetcalfScoring(ScoringBase):
"""Metcalf scoring method.

Attributes:
name: The name of this scoring method, set to a fixed value `metcalf`.
DATALINKS: The DataLinks object to use for scoring.
LINKFINDER: The LinkFinder object to use for scoring.
NAME: The name of the scoring method. This is set to 'metcalf'.
CACHE: The name of the cache file to use for storing the MetcalfScoring.
"""

name = "metcalf"
DATALINKS = None
LINKFINDER = None
NAME = "metcalf"
CACHE = "cache_metcalf_scoring.pckl"

def __init__(self, npl: NPLinker) -> None:
Expand All @@ -57,21 +57,20 @@ def __init__(self, npl: NPLinker) -> None:
self.cutoff = 1.0
self.standardised = True

# TODO CG: not sure why using staticmethod here. Check later and refactor if possible
# TODO CG: refactor this method and extract code for cache file to a separate method
@staticmethod
def setup(npl: NPLinker):
"""Setup the MetcalfScoring object.
@classmethod
def setup(cls, npl: NPLinker):
"""Setup the DataLinks and LinkFinder objects.

DataLinks and LinkFinder objects are created and cached for later use.
This method is only called once to setup the DataLinks and LinkFinder objects.
"""
logger.info(
"MetcalfScoring.setup (bgcs={}, gcfs={}, spectra={}, molfams={}, strains={})".format(
len(npl.bgcs), len(npl.gcfs), len(npl.spectra), len(npl.molfams), len(npl.strains)
)
)

cache_file = npl.output_dir / MetcalfScoring.CACHE
cache_file = npl.output_dir / cls.CACHE

# the metcalf preprocessing can take a long time for large datasets, so it's
# better to cache as the data won't change unless the number of objects does
Expand All @@ -97,27 +96,25 @@ def setup(npl: NPLinker):
break

if cache_ok:
MetcalfScoring.DATALINKS = datalinks
MetcalfScoring.LINKFINDER = linkfinder
cls.DATALINKS = datalinks
cls.LINKFINDER = linkfinder

if MetcalfScoring.DATALINKS is None:
if cls.DATALINKS is None:
logger.info("MetcalfScoring.setup preprocessing dataset (this may take some time)")
MetcalfScoring.DATALINKS = DataLinks(npl.gcfs, npl.spectra, npl.molfams, npl.strains)
MetcalfScoring.LINKFINDER = LinkFinder()
MetcalfScoring.LINKFINDER.calc_score(MetcalfScoring.DATALINKS, link_type=LINK_TYPES[0])
MetcalfScoring.LINKFINDER.calc_score(MetcalfScoring.DATALINKS, link_type=LINK_TYPES[1])
cls.DATALINKS = DataLinks(npl.gcfs, npl.spectra, npl.molfams, npl.strains)
cls.LINKFINDER = LinkFinder()
cls.LINKFINDER.calc_score(MetcalfScoring.DATALINKS, link_type=LINK_TYPES[0])
cls.LINKFINDER.calc_score(MetcalfScoring.DATALINKS, link_type=LINK_TYPES[1])
logger.info("MetcalfScoring.setup caching results")
save_pickled_data(
(dataset_counts, MetcalfScoring.DATALINKS, MetcalfScoring.LINKFINDER), cache_file
)
save_pickled_data((dataset_counts, cls.DATALINKS, cls.LINKFINDER), cache_file)

logger.info("MetcalfScoring.setup completed")

# TODO CG: is it needed? remove it if not
@property
def datalinks(self) -> DataLinks | None:
"""Get the DataLinks object used for scoring."""
return MetcalfScoring.DATALINKS
return self.DATALINKS

def get_links(
self, *objects: GCF | Spectrum | MolecularFamily, link_collection: LinkCollection
Expand Down
50 changes: 0 additions & 50 deletions src/nplinker/scoring/methods.py

This file was deleted.

12 changes: 6 additions & 6 deletions src/nplinker/scoring/np_class_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
from nplinker.genomics import BGC
from nplinker.genomics import GCF
from nplinker.metabolomics import Spectrum
from nplinker.scoring.abc import ScoringBase
from nplinker.scoring.metcalf_scoring import MetcalfScoring
from nplinker.scoring.methods import ScoringMethod
from nplinker.scoring.object_link import ObjectLink


logger = logging.getLogger(__name__)


class NPClassScoring(ScoringMethod):
NAME = "npclassscore"
class NPClassScoring(ScoringBase):
name = "npclassscore"

def __init__(self, npl):
super().__init__(npl)
Expand Down Expand Up @@ -313,8 +313,8 @@ def _get_met_classes(self, spec_like, method="mix"):
)
return spec_like_classes, spec_like_classes_names_inds

@staticmethod
def setup(npl):
@classmethod
def setup(cls, npl):
"""Perform any one-off initialisation required (will only be called once)."""
logger.info("Set up NPClassScore scoring")
met_options = npl.chem_classes.class_predict_options
Expand Down Expand Up @@ -347,7 +347,7 @@ def get_links(self, objects, link_collection):
logger.info("Using Metcalf scoring to get shared strains")
# get mapping of shared strains
if not self.npl._datalinks:
self.npl._datalinks = self.npl.scoring_method(MetcalfScoring.NAME).datalinks
self.npl._datalinks = self.npl.scoring_method(MetcalfScoring.name).datalinks
if obj_is_gen:
common_strains = self.npl.get_common_strains(targets, objects)
else:
Expand Down
20 changes: 11 additions & 9 deletions src/nplinker/scoring/rosetta_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
from nplinker.genomics.bgc import BGC
from nplinker.genomics.gcf import GCF
from nplinker.metabolomics import MolecularFamily
from nplinker.scoring.methods import ScoringMethod
from nplinker.scoring.abc import ScoringBase
from nplinker.scoring.object_link import ObjectLink
from nplinker.scoring.rosetta.rosetta import Rosetta


logger = logging.getLogger(__name__)


class RosettaScoring(ScoringMethod):
NAME = "rosetta"
class RosettaScoring(ScoringBase):
name = "rosetta"
ROSETTA_OBJ = None

def __init__(self, npl):
Expand All @@ -22,10 +22,14 @@ def __init__(self, npl):
self.spec_score_cutoff = 0.0
self.bgc_score_cutoff = 0.0

@staticmethod
def setup(npl):
@classmethod
def setup(cls, npl):
"""Setup the Rosetta object and run the scoring algorithm.

This method is only called once to setup the Rosetta object.
"""
logger.info("RosettaScoring setup")
RosettaScoring.ROSETTA_OBJ = Rosetta(npl, ignore_genomic_cache=False)
cls.ROSETTA_OBJ = Rosetta(npl, ignore_genomic_cache=False)
ms1_tol = Rosetta.DEF_MS1_TOL
ms2_tol = Rosetta.DEF_MS2_TOL
score_thresh = Rosetta.DEF_SCORE_THRESH
Expand All @@ -35,9 +39,7 @@ def setup(npl):
npl.config
)

RosettaScoring.ROSETTA_OBJ.run(
npl.spectra, npl.bgcs, ms1_tol, ms2_tol, score_thresh, min_match_peaks
)
cls.ROSETTA_OBJ.run(npl.spectra, npl.bgcs, ms1_tol, ms2_tol, score_thresh, min_match_peaks)
logger.info("RosettaScoring setup completed")

@staticmethod
Expand Down
Loading