Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor StrainCollection class #135

Merged
merged 22 commits into dev from refactor_StrainCollection
Apr 5, 2023
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
5b24aa5
Update .prospector.yml for pylint
CunliangGeng Mar 31, 2023
f5b1548
update type hints from Path to PathLike for utils
CunliangGeng Mar 31, 2023
6b12842
remove unused method `Strain.has_alias`
CunliangGeng Mar 9, 2023
f657e2d
change `Strain.aliases` from attribute to property
CunliangGeng Mar 9, 2023
e2032f6
add docstring to Strain class
CunliangGeng Mar 30, 2023
001967d
adjust the orders of methods in Strain class
CunliangGeng Mar 30, 2023
24fe7d9
update strain unit tests
CunliangGeng Mar 30, 2023
0c8c962
adjust method orders of StrainCollection class
CunliangGeng Mar 30, 2023
a581d01
update docstrings and type hints in StrainCollection
CunliangGeng Mar 30, 2023
4dea67e
update attribute names
CunliangGeng Apr 4, 2023
0ec47de
update `__contains__` in StrainCollection
CunliangGeng Mar 30, 2023
652d18b
refactor `__eq__` method
CunliangGeng Mar 31, 2023
40256d4
refactor `add` method
CunliangGeng Mar 31, 2023
461db79
refactor `filter` method
CunliangGeng Apr 5, 2023
5884f57
update tests for remove and filter
CunliangGeng Mar 31, 2023
de08d95
update `lookup` method and its tests
CunliangGeng Mar 31, 2023
db46262
add tests for magic methods
CunliangGeng Mar 31, 2023
4666069
update tests
CunliangGeng Mar 31, 2023
ae579aa
refactor `add_from_file` method
CunliangGeng Mar 31, 2023
568a560
update `save_to_file` method
CunliangGeng Mar 31, 2023
4e86b86
update `generate_strain_mappings` method
CunliangGeng Mar 31, 2023
49f75f8
Merge branch 'dev' into refactor_StrainCollection
CunliangGeng Apr 5, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .prospector.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,10 @@ pydocstyle:
D213, # Multi-line docstring summary should start at the second line
D404, # First word of the docstring should not be This
]

pylint:
disable: [
W0212, # Access to a protected member %s of a client class
W1514, # Using open without explicitly specifying an encoding
W1203, # Use %s formatting in logging functions
]
229 changes: 103 additions & 126 deletions src/nplinker/strain_collection.py
Original file line number Diff line number Diff line change
@@ -1,92 +1,97 @@
import csv
import os
from os import PathLike
from pathlib import Path
from typing import Iterator
from deprecated import deprecated
from .logconfig import LogConfig
from .strains import Strain
from .utils import list_dirs
from .utils import list_files


logger = LogConfig.getLogger(__name__)


class StrainCollection():

def __init__(self):
"""A collection of Strain objects."""
self._strains: list[Strain] = []
self._lookup: dict[str, Strain] = {}
self._lookup_indices: dict[int, Strain] = {}
self._strain_dict_id: dict[str, Strain] = {}
self._strain_dict_index: dict[int, Strain] = {}

def __repr__(self) -> str:
return str(self)

def __str__(self) -> str:
if len(self) > 20:
return f'StrainCollection(n={len(self)})'

def add(self, strain: Strain):
"""Add the strain to the aliases.
This also adds those strain's aliases to this' strain's aliases.
return f'StrainCollection(n={len(self)}) [' + ','.join(
s.id for s in self._strains) + ']'

def __len__(self) -> int:
return len(self._strains)

def __eq__(self, other) -> bool:
return (self._strains == other._strains
and self._strain_dict_id == other._strain_dict_id
and self._strain_dict_index == other._strain_dict_index)

def __contains__(self, strain: str | Strain) -> bool:
if isinstance(strain, str):
value = strain in self._strain_dict_id
elif isinstance(strain, Strain):
value = strain.id in self._strain_dict_id
else:
raise TypeError(f"Expected Strain or str, got {type(strain)}")
return value

def __iter__(self) -> Iterator[Strain]:
return iter(self._strains)

def add(self, strain: Strain) -> None:
"""Add strain to the collection.

If the strain already exists, merge the aliases.

Args:
strain(Strain): Strain to add to self.

Examples:
>>>
"""
if strain.id in self._lookup:
# if it already exists, just merge the set of aliases and update
# lookup entries
strain(Strain): The strain to add.
"""
# if the strain exists, merge the aliases
if strain.id in self._strain_dict_id:
existing: Strain = self.lookup(strain.id)
existing.aliases.update(strain.aliases)
for alias in strain.aliases:
self._lookup[alias] = existing
return

self._lookup_indices[len(self)] = strain
self._strains.append(strain)
# insert a mapping from strain=>strain, plus all its aliases
self._lookup[strain.id] = strain
for alias in strain.aliases:
self._lookup[alias] = strain
existing.add_alias(alias)
self._strain_dict_id[alias] = existing
else:
self._strain_dict_index[len(self)] = strain
self._strains.append(strain)
self._strain_dict_id[strain.id] = strain
for alias in strain.aliases:
self._strain_dict_id[alias] = strain

def remove(self, strain: Strain):
"""Remove the specified strain from the aliases.
TODO: #90 Implement removing the strain also from self._lookup indices.
"""Remove a strain from the collection.

Args:
strain(Strain): Strain to remove.
strain(Strain): The strain to remove.
"""
if strain.id not in self._lookup:
return

self._strains.remove(strain)
del self._lookup[strain.id]
for alias in strain.aliases:
del self._lookup[alias]
if strain.id in self._strain_dict_id:
self._strains.remove(strain)
# remove from dict id
del self._strain_dict_id[strain.id]
for alias in strain.aliases:
del self._strain_dict_id[alias]

def filter(self, strain_set: set[Strain]):
"""
Remove all strains that are not in strain_set from the strain collection
"""
to_remove = [x for x in self._strains if x not in strain_set]
for strain in to_remove:
self.remove(strain)

def __contains__(self, strain_id: str|Strain) -> bool:
"""Check if the strain or strain id are contained in the lookup table.

Args:
strain_id(str|Strain): Strain or strain id to look up.

Returns:
bool: Whether the strain is contained in the collection.
"""

if isinstance(strain_id, str):
return strain_id in self._lookup
# assume it's a Strain object
if isinstance(strain_id, Strain):
return strain_id.id in self._lookup
return False

def __iter__(self):
return iter(self._strains)

def __next__(self):
return next(self._strains)
# note that we need to copy the list of strains, as we are modifying it
for strain in self._strains.copy():
if strain not in strain_set:
self.remove(strain)

def lookup_index(self, index: int) -> Strain:
"""Return the strain from lookup by index.
Expand All @@ -97,90 +102,81 @@ def lookup_index(self, index: int) -> Strain:
Returns:
Strain: Strain identified by the given index.
"""
return self._lookup_indices[index]
return self._strain_dict_index[index]

def lookup(self, strain_id: str) -> Strain:
"""Check whether the strain id is contained in the lookup table. If so, return the strain, otherwise return `default`.
def lookup(self, name: str) -> Strain:
"""Lookup a strain by name (id or alias).

Raises:
Exception if strain_id is not found.
If the name is found, return the strain object; Otherwise, raise a
KeyError.

Args:
strain_id(str): Strain id to lookup.
name(str): Strain name (id or alias) to lookup.

Returns:
Strain: Strain retrieved during lookup or object passed as default.
Strain: Strain identified by the given name.

Raises:
KeyError: If the strain name is not found.
"""
if strain_id not in self._lookup:
# logger.error('Strain lookup failed for "{}"'.format(strain_id))
raise KeyError(f'Strain lookup failed for "{strain_id}"')
if name not in self._strain_dict_id:
raise KeyError(f"Strain {name} not found in strain collection.")
return self._strain_dict_id[name]

return self._lookup[strain_id]
def add_from_file(self, file: str | PathLike) -> None:
"""Add strains from a strain mapping file.

def add_from_file(self, file: str | os.PathLike):
"""Read strains and aliases from file and store in self.
A strain mapping file is a csv file with the first column being the
id of the strain, and the remaining columns being aliases for the
strain.

Args:
file(str): Path to strain mapping file to load.
file(str | PathLike): Path to strain mapping file (.csv).
"""

if not os.path.exists(file):
logger.warning(f'strain mappings file not found: {file}')
return

line = 1
with open(file) as f:
reader = csv.reader(f)
for ids in reader:
if len(ids) == 0:
for names in reader:
if len(names) == 0:
continue
strain = Strain(ids[0])
for id in ids[1:]:
if len(id) == 0:
logger.warning(
'Found zero-length strain label in {} on line {}'.
format(file, line))
else:
strain.add_alias(id)
strain = Strain(names[0])
for alias in names[1:]:
strain.add_alias(alias)
self.add(strain)

line += 1

def save_to_file(self, file: str | os.PathLike):
"""Save this strain collection to file.
def save_to_file(self, file: str | PathLike) -> None:
"""Save strains to a strain mapping file (.csv).

Args:
file(str): Output file.

Examples:
>>>
"""
file(str | PathLike): Path to strain mapping file (.csv).
"""
with open(file, 'w') as f:
for strain in self._strains:
for strain in self:
ids = [strain.id] + list(strain.aliases)
f.write(','.join(ids) + '\n')

def generate_strain_mappings(self, strain_mappings_file: str,
antismash_dir: str) -> None:
# TODO to move this method to a separate class
@deprecated(version="1.3.3", reason="This method will be removed")
def generate_strain_mappings(self, strain_mappings_file: str | PathLike,
antismash_dir: str | PathLike) -> None:
"""Add AntiSMASH BGC file names as strain alias to strain mappings file.

Note that if AntiSMASH ID (e.g. GCF_000016425.1) is not found in
existing self.strains, its corresponding BGC file names will not be
added.

Args:
strain_mappings_file(str): Path to strain mappings file
antismash_dir(str): Path to AntiSMASH directory
strain_mappings_file(str | PathLike): Path to strain mappings file.
antismash_dir(str | PathLike): Path to AntiSMASH output directory.
"""
if os.path.exists(strain_mappings_file):
if Path(strain_mappings_file).exists():
logger.info('Strain mappings file exist')
return

# if not exist, generate strain mapping file with antismash BGC names
logger.info('Generating strain mappings file')
subdirs = list_dirs(antismash_dir)
for d in subdirs:
antismash_id = os.path.basename(d)
antismash_id = Path(d).name

# use antismash_id (e.g. GCF_000016425.1) as strain name to query
# TODO: self is empty at the moment, why lookup here?
Expand All @@ -194,27 +190,8 @@ def generate_strain_mappings(self, strain_mappings_file: str,
# if strain `antismash_id` exist, add all gbk file names as strain alias
gbk_files = list_files(d, suffix=".gbk", keep_parent=False)
for f in gbk_files:
gbk_filename = os.path.splitext(f)[0]
gbk_filename = Path(f).stem
strain.add_alias(gbk_filename)

logger.info(f'Saving strains to {strain_mappings_file}')
self.save_to_file(strain_mappings_file)

def __len__(self) -> int:
return len(self._strains)

def __repr__(self) -> str:
return str(self)

def __str__(self):
if len(self) > 20:
return f'StrainCollection(n={len(self)})'

return f'StrainCollection(n={len(self)}) [' + ','.join(
s.id for s in self._strains) + ']'

def __eq__(self, other):
result = self._strains == other._strains
result &= self._lookup == other._lookup
result &= self._lookup_indices == other._lookup_indices
return result
Loading