Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add pickling #258

Merged
merged 6 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/nplinker/genomics/bgc.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ def __eq__(self, other) -> bool:
def __hash__(self) -> int:
    """Return a hash derived from the `id` and `product_prediction` attributes."""
    identity = (self.id, self.product_prediction)
    return hash(identity)

def __reduce__(self) -> tuple:
    """Describe how pickle should serialise and rebuild this object.

    Returns:
        A 3-tuple `(callable, args, state)`. On unpickling, `callable(*args)` is called
        with the `id` followed by the unpacked `product_prediction` values, and `state`
        (the instance `__dict__`) is then restored onto the new object.
    """
    init_args = (self.id, *self.product_prediction)
    return self.__class__, init_args, self.__dict__

def add_parent(self, gcf: GCF) -> None:
"""Add a parent GCF to the BGC.

Expand Down
4 changes: 4 additions & 0 deletions src/nplinker/genomics/gcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ def __hash__(self) -> int:
"""
return hash(self.id)

def __reduce__(self) -> tuple:
    """Describe how pickle should serialise and rebuild this object.

    Returns:
        A 3-tuple `(callable, args, state)`. On unpickling, `callable(*args)` is called
        with the `id` only, and `state` (the instance `__dict__`) is then restored onto
        the new object.
    """
    init_args = (self.id,)
    return self.__class__, init_args, self.__dict__

@property
def bgcs(self) -> set[BGC]:
"""Get the BGC objects."""
Expand Down
4 changes: 4 additions & 0 deletions src/nplinker/metabolomics/molecular_family.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def __eq__(self, other) -> bool:
def __hash__(self) -> int:
    """Return a hash derived solely from the `id` attribute."""
    key = self.id
    return hash(key)

def __reduce__(self) -> tuple:
    """Describe how pickle should serialise and rebuild this object.

    Returns:
        A 3-tuple `(callable, args, state)`. On unpickling, `callable(*args)` is called
        with the `id` only, and `state` (the instance `__dict__`) is then restored onto
        the new object.
    """
    init_args = (self.id,)
    return self.__class__, init_args, self.__dict__

@property
def spectra(self) -> set[Spectrum]:
"""Get Spectrum objects in the molecular family."""
Expand Down
8 changes: 8 additions & 0 deletions src/nplinker/metabolomics/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,14 @@ def __eq__(self, other) -> bool:
def __hash__(self) -> int:
    """Return a hash derived from the `id` and `precursor_mz` attributes."""
    identity = (self.id, self.precursor_mz)
    return hash(identity)

def __reduce__(self) -> tuple:
    """Describe how pickle should serialise and rebuild this object.

    Returns:
        A 3-tuple `(callable, args, state)`. On unpickling, `callable(*args)` is called
        with `id`, `mz`, `intensity`, `precursor_mz`, `rt` and `metadata` (in that order),
        and `state` (the instance `__dict__`) is then restored onto the new object.
    """
    init_args = (
        self.id,
        self.mz,
        self.intensity,
        self.precursor_mz,
        self.rt,
        self.metadata,
    )
    return self.__class__, init_args, self.__dict__

@cached_property
def peaks(self) -> np.ndarray:
"""Get the peaks, a 2D array with each row containing the values of (m/z, intensity)."""
Expand Down
20 changes: 20 additions & 0 deletions src/nplinker/nplinker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
import logging
import pickle
from os import PathLike
from pprint import pformat
from typing import Sequence
Expand Down Expand Up @@ -295,3 +296,22 @@ def lookup_mf(self, id: str) -> MolecularFamily | None:
The MolecularFamily object with the given ID, or None if no such object exists.
"""
return self._mf_dict.get(id, None)

def save_data(self, file: str | PathLike, links: LinkGraph | None = None) -> None:
    """Serialise the loaded data, and optionally the links, to a pickle file.

    A single tuple `(bgcs, gcfs, spectra, mfs, strains, links)` is written. The first five
    items are read from this object's attributes of the same name; `links` is stored as
    `None` when it is not supplied.

    Args:
        file: Destination path of the pickle file; it is opened in binary write mode.
        links: The LinkGraph object to store alongside the data.
    """
    payload = (self.bgcs, self.gcfs, self.spectra, self.mfs, self.strains, links)
    with open(file, "wb") as handle:
        # Same as `pickle.dump(payload, handle)` with the default protocol.
        pickle.Pickler(handle).dump(payload)
116 changes: 0 additions & 116 deletions src/nplinker/pickler.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/nplinker/scoring/metcalf_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def get_links(self, *objects, **parameters):
"MetcalfScoring.metcalf_mean and metcalf_std are not set. Run MetcalfScoring.setup first."
)
# use negative infinity as the score cutoff to ensure we get all links
scores_list = self._get_links(*objects, obj_type=obj_type, score_cutoff=np.NINF)
scores_list = self._get_links(*objects, obj_type=obj_type, score_cutoff=-np.inf)
scores_list = self._calc_standardised_score(scores_list)

links = LinkGraph()
Expand Down
36 changes: 36 additions & 0 deletions tests/integration/test_nplinker_local.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import os
import pickle
import pytest
from nplinker.genomics import GCF
from nplinker.metabolomics import MolecularFamily
from nplinker.metabolomics import Spectrum
from nplinker.nplinker import NPLinker
from . import DATA_DIR

Expand Down Expand Up @@ -70,3 +74,35 @@ def test_get_links(npl):
for _, _, scores in lg.links:
score = scores[scoring_method]
assert score.value >= 0


def test_save_data(npl):
    """Check that the pickle written by `npl.save_data` can be loaded back intact.

    Links are computed for the first three GCFs, saved together with the loaded data,
    and the unpickled tuple is then checked for the expected object counts and for
    links whose source objects are present in the unpickled collections.
    """
    scoring_method = "metcalf"
    links = npl.get_links(npl.gcfs[:3], scoring_method)

    pickle_file = os.path.join(npl.output_dir, "npl.pkl")
    npl.save_data(pickle_file, links)

    # The file was written by this test a moment ago, so unpickling it is safe here.
    with open(pickle_file, "rb") as f:
        bgcs, gcfs, spectra, mfs, strains, lg = pickle.load(f)

    # tests from `test_load_data`
    assert len(bgcs) == 390
    assert len(gcfs) == 64
    assert len(spectra) == 24652
    assert len(mfs) == 29
    assert len(strains) == 46

    # tests from `test_get_links`
    for obj1, _, scores in lg.links:
        score = scores[scoring_method]
        assert score.value >= 0

        if isinstance(obj1, GCF):
            assert obj1 in gcfs
        elif isinstance(obj1, Spectrum):
            assert obj1 in spectra
        elif isinstance(obj1, MolecularFamily):
            assert obj1 in mfs
        else:
            # `assert False` is removed under `python -O` and reports nothing useful;
            # raise explicitly so an unexpected type always fails with a clear message.
            raise AssertionError(f"Unexpected link source type: {type(obj1).__name__}")
12 changes: 6 additions & 6 deletions tests/unit/scoring/test_metcalf_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def test_get_links_invalid_mixed_types(mc, spectra, mfs):
def test_get_links_gcf_standardised_false(mc, gcfs, spectra, mfs):
"""Test `get_links` method when input is GCF objects and `standardised` is False."""
# when cutoff is negative infinity, i.e. taking all scores
lg = mc.get_links(*gcfs, cutoff=np.NINF, standardised=False)
lg = mc.get_links(*gcfs, cutoff=-np.inf, standardised=False)
assert lg[gcfs[0]][spectra[0]][mc.name].value == 12
assert lg[gcfs[1]][spectra[0]][mc.name].value == -9
assert lg[gcfs[2]][spectra[0]][mc.name].value == 11
Expand All @@ -121,7 +121,7 @@ def test_get_links_gcf_standardised_false(mc, gcfs, spectra, mfs):

def test_get_links_gcf_standardised_true(mc, gcfs):
"""Test `get_links` method when input is GCF objects and `standardised` is True."""
lg = mc.get_links(*gcfs, cutoff=np.NINF, standardised=True)
lg = mc.get_links(*gcfs, cutoff=-np.inf, standardised=True)
assert len(lg.links) == 18

lg = mc.get_links(*gcfs, cutoff=0, standardised=True)
Expand All @@ -130,7 +130,7 @@ def test_get_links_gcf_standardised_true(mc, gcfs):

def test_get_links_spec_standardised_false(mc, gcfs, spectra):
"""Test `get_links` method when input is Spectrum objects and `standardised` is False."""
lg = mc.get_links(*spectra, cutoff=np.NINF, standardised=False)
lg = mc.get_links(*spectra, cutoff=-np.inf, standardised=False)
assert lg[spectra[0]][gcfs[0]][mc.name].value == 12
assert lg[spectra[0]][gcfs[1]][mc.name].value == -9
assert lg[spectra[0]][gcfs[2]][mc.name].value == 11
Expand All @@ -143,7 +143,7 @@ def test_get_links_spec_standardised_false(mc, gcfs, spectra):

def test_get_links_spec_standardised_true(mc, gcfs, spectra):
"""Test `get_links` method when input is Spectrum objects and `standardised` is True."""
lg = mc.get_links(*spectra, cutoff=np.NINF, standardised=True)
lg = mc.get_links(*spectra, cutoff=-np.inf, standardised=True)
assert len(lg.links) == 9

lg = mc.get_links(*spectra, cutoff=0, standardised=True)
Expand All @@ -152,7 +152,7 @@ def test_get_links_spec_standardised_true(mc, gcfs, spectra):

def test_get_links_mf_standardised_false(mc, gcfs, mfs):
"""Test `get_links` method when input is MolecularFamily objects and `standardised` is False."""
lg = mc.get_links(*mfs, cutoff=np.NINF, standardised=False)
lg = mc.get_links(*mfs, cutoff=-np.inf, standardised=False)
assert lg[mfs[0]][gcfs[0]][mc.name].value == 12
assert lg[mfs[0]][gcfs[1]][mc.name].value == -9
assert lg[mfs[0]][gcfs[2]][mc.name].value == 11
Expand All @@ -165,7 +165,7 @@ def test_get_links_mf_standardised_false(mc, gcfs, mfs):

def test_get_links_mf_standardised_true(mc, gcfs, mfs):
"""Test `get_links` method when input is MolecularFamily objects and `standardised` is True."""
lg = mc.get_links(*mfs, cutoff=np.NINF, standardised=True)
lg = mc.get_links(*mfs, cutoff=-np.inf, standardised=True)
assert len(lg.links) == 9

lg = mc.get_links(*mfs, cutoff=0, standardised=True)
Expand Down
Loading