From 5b24aa50d98b5ee3fececcedabdb2f9273f9675e Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Fri, 31 Mar 2023 14:50:07 +0200 Subject: [PATCH 01/21] Update .prospector.yml for pylint --- .prospector.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.prospector.yml b/.prospector.yml index 5a10896b..81b249fb 100644 --- a/.prospector.yml +++ b/.prospector.yml @@ -28,3 +28,10 @@ pydocstyle: D213, # Multi-line docstring summary should start at the second line D404, # First word of the docstring should not be This ] + +pylint: + disable: [ + W0212, # Access to a protected member %s of a client class + W1514, # Using open without explicitly specifying an encoding + W1203, # Use %s formatting in logging functions +] From f5b1548df009ac3bf9c80e06e67329b8b9085d4c Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Fri, 31 Mar 2023 15:33:42 +0200 Subject: [PATCH 02/21] update type hints from Path to PathLike for utils --- src/nplinker/utils.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/nplinker/utils.py b/src/nplinker/utils.py index 186215b6..696ef532 100644 --- a/src/nplinker/utils.py +++ b/src/nplinker/utils.py @@ -66,7 +66,7 @@ def find_delimiter(file: str | PathLike) -> str: with open(file, mode='rt', encoding='utf-8') as fp: delimiter = sniffer.sniff(fp.read(5000)).delimiter return delimiter - + def get_headers(file: str | PathLike) -> list[str]: """Read headers from the given tabular file. @@ -94,7 +94,7 @@ def get_headers(file: str | PathLike) -> list[str]: def _save_response_content(content: Iterator[bytes], - destination: str | Path, + destination: str | PathLike, length: int | None = None) -> None: with open(destination, "wb") as fh, tqdm(total=length) as pbar: for chunk in content: @@ -107,7 +107,7 @@ def _save_response_content(content: Iterator[bytes], def _urlretrieve(url: str, - filename: str | Path, + filename: str | PathLike, chunk_size: int = 1024 * 32) -> None: with urllib.request.urlopen( urllib.request.Request(url, headers={"User-Agent": @@ -118,7 +118,7 @@ def _urlretrieve(url: str, length=response.length) -def calculate_md5(fpath: str | Path, chunk_size: int = 1024 * 1024) -> str: +def calculate_md5(fpath: str | PathLike, chunk_size: int = 1024 * 1024) -> str: # Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are # not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without # it torchvision.datasets is unusable in these environments since we perform a MD5 check everywhere. @@ -132,11 +132,11 @@ def calculate_md5(fpath: str | Path, chunk_size: int = 1024 * 1024) -> str: return md5.hexdigest() -def check_md5(fpath: str | Path, md5: str) -> bool: +def check_md5(fpath: str | PathLike, md5: str) -> bool: return md5 == calculate_md5(fpath) -def check_integrity(fpath: str | Path, md5: str | None = None) -> bool: +def check_integrity(fpath: str | PathLike, md5: str | None = None) -> bool: if not os.path.isfile(fpath): return False if md5 is None: @@ -161,7 +161,7 @@ def _get_redirect_url(url: str, max_hops: int = 3) -> str: def download_url(url: str, - root: str | Path, + root: str | PathLike, filename: str | None = None, md5: str | None = None, max_redirect_hops: int = 3) -> None: @@ -208,7 +208,7 @@ def download_url(url: str, raise RuntimeError("File not found or corrupted, or md5 validation failed.") -def list_dirs(root: str | Path, +def list_dirs(root: str | PathLike, keep_parent: bool = True) -> list[str]: """List all directories at a given root @@ -224,7 +224,7 @@ def list_dirs(root: str | Path, return directories -def list_files(root: str | Path, +def list_files(root: str | PathLike, prefix: str | tuple[str, ...] = "", suffix: str | tuple[str, ...] = "", keep_parent: bool = True) -> list[str]: @@ -253,7 +253,7 @@ def list_files(root: str | Path, return files -def _extract_tar(from_path: str | Path, to_path: str | Path, +def _extract_tar(from_path: str | PathLike, to_path: str | PathLike, compression: str | None) -> None: with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar: @@ -266,7 +266,7 @@ def _extract_tar(from_path: str | Path, to_path: str | Path, } -def _extract_zip(from_path: str | Path, to_path: str | Path, +def _extract_zip(from_path: str | PathLike, to_path: str | PathLike, compression: str | None) -> None: with zipfile.ZipFile(from_path, "r", @@ -375,7 +375,7 @@ def _decompress(from_path: Path | str, return str(to_path) -def extract_archive(from_path: str | Path, +def extract_archive(from_path: str | PathLike, to_path: str | Path | None = None, remove_finished: bool = False) -> str: """Extract an archive. @@ -438,7 +438,7 @@ def extract_file_matching_pattern(archive: zipfile.ZipFile, prefix: str, suffix: def download_and_extract_archive( url: str, - download_root: str | Path, + download_root: str | PathLike, extract_root: str | Path | None = None, filename: str | None = None, md5: str | None = None, From 6b1284280b64c46afe61aa132479371d79b9b9f9 Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Thu, 9 Mar 2023 17:00:24 +0100 Subject: [PATCH 03/21] remove unused method `Strain.has_alias` --- src/nplinker/strains.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/nplinker/strains.py b/src/nplinker/strains.py index 8887fa6f..be8c88eb 100644 --- a/src/nplinker/strains.py +++ b/src/nplinker/strains.py @@ -9,17 +9,6 @@ def __init__(self, primary_strain_id: str): self.id: str = primary_strain_id self.aliases: set[str] = set() - def has_alias(self, alt_id: str) -> bool: - """Check if strain has an alias. - - Args: - alt_id(str): Alias to check. - - Returns: - bool: Whether the alias is registered in the set of aliases or not. - """ - return alt_id in self.aliases - def add_alias(self, alt_id: str): """Add an alias to the list of known aliases. From f657e2daf843dfec2046682b9d6c5bf843c077d1 Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Thu, 9 Mar 2023 17:06:26 +0100 Subject: [PATCH 04/21] change `Strain.aliases` from attribute to property Using property will force user to use method `add_alias` but not directly add alias to the set of aliases. --- src/nplinker/strain_collection.py | 2 +- src/nplinker/strains.py | 41 ++++++++++++++++++++----------- tests/conftest.py | 2 +- tests/test_strain.py | 13 ++-------- 4 files changed, 30 insertions(+), 28 deletions(-) diff --git a/src/nplinker/strain_collection.py b/src/nplinker/strain_collection.py index ca160764..2ae1d0d3 100644 --- a/src/nplinker/strain_collection.py +++ b/src/nplinker/strain_collection.py @@ -30,8 +30,8 @@ def add(self, strain: Strain): # if it already exists, just merge the set of aliases and update # lookup entries existing: Strain = self.lookup(strain.id) - existing.aliases.update(strain.aliases) for alias in strain.aliases: + existing.add_alias(alias) self._lookup[alias] = existing return diff --git a/src/nplinker/strains.py b/src/nplinker/strains.py index be8c88eb..b2820b5d 100644 --- a/src/nplinker/strains.py +++ b/src/nplinker/strains.py @@ -1,39 +1,50 @@ from .logconfig import LogConfig + logger = LogConfig.getLogger(__name__) class Strain(): + def __init__(self, primary_id: str) -> None: + """To model the mapping between strain id and its aliases. + + It's recommended to use NCBI taxonomy strain id or name as the primary + id. + + Args: + primary_id(str): the representative id of the strain. + """ + self.id: str = primary_id + self._aliases: set[str] = set() - def __init__(self, primary_strain_id: str): - self.id: str = primary_strain_id - self.aliases: set[str] = set() + @property + def aliases(self) -> set[str]: + return self._aliases - def add_alias(self, alt_id: str): + def add_alias(self, alias: str) -> None: """Add an alias to the list of known aliases. Args: - alt_id(str): Alternative id to add to the list of known aliases. + alias(str): The alias to add to the list of known aliases. """ - if len(alt_id) == 0: + if len(alias) == 0: logger.warning( - f'Refusing to add zero-length alias to strain {self}') - return - - self.aliases.add(alt_id) + 'Refusing to add an empty-string alias to strain {%s}', self) + else: + self._aliases.add(alias) def __repr__(self) -> str: return str(self) def __str__(self) -> str: - return f'Strain({self.id}) [{len(self.aliases)} aliases]' - - def __eq__(self, other): + return f'Strain({self.id}) [{len(self._aliases)} aliases]' + + def __eq__(self, other) -> bool: return ( isinstance(other, Strain) and self.id == other.id - and self.aliases == other.aliases + and self._aliases == other._aliases ) - + def __hash__(self) -> int: return hash(self.id) diff --git a/tests/conftest.py b/tests/conftest.py index 378cc94a..63200203 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -53,5 +53,5 @@ def collection_from_file() -> StrainCollection: @pytest.fixture def strain() -> Strain: item = Strain("peter") - item.aliases = set(["dieter"]) + item.add_alias("dieter") return item diff --git a/tests/test_strain.py b/tests/test_strain.py index b73fbd39..990ba450 100644 --- a/tests/test_strain.py +++ b/tests/test_strain.py @@ -1,21 +1,13 @@ -import pytest from nplinker.strains import Strain def test_default(): sut = Strain("peter") assert sut.id == "peter" + assert isinstance(sut.aliases, set) assert len(sut.aliases) == 0 -@pytest.mark.parametrize("alias, expected", [ - ["dieter", True], - ["ulrich", False] -]) -def test_has_alias(strain: Strain, alias: str, expected: bool): - assert strain.has_alias(alias) == expected - - def test_add_alias(strain: Strain): strain.add_alias("test") assert len(strain.aliases) == 2 @@ -24,5 +16,4 @@ def test_add_alias(strain: Strain): def test_equal(strain: Strain): other = Strain("peter") other.add_alias("dieter") - - assert strain == other \ No newline at end of file + assert strain == other From e2032f63622e2b167d18b326d4c2a032c09b44d0 Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Thu, 30 Mar 2023 14:33:17 +0200 Subject: [PATCH 05/21] add docstring to Strain class --- src/nplinker/strains.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/nplinker/strains.py b/src/nplinker/strains.py index b2820b5d..291d0476 100644 --- a/src/nplinker/strains.py +++ b/src/nplinker/strains.py @@ -1,3 +1,4 @@ +from __future__ import annotations from .logconfig import LogConfig @@ -19,6 +20,11 @@ def __init__(self, primary_id: str) -> None: @property def aliases(self) -> set[str]: + """Get the set of known aliases. + + Returns: + set[str]: A set of aliases associated with the strain. + """ return self._aliases def add_alias(self, alias: str) -> None: From 001967d2e84ba8983233547a805e6257e07dfc75 Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Thu, 30 Mar 2023 14:35:44 +0200 Subject: [PATCH 06/21] adjust the orders of methods in Strain class --- src/nplinker/strains.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/src/nplinker/strains.py b/src/nplinker/strains.py index 291d0476..de41a8b0 100644 --- a/src/nplinker/strains.py +++ b/src/nplinker/strains.py @@ -6,6 +6,7 @@ class Strain(): + def __init__(self, primary_id: str) -> None: """To model the mapping between strain id and its aliases. @@ -18,6 +19,19 @@ def __init__(self, primary_id: str) -> None: self.id: str = primary_id self._aliases: set[str] = set() + def __repr__(self) -> str: + return str(self) + + def __str__(self) -> str: + return f'Strain({self.id}) [{len(self._aliases)} aliases]' + + def __eq__(self, other) -> bool: + return (isinstance(other, Strain) and self.id == other.id + and self._aliases == other._aliases) + + def __hash__(self) -> int: + return hash(self.id) + @property def aliases(self) -> set[str]: """Get the set of known aliases. @@ -38,19 +52,3 @@ def add_alias(self, alias: str) -> None: 'Refusing to add an empty-string alias to strain {%s}', self) else: self._aliases.add(alias) - - def __repr__(self) -> str: - return str(self) - - def __str__(self) -> str: - return f'Strain({self.id}) [{len(self._aliases)} aliases]' - - def __eq__(self, other) -> bool: - return ( - isinstance(other, Strain) - and self.id == other.id - and self._aliases == other._aliases - ) - - def __hash__(self) -> int: - return hash(self.id) From 24fe7d90e3e87e89b5d0a1cced0102ef70c67f82 Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Thu, 30 Mar 2023 14:47:11 +0200 Subject: [PATCH 07/21] update strain unit tests --- tests/conftest.py | 4 ++-- tests/test_strain.py | 32 ++++++++++++++++++++++++-------- tests/test_strain_collection.py | 4 ++-- 3 files changed, 28 insertions(+), 12 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 63200203..cac0255c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -52,6 +52,6 @@ def collection_from_file() -> StrainCollection: @pytest.fixture def strain() -> Strain: - item = Strain("peter") - item.add_alias("dieter") + item = Strain("strain_1") + item.add_alias("strain_1_a") return item diff --git a/tests/test_strain.py b/tests/test_strain.py index 990ba450..c8ed0a56 100644 --- a/tests/test_strain.py +++ b/tests/test_strain.py @@ -2,18 +2,34 @@ def test_default(): - sut = Strain("peter") - assert sut.id == "peter" + sut = Strain("strain_1") + assert sut.id == "strain_1" assert isinstance(sut.aliases, set) assert len(sut.aliases) == 0 -def test_add_alias(strain: Strain): - strain.add_alias("test") - assert len(strain.aliases) == 2 +def test_repr(strain: Strain): + assert repr(strain) == "Strain(strain_1) [1 aliases]" + +def test_str(strain: Strain): + assert str(strain) == "Strain(strain_1) [1 aliases]" -def test_equal(strain: Strain): - other = Strain("peter") - other.add_alias("dieter") + +def test_eq(strain: Strain): + other = Strain("strain_1") + other.add_alias("strain_1_a") assert strain == other + + +def test_hash(strain: Strain): + assert hash(strain) == hash("strain_1") + + +def test_alias(strain: Strain): + assert len(strain.aliases) == 1 + + +def test_add_alias(strain: Strain): + strain.add_alias("strain_1_b") + assert len(strain.aliases) == 2 diff --git a/tests/test_strain_collection.py b/tests/test_strain_collection.py index ed9bf457..3316670e 100644 --- a/tests/test_strain_collection.py +++ b/tests/test_strain_collection.py @@ -43,8 +43,8 @@ def test_lookup(strain: Strain): def test_contains(collection: StrainCollection, strain: Strain): assert strain in collection assert strain.id in collection - assert "peter" in collection - assert "dieter" in collection + assert "strain_1" in collection + assert "strain_1_a" in collection assert "test" not in collection From 0c8c9626601cca135fa23b2556592da5e42868f3 Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Thu, 30 Mar 2023 15:21:54 +0200 Subject: [PATCH 08/21] adjust method orders of StrainCollection class --- src/nplinker/strain_collection.py | 84 +++++++++++++++---------------- 1 file changed, 42 insertions(+), 42 deletions(-) diff --git a/src/nplinker/strain_collection.py b/src/nplinker/strain_collection.py index 2ae1d0d3..ee6a4a88 100644 --- a/src/nplinker/strain_collection.py +++ b/src/nplinker/strain_collection.py @@ -16,6 +16,48 @@ def __init__(self): self._lookup: dict[str, Strain] = {} self._lookup_indices: dict[int, Strain] = {} + def __repr__(self) -> str: + return str(self) + + def __str__(self): + if len(self) > 20: + return f'StrainCollection(n={len(self)})' + + return f'StrainCollection(n={len(self)}) [' + ','.join( + s.id for s in self._strains) + ']' + + def __len__(self) -> int: + return len(self._strains) + + def __eq__(self, other): + result = self._strains == other._strains + result &= self._lookup == other._lookup + result &= self._lookup_indices == other._lookup_indices + return result + + def __contains__(self, strain_id: str | Strain) -> bool: + """Check if the strain or strain id are contained in the lookup table. + + Args: + strain_id(str|Strain): Strain or strain id to look up. + + Returns: + bool: Whether the strain is contained in the collection. + """ + + if isinstance(strain_id, str): + return strain_id in self._lookup + # assume it's a Strain object + if isinstance(strain_id, Strain): + return strain_id.id in self._lookup + return False + + def __iter__(self): + return iter(self._strains) + + def __next__(self): + return next(self._strains) + def add(self, strain: Strain): """Add the strain to the aliases. This also adds those strain's aliases to this' strain's aliases. @@ -65,29 +107,6 @@ def filter(self, strain_set: set[Strain]): for strain in to_remove: self.remove(strain) - def __contains__(self, strain_id: str|Strain) -> bool: - """Check if the strain or strain id are contained in the lookup table. - - Args: - strain_id(str|Strain): Strain or strain id to look up. - - Returns: - bool: Whether the strain is contained in the collection. - """ - - if isinstance(strain_id, str): - return strain_id in self._lookup - # assume it's a Strain object - if isinstance(strain_id, Strain): - return strain_id.id in self._lookup - return False - - def __iter__(self): - return iter(self._strains) - - def __next__(self): - return next(self._strains) - def lookup_index(self, index: int) -> Strain: """Return the strain from lookup by index. @@ -199,22 +218,3 @@ def generate_strain_mappings(self, strain_mappings_file: str, logger.info(f'Saving strains to {strain_mappings_file}') self.save_to_file(strain_mappings_file) - - def __len__(self) -> int: - return len(self._strains) - - def __repr__(self) -> str: - return str(self) - - def __str__(self): - if len(self) > 20: - return f'StrainCollection(n={len(self)})' - - return f'StrainCollection(n={len(self)}) [' + ','.join( - s.id for s in self._strains) + ']' - - def __eq__(self, other): - result = self._strains == other._strains - result &= self._lookup == other._lookup - result &= self._lookup_indices == other._lookup_indices - return result From a581d0112f50c318d75a6ba2030452a66bb82e63 Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Thu, 30 Mar 2023 15:52:08 +0200 Subject: [PATCH 09/21] update docstrings and type hints in StrainCollection add type hints --- src/nplinker/strain_collection.py | 32 +++++++++++++------------------ 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/src/nplinker/strain_collection.py b/src/nplinker/strain_collection.py index ee6a4a88..ac42c314 100644 --- a/src/nplinker/strain_collection.py +++ b/src/nplinker/strain_collection.py @@ -1,5 +1,6 @@ import csv import os +from typing import Iterator from .logconfig import LogConfig from .strains import Strain from .utils import list_dirs @@ -12,6 +13,7 @@ class StrainCollection(): def __init__(self): + """A collection of Strain objects.""" self._strains: list[Strain] = [] self._lookup: dict[str, Strain] = {} self._lookup_indices: dict[int, Strain] = {} @@ -19,7 +21,7 @@ def __init__(self): def __repr__(self) -> str: return str(self) - def __str__(self): + def __str__(self) -> str: if len(self) > 20: return f'StrainCollection(n={len(self)})' @@ -29,7 +31,7 @@ def __str__(self): def __len__(self) -> int: return len(self._strains) - def __eq__(self, other): + def __eq__(self, other) -> bool: result = self._strains == other._strains result &= self._lookup == other._lookup result &= self._lookup_indices == other._lookup_indices @@ -52,25 +54,19 @@ def __contains__(self, strain_id: str | Strain) -> bool: return strain_id.id in self._lookup return False - def __iter__(self): + def __iter__(self) -> Iterator[Strain]: return iter(self._strains) - def __next__(self): - return next(self._strains) + def add(self, strain: Strain) -> None: + """Add strain to the collection. - def add(self, strain: Strain): - """Add the strain to the aliases. - This also adds those strain's aliases to this' strain's aliases. + If the strain already exists, merge the aliases. Args: - strain(Strain): Strain to add to self. - - Examples: - >>> - """ + strain(Strain): The strain to add. + """ + # if the strain exists, merge the aliases if strain.id in self._lookup: - # if it already exists, just merge the set of aliases and update - # lookup entries existing: Strain = self.lookup(strain.id) for alias in strain.aliases: existing.add_alias(alias) @@ -79,17 +75,15 @@ def add(self, strain: Strain): self._lookup_indices[len(self)] = strain self._strains.append(strain) - # insert a mapping from strain=>strain, plus all its aliases self._lookup[strain.id] = strain for alias in strain.aliases: self._lookup[alias] = strain def remove(self, strain: Strain): - """Remove the specified strain from the aliases. - TODO: #90 Implement removing the strain also from self._lookup indices. + """Remove a strain from the collection. Args: - strain(Strain): Strain to remove. + strain(Strain): The strain to remove. """ if strain.id not in self._lookup: return From 4dea67eee4ca0563fa92038206e3758cbd9f7542 Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Tue, 4 Apr 2023 14:24:57 +0200 Subject: [PATCH 10/21] update attribute names self._lookup -> self._strain_dict_id self._lookup_indices -> self._strain_dict_index --- src/nplinker/strain_collection.py | 37 +++++++++++++++---------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/src/nplinker/strain_collection.py b/src/nplinker/strain_collection.py index ac42c314..48eadc87 100644 --- a/src/nplinker/strain_collection.py +++ b/src/nplinker/strain_collection.py @@ -15,8 +15,8 @@ class StrainCollection(): def __init__(self): """A collection of Strain objects.""" self._strains: list[Strain] = [] - self._lookup: dict[str, Strain] = {} - self._lookup_indices: dict[int, Strain] = {} + self._strain_dict_id: dict[str, Strain] = {} + self._strain_dict_index: dict[int, Strain] = {} def __repr__(self) -> str: return str(self) @@ -33,8 +33,8 @@ def __len__(self) -> int: def __eq__(self, other) -> bool: result = self._strains == other._strains - result &= self._lookup == other._lookup - result &= self._lookup_indices == other._lookup_indices + result &= self._strain_dict_id == other._strain_dict_id + result &= self._strain_dict_index == other._strain_dict_index return result def __contains__(self, strain_id: str | Strain) -> bool: @@ -66,18 +66,18 @@ def add(self, strain: Strain) -> None: strain(Strain): The strain to add. """ # if the strain exists, merge the aliases - if strain.id in self._lookup: + if strain.id in self._strain_dict_id: existing: Strain = self.lookup(strain.id) for alias in strain.aliases: existing.add_alias(alias) - self._lookup[alias] = existing + self._strain_dict_id[alias] = existing return - self._lookup_indices[len(self)] = strain + self._strain_dict_index[len(self)] = strain self._strains.append(strain) - self._lookup[strain.id] = strain + self._strain_dict_id[strain.id] = strain for alias in strain.aliases: - self._lookup[alias] = strain + self._strain_dict_id[alias] = strain def remove(self, strain: Strain): """Remove a strain from the collection. @@ -85,13 +85,12 @@ def remove(self, strain: Strain): Args: strain(Strain): The strain to remove. """ - if strain.id not in self._lookup: - return - - self._strains.remove(strain) - del self._lookup[strain.id] - for alias in strain.aliases: - del self._lookup[alias] + if strain.id in self._strain_dict_id: + self._strains.remove(strain) + # remove from dict id + del self._strain_dict_id[strain.id] + for alias in strain.aliases: + del self._strain_dict_id[alias] def filter(self, strain_set: set[Strain]): """ @@ -110,7 +109,7 @@ def lookup_index(self, index: int) -> Strain: Returns: Strain: Strain identified by the given index. """ - return self._lookup_indices[index] + return self._strain_dict_index[index] def lookup(self, strain_id: str) -> Strain: """Check whether the strain id is contained in the lookup table. If so, return the strain, otherwise return `default`. @@ -124,11 +123,11 @@ def lookup(self, strain_id: str) -> Strain: Returns: Strain: Strain retrieved during lookup or object passed as default. """ - if strain_id not in self._lookup: + if strain_id not in self._strain_dict_id: # logger.error('Strain lookup failed for "{}"'.format(strain_id)) raise KeyError(f'Strain lookup failed for "{strain_id}"') - return self._lookup[strain_id] + return self._strain_dict_id[strain_id] def add_from_file(self, file: str | os.PathLike): """Read strains and aliases from file and store in self. From 0ec47deac61703b1d79b6a6f8923767584e00afa Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Thu, 30 Mar 2023 15:33:34 +0200 Subject: [PATCH 11/21] update `__contains__` in StrainCollection update `__contains__` method --- src/nplinker/strain_collection.py | 24 ++++++++---------------- tests/test_strain_collection.py | 1 - 2 files changed, 8 insertions(+), 17 deletions(-) diff --git a/src/nplinker/strain_collection.py b/src/nplinker/strain_collection.py index 48eadc87..8efff1fc 100644 --- a/src/nplinker/strain_collection.py +++ b/src/nplinker/strain_collection.py @@ -37,22 +37,14 @@ def __eq__(self, other) -> bool: result &= self._strain_dict_index == other._strain_dict_index return result - def __contains__(self, strain_id: str | Strain) -> bool: - """Check if the strain or strain id are contained in the lookup table. - - Args: - strain_id(str|Strain): Strain or strain id to look up. - - Returns: - bool: Whether the strain is contained in the collection. - """ - - if isinstance(strain_id, str): - return strain_id in self._lookup - # assume it's a Strain object - if isinstance(strain_id, Strain): - return strain_id.id in self._lookup - return False + def __contains__(self, strain: str | Strain) -> bool: + if isinstance(strain, str): + value = strain in self._strain_dict_id + elif isinstance(strain, Strain): + value = strain.id in self._strain_dict_id + else: + raise TypeError(f"Expected Strain or str, got {type(strain)}") + return value def __iter__(self) -> Iterator[Strain]: return iter(self._strains) diff --git a/tests/test_strain_collection.py b/tests/test_strain_collection.py index 3316670e..9ffc9562 100644 --- a/tests/test_strain_collection.py +++ b/tests/test_strain_collection.py @@ -41,7 +41,6 @@ def test_lookup(strain: Strain): def test_contains(collection: StrainCollection, strain: Strain): - assert strain in collection assert strain.id in collection assert "strain_1" in collection assert "strain_1_a" in collection From 652d18bb1e670a60041df6294c4fa27c1bab91e6 Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Fri, 31 Mar 2023 11:04:43 +0200 Subject: [PATCH 12/21] refactor `__eq__` method --- src/nplinker/strain_collection.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/nplinker/strain_collection.py b/src/nplinker/strain_collection.py index 8efff1fc..4fda2754 100644 --- a/src/nplinker/strain_collection.py +++ b/src/nplinker/strain_collection.py @@ -32,10 +32,9 @@ def __len__(self) -> int: return len(self._strains) def __eq__(self, other) -> bool: - result = self._strains == other._strains - result &= self._strain_dict_id == other._strain_dict_id - result &= self._strain_dict_index == other._strain_dict_index - return result + return (self._strains == other._strains + and self._strain_dict_id == other._strain_dict_id + and self._strain_dict_index == other._strain_dict_index) def __contains__(self, strain: str | Strain) -> bool: if isinstance(strain, str): From 40256d4635f2daf929bed74d379f0059090e350b Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Fri, 31 Mar 2023 11:00:48 +0200 Subject: [PATCH 13/21] refactor `add` method --- src/nplinker/strain_collection.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/nplinker/strain_collection.py b/src/nplinker/strain_collection.py index 4fda2754..a7a544b9 100644 --- a/src/nplinker/strain_collection.py +++ b/src/nplinker/strain_collection.py @@ -62,13 +62,12 @@ def add(self, strain: Strain) -> None: for alias in strain.aliases: existing.add_alias(alias) self._strain_dict_id[alias] = existing - return - - self._strain_dict_index[len(self)] = strain - self._strains.append(strain) - self._strain_dict_id[strain.id] = strain - for alias in strain.aliases: - self._strain_dict_id[alias] = strain + else: + self._strain_dict_index[len(self)] = strain + self._strains.append(strain) + self._strain_dict_id[strain.id] = strain + for alias in strain.aliases: + self._strain_dict_id[alias] = strain def remove(self, strain: Strain): """Remove a strain from the collection. From 461db792cdbf1d8b420cc1377400188e2b133330 Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Wed, 5 Apr 2023 10:22:12 +0200 Subject: [PATCH 14/21] refactor `filter` method --- src/nplinker/strain_collection.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/nplinker/strain_collection.py b/src/nplinker/strain_collection.py index a7a544b9..372babd5 100644 --- a/src/nplinker/strain_collection.py +++ b/src/nplinker/strain_collection.py @@ -86,9 +86,10 @@ def filter(self, strain_set: set[Strain]): """ Remove all strains that are not in strain_set from the strain collection """ - to_remove = [x for x in self._strains if x not in strain_set] - for strain in to_remove: - self.remove(strain) + # note that we need to copy the list of strains, as we are modifying it + for strain in self._strains.copy(): + if strain not in strain_set: + self.remove(strain) def lookup_index(self, index: int) -> Strain: """Return the strain from lookup by index. From 5884f57074e2c43c2a118cbdbcc89f9622892a1d Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Fri, 31 Mar 2023 12:08:40 +0200 Subject: [PATCH 15/21] update tests for remove and filter --- tests/test_strain_collection.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/tests/test_strain_collection.py b/tests/test_strain_collection.py index 9ffc9562..5f6d3dee 100644 --- a/tests/test_strain_collection.py +++ b/tests/test_strain_collection.py @@ -59,15 +59,23 @@ def test_lookup_index_exception(collection: StrainCollection): def test_remove(collection: StrainCollection, strain: Strain): + assert strain in collection collection.remove(strain) - with pytest.raises(KeyError): collection.lookup(strain.id) - assert strain not in collection - - # needs fixing, see #90 - assert collection.lookup_index(0) == strain + # TODO: issue #90 + # with pytest.raises(KeyError): + # collection.lookup_index(0) + + +def test_filter(collection: StrainCollection, strain: Strain): + assert strain in collection + collection.add(Strain("strain_2")) + collection.filter({strain}) + assert strain in collection + assert "strain_2" not in collection + assert len(collection) == 1 def test_equal(collection_from_file: StrainCollection): From de08d95f79f0a1261241fad250dfe950a8a58aa9 Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Fri, 31 Mar 2023 13:52:30 +0200 Subject: [PATCH 16/21] update `lookup` method and its tests --- src/nplinker/strain_collection.py | 23 ++++++++++++----------- tests/test_strain_collection.py | 11 ++++++----- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/src/nplinker/strain_collection.py b/src/nplinker/strain_collection.py index 372babd5..54d937cc 100644 --- a/src/nplinker/strain_collection.py +++ b/src/nplinker/strain_collection.py @@ -102,23 +102,24 @@ def lookup_index(self, index: int) -> Strain: """ return self._strain_dict_index[index] - def lookup(self, strain_id: str) -> Strain: - """Check whether the strain id is contained in the lookup table. If so, return the strain, otherwise return `default`. + def lookup(self, name: str) -> Strain: + """Lookup a strain by name (id or alias). - Raises: - Exception if strain_id is not found. + If the name is found, return the strain object; Otherwise, raise a + KeyError. Args: - strain_id(str): Strain id to lookup. + name(str): Strain name (id or alias) to lookup. Returns: - Strain: Strain retrieved during lookup or object passed as default. - """ - if strain_id not in self._strain_dict_id: - # logger.error('Strain lookup failed for "{}"'.format(strain_id)) - raise KeyError(f'Strain lookup failed for "{strain_id}"') + Strain: Strain identified by the given name. - return self._strain_dict_id[strain_id] + Raises: + KeyError: If the strain name is not found. + """ + if name not in self._strain_dict_id: + raise KeyError(f"Strain {name} not found in strain collection.") + return self._strain_dict_id[name] def add_from_file(self, file: str | os.PathLike): """Read strains and aliases from file and store in self. diff --git a/tests/test_strain_collection.py b/tests/test_strain_collection.py index 5f6d3dee..2fcf4a50 100644 --- a/tests/test_strain_collection.py +++ b/tests/test_strain_collection.py @@ -33,11 +33,12 @@ def test_add(): assert sut.lookup_index(0) == item -def test_lookup(strain: Strain): - sut = StrainCollection() - sut.add(strain) - - assert sut.lookup(strain.id) == strain +def test_lookup(collection: StrainCollection, strain: Strain): + assert collection.lookup(strain.id) == strain + for alias in strain.aliases: + assert collection.lookup(alias) == strain + with pytest.raises(KeyError): + collection.lookup("strain_not_exist") def test_contains(collection: StrainCollection, strain: Strain): From db4626296c38fe9c2a11581eee23bc4e82acc4d5 Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Fri, 31 Mar 2023 14:19:52 +0200 Subject: [PATCH 17/21] add tests for magic methods --- tests/test_strain_collection.py | 38 +++++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/tests/test_strain_collection.py b/tests/test_strain_collection.py index 2fcf4a50..66b0507e 100644 --- a/tests/test_strain_collection.py +++ b/tests/test_strain_collection.py @@ -16,6 +16,37 @@ def test_default(): assert sut is not None +def test_repr(collection: StrainCollection): + assert repr(collection) == str(collection) + + +def test_str(collection: StrainCollection): + assert str(collection) == 'StrainCollection(n=1) [strain_1]' + + +def test_len(collection: StrainCollection): + assert len(collection) == 1 + + +def test_eq(collection: StrainCollection, strain: Strain): + other = StrainCollection() + other.add(strain) + assert collection == other + + +def test_contains(collection: StrainCollection, strain: Strain): + assert strain in collection + assert strain.id in collection + for alias in strain.aliases: + assert alias in collection + assert "strain_not_exist" not in collection + + +def test_iter(collection: StrainCollection, strain: Strain): + for actual in collection: + assert actual == strain + + def test_add_from_file(collection_from_file: StrainCollection): assert len(collection_from_file) == 27 assert len(collection_from_file.lookup_index(1).aliases) == 29 @@ -41,13 +72,6 @@ def test_lookup(collection: StrainCollection, strain: Strain): collection.lookup("strain_not_exist") -def test_contains(collection: StrainCollection, strain: Strain): - assert strain.id in collection - assert "strain_1" in collection - assert "strain_1_a" in collection - assert "test" not in collection - - def test_lookup_index(collection: StrainCollection, strain: Strain): actual = collection.lookup_index(0) assert actual == strain From 46660698be0a4e6fae6b8488e706eec93baa9cba Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Fri, 31 Mar 2023 14:50:23 +0200 Subject: [PATCH 18/21] update tests --- tests/test_strain_collection.py | 55 +++++++++++---------------------- 1 file changed, 18 insertions(+), 37 deletions(-) diff --git a/tests/test_strain_collection.py b/tests/test_strain_collection.py index 66b0507e..de63dc5b 100644 --- a/tests/test_strain_collection.py +++ b/tests/test_strain_collection.py @@ -11,11 +11,6 @@ def collection(strain: Strain) -> StrainCollection: return sut -def test_default(): - sut = StrainCollection() - assert sut is not None - - def test_repr(collection: StrainCollection): assert repr(collection) == str(collection) @@ -52,42 +47,20 @@ def test_add_from_file(collection_from_file: StrainCollection): assert len(collection_from_file.lookup_index(1).aliases) == 29 -def test_add(): +def test_add(strain: Strain): sut = StrainCollection() - item = Strain("test_id") - item.add_alias("blub") - - sut.add(item) - - assert sut.lookup(item.id) == item - assert sut.lookup(next(iter(item.aliases))) == item - assert sut.lookup_index(0) == item - - -def test_lookup(collection: StrainCollection, strain: Strain): - assert collection.lookup(strain.id) == strain + sut.add(strain) + assert strain in sut for alias in strain.aliases: - assert collection.lookup(alias) == strain - with pytest.raises(KeyError): - collection.lookup("strain_not_exist") - - -def test_lookup_index(collection: StrainCollection, strain: Strain): - actual = collection.lookup_index(0) - assert actual == strain - - -def test_lookup_index_exception(collection: StrainCollection): - with pytest.raises(KeyError) as exc: - collection.lookup_index(5) - assert isinstance(exc.value, KeyError) + assert alias in sut + assert sut._strain_dict_index[0] == strain def test_remove(collection: StrainCollection, strain: Strain): assert strain in collection collection.remove(strain) with pytest.raises(KeyError): - collection.lookup(strain.id) + _ = collection._strain_dict_id[strain.id] assert strain not in collection # TODO: issue #90 # with pytest.raises(KeyError): @@ -103,8 +76,16 @@ def test_filter(collection: StrainCollection, strain: Strain): assert len(collection) == 1 -def test_equal(collection_from_file: StrainCollection): - other = StrainCollection() - other.add_from_file(DATA_DIR / "strain_mappings.csv") +def test_lookup_index(collection: StrainCollection, strain: Strain): + actual = collection.lookup_index(0) + assert actual == strain + with pytest.raises(KeyError): + collection.lookup_index(1) + - assert collection_from_file == other +def test_lookup(collection: StrainCollection, strain: Strain): + assert collection.lookup(strain.id) == strain + for alias in strain.aliases: + assert collection.lookup(alias) == strain + with pytest.raises(KeyError): + collection.lookup("strain_not_exist") From ae579aacaa4b40b23a75435024bfb99dae4885e3 Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Fri, 31 Mar 2023 15:11:33 +0200 Subject: [PATCH 19/21] refactor `add_from_file` method --- src/nplinker/strain_collection.py | 37 ++++++++++++------------------- tests/test_strain_collection.py | 13 ++++++----- 2 files changed, 21 insertions(+), 29 deletions(-) diff --git a/src/nplinker/strain_collection.py b/src/nplinker/strain_collection.py index 54d937cc..e51cbfd6 100644 --- a/src/nplinker/strain_collection.py +++ b/src/nplinker/strain_collection.py @@ -1,5 +1,5 @@ import csv -import os +from os import PathLike from typing import Iterator from .logconfig import LogConfig from .strains import Strain @@ -121,36 +121,27 @@ def lookup(self, name: str) -> Strain: raise KeyError(f"Strain {name} not found in strain collection.") return self._strain_dict_id[name] - def add_from_file(self, file: str | os.PathLike): - """Read strains and aliases from file and store in self. + def add_from_file(self, file: str | PathLike) -> None: + """Add strains from a strain mapping file. + + A strain mapping file is a csv file with the first column being the + id of the strain, and the remaining columns being aliases for the + strain. Args: - file(str): Path to strain mapping file to load. + file(str | PathLike): Path to strain mapping file (.csv). """ - - if not os.path.exists(file): - logger.warning(f'strain mappings file not found: {file}') - return - - line = 1 with open(file) as f: reader = csv.reader(f) - for ids in reader: - if len(ids) == 0: + for names in reader: + if len(names) == 0: continue - strain = Strain(ids[0]) - for id in ids[1:]: - if len(id) == 0: - logger.warning( - 'Found zero-length strain label in {} on line {}'. - format(file, line)) - else: - strain.add_alias(id) + strain = Strain(names[0]) + for alias in names[1:]: + strain.add_alias(alias) self.add(strain) - line += 1 - - def save_to_file(self, file: str | os.PathLike): + def save_to_file(self, file: str | PathLike): """Save this strain collection to file. Args: diff --git a/tests/test_strain_collection.py b/tests/test_strain_collection.py index de63dc5b..8911f6dd 100644 --- a/tests/test_strain_collection.py +++ b/tests/test_strain_collection.py @@ -41,12 +41,6 @@ def test_iter(collection: StrainCollection, strain: Strain): for actual in collection: assert actual == strain - -def test_add_from_file(collection_from_file: StrainCollection): - assert len(collection_from_file) == 27 - assert len(collection_from_file.lookup_index(1).aliases) == 29 - - def test_add(strain: Strain): sut = StrainCollection() sut.add(strain) @@ -89,3 +83,10 @@ def test_lookup(collection: StrainCollection, strain: Strain): assert collection.lookup(alias) == strain with pytest.raises(KeyError): collection.lookup("strain_not_exist") + + +def test_add_from_file(): + sut = StrainCollection() + sut.add_from_file(DATA_DIR / "strain_mappings.csv") + assert len(sut) == 27 + assert len(sut.lookup_index(1).aliases) == 29 From 568a560ae667731f4bf91c9ede8887cb26448481 Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Fri, 31 Mar 2023 15:23:34 +0200 Subject: [PATCH 20/21] update `save_to_file` method --- src/nplinker/strain_collection.py | 15 ++++++--------- tests/test_strain_collection.py | 12 ++++++++++++ 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/nplinker/strain_collection.py b/src/nplinker/strain_collection.py index e51cbfd6..48ca7d2e 100644 --- a/src/nplinker/strain_collection.py +++ b/src/nplinker/strain_collection.py @@ -1,4 +1,5 @@ import csv +import os from os import PathLike from typing import Iterator from .logconfig import LogConfig @@ -6,7 +7,6 @@ from .utils import list_dirs from .utils import list_files - logger = LogConfig.getLogger(__name__) @@ -141,17 +141,14 @@ def add_from_file(self, file: str | PathLike) -> None: strain.add_alias(alias) self.add(strain) - def save_to_file(self, file: str | PathLike): - """Save this strain collection to file. + def save_to_file(self, file: str | PathLike) -> None: + """Save strains to a strain mapping file (.csv). Args: - file(str): Output file. - - Examples: - >>> - """ + file(str | PathLike): Path to strain mapping file (.csv). + """ with open(file, 'w') as f: - for strain in self._strains: + for strain in self: ids = [strain.id] + list(strain.aliases) f.write(','.join(ids) + '\n') diff --git a/tests/test_strain_collection.py b/tests/test_strain_collection.py index 8911f6dd..26ab2767 100644 --- a/tests/test_strain_collection.py +++ b/tests/test_strain_collection.py @@ -90,3 +90,15 @@ def test_add_from_file(): sut.add_from_file(DATA_DIR / "strain_mappings.csv") assert len(sut) == 27 assert len(sut.lookup_index(1).aliases) == 29 + + +def test_save_to_file(collection: StrainCollection, tmp_path): + collection.add(Strain("strain_2")) + path = tmp_path / "test.csv" + collection.save_to_file(path) + assert path.exists() + with open(path) as f: + lines = f.readlines() + assert len(lines) == 2 + assert lines[0].strip() == "strain_1,strain_1_a" + assert lines[1].strip() == "strain_2" From 4e86b86de77e675c910246fec5d8717826f1a4b4 Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Fri, 31 Mar 2023 16:22:57 +0200 Subject: [PATCH 21/21] update `generate_strain_mappings` method --- src/nplinker/strain_collection.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/nplinker/strain_collection.py b/src/nplinker/strain_collection.py index 48ca7d2e..29d532e2 100644 --- a/src/nplinker/strain_collection.py +++ b/src/nplinker/strain_collection.py @@ -1,7 +1,9 @@ import csv import os from os import PathLike +from pathlib import Path from typing import Iterator +from deprecated import deprecated from .logconfig import LogConfig from .strains import Strain from .utils import list_dirs @@ -152,8 +154,10 @@ def save_to_file(self, file: str | PathLike) -> None: ids = [strain.id] + list(strain.aliases) f.write(','.join(ids) + '\n') - def generate_strain_mappings(self, strain_mappings_file: str, - antismash_dir: str) -> None: + # TODO to move this method to a separate class + @deprecated(version="1.3.3", reason="This method will be removed") + def generate_strain_mappings(self, strain_mappings_file: str | PathLike, + antismash_dir: str | PathLike) -> None: """Add AntiSMASH BGC file names as strain alias to strain mappings file. Note that if AntiSMASH ID (e.g. GCF_000016425.1) is not found in @@ -161,10 +165,10 @@ def generate_strain_mappings(self, strain_mappings_file: str, added. Args: - strain_mappings_file(str): Path to strain mappings file - antismash_dir(str): Path to AntiSMASH directory + strain_mappings_file(str | PathLike): Path to strain mappings file. + antismash_dir(str | PathLike): Path to AntiSMASH output directory. """ - if os.path.exists(strain_mappings_file): + if Path(strain_mappings_file).exists(): logger.info('Strain mappings file exist') return @@ -172,7 +176,7 @@ def generate_strain_mappings(self, strain_mappings_file: str, logger.info('Generating strain mappings file') subdirs = list_dirs(antismash_dir) for d in subdirs: - antismash_id = os.path.basename(d) + antismash_id = Path(d).name # use antismash_id (e.g. GCF_000016425.1) as strain name to query # TODO: self is empty at the moment, why lookup here? @@ -186,7 +190,7 @@ def generate_strain_mappings(self, strain_mappings_file: str, # if strain `antismash_id` exist, add all gbk file names as strain alias gbk_files = list_files(d, suffix=".gbk", keep_parent=False) for f in gbk_files: - gbk_filename = os.path.splitext(f)[0] + gbk_filename = Path(f).stem strain.add_alias(gbk_filename) logger.info(f'Saving strains to {strain_mappings_file}')