Skip to content

Commit

Permalink
refactor Spectrum class
Browse files Browse the repository at this point in the history
Refactor the whole `Spectrum` class. There are quite a few commits, but the main idea behind these changes is simple:

- remove attributes, properties and methods that are not used anywhere
- change attributes to properties where necessary
- update class arguments and their order
- update the order of attributes and methods
- add typings
- add docstrings
- add unit tests for all methods

You can find more details about the changes in each commit.
  • Loading branch information
CunliangGeng authored Dec 19, 2023
1 parent 30f157a commit 49d7e22
Show file tree
Hide file tree
Showing 14 changed files with 193 additions and 241 deletions.
3 changes: 1 addition & 2 deletions src/nplinker/metabolomics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import logging
from .molecular_family import MolecularFamily
from .spectrum import GNPS_KEY
from .spectrum import Spectrum


logging.getLogger(__name__).addHandler(logging.NullHandler())

__all__ = ["MolecularFamily", "GNPS_KEY", "Spectrum"]
__all__ = ["MolecularFamily", "Spectrum"]
13 changes: 7 additions & 6 deletions src/nplinker/metabolomics/gnps/gnps_spectrum_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def _validate(self):

def _load(self):
"""Load the MGF file into Spectrum objects."""
i = 0
for spec in mgf.MGF(self._file):
# Skip if m/z array is empty, as this is an invalid spectrum.
# The invalid spectrum does not exist in other GNPS files, e.g.
Expand All @@ -77,20 +76,22 @@ def _load(self):
continue

# Load the spectrum
peaks: list[tuple[float, float]] = list(zip(spec["m/z array"], spec["intensity array"]))
spectrum_id: str = spec["params"]["scans"]
# calculate precursor m/z from precursor mass and charge
precursor_mass = spec["params"]["pepmass"][0]
precursor_charge = self._get_precursor_charge(spec["params"]["charge"])
precursor_mz: float = precursor_mass / abs(precursor_charge)
rt: float | None = spec["params"].get("rtinseconds", None)
rt = spec["params"].get("rtinseconds", 0)

spectrum = Spectrum(
id=i, peaks=peaks, spectrum_id=spectrum_id, precursor_mz=precursor_mz, rt=rt
spectrum_id=spectrum_id,
mz=list(spec["m/z array"]),
intensity=list(spec["intensity array"]),
precursor_mz=precursor_mz,
rt=rt,
metadata=spec["params"],
)
spectrum.metadata = spec["params"]
self._spectra.append(spectrum)
i += 1

def _get_precursor_charge(self, charges: list) -> int:
"""Get the precursor charge from the charge list.
Expand Down
250 changes: 62 additions & 188 deletions src/nplinker/metabolomics/spectrum.py
Original file line number Diff line number Diff line change
@@ -1,214 +1,88 @@
from __future__ import annotations
from functools import cached_property
from typing import TYPE_CHECKING
import numpy as np
from nplinker.strain import Strain
from nplinker.strain_collection import StrainCollection
from nplinker.utils import sqrt_normalise


if TYPE_CHECKING:
from .molecular_family import MolecularFamily

GNPS_KEY = "gnps"

JCAMP = (
"##TITLE={}\\n"
+ "##JCAMP-DX=nplinker vTODO\\n"
+ "##DATA TYPE=Spectrum\\n"
+ "##DATA CLASS=PEAKTABLE\\n"
+ "##ORIGIN=TODO_DATASET_ID\\n"
+ "##OWNER=nobody\\n"
+ "##XUNITS=M/Z\\n"
+ "##YUNITS=RELATIVE ABUNDANCE\\n"
+ "##NPOINTS={}\\n"
+ "##PEAK TABLE=(XY..XY)\\n"
+ "{}\\n"
+ "##END=\\n"
)


class Spectrum:
def __init__(self, id, peaks, spectrum_id: str, precursor_mz, parent_mz=None, rt=None):
self.id = id
self.peaks = sorted(peaks, key=lambda x: x[0]) # ensure sorted by mz
self.normalised_peaks = sqrt_normalise(self.peaks) # useful later
self.n_peaks = len(self.peaks)
self.max_ms2_intensity = max(intensity for mz, intensity in self.peaks)
self.total_ms2_intensity = sum(intensity for mz, intensity in self.peaks)
self.spectrum_id = spectrum_id # MS1.name
self.rt = rt
# TODO CG: should include precursor mass and charge to calculate precursor_mz
# parent_mz can be calculate from precursor_mass and charge mass
def __init__(
self,
spectrum_id: str,
mz: list[float],
intensity: list[float],
precursor_mz: float,
rt: float = 0,
metadata: dict | None = None,
) -> None:
"""Class to model MS/MS Spectrum.
Args:
spectrum_id (str): the spectrum ID.
mz (list[float]): the list of m/z values.
intensity (list[float]): the list of intensity values.
precursor_mz (float): the precursor m/z.
rt (float): the retention time in seconds. Defaults to 0.
metadata (dict, optional): the metadata of the spectrum, i.e. the header information
in the MGF file.
Attributes:
spectrum_id (str): the spectrum ID.
mz (list[float]): the list of m/z values.
intensity (list[float]): the list of intensity values.
precursor_mz (float): the m/z value of the precursor.
rt (float): the retention time in seconds.
metadata (dict): the metadata of the spectrum, i.e. the header information in the MGF
file.
gnps_annotations (dict): the GNPS annotations of the spectrum.
gnps_id (str | None): the GNPS ID of the spectrum.
strains (StrainCollection): the strains that this spectrum belongs to.
family (MolecularFamily): the molecular family that this spectrum belongs to.
peaks (np.ndarray): 2D array of peaks, each row is a peak of (m/z, intensity) values.
"""
self.spectrum_id = spectrum_id
self.mz = mz
self.intensity = intensity
self.precursor_mz = precursor_mz
self.parent_mz = parent_mz
self.gnps_id = None # CCMSLIB...
# TODO should add intensity here too
self.metadata = {}
self.edges = []
self.strains = StrainCollection()
# this is a dict indexed by Strain objects (the strains found in this Spectrum), with
# the values being dicts of the form {growth_medium: peak intensity} for the parent strain
self.growth_media = {}
self.family: MolecularFamily | None = None
# a dict indexed by filename, or "gnps"
self.annotations = {}
self._losses = None
self._jcamp = None

def add_strain(self, strain, growth_medium, peak_intensity):
# adds the strain to the StrainCollection if not already there
self.strains.add(strain)

if strain not in self.growth_media:
self.growth_media[strain] = {}

if growth_medium is None:
self.growth_media[strain].update(
{f"unknown_medium_{len(self.growth_media[strain])}": peak_intensity}
)
return

if strain in self.growth_media and growth_medium in self.growth_media[strain]:
raise Exception("Growth medium clash: {} / {} {}".format(self, strain, growth_medium))

self.growth_media[strain].update({growth_medium: peak_intensity})

@property
def is_library(self):
return GNPS_KEY in self.annotations

def set_annotations(self, key, data):
self.annotations[key] = data

@property
def gnps_annotations(self):
if GNPS_KEY not in self.annotations:
return None

return self.annotations[GNPS_KEY][0]

def has_annotations(self):
return len(self.annotations) > 0

def get_metadata_value(self, key):
val = self.metadata.get(key, None)
return val

def has_strain(self, strain: Strain):
return strain in self.strains

def get_growth_medium(self, strain):
if strain not in self.strains:
return None

gms = self.growth_media[strain]
return list(gms.keys())[0]

def to_jcamp_str(self, force_refresh=False):
if self._jcamp is not None and not force_refresh:
return self._jcamp
self.rt = rt
self.metadata = metadata or {}

peakdata = "\\n".join("{}, {}".format(*p) for p in self.peaks)
self._jcamp = JCAMP.format(str(self), self.n_peaks, peakdata)
return self._jcamp
self.gnps_annotations: dict = {}
self.gnps_id: str | None = None
self.strains: StrainCollection = StrainCollection()
self.family: MolecularFamily | None = None

def __str__(self):
return "Spectrum(id={}, spectrum_id={}, strains={})".format(
self.id, self.spectrum_id, len(self.strains)
)
def __str__(self) -> str:
return f"Spectrum(spectrum_id={self.spectrum_id}, #strains={len(self.strains)})"

def __repr__(self):
def __repr__(self) -> str:
return str(self)

def __eq__(self, other) -> bool:
if isinstance(other, Spectrum):
return (
self.id == other.id
and self.spectrum_id == other.spectrum_id
and self.precursor_mz == other.precursor_mz
and self.parent_mz == other.parent_mz
)
return self.spectrum_id == other.spectrum_id and self.precursor_mz == other.precursor_mz
return NotImplemented

def __hash__(self) -> int:
return hash((self.id, self.spectrum_id, self.precursor_mz, self.parent_mz))

def __cmp__(self, other):
if self.parent_mz >= other.parent_mz:
return 1
else:
return -1

def __lt__(self, other):
if self.parent_mz <= other.parent_mz:
return 1
else:
return 0

# from molnet repo
def keep_top_k(self, k=6, mz_range=50):
# only keep peaks that are in the top k in += mz_range
start_pos = 0
new_peaks = []
for mz, intensity in self.peaks:
while self.peaks[start_pos][0] < mz - mz_range:
start_pos += 1
end_pos = start_pos
return hash((self.spectrum_id, self.precursor_mz))

n_bigger = 0
while end_pos < len(self.peaks) and self.peaks[end_pos][0] <= mz + mz_range:
if self.peaks[end_pos][1] > intensity:
n_bigger += 1
end_pos += 1
@cached_property
def peaks(self) -> np.ndarray:
"""Get the peaks, a 2D array with each row containing the values of (m/z, intensity)."""
return np.array(list(zip(self.mz, self.intensity)))

if n_bigger < k:
new_peaks.append((mz, intensity))

self.peaks = new_peaks
self.n_peaks = len(self.peaks)
if self.n_peaks > 0:
self.normalised_peaks = sqrt_normalise(self.peaks)
self.max_ms2_intensity = max(intensity for mz, intensity in self.peaks)
self.total_ms2_intensity = sum(intensity for mz, intensity in self.peaks)
else:
self.normalised_peaks = []
self.max_ms2_intensity = 0.0
self.total_ms2_intensity = 0.0

@property
def losses(self):
"""All mass shifts in the spectrum, and the indices of the peaks."""
if self._losses is None:
# populate loss table
losses = []
for i in range(len(self.peaks)):
loss = self.precursor_mz - self.peaks[i][0]
losses.append((loss, self.id, i))

# THIS SEEMED TO ME LIKE IT WOULD TAKE THE WRONG DIFFERENCES AS LOSSES:
# TODO: please check!
# for j in range(i):
# loss = self.peaks[i][0] - self.peaks[j][0]
# losses.append((loss, i, j))

# Sort by loss
losses.sort(key=lambda x: x[0])
self._losses = losses
return self._losses

def has_loss(self, mass, tol):
"""Check if the scan has the specified loss (within tolerance)."""
matched_losses = []

idx = 0
# Check losses in range [0, mass]
while idx < len(self.losses) and self.losses[idx][0] <= mass:
if mass - self.losses[idx][0] < tol:
matched_losses.append(self.losses[idx])
idx += 1
def has_strain(self, strain: Strain):
"""Check if the given strain exists in the spectrum.
# Add all losses in range [mass, mass+tol(
while idx < len(self.losses) and self.losses[idx][0] < mass + tol:
matched_losses.append(self.losses[idx])
idx += 1
Args:
strain(Strain): `Strain` object.
return matched_losses
Returns:
bool: True when the given strain exists in the spectrum.
"""
return strain in self.strains
4 changes: 2 additions & 2 deletions src/nplinker/pickler.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ def persistent_id(self, obj):
elif isinstance(obj, GCF):
return ("GCF", obj.gcf_id)
elif isinstance(obj, Spectrum):
return ("Spectrum", obj.id)
return ("Spectrum", obj.spectrum_id)
elif isinstance(obj, MolecularFamily):
return ("MolecularFamily", obj.id)
return ("MolecularFamily", obj.family_id)
else:
# TODO: ideally should use isinstance(obj, ScoringMethod) here
# but it's currently a problem because it creates a circular
Expand Down
2 changes: 1 addition & 1 deletion src/nplinker/scoring/iokr/nplinker_iokr.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def score_smiles(self, ms_list, candidate_smiles):
candidates = iokr_opt.preprocess_candidates(candidate_fps, latent, latent_basis, gamma)

for ms_index, ms in enumerate(ms_list):
logger.debug("Rank spectrum {} ({}/{})".format(ms.id, ms_index, len(ms_list)))
logger.debug("Rank spectrum {} ({}/{})".format(ms.spectrum_id, ms_index, len(ms_list)))
ms.filter = spectrum_filters.filter_by_frozen_dag
logger.debug("kernel vector")
t0 = time.time()
Expand Down
4 changes: 2 additions & 2 deletions src/nplinker/scoring/iokr/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def __init__(self, mgf_dict=None, spec=None):

def init_from_spec(self, spec):
self.id = spec.id
self.raw_parentmass = spec.parent_mz
self.raw_spectrum = numpy.array(spec.peaks)
self.raw_parentmass = spec.precursor_mz
self.raw_spectrum = spec.peaks
# TODO this is a temporary default for the Crusemann data
# should check for it in the mgf in metabolomics.py and store
# in the Spectrum object if found
Expand Down
1 change: 0 additions & 1 deletion src/nplinker/scoring/rosetta/rosetta.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,6 @@ def export_to_csv(self, filename):
for hit in self._rosetta_hits:
csvwriter.writerow(
[
hit.spec.id,
hit.spec.spectrum_id,
hit.gnps_id,
hit.spec_match_score,
Expand Down
Loading

0 comments on commit 49d7e22

Please sign in to comment.