Skip to content

Commit

Permalink
Implement comparison operator for StrainCollection class (#111)
Browse files Browse the repository at this point in the history
* Implement comparison operator for `StrainCollection` class
Fixes #110

* Update src/nplinker/strains.py

Co-authored-by: Cunliang Geng <c.geng@esciencecenter.nl>

---------

Co-authored-by: Cunliang Geng <c.geng@esciencecenter.nl>
  • Loading branch information
hechth and CunliangGeng authored Feb 23, 2023
1 parent 0e450c4 commit a60e9f1
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/nplinker/strain_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,9 @@ def __str__(self):

return f'StrainCollection(n={len(self)}) [' + ','.join(
s.id for s in self._strains) + ']'

def __eq__(self, other):
result = self._strains == other._strains
result &= self._lookup == other._lookup
result &= self._lookup_indices == other._lookup_indices
return result
7 changes: 7 additions & 0 deletions src/nplinker/strains.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,10 @@ def __repr__(self) -> str:

def __str__(self) -> str:
return f'Strain({self.id}) [{len(self.aliases)} aliases]'

def __eq__(self, other):
return (
isinstance(other, Strain)
and self.id == other.id
and self.aliases == other.aliases
)
7 changes: 7 additions & 0 deletions tests/test_strain.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,10 @@ def test_has_alias(strain: Strain, alias: str, expected: bool):
def test_add_alias(strain: Strain):
strain.add_alias("test")
assert len(strain.aliases) == 2


def test_equal(strain: Strain):
other = Strain("peter")
other.add_alias("dieter")

assert strain == other
8 changes: 8 additions & 0 deletions tests/test_strain_collection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from nplinker.strain_collection import StrainCollection
from nplinker.strains import Strain
from tests import DATA_DIR


@pytest.fixture
Expand Down Expand Up @@ -68,3 +69,10 @@ def test_remove(collection: StrainCollection, strain: Strain):

# needs fixing, see #90
assert collection.lookup_index(0) == strain


def test_equal(collection_from_file: StrainCollection):
other = StrainCollection()
other.add_from_file(DATA_DIR / "strain_mappings.csv")

assert collection_from_file == other

0 comments on commit a60e9f1

Please sign in to comment.