From 4c76b5845ebf9d7443b1f2966c109c355184da95 Mon Sep 17 00:00:00 2001 From: "Haoyu (Daniel) YANG" Date: Thu, 3 Oct 2024 03:41:22 +0800 Subject: [PATCH] Fix dict equality check with numpy array (#4086) * use np.testing to check dict equality * add unit tests * perhaps use a helper function instead? * add test for misc, thanks gpt * fix typo * add check if return type --- src/pymatgen/core/sites.py | 7 ++++--- src/pymatgen/util/misc.py | 21 +++++++++++++++++++ tests/core/test_sites.py | 14 +++++++++++++ tests/util/test_misc.py | 43 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 82 insertions(+), 3 deletions(-) create mode 100644 src/pymatgen/util/misc.py create mode 100644 tests/util/test_misc.py diff --git a/src/pymatgen/core/sites.py b/src/pymatgen/core/sites.py index e984c4296de..86054063f11 100644 --- a/src/pymatgen/core/sites.py +++ b/src/pymatgen/core/sites.py @@ -13,6 +13,7 @@ from pymatgen.core.lattice import Lattice from pymatgen.core.periodic_table import DummySpecies, Element, Species, get_el_sp from pymatgen.util.coord import pbc_diff +from pymatgen.util.misc import is_np_dict_equal if TYPE_CHECKING: from typing import Any @@ -90,7 +91,7 @@ def __getitem__(self, el: Element) -> float: def __eq__(self, other: object) -> bool: """Site is equal to another site if the species and occupancies are the - same, and the coordinates are the same to some tolerance. `numpy.allclose` + same, and the coordinates are the same to some tolerance. `np.allclose` is used to determine if coordinates are close. 
""" if not isinstance(other, type(self)): @@ -99,7 +100,7 @@ def __eq__(self, other: object) -> bool: return ( self.species == other.species and np.allclose(self.coords, other.coords, atol=type(self).position_atol) - and self.properties == other.properties + and is_np_dict_equal(self.properties, other.properties) ) def __hash__(self) -> int: @@ -364,7 +365,7 @@ def __eq__(self, other: object) -> bool: self.species == other.species and self.lattice == other.lattice and np.allclose(self.coords, other.coords, atol=Site.position_atol) - and self.properties == other.properties + and is_np_dict_equal(self.properties, other.properties) ) def __repr__(self) -> str: diff --git a/src/pymatgen/util/misc.py b/src/pymatgen/util/misc.py new file mode 100644 index 00000000000..bba8d862d2c --- /dev/null +++ b/src/pymatgen/util/misc.py @@ -0,0 +1,21 @@ +"""Other util functions.""" + +from __future__ import annotations + +import numpy as np + + +def is_np_dict_equal(dict1, dict2, /) -> bool: + """Compare two dicts whose values could be NumPy arrays. + + Args: + dict1 (dict): The first dict. + dict2 (dict): The second dict. + + Returns: + bool: Whether these two dicts are equal. + """ + if dict1.keys() != dict2.keys(): + return False + + return all(np.array_equal(dict1[key], dict2[key]) for key in dict1) diff --git a/tests/core/test_sites.py b/tests/core/test_sites.py index d81d47f373e..d530c8e8829 100644 --- a/tests/core/test_sites.py +++ b/tests/core/test_sites.py @@ -168,6 +168,20 @@ def test_equality_with_label(self): assert self.labeled_site.label != site.label assert self.labeled_site == site + def test_equality_prop_with_np_array(self): + """Some properties (e.g. selective dynamics for POSCAR) could be numpy arrays, + so using "==" for the equality check might fail in these cases. 
+ """ + site_0 = PeriodicSite( + "Fe", [0.25, 0.35, 0.45], self.lattice, properties={"selective_dynamics": np.array([True, True, False])} + ) + assert site_0 == site_0 + + site_1 = PeriodicSite( + "Fe", [0.25, 0.35, 0.45], self.lattice, properties={"selective_dynamics": np.array([True, False, False])} + ) + assert site_0 != site_1 + def test_as_from_dict(self): dct = self.site2.as_dict() site = PeriodicSite.from_dict(dct) diff --git a/tests/util/test_misc.py b/tests/util/test_misc.py new file mode 100644 index 00000000000..99c851e3968 --- /dev/null +++ b/tests/util/test_misc.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import numpy as np + +from pymatgen.util.misc import is_np_dict_equal + + +class TestIsNpDictEqual: + def test_different_keys(self): + """Test two dicts with different keys.""" + dict1 = {"a": np.array([1, 2, 3])} + dict2 = {"a": np.array([1, 2, 3]), "b": "hello"} + equal = is_np_dict_equal(dict1, dict2) + # make sure it's not a np.bool + assert isinstance(equal, bool) + assert not equal + + def test_both_list(self): + """Test two dicts where both have lists as values.""" + dict1 = {"a": [1, 2, 3]} + dict2 = {"a": [1, 2, 3]} + assert is_np_dict_equal(dict1, dict2) + + def test_both_np_array(self): + """Test two dicts where both have NumPy arrays as values.""" + dict1 = {"a": np.array([1, 2, 3])} + dict2 = {"a": np.array([1, 2, 3])} + assert is_np_dict_equal(dict1, dict2) + + def test_one_np_one_list(self): + """Test two dicts where one has a NumPy array and the other has a list.""" + dict1 = {"a": np.array([1, 2, 3])} + dict2 = {"a": [1, 2, 3]} + assert is_np_dict_equal(dict1, dict2) + + def test_nested_arrays(self): + """Test two dicts with deeper nested arrays.""" + dict1 = {"a": np.array([[1, 2], [3, 4]])} + dict2 = {"a": np.array([[1, 2], [3, 4]])} + assert is_np_dict_equal(dict1, dict2) + + dict3 = {"a": np.array([[1, 2], [3, 5]])} + assert not is_np_dict_equal(dict1, dict3)