diff --git a/src/pymatgen/util/misc.py b/src/pymatgen/util/misc.py index bba8d862d2c..5ea2766afa4 100644 --- a/src/pymatgen/util/misc.py +++ b/src/pymatgen/util/misc.py @@ -6,7 +6,7 @@ def is_np_dict_equal(dict1, dict2, /) -> bool: - """Compare two dict whose value could be np arrays. + """Compare two dict whose value could be NumPy arrays. Args: dict1 (dict): The first dict. @@ -19,3 +19,25 @@ def is_np_dict_equal(dict1, dict2, /) -> bool: return False return all(np.array_equal(dict1[key], dict2[key]) for key in dict1) + + +def is_np_dict_equal_try_except(dict1, dict2, /) -> bool: + """Another implementation with try-except. + + TODO: need speed test. + """ + if dict1.keys() != dict2.keys(): + return False + + for key in dict1: + value1 = dict1[key] + value2 = dict2[key] + + try: + if value1 != value2: + return False + except ValueError: + if not np.array_equal(value1, value2): + return False + + return True