Fix dict equality check with numpy array #4086

Merged: 6 commits, Oct 2, 2024

DanielYang59 (Contributor) commented Sep 30, 2024

Summary

Efficiency

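np.testing.assert_equal turns out to be much slower than np.array_equal when comparing each value: roughly 10 to 16 times slower for array sizes 10 to 10,000 in the benchmark below. The size-1 row shows a far larger gap, which may include one-time overhead from the first call.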

Script (by GPT)
import numpy as np
import time

# 1. Comparison using `np.testing.assert_equal()`
def dicts_equal_np_test_assert(d1, d2):
    """Compare two dictionaries using np.testing.assert_equal."""
    try:
        np.testing.assert_equal(d1, d2)
        return True
    except AssertionError:
        return False

# 2. Comparison by using array_equal
def dicts_equal_array_equal(d1, d2):
    """Compare two dictionaries by using np.array_equal."""
    if d1.keys() != d2.keys():
        return False

    for key in d1:
        if not np.array_equal(d1[key], d2[key]):
            return False

    return True

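# 3. Helper function to average the time over 10 runs
def average_time(func, d1, d2, num_runs=10):
    """Return the mean wall-clock time of func(d1, d2) in seconds."""
    total_time = 0.0
    for _ in range(num_runs):
        # perf_counter has higher resolution than time.time for short intervals
        start = time.perf_counter()
        func(d1, d2)
        total_time += time.perf_counter() - start
    return total_time / num_runs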

# 4. Running comparisons across array sizes
def run_comparisons(array_size):
    """Generate test dictionaries and run comparison methods."""
    # Create two identical dictionaries with numpy arrays
    arr1 = np.random.random(array_size)
    arr2 = np.copy(arr1)  # Create an identical array for comparison
    d1 = {"array": arr1}
    d2 = {"array": arr2}

    print(f"\nComparing dictionaries with array size: {array_size}")

    # 4.1. Using np.testing.assert_equal (averaged over 10 runs)
    avg_time_np_assert = average_time(dicts_equal_np_test_assert, d1, d2)
    print(f"np.testing.assert_equal: Average Time = {avg_time_np_assert * 1E6:.4f} microseconds")

    # 4.2. Using np.array_equal (averaged over 10 runs)
    avg_time_array_equal = average_time(dicts_equal_array_equal, d1, d2)
    print(f"np.array_equal: Average Time = {avg_time_array_equal * 1E6:.4f} microseconds")

if __name__ == "__main__":
    # Test with different array sizes
    for size in [1, 10, 100, 1000, 10_000]:
        run_comparisons(size)

Output
Comparing dictionaries with array size: 1
np.testing.assert_equal: Average Time = 2419.4241 microseconds
np.array_equal: Average Time = 2.5749 microseconds

Comparing dictionaries with array size: 10
np.testing.assert_equal: Average Time = 19.1927 microseconds
np.array_equal: Average Time = 1.4067 microseconds

Comparing dictionaries with array size: 100
np.testing.assert_equal: Average Time = 19.1450 microseconds
np.array_equal: Average Time = 1.2159 microseconds

Comparing dictionaries with array size: 1000
np.testing.assert_equal: Average Time = 21.8868 microseconds
np.array_equal: Average Time = 1.5020 microseconds

Comparing dictionaries with array size: 10000
np.testing.assert_equal: Average Time = 43.8929 microseconds
np.array_equal: Average Time = 4.3869 microseconds
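
The size-1 row above is far out of line with the others, which may reflect one-time cost on the very first call rather than a real size effect. A minimal sketch of a more robust measurement, assuming a warm-up call and the standard library's timeit.repeat are acceptable here; best_time_us is a hypothetical helper name, not part of the PR:

```python
import timeit

import numpy as np


def dicts_equal_np_test_assert(d1, d2):
    """Compare two dicts using np.testing.assert_equal."""
    try:
        np.testing.assert_equal(d1, d2)
        return True
    except AssertionError:
        return False


def dicts_equal_array_equal(d1, d2):
    """Compare two dicts value by value using np.array_equal."""
    if d1.keys() != d2.keys():
        return False
    return all(np.array_equal(d1[key], d2[key]) for key in d1)


def best_time_us(func, d1, d2, number=200, repeat=5):
    """Hypothetical helper: best per-call time in microseconds."""
    func(d1, d2)  # warm-up call, excluded from the timing
    totals = timeit.repeat(lambda: func(d1, d2), number=number, repeat=repeat)
    # Each total covers `number` calls; the minimum is least affected by noise.
    return min(totals) / number * 1e6


if __name__ == "__main__":
    for size in [1, 10, 100, 1000, 10_000]:
        arr = np.random.random(size)
        d1, d2 = {"array": arr}, {"array": arr.copy()}
        t_assert = best_time_us(dicts_equal_np_test_assert, d1, d2)
        t_equal = best_time_us(dicts_equal_array_equal, d1, d2)
        print(f"size {size}: assert_equal {t_assert:.2f} us, array_equal {t_equal:.2f} us")
```

Taking the minimum over several repeats, with the first call discarded, should make the rows comparable across sizes; absolute numbers will still vary by machine and NumPy version.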

DanielYang59 marked this pull request as ready for review September 30, 2024 08:38
DanielYang59 marked this pull request as draft September 30, 2024 08:40
DanielYang59 marked this pull request as ready for review September 30, 2024 09:33
shyuep merged commit 4c76b58 into materialsproject:master Oct 2, 2024
43 checks passed
shyuep (Member) commented Oct 2, 2024

Thanks.

DanielYang59 deleted the fix-4085-dict-compar branch October 3, 2024 01:52
DanielYang59 (Contributor, Author) commented Oct 3, 2024

I realized I only tested string, list and np.array values (all sequences of some sort) and no other data types. Would np.array_equal fail for any dtype? From the source code, it looks like it returns False directly for anything that cannot be converted to an array: https://github.com/numpy/numpy/blob/2f7fe64b8b6d7591dd208942f1cc74473d5db4cb/numpy/_core/numeric.py#L2550-L2553
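
One way to explore that question is to probe np.array_equal with a few non-ndarray inputs. The sketch below is only an illustration; probe_array_equal is a hypothetical helper, and the cases are chosen here, not taken from the PR's test suite:

```python
import numpy as np


def probe_array_equal():
    """Return np.array_equal results for a few sample inputs."""
    nan_arr = np.array([1.0, np.nan])
    return {
        # Strings and lists are converted with np.asarray before comparing.
        "equal strings": np.array_equal("abc", "abc"),
        "equal lists": np.array_equal([1, 2], [1, 2]),
        # Different shapes compare unequal without raising.
        "different lengths": np.array_equal([1, 2], [1, 2, 3]),
        # NaN != NaN, so identical arrays containing NaN compare unequal by default.
        "NaN, default": np.array_equal(nan_arr, nan_arr.copy()),
        "NaN, equal_nan=True": np.array_equal(nan_arr, nan_arr.copy(), equal_nan=True),
    }


if __name__ == "__main__":
    for label, result in probe_array_equal().items():
        print(f"{label}: {result}")
```

The NaN rows are worth noting for dict comparison: two dicts holding identical arrays with NaN entries would compare unequal unless equal_nan=True is passed.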

Or do we want to play it safe: loop through each value, try comparing with ==, and fall back to np.array_equal if that raises an exception?

  • need speed benchmark
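
The "try ==, fall back to np.array_equal" idea could be sketched as follows. This is a minimal illustration under stated assumptions, not the merged implementation: values_equal and dicts_equal_with_fallback are hypothetical names, and the sketch also falls back whenever == returns something other than a plain bool (for example an element-wise array).

```python
import numpy as np


def values_equal(v1, v2):
    """Hypothetical helper: compare with ==, fall back to np.array_equal."""
    try:
        result = v1 == v2
    except Exception:
        # For example, comparing lists that contain multi-element arrays
        # raises ValueError when Python asks for the truth value.
        return bool(np.array_equal(v1, v2))
    if isinstance(result, bool):
        # Plain Python comparison (scalars, strings, lists of scalars, ...).
        return result
    # == produced an array or NumPy scalar, so let NumPy decide,
    # which also checks that the shapes match.
    return bool(np.array_equal(v1, v2))


def dicts_equal_with_fallback(d1, d2):
    """Hypothetical helper: compare two dicts key by key."""
    if d1.keys() != d2.keys():
        return False
    return all(values_equal(d1[key], d2[key]) for key in d1)


if __name__ == "__main__":
    a = {"name": "Fe", "selective_dynamics": np.array([True, False, True])}
    b = {"name": "Fe", "selective_dynamics": np.array([True, False, True])}
    print(dicts_equal_with_fallback(a, b))
```

Whether the extra try/except costs anything measurable compared with calling np.array_equal on every value is exactly what the benchmark item above would need to establish.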

Development

Successfully merging this pull request may close these issues.

pymatgen.analysis.structure_matcher.StructureMatcher.fit() does not work properly on structures with selective_dynamics property
2 participants