diff --git a/src/nplinker/strain_collection.py b/src/nplinker/strain_collection.py index 2ae1d0d3..5169f02a 100644 --- a/src/nplinker/strain_collection.py +++ b/src/nplinker/strain_collection.py @@ -1,5 +1,9 @@ import csv import os +from os import PathLike +from pathlib import Path +from typing import Iterator +from deprecated import deprecated from .logconfig import LogConfig from .strains import Strain from .utils import list_dirs @@ -12,81 +16,83 @@ class StrainCollection(): def __init__(self): + """A collection of Strain objects.""" self._strains: list[Strain] = [] - self._lookup: dict[str, Strain] = {} - self._lookup_indices: dict[int, Strain] = {} + self._strain_dict_id: dict[str, Strain] = {} + self._strain_dict_index: dict[int, Strain] = {} - def add(self, strain: Strain): - """Add the strain to the aliases. - This also adds those strain's aliases to this' strain's aliases. + def __repr__(self) -> str: + return str(self) + + def __str__(self) -> str: + if len(self) > 20: + return f'StrainCollection(n={len(self)})' + + return f'StrainCollection(n={len(self)}) [' + ','.join( + s.id for s in self._strains) + ']' + + def __len__(self) -> int: + return len(self._strains) + + def __eq__(self, other) -> bool: + return (self._strains == other._strains + and self._strain_dict_id == other._strain_dict_id + and self._strain_dict_index == other._strain_dict_index) + + def __contains__(self, strain: str | Strain) -> bool: + if isinstance(strain, str): + value = strain in self._strain_dict_id + elif isinstance(strain, Strain): + value = strain.id in self._strain_dict_id + else: + raise TypeError(f"Expected Strain or str, got {type(strain)}") + return value + + def __iter__(self) -> Iterator[Strain]: + return iter(self._strains) + + def add(self, strain: Strain) -> None: + """Add strain to the collection. + + If the strain already exists, merge the aliases. Args: - strain(Strain): Strain to add to self. - - Examples: - >>> - """ - if strain.id in self._lookup: - # if it already exists, just merge the set of aliases and update - # lookup entries + strain(Strain): The strain to add. + """ + # if the strain exists, merge the aliases + if strain.id in self._strain_dict_id: existing: Strain = self.lookup(strain.id) for alias in strain.aliases: existing.add_alias(alias) - self._lookup[alias] = existing - return - - self._lookup_indices[len(self)] = strain - self._strains.append(strain) - # insert a mapping from strain=>strain, plus all its aliases - self._lookup[strain.id] = strain - for alias in strain.aliases: - self._lookup[alias] = strain + self._strain_dict_id[alias] = existing + else: + self._strain_dict_index[len(self)] = strain + self._strains.append(strain) + self._strain_dict_id[strain.id] = strain + for alias in strain.aliases: + self._strain_dict_id[alias] = strain def remove(self, strain: Strain): - """Remove the specified strain from the aliases. - TODO: #90 Implement removing the strain also from self._lookup indices. + """Remove a strain from the collection. Args: - strain(Strain): Strain to remove. + strain(Strain): The strain to remove. """ - if strain.id not in self._lookup: - return - - self._strains.remove(strain) - del self._lookup[strain.id] - for alias in strain.aliases: - del self._lookup[alias] + if strain.id in self._strain_dict_id: + self._strains.remove(strain) + # remove from dict id + del self._strain_dict_id[strain.id] + for alias in strain.aliases: + del self._strain_dict_id[alias] def filter(self, strain_set: set[Strain]): """ Remove all strains that are not in strain_set from the strain collection """ - to_remove = [x for x in self._strains if x not in strain_set] - for strain in to_remove: - self.remove(strain) - - def __contains__(self, strain_id: str|Strain) -> bool: - """Check if the strain or strain id are contained in the lookup table. - - Args: - strain_id(str|Strain): Strain or strain id to look up. - - Returns: - bool: Whether the strain is contained in the collection. - """ - - if isinstance(strain_id, str): - return strain_id in self._lookup - # assume it's a Strain object - if isinstance(strain_id, Strain): - return strain_id.id in self._lookup - return False - - def __iter__(self): - return iter(self._strains) - - def __next__(self): - return next(self._strains) + # note that we need to copy the list of strains, as we are modifying it + for strain in self._strains.copy(): + if strain not in strain_set: + self.remove(strain) def lookup_index(self, index: int) -> Strain: """Return the strain from lookup by index. @@ -97,71 +103,62 @@ def lookup_index(self, index: int) -> Strain: Returns: Strain: Strain identified by the given index. """ - return self._lookup_indices[index] + return self._strain_dict_index[index] - def lookup(self, strain_id: str) -> Strain: - """Check whether the strain id is contained in the lookup table. If so, return the strain, otherwise return `default`. + def lookup(self, name: str) -> Strain: + """Lookup a strain by name (id or alias). - Raises: - Exception if strain_id is not found. + If the name is found, return the strain object; Otherwise, raise a + KeyError. Args: - strain_id(str): Strain id to lookup. + name(str): Strain name (id or alias) to lookup. Returns: - Strain: Strain retrieved during lookup or object passed as default. + Strain: Strain identified by the given name. + + Raises: + KeyError: If the strain name is not found. """ - if strain_id not in self._lookup: - # logger.error('Strain lookup failed for "{}"'.format(strain_id)) - raise KeyError(f'Strain lookup failed for "{strain_id}"') + if name not in self._strain_dict_id: + raise KeyError(f"Strain {name} not found in strain collection.") + return self._strain_dict_id[name] - return self._lookup[strain_id] + def add_from_file(self, file: str | PathLike) -> None: + """Add strains from a strain mapping file. - def add_from_file(self, file: str | os.PathLike): - """Read strains and aliases from file and store in self. + A strain mapping file is a csv file with the first column being the + id of the strain, and the remaining columns being aliases for the + strain. Args: - file(str): Path to strain mapping file to load. + file(str | PathLike): Path to strain mapping file (.csv). """ - - if not os.path.exists(file): - logger.warning(f'strain mappings file not found: {file}') - return - - line = 1 with open(file) as f: reader = csv.reader(f) - for ids in reader: - if len(ids) == 0: + for names in reader: + if len(names) == 0: continue - strain = Strain(ids[0]) - for id in ids[1:]: - if len(id) == 0: - logger.warning( - 'Found zero-length strain label in {} on line {}'. - format(file, line)) - else: - strain.add_alias(id) + strain = Strain(names[0]) + for alias in names[1:]: + strain.add_alias(alias) self.add(strain) - line += 1 - - def save_to_file(self, file: str | os.PathLike): - """Save this strain collection to file. + def save_to_file(self, file: str | PathLike) -> None: + """Save strains to a strain mapping file (.csv). Args: - file(str): Output file. - - Examples: - >>> - """ + file(str | PathLike): Path to strain mapping file (.csv). + """ with open(file, 'w') as f: - for strain in self._strains: + for strain in self: ids = [strain.id] + list(strain.aliases) f.write(','.join(ids) + '\n') - def generate_strain_mappings(self, strain_mappings_file: str, - antismash_dir: str) -> None: + # TODO to move this method to a separate class + @deprecated(version="1.3.3", reason="This method will be removed") + def generate_strain_mappings(self, strain_mappings_file: str | PathLike, + antismash_dir: str | PathLike) -> None: """Add AntiSMASH BGC file names as strain alias to strain mappings file. Note that if AntiSMASH ID (e.g. GCF_000016425.1) is not found in @@ -169,10 +166,10 @@ def generate_strain_mappings(self, strain_mappings_file: str, added. Args: - strain_mappings_file(str): Path to strain mappings file - antismash_dir(str): Path to AntiSMASH directory + strain_mappings_file(str | PathLike): Path to strain mappings file. + antismash_dir(str | PathLike): Path to AntiSMASH output directory. """ - if os.path.exists(strain_mappings_file): + if Path(strain_mappings_file).exists(): logger.info('Strain mappings file exist') return @@ -180,7 +177,7 @@ def generate_strain_mappings(self, strain_mappings_file: str, logger.info('Generating strain mappings file') subdirs = list_dirs(antismash_dir) for d in subdirs: - antismash_id = os.path.basename(d) + antismash_id = Path(d).name # use antismash_id (e.g. GCF_000016425.1) as strain name to query # TODO: self is empty at the moment, why lookup here? @@ -194,27 +191,8 @@ def generate_strain_mappings(self, strain_mappings_file: str, # if strain `antismash_id` exist, add all gbk file names as strain alias gbk_files = list_files(d, suffix=".gbk", keep_parent=False) for f in gbk_files: - gbk_filename = os.path.splitext(f)[0] + gbk_filename = Path(f).stem strain.add_alias(gbk_filename) logger.info(f'Saving strains to {strain_mappings_file}') self.save_to_file(strain_mappings_file) - - def __len__(self) -> int: - return len(self._strains) - - def __repr__(self) -> str: - return str(self) - - def __str__(self): - if len(self) > 20: - return f'StrainCollection(n={len(self)})' - - return f'StrainCollection(n={len(self)}) [' + ','.join( - s.id for s in self._strains) + ']' - - def __eq__(self, other): - result = self._strains == other._strains - result &= self._lookup == other._lookup - result &= self._lookup_indices == other._lookup_indices - return result diff --git a/tests/test_strain_collection.py b/tests/test_strain_collection.py index 3316670e..26ab2767 100644 --- a/tests/test_strain_collection.py +++ b/tests/test_strain_collection.py @@ -11,68 +11,94 @@ def collection(strain: Strain) -> StrainCollection: return sut -def test_default(): - sut = StrainCollection() - assert sut is not None - - -def test_add_from_file(collection_from_file: StrainCollection): - assert len(collection_from_file) == 27 - assert len(collection_from_file.lookup_index(1).aliases) == 29 - +def test_repr(collection: StrainCollection): + assert repr(collection) == str(collection) -def test_add(): - sut = StrainCollection() - item = Strain("test_id") - item.add_alias("blub") - sut.add(item) +def test_str(collection: StrainCollection): + assert str(collection) == 'StrainCollection(n=1) [strain_1]' - assert sut.lookup(item.id) == item - assert sut.lookup(next(iter(item.aliases))) == item - assert sut.lookup_index(0) == item +def test_len(collection: StrainCollection): + assert len(collection) == 1 -def test_lookup(strain: Strain): - sut = StrainCollection() - sut.add(strain) - assert sut.lookup(strain.id) == strain +def test_eq(collection: StrainCollection, strain: Strain): + other = StrainCollection() + other.add(strain) + assert collection == other def test_contains(collection: StrainCollection, strain: Strain): assert strain in collection assert strain.id in collection - assert "strain_1" in collection - assert "strain_1_a" in collection - assert "test" not in collection + for alias in strain.aliases: + assert alias in collection + assert "strain_not_exist" not in collection -def test_lookup_index(collection: StrainCollection, strain: Strain): - actual = collection.lookup_index(0) - assert actual == strain - +def test_iter(collection: StrainCollection, strain: Strain): + for actual in collection: + assert actual == strain -def test_lookup_index_exception(collection: StrainCollection): - with pytest.raises(KeyError) as exc: - collection.lookup_index(5) - assert isinstance(exc.value, KeyError) +def test_add(strain: Strain): + sut = StrainCollection() + sut.add(strain) + assert strain in sut + for alias in strain.aliases: + assert alias in sut + assert sut._strain_dict_index[0] == strain def test_remove(collection: StrainCollection, strain: Strain): + assert strain in collection collection.remove(strain) - with pytest.raises(KeyError): - collection.lookup(strain.id) - + _ = collection._strain_dict_id[strain.id] assert strain not in collection + # TODO: issue #90 + # with pytest.raises(KeyError): + # collection.lookup_index(0) - # needs fixing, see #90 - assert collection.lookup_index(0) == strain +def test_filter(collection: StrainCollection, strain: Strain): + assert strain in collection + collection.add(Strain("strain_2")) + collection.filter({strain}) + assert strain in collection + assert "strain_2" not in collection + assert len(collection) == 1 -def test_equal(collection_from_file: StrainCollection): - other = StrainCollection() - other.add_from_file(DATA_DIR / "strain_mappings.csv") - assert collection_from_file == other +def test_lookup_index(collection: StrainCollection, strain: Strain): + actual = collection.lookup_index(0) + assert actual == strain + with pytest.raises(KeyError): + collection.lookup_index(1) + + +def test_lookup(collection: StrainCollection, strain: Strain): + assert collection.lookup(strain.id) == strain + for alias in strain.aliases: + assert collection.lookup(alias) == strain + with pytest.raises(KeyError): + collection.lookup("strain_not_exist") + + +def test_add_from_file(): + sut = StrainCollection() + sut.add_from_file(DATA_DIR / "strain_mappings.csv") + assert len(sut) == 27 + assert len(sut.lookup_index(1).aliases) == 29 + + +def test_save_to_file(collection: StrainCollection, tmp_path): + collection.add(Strain("strain_2")) + path = tmp_path / "test.csv" + collection.save_to_file(path) + assert path.exists() + with open(path) as f: + lines = f.readlines() + assert len(lines) == 2 + assert lines[0].strip() == "strain_1,strain_1_a" + assert lines[1].strip() == "strain_2"