Skip to content

Commit

Permalink
Refactor StrainCollection class (#135)
Browse files Browse the repository at this point in the history
* Update .prospector.yml for pylint

* update type hints from Path to PathLike for utils

* remove unused method `Strain.has_alias`

* change `Strain.aliases` from attribute to property

Using a property forces users to call the `add_alias` method instead of adding an alias directly to the set of aliases.

* add docstring to Strain class

* adjust the order of methods in the Strain class

* update strain unit tests

* adjust the method order of the StrainCollection class

* update docstrings and type hints in StrainCollection

add type hints

* update attribute names

self._lookup -> self._strain_dict_id
self._lookup_indices -> self._strain_dict_index

* update `__contains__` in StrainCollection

update `__contains__` method

* refactor `__eq__` method

* refactor `add` method

* refactor `filter` method

* update tests for remove and filter

* update `lookup` method and its tests

* add tests for magic methods

* update tests

* refactor `add_from_file` method

* update `save_to_file` method

* update `generate_strain_mappings` method
  • Loading branch information
CunliangGeng authored Apr 5, 2023
1 parent 610d447 commit 5e58957
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 165 deletions.
226 changes: 102 additions & 124 deletions src/nplinker/strain_collection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import csv
import os
from os import PathLike
from pathlib import Path
from typing import Iterator
from deprecated import deprecated
from .logconfig import LogConfig
from .strains import Strain
from .utils import list_dirs
Expand All @@ -12,81 +16,83 @@
class StrainCollection():

def __init__(self):
"""A collection of Strain objects."""
self._strains: list[Strain] = []
self._lookup: dict[str, Strain] = {}
self._lookup_indices: dict[int, Strain] = {}
self._strain_dict_id: dict[str, Strain] = {}
self._strain_dict_index: dict[int, Strain] = {}

def add(self, strain: Strain):
"""Add the strain to the aliases.
This also adds those strain's aliases to this' strain's aliases.
def __repr__(self) -> str:
return str(self)

def __str__(self) -> str:
if len(self) > 20:
return f'StrainCollection(n={len(self)})'

return f'StrainCollection(n={len(self)}) [' + ','.join(
s.id for s in self._strains) + ']'

def __len__(self) -> int:
return len(self._strains)

def __eq__(self, other) -> bool:
return (self._strains == other._strains
and self._strain_dict_id == other._strain_dict_id
and self._strain_dict_index == other._strain_dict_index)

def __contains__(self, strain: str | Strain) -> bool:
if isinstance(strain, str):
value = strain in self._strain_dict_id
elif isinstance(strain, Strain):
value = strain.id in self._strain_dict_id
else:
raise TypeError(f"Expected Strain or str, got {type(strain)}")
return value

def __iter__(self) -> Iterator[Strain]:
return iter(self._strains)

def add(self, strain: Strain) -> None:
"""Add strain to the collection.
If the strain already exists, merge the aliases.
Args:
strain(Strain): Strain to add to self.
Examples:
>>>
"""
if strain.id in self._lookup:
# if it already exists, just merge the set of aliases and update
# lookup entries
strain(Strain): The strain to add.
"""
# if the strain exists, merge the aliases
if strain.id in self._strain_dict_id:
existing: Strain = self.lookup(strain.id)
for alias in strain.aliases:
existing.add_alias(alias)
self._lookup[alias] = existing
return

self._lookup_indices[len(self)] = strain
self._strains.append(strain)
# insert a mapping from strain=>strain, plus all its aliases
self._lookup[strain.id] = strain
for alias in strain.aliases:
self._lookup[alias] = strain
self._strain_dict_id[alias] = existing
else:
self._strain_dict_index[len(self)] = strain
self._strains.append(strain)
self._strain_dict_id[strain.id] = strain
for alias in strain.aliases:
self._strain_dict_id[alias] = strain

def remove(self, strain: Strain):
"""Remove the specified strain from the aliases.
TODO: #90 Implement removing the strain also from self._lookup indices.
"""Remove a strain from the collection.
Args:
strain(Strain): Strain to remove.
strain(Strain): The strain to remove.
"""
if strain.id not in self._lookup:
return

self._strains.remove(strain)
del self._lookup[strain.id]
for alias in strain.aliases:
del self._lookup[alias]
if strain.id in self._strain_dict_id:
self._strains.remove(strain)
# remove from dict id
del self._strain_dict_id[strain.id]
for alias in strain.aliases:
del self._strain_dict_id[alias]

def filter(self, strain_set: set[Strain]):
"""
Remove all strains that are not in strain_set from the strain collection
"""
to_remove = [x for x in self._strains if x not in strain_set]
for strain in to_remove:
self.remove(strain)

def __contains__(self, strain_id: str|Strain) -> bool:
"""Check if the strain or strain id are contained in the lookup table.
Args:
strain_id(str|Strain): Strain or strain id to look up.
Returns:
bool: Whether the strain is contained in the collection.
"""

if isinstance(strain_id, str):
return strain_id in self._lookup
# assume it's a Strain object
if isinstance(strain_id, Strain):
return strain_id.id in self._lookup
return False

def __iter__(self):
return iter(self._strains)

def __next__(self):
return next(self._strains)
# note that we need to copy the list of strains, as we are modifying it
for strain in self._strains.copy():
if strain not in strain_set:
self.remove(strain)

def lookup_index(self, index: int) -> Strain:
"""Return the strain from lookup by index.
Expand All @@ -97,90 +103,81 @@ def lookup_index(self, index: int) -> Strain:
Returns:
Strain: Strain identified by the given index.
"""
return self._lookup_indices[index]
return self._strain_dict_index[index]

def lookup(self, strain_id: str) -> Strain:
"""Check whether the strain id is contained in the lookup table. If so, return the strain, otherwise return `default`.
def lookup(self, name: str) -> Strain:
"""Lookup a strain by name (id or alias).
Raises:
Exception if strain_id is not found.
If the name is found, return the strain object; Otherwise, raise a
KeyError.
Args:
strain_id(str): Strain id to lookup.
name(str): Strain name (id or alias) to lookup.
Returns:
Strain: Strain retrieved during lookup or object passed as default.
Strain: Strain identified by the given name.
Raises:
KeyError: If the strain name is not found.
"""
if strain_id not in self._lookup:
# logger.error('Strain lookup failed for "{}"'.format(strain_id))
raise KeyError(f'Strain lookup failed for "{strain_id}"')
if name not in self._strain_dict_id:
raise KeyError(f"Strain {name} not found in strain collection.")
return self._strain_dict_id[name]

return self._lookup[strain_id]
def add_from_file(self, file: str | PathLike) -> None:
"""Add strains from a strain mapping file.
def add_from_file(self, file: str | os.PathLike):
"""Read strains and aliases from file and store in self.
A strain mapping file is a csv file with the first column being the
id of the strain, and the remaining columns being aliases for the
strain.
Args:
file(str): Path to strain mapping file to load.
file(str | PathLike): Path to strain mapping file (.csv).
"""

if not os.path.exists(file):
logger.warning(f'strain mappings file not found: {file}')
return

line = 1
with open(file) as f:
reader = csv.reader(f)
for ids in reader:
if len(ids) == 0:
for names in reader:
if len(names) == 0:
continue
strain = Strain(ids[0])
for id in ids[1:]:
if len(id) == 0:
logger.warning(
'Found zero-length strain label in {} on line {}'.
format(file, line))
else:
strain.add_alias(id)
strain = Strain(names[0])
for alias in names[1:]:
strain.add_alias(alias)
self.add(strain)

line += 1

def save_to_file(self, file: str | os.PathLike):
"""Save this strain collection to file.
def save_to_file(self, file: str | PathLike) -> None:
"""Save strains to a strain mapping file (.csv).
Args:
file(str): Output file.
Examples:
>>>
"""
file(str | PathLike): Path to strain mapping file (.csv).
"""
with open(file, 'w') as f:
for strain in self._strains:
for strain in self:
ids = [strain.id] + list(strain.aliases)
f.write(','.join(ids) + '\n')

def generate_strain_mappings(self, strain_mappings_file: str,
antismash_dir: str) -> None:
# TODO to move this method to a separate class
@deprecated(version="1.3.3", reason="This method will be removed")
def generate_strain_mappings(self, strain_mappings_file: str | PathLike,
antismash_dir: str | PathLike) -> None:
"""Add AntiSMASH BGC file names as strain alias to strain mappings file.
Note that if AntiSMASH ID (e.g. GCF_000016425.1) is not found in
existing self.strains, its corresponding BGC file names will not be
added.
Args:
strain_mappings_file(str): Path to strain mappings file
antismash_dir(str): Path to AntiSMASH directory
strain_mappings_file(str | PathLike): Path to strain mappings file.
antismash_dir(str | PathLike): Path to AntiSMASH output directory.
"""
if os.path.exists(strain_mappings_file):
if Path(strain_mappings_file).exists():
logger.info('Strain mappings file exist')
return

# if not exist, generate strain mapping file with antismash BGC names
logger.info('Generating strain mappings file')
subdirs = list_dirs(antismash_dir)
for d in subdirs:
antismash_id = os.path.basename(d)
antismash_id = Path(d).name

# use antismash_id (e.g. GCF_000016425.1) as strain name to query
# TODO: self is empty at the moment, why lookup here?
Expand All @@ -194,27 +191,8 @@ def generate_strain_mappings(self, strain_mappings_file: str,
# if strain `antismash_id` exist, add all gbk file names as strain alias
gbk_files = list_files(d, suffix=".gbk", keep_parent=False)
for f in gbk_files:
gbk_filename = os.path.splitext(f)[0]
gbk_filename = Path(f).stem
strain.add_alias(gbk_filename)

logger.info(f'Saving strains to {strain_mappings_file}')
self.save_to_file(strain_mappings_file)

def __len__(self) -> int:
return len(self._strains)

def __repr__(self) -> str:
return str(self)

def __str__(self):
if len(self) > 20:
return f'StrainCollection(n={len(self)})'

return f'StrainCollection(n={len(self)}) [' + ','.join(
s.id for s in self._strains) + ']'

def __eq__(self, other):
result = self._strains == other._strains
result &= self._lookup == other._lookup
result &= self._lookup_indices == other._lookup_indices
return result
Loading

0 comments on commit 5e58957

Please sign in to comment.