Skip to content

Commit

Permalink
Fix dict equality check with numpy array (#4086)
Browse files Browse the repository at this point in the history
* use np.testing to check dict equality

* add unit tests

* perhaps use a helper function instead?

* add test for misc, thanks gpt

* fix typo

* add check if return type
  • Loading branch information
DanielYang59 authored Oct 2, 2024
1 parent 47b1b42 commit 4c76b58
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/pymatgen/core/sites.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pymatgen.core.lattice import Lattice
from pymatgen.core.periodic_table import DummySpecies, Element, Species, get_el_sp
from pymatgen.util.coord import pbc_diff
from pymatgen.util.misc import is_np_dict_equal

if TYPE_CHECKING:
from typing import Any
Expand Down Expand Up @@ -90,7 +91,7 @@ def __getitem__(self, el: Element) -> float:

def __eq__(self, other: object) -> bool:
"""Site is equal to another site if the species and occupancies are the
same, and the coordinates are the same to some tolerance. `numpy.allclose`
same, and the coordinates are the same to some tolerance. `np.allclose`
is used to determine if coordinates are close.
"""
if not isinstance(other, type(self)):
Expand All @@ -99,7 +100,7 @@ def __eq__(self, other: object) -> bool:
return (
self.species == other.species
and np.allclose(self.coords, other.coords, atol=type(self).position_atol)
and self.properties == other.properties
and is_np_dict_equal(self.properties, other.properties)
)

def __hash__(self) -> int:
Expand Down Expand Up @@ -364,7 +365,7 @@ def __eq__(self, other: object) -> bool:
self.species == other.species
and self.lattice == other.lattice
and np.allclose(self.coords, other.coords, atol=Site.position_atol)
and self.properties == other.properties
and is_np_dict_equal(self.properties, other.properties)
)

def __repr__(self) -> str:
Expand Down
21 changes: 21 additions & 0 deletions src/pymatgen/util/misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Other util functions."""

from __future__ import annotations

import numpy as np


def is_np_dict_equal(dict1, dict2, /) -> bool:
"""Compare two dict whose value could be np arrays.
Args:
dict1 (dict): The first dict.
dict2 (dict): The second dict.
Returns:
bool: Whether these two dicts are equal.
"""
if dict1.keys() != dict2.keys():
return False

return all(np.array_equal(dict1[key], dict2[key]) for key in dict1)
14 changes: 14 additions & 0 deletions tests/core/test_sites.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,20 @@ def test_equality_with_label(self):
assert self.labeled_site.label != site.label
assert self.labeled_site == site

def test_equality_prop_with_np_array(self):
"""Some property (e.g. selective dynamics for POSCAR) could be numpy arrays,
use "==" for equality check might fail in these cases.
"""
site_0 = PeriodicSite(
"Fe", [0.25, 0.35, 0.45], self.lattice, properties={"selective_dynamics": np.array([True, True, False])}
)
assert site_0 == site_0

site_1 = PeriodicSite(
"Fe", [0.25, 0.35, 0.45], self.lattice, properties={"selective_dynamics": np.array([True, False, False])}
)
assert site_0 != site_1

def test_as_from_dict(self):
dct = self.site2.as_dict()
site = PeriodicSite.from_dict(dct)
Expand Down
43 changes: 43 additions & 0 deletions tests/util/test_misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import annotations

import numpy as np

from pymatgen.util.misc import is_np_dict_equal


class TestIsNpDictEqual:
def test_different_keys(self):
"""Test two dicts with different keys."""
dict1 = {"a": np.array([1, 2, 3])}
dict2 = {"a": np.array([1, 2, 3]), "b": "hello"}
equal = is_np_dict_equal(dict1, dict2)
# make sure it's not a np.bool
assert isinstance(equal, bool)
assert not equal

def test_both_list(self):
"""Test two dicts where both have lists as values."""
dict1 = {"a": [1, 2, 3]}
dict2 = {"a": [1, 2, 3]}
assert is_np_dict_equal(dict1, dict2)

def test_both_np_array(self):
"""Test two dicts where both have NumPy arrays as values."""
dict1 = {"a": np.array([1, 2, 3])}
dict2 = {"a": np.array([1, 2, 3])}
assert is_np_dict_equal(dict1, dict2)

def test_one_np_one_list(self):
"""Test two dicts where one has a NumPy array and the other has a list."""
dict1 = {"a": np.array([1, 2, 3])}
dict2 = {"a": [1, 2, 3]}
assert is_np_dict_equal(dict1, dict2)

def test_nested_arrays(self):
"""Test two dicts with deeper nested arrays."""
dict1 = {"a": np.array([[1, 2], [3, 4]])}
dict2 = {"a": np.array([[1, 2], [3, 4]])}
assert is_np_dict_equal(dict1, dict2)

dict3 = {"a": np.array([[1, 2], [3, 5]])}
assert not is_np_dict_equal(dict1, dict3)

0 comments on commit 4c76b58

Please sign in to comment.