-
Notifications
You must be signed in to change notification settings - Fork 875
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix dict equality check with numpy array (#4086)
* use np.testing to check dict equality * add unit tests * perhaps use a helper function instead? * add test for misc, thanks gpt * fix typo * add check if return type
- Loading branch information
1 parent
47b1b42
commit 4c76b58
Showing
4 changed files
with
82 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
"""Other util functions.""" | ||
|
||
from __future__ import annotations | ||
|
||
import numpy as np | ||
|
||
|
||
def is_np_dict_equal(dict1, dict2, /) -> bool: | ||
"""Compare two dict whose value could be np arrays. | ||
Args: | ||
dict1 (dict): The first dict. | ||
dict2 (dict): The second dict. | ||
Returns: | ||
bool: Whether these two dicts are equal. | ||
""" | ||
if dict1.keys() != dict2.keys(): | ||
return False | ||
|
||
return all(np.array_equal(dict1[key], dict2[key]) for key in dict1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from __future__ import annotations | ||
|
||
import numpy as np | ||
|
||
from pymatgen.util.misc import is_np_dict_equal | ||
|
||
|
||
class TestIsNpDictEqual: | ||
def test_different_keys(self): | ||
"""Test two dicts with different keys.""" | ||
dict1 = {"a": np.array([1, 2, 3])} | ||
dict2 = {"a": np.array([1, 2, 3]), "b": "hello"} | ||
equal = is_np_dict_equal(dict1, dict2) | ||
# make sure it's not a np.bool | ||
assert isinstance(equal, bool) | ||
assert not equal | ||
|
||
def test_both_list(self): | ||
"""Test two dicts where both have lists as values.""" | ||
dict1 = {"a": [1, 2, 3]} | ||
dict2 = {"a": [1, 2, 3]} | ||
assert is_np_dict_equal(dict1, dict2) | ||
|
||
def test_both_np_array(self): | ||
"""Test two dicts where both have NumPy arrays as values.""" | ||
dict1 = {"a": np.array([1, 2, 3])} | ||
dict2 = {"a": np.array([1, 2, 3])} | ||
assert is_np_dict_equal(dict1, dict2) | ||
|
||
def test_one_np_one_list(self): | ||
"""Test two dicts where one has a NumPy array and the other has a list.""" | ||
dict1 = {"a": np.array([1, 2, 3])} | ||
dict2 = {"a": [1, 2, 3]} | ||
assert is_np_dict_equal(dict1, dict2) | ||
|
||
def test_nested_arrays(self): | ||
"""Test two dicts with deeper nested arrays.""" | ||
dict1 = {"a": np.array([[1, 2], [3, 4]])} | ||
dict2 = {"a": np.array([[1, 2], [3, 4]])} | ||
assert is_np_dict_equal(dict1, dict2) | ||
|
||
dict3 = {"a": np.array([[1, 2], [3, 5]])} | ||
assert not is_np_dict_equal(dict1, dict3) |