From d699678f2801d438da5483221e77c1ab2c1ce799 Mon Sep 17 00:00:00 2001 From: "Haoyu (Daniel) YANG" Date: Tue, 22 Oct 2024 04:59:51 +0800 Subject: [PATCH] Enhance test for util function `is_np_dict_equal` (#4092) * add unit test for more data types * Add experimental is_np_dict_equal_try_except function * remove experimental implementation for now --- src/pymatgen/util/misc.py | 2 +- tests/util/test_misc.py | 61 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/src/pymatgen/util/misc.py b/src/pymatgen/util/misc.py index bba8d862d2c..0dab2eccb8b 100644 --- a/src/pymatgen/util/misc.py +++ b/src/pymatgen/util/misc.py @@ -6,7 +6,7 @@ def is_np_dict_equal(dict1, dict2, /) -> bool: - """Compare two dict whose value could be np arrays. + """Compare two dict whose value could be NumPy arrays. Args: dict1 (dict): The first dict. diff --git a/tests/util/test_misc.py b/tests/util/test_misc.py index 99c851e3968..4bcf881dcad 100644 --- a/tests/util/test_misc.py +++ b/tests/util/test_misc.py @@ -1,5 +1,7 @@ from __future__ import annotations +from dataclasses import dataclass + import numpy as np from pymatgen.util.misc import is_np_dict_equal @@ -41,3 +43,62 @@ def test_nested_arrays(self): dict3 = {"a": np.array([[1, 2], [3, 5]])} assert not is_np_dict_equal(dict1, dict3) + + def test_diff_dtype(self): + """Make sure it also works for other data types as value.""" + + @dataclass + class CustomClass: + name: str + value: int + + # Test with bool values + dict1 = {"a": True} + dict2 = {"a": True} + assert is_np_dict_equal(dict1, dict2) + + dict3 = {"a": False} + assert not is_np_dict_equal(dict1, dict3) + + # Test with string values + dict4 = {"a": "hello"} + dict5 = {"a": "hello"} + assert is_np_dict_equal(dict4, dict5) + + dict6 = {"a": "world"} + assert not is_np_dict_equal(dict4, dict6) + + # Test with a custom data class + dict7 = {"a": CustomClass(name="test", value=1)} + dict8 = {"a": CustomClass(name="test", value=1)} + assert is_np_dict_equal(dict7, dict8) + + dict9 = {"a": CustomClass(name="test", value=2)} + assert not is_np_dict_equal(dict7, dict9) + + # Test with None + dict10 = {"a": None} + dict11 = {"a": None} + assert is_np_dict_equal(dict10, dict11) + + dict12 = {"a": None} + dict13 = {"a": "non-none"} + assert not is_np_dict_equal(dict12, dict13) + + # Test with nested complex lists + dict14 = {"a": [[1, 2], ["hello", 3.0]]} + dict15 = {"a": [np.array([1, 2]), ["hello", 3.0]]} + assert is_np_dict_equal(dict14, dict15) + + dict16 = {"a": [[1, 2], ["world", 3.0]]} + assert not is_np_dict_equal(dict14, dict16) + + # Test with unhashable dicts + dict17 = {"a": {"key1": "val1", "key2": "val2"}} + dict18 = {"a": {"key1": "val1", "key2": "val2"}} + assert is_np_dict_equal(dict17, dict18) + + # Test with unhashable sets + dict19 = {"a": {1, 2, 3}} + dict20 = {"a": {1, 2, 3}} + assert is_np_dict_equal(dict19, dict20)