From 18d9d53c21350a060767597c19d9c78beac211a3 Mon Sep 17 00:00:00 2001 From: "Haoyu (Daniel)" Date: Thu, 3 Oct 2024 11:11:37 +0800 Subject: [PATCH] add unit test for more data types --- tests/util/test_misc.py | 61 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/tests/util/test_misc.py b/tests/util/test_misc.py index 99c851e3968..4bcf881dcad 100644 --- a/tests/util/test_misc.py +++ b/tests/util/test_misc.py @@ -1,5 +1,7 @@ from __future__ import annotations +from dataclasses import dataclass + import numpy as np from pymatgen.util.misc import is_np_dict_equal @@ -41,3 +43,62 @@ def test_nested_arrays(self): dict3 = {"a": np.array([[1, 2], [3, 5]])} assert not is_np_dict_equal(dict1, dict3) + + def test_diff_dtype(self): + """Make sure it also works for other data types as value.""" + + @dataclass + class CustomClass: + name: str + value: int + + # Test with bool values + dict1 = {"a": True} + dict2 = {"a": True} + assert is_np_dict_equal(dict1, dict2) + + dict3 = {"a": False} + assert not is_np_dict_equal(dict1, dict3) + + # Test with string values + dict4 = {"a": "hello"} + dict5 = {"a": "hello"} + assert is_np_dict_equal(dict4, dict5) + + dict6 = {"a": "world"} + assert not is_np_dict_equal(dict4, dict6) + + # Test with a custom data class + dict7 = {"a": CustomClass(name="test", value=1)} + dict8 = {"a": CustomClass(name="test", value=1)} + assert is_np_dict_equal(dict7, dict8) + + dict9 = {"a": CustomClass(name="test", value=2)} + assert not is_np_dict_equal(dict7, dict9) + + # Test with None + dict10 = {"a": None} + dict11 = {"a": None} + assert is_np_dict_equal(dict10, dict11) + + dict12 = {"a": None} + dict13 = {"a": "non-none"} + assert not is_np_dict_equal(dict12, dict13) + + # Test with nested complex lists + dict14 = {"a": [[1, 2], ["hello", 3.0]]} + dict15 = {"a": [np.array([1, 2]), ["hello", 3.0]]} + assert is_np_dict_equal(dict14, dict15) + + dict16 = {"a": [[1, 2], ["world", 3.0]]} + assert not is_np_dict_equal(dict14, dict16) + + # Test with unhashable dicts + dict17 = {"a": {"key1": "val1", "key2": "val2"}} + dict18 = {"a": {"key1": "val1", "key2": "val2"}} + assert is_np_dict_equal(dict17, dict18) + + # Test with unhashable sets + dict19 = {"a": {1, 2, 3}} + dict20 = {"a": {1, 2, 3}} + assert is_np_dict_equal(dict19, dict20)