Skip to content

Commit b831e73

Browse files
committed
perhaps use a helper function instead?
1 parent fc62b29 commit b831e73

File tree

2 files changed

+26
-16
lines changed

2 files changed

+26
-16
lines changed

src/pymatgen/core/sites.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from pymatgen.core.lattice import Lattice
1414
from pymatgen.core.periodic_table import DummySpecies, Element, Species, get_el_sp
1515
from pymatgen.util.coord import pbc_diff
16+
from pymatgen.util.misc import is_np_dict_equal
1617

1718
if TYPE_CHECKING:
1819
from typing import Any
@@ -96,18 +97,10 @@ def __eq__(self, other: object) -> bool:
9697
if not isinstance(other, type(self)):
9798
return NotImplemented
9899

99-
# Some properties could be np.array, and in these cases
100-
# using "==" for dict equality check would fail
101-
try:
102-
np.testing.assert_equal(self.properties, other.properties)
103-
prop_equal = True
104-
except AssertionError:
105-
prop_equal = False
106-
107100
return (
108101
self.species == other.species
109102
and np.allclose(self.coords, other.coords, atol=type(self).position_atol)
110-
and prop_equal
103+
and is_np_dict_equal(self.properties, other.properties)
111104
)
112105

113106
def __hash__(self) -> int:
@@ -368,17 +361,11 @@ def __eq__(self, other: object) -> bool:
368361
if not isinstance(other, type(self)):
369362
return NotImplemented
370363

371-
try:
372-
np.testing.assert_equal(self.properties, other.properties)
373-
prop_equal = True
374-
except AssertionError:
375-
prop_equal = False
376-
377364
return (
378365
self.species == other.species
379366
and self.lattice == other.lattice
380367
and np.allclose(self.coords, other.coords, atol=Site.position_atol)
381-
and prop_equal
368+
and is_np_dict_equal(self.properties, other.properties)
382369
)
383370

384371
def __repr__(self) -> str:

src/pymatgen/util/misc.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""The util package implements various utilities that are commonly used by various
2+
packages.
3+
"""
4+
5+
from __future__ import annotations
6+
7+
import numpy as np
8+
9+
10+
def is_np_dict_equal(dict1, dict2, /) -> bool:
11+
"""Compare two dict whose value could be np arrays.
12+
13+
Args:
14+
dict1 (dict): The first dict.
15+
dict2 (dict): The second dict.
16+
17+
Returns:
18+
bool: Whether these two dicts are equal.
19+
"""
20+
if dict1.keys() != dict2.keys():
21+
return False
22+
23+
return all(np.array_equal(dict1[key], dict2[key]) for key in dict1)

0 commit comments

Comments
 (0)