diff --git a/src/nplinker/strain_collection.py b/src/nplinker/strain_collection.py index a2f60d92..e471a563 100644 --- a/src/nplinker/strain_collection.py +++ b/src/nplinker/strain_collection.py @@ -15,8 +15,8 @@ class StrainCollection(): def __init__(self): """A collection of Strain objects.""" self._strains: list[Strain] = [] - self._lookup: dict[str, Strain] = {} - self._lookup_indices: dict[int, Strain] = {} + self._strain_dict_id: dict[str, Strain] = {} + self._strain_dict_index: dict[int, Strain] = {} def __repr__(self) -> str: return str(self) @@ -33,8 +33,8 @@ def __len__(self) -> int: def __eq__(self, other) -> bool: result = self._strains == other._strains - result &= self._lookup == other._lookup - result &= self._lookup_indices == other._lookup_indices + result &= self._strain_dict_id == other._strain_dict_id + result &= self._strain_dict_index == other._strain_dict_index return result def __contains__(self, strain: str | Strain) -> bool: @@ -58,18 +58,18 @@ def add(self, strain: Strain) -> None: strain(Strain): The strain to add. """ # if the strain exists, merge the aliases - if strain.id in self._lookup: + if strain.id in self._strain_dict_id: existing: Strain = self.lookup(strain.id) for alias in strain.aliases: existing.add_alias(alias) - self._lookup[alias] = existing + self._strain_dict_id[alias] = existing return - self._lookup_indices[len(self)] = strain + self._strain_dict_index[len(self)] = strain self._strains.append(strain) - self._lookup[strain.id] = strain + self._strain_dict_id[strain.id] = strain for alias in strain.aliases: - self._lookup[alias] = strain + self._strain_dict_id[alias] = strain def remove(self, strain: Strain): """Remove a strain from the collection. @@ -77,13 +77,17 @@ def remove(self, strain: Strain): Args: strain(Strain): The strain to remove. """ - if strain.id not in self._lookup: - return - - self._strains.remove(strain) - del self._lookup[strain.id] - for alias in strain.aliases: - del self._lookup[alias] + if strain.id in self._strain_dict_id: + self._strains.remove(strain) + # remove from dict id + del self._strain_dict_id[strain.id] + for alias in strain.aliases: + del self._strain_dict_id[alias] + # remove from dict index + for i in range(len(self)): + if self._strain_dict_index[i] == strain: + del self._strain_dict_index[i] + break def filter(self, strain_set: set[Strain]): """ @@ -102,7 +106,7 @@ def lookup_index(self, index: int) -> Strain: Returns: Strain: Strain identified by the given index. """ - return self._lookup_indices[index] + return self._strain_dict_index[index] def lookup(self, strain_id: str) -> Strain: """Check whether the strain id is contained in the lookup table. If so, return the strain, otherwise return `default`. @@ -116,11 +120,11 @@ def lookup(self, strain_id: str) -> Strain: Returns: Strain: Strain retrieved during lookup or object passed as default. """ - if strain_id not in self._lookup: + if strain_id not in self._strain_dict_id: # logger.error('Strain lookup failed for "{}"'.format(strain_id)) raise KeyError(f'Strain lookup failed for "{strain_id}"') - return self._lookup[strain_id] + return self._strain_dict_id[strain_id] def add_from_file(self, file: str | os.PathLike): """Read strains and aliases from file and store in self.