EquiformerV2 fails rotational equivariance test #823

Open
wants to merge 6 commits into
base: main

curtischong
Contributor

@curtischong curtischong commented Aug 25, 2024

  • I did the test on test_equiformer_v2_deprecated.py rather than on test_equiformer_v2.py because it more closely followed the format that I saw in gemnet. Regardless, the test should still verify whether or not the model is equivariant.

  • I ran the equivariance tests in test_gemnet.py and test_gemnet_oc.py to confirm that the same check at least passes there, which it does. Unfortunately, I couldn't get the entire test suite to run because of invalid-snapshot errors.

  • I think this equivariance failure may help explain EquiformerV2's results on these benchmarks: https://huggingface.co/spaces/atomind/mlip-arena

Here is the terminal output of pytest test_equiformer_v2_deprecated.py:

============================================================== FAILURES ===============================================================
______________________________________________ TestEquiformerV2.test_rotation_invariance ______________________________________________

self = <test_equiformer_v2_deprecated.TestEquiformerV2 object at 0x1432a5910>

    def test_rotation_invariance(self) -> None:
        random.seed(1)
        data = self.data

        # Sampling a random rotation within [-180, 180] for all axes.
        transform = RandomRotate([-180, 180], [0, 1, 2])
        data_rotated, rot, inv_rot = transform(data.clone())
        assert not np.array_equal(data.pos, data_rotated.pos)

        # Pass it through the model.
        batch = data_list_collater([data, data_rotated])
        out = self.model(batch)

        # Compare predicted energies and forces (after inv-rotation).
        energies = out["energy"].detach()
>       np.testing.assert_almost_equal(energies[0], energies[1], decimal=5)

test_equiformer_v2_deprecated.py:140:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

args = (tensor(-0.0383), tensor(-0.0354)), kwds = {'decimal': 5}

    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError:
E           Arrays are not almost equal to 5 decimals
E            ACTUAL: tensor(-0.0383)
E            DESIRED: tensor(-0.0354)

/opt/homebrew/Cellar/python@3.11/3.11.9_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/contextlib.py:81: AssertionError
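
For readers unfamiliar with the check in this traceback, here is a minimal, self-contained sketch of the same idea on a toy model. Everything in it is an assumption made for illustration: `toy_model` is a hypothetical pairwise potential (not EquiformerV2), and the rotation helpers are written by hand rather than taken from `RandomRotate`. It uses the row-vector convention (`positions @ R`), so forces from the rotated copy are mapped back with the transposed matrices in reverse order.

```python
import math
import random


def matmul_rows(vectors, matrix):
    """Apply a 3x3 matrix to each row vector (row-vector convention: v @ M)."""
    return [[sum(v[k] * matrix[k][j] for k in range(3)) for j in range(3)]
            for v in vectors]


def transpose(matrix):
    return [list(col) for col in zip(*matrix)]


def rotation_z(deg):
    c, s = math.cos(math.radians(deg)), math.sin(math.radians(deg))
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]


def rotation_x(deg):
    c, s = math.cos(math.radians(deg)), math.sin(math.radians(deg))
    return [[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]]


def toy_model(positions):
    """Hypothetical potential E = sum_{i<j} exp(-r_ij) with analytic forces."""
    n = len(positions)
    energy = 0.0
    forces = [[0.0, 0.0, 0.0] for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            d = [positions[i][k] - positions[j][k] for k in range(3)]
            r = math.sqrt(sum(x * x for x in d))
            e = math.exp(-r)
            energy += e
            for k in range(3):
                f = e * d[k] / r  # F_i = -dE/dr_i
                forces[i][k] += f
                forces[j][k] -= f
    return energy, forces


rng = random.Random(1)
positions = [[rng.uniform(-2.0, 2.0) for _ in range(3)] for _ in range(5)]
rz = rotation_z(rng.uniform(-180.0, 180.0))
rx = rotation_x(rng.uniform(-180.0, 180.0))

# Rotate the structure, evaluate both copies, then rotate the forces back.
rotated = matmul_rows(matmul_rows(positions, rz), rx)
energy, forces = toy_model(positions)
energy_rot, forces_rot = toy_model(rotated)
forces_back = matmul_rows(matmul_rows(forces_rot, transpose(rx)), transpose(rz))

energy_error = abs(energy - energy_rot)
force_error = max(abs(a - b)
                  for fa, fb in zip(forces, forces_back)
                  for a, b in zip(fa, fb))
print(energy_error, force_error)
```

Because this toy potential depends only on interatomic distances and runs in float64, both errors should sit near floating-point rounding. A learned model evaluated in float32 will not reach that level, so the interesting question is how far above rounding its error lies.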

@curtischong curtischong changed the title Test equiformer equivariance [DRAFT] Test equiformer equivariance Aug 25, 2024
@curtischong curtischong changed the title [DRAFT] Test equiformer equivariance EquiformerV2 fails rotational equivariance test Aug 25, 2024
@lbluque
Collaborator

lbluque commented Aug 26, 2024

Hi @curtischong, thanks for flagging and adding a test for this!

Can you make sure to set the model to evaluation mode with model.eval() before calling forward?
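
For context, a minimal sketch of why evaluation mode matters for a test like this, assuming the model contains training-only stochastic layers such as dropout. The small network below is hypothetical and stands in for the real model; nothing here is taken from the EquiformerV2 code. In train mode each forward pass draws a fresh dropout mask, so two passes over identical input will usually disagree, which would swamp any equivariance comparison. Calling eval() only flips the training flag on every submodule; it does not change the parameter dtype.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in network with a dropout layer.
model = nn.Sequential(nn.Linear(4, 8), nn.Dropout(p=0.5), nn.Linear(8, 1))
x = torch.randn(2, 4)

model.train()
train_a, train_b = model(x), model(x)  # independent dropout masks per call

model.eval()
with torch.no_grad():
    eval_a, eval_b = model(x), model(x)  # dropout is a no-op in eval mode

same_in_train = torch.equal(train_a, train_b)
same_in_eval = torch.equal(eval_a, eval_b)
dtype_after_eval = next(model.parameters()).dtype

print("identical outputs in train mode:", same_in_train)
print("identical outputs in eval mode: ", same_in_eval)
print("parameter dtype after eval():   ", dtype_after_eval)
```

Moving a model to 64-bit precision is a separate step (for example `model.double()` together with float64 inputs) and is not implied by eval().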

@curtischong
Contributor Author

curtischong commented Aug 27, 2024

Wow! I can't believe the other tests don't explicitly put the models in eval mode with eval()!

What exactly does eval() do? Does it switch the model to 64-bit precision? (Sorry, I don't have time to read the code right now.)

I made the change, and the output is below. The energy predictions now agree closely enough, but the forces do not.

The forces assertion only passes if I loosen it to decimal=2 in:

np.testing.assert_array_almost_equal(
    forces[: forces.shape[0] // 2],
    torch.matmul(forces[forces.shape[0] // 2 :], inv_rot),
    decimal=4,
)

============================================================== FAILURES ===============================================================
______________________________________________ TestEquiformerV2.test_rotation_invariance ______________________________________________

self = <test_equiformer_v2_deprecated.TestEquiformerV2 object at 0x1681bc7d0>

    def test_rotation_invariance(self) -> None:
        random.seed(1)
        data = self.data

        # Sampling a random rotation within [-180, 180] for all axes.
        transform = RandomRotate([-180, 180], [0, 1, 2])
        data_rotated, rot, inv_rot = transform(data.clone())
        assert not np.array_equal(data.pos, data_rotated.pos)

        # Pass it through the model.
        batch = data_list_collater([data, data_rotated])
        out = self.model(batch)

        # Compare predicted energies and forces (after inv-rotation).
        energies = out["energy"].detach()
        np.testing.assert_almost_equal(energies[0], energies[1], decimal=5)

        forces = out["forces"].detach()
        logging.info(forces)
>       np.testing.assert_array_almost_equal(
            forces[: forces.shape[0] // 2],
            torch.matmul(forces[forces.shape[0] // 2 :], inv_rot),
            decimal=4,
        )

test_equiformer_v2_deprecated.py:146:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/opt/homebrew/Cellar/python@3.11/3.11.9_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/contextlib.py:81: in inner
    return func(*args, **kwds)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

args = (<function assert_array_almost_equal.<locals>.compare at 0x16e038220>, tensor([[-1.4432e-01,  4.8881e-02,  3.0350e-01]...e-02,  7.1743e-03],
        [ 1.7721e-02,  7.4172e-02, -6.9789e-02],
        [ 1.3236e-02, -1.7110e-02, -1.8661e-02]]))
kwds = {'err_msg': '', 'header': 'Arrays are not almost equal to 4 decimals', 'precision': 4, 'verbose': True}

    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError:
E           Arrays are not almost equal to 4 decimals
E
E           Mismatched elements: 24 / 102 (23.5%)
E           Max absolute difference: 0.00195165
E           Max relative difference: 0.5172143
E            x: array([[-1.4432e-01,  4.8881e-02,  3.0350e-01],
E                  [ 2.0453e-03, -4.6955e-02, -2.3991e-01],
E                  [-1.3254e-02, -4.2705e-02, -2.5257e-02],...
E            y: array([[-1.4433e-01,  4.8874e-02,  3.0346e-01],
E                  [ 2.0608e-03, -4.6957e-02, -2.3986e-01],
E                  [-1.3226e-02, -4.2615e-02, -2.5217e-02],...

/opt/homebrew/Cellar/python@3.11/3.11.9_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/contextlib.py:81: AssertionError
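
To relate the `decimal` argument to the numbers in this output: NumPy documents the check as `abs(desired - actual) < 1.5 * 10**(-decimal)`. The sketch below applies that criterion to the maximum absolute difference reported above. The helper names are hypothetical and written for this illustration; they are not part of NumPy.

```python
def passes_almost_equal(abs_diff, decimal):
    """Criterion documented for np.testing.assert_array_almost_equal."""
    return abs_diff < 1.5 * 10 ** (-decimal)


def largest_passing_decimal(abs_diff, max_decimal=10):
    """Largest `decimal` for which the check would still pass, or None."""
    best = None
    for decimal in range(max_decimal + 1):
        if passes_almost_equal(abs_diff, decimal):
            best = decimal
    return best


# Max absolute difference taken from the pytest output above.
force_diff = 0.00195165

for decimal in (2, 3, 4):
    print(decimal, 1.5 * 10 ** (-decimal), passes_almost_equal(force_diff, decimal))

print("largest passing decimal:", largest_passing_decimal(force_diff))
```

This is consistent with the observation that the assertion passes at decimal=2: the threshold there is 0.015, whereas decimal=3 would require a difference below 0.0015.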

@curtischong
Contributor Author

Note: when I tightened np.testing.assert_almost_equal(energies[0], energies[1], decimal=5) to 6 decimal places, the energy assertion also fails. I find this strange, since https://docs.e3nn.org/en/stable/guide/equivar_testing.html shows errors within 7-8 decimals of precision. Since we're only using one layer, this error does seem a bit high, although the model routes data through more functions, so the error could accumulate.

@kyonofx
Collaborator

kyonofx commented Sep 10, 2024

Hi Curtis, if you set enforce_max_neighbors_strictly=False (https://github.com/FAIR-Chem/fairchem/blob/main/src/fairchem/core/models/equiformer_v2/equiformer_v2_deprecated.py#L151), you should get a lower equivariance error. This issue was previously discussed in:

#428
#467

@curtischong
Contributor Author

I added the flag and it does improve the agreement! However, I'm still concerned: for forces, we only get 4 decimal places of agreement (setting decimal=5 fails the test). I remember attending a lecture by Albert Musaelian where he mentioned that we should always use 64-bit precision, especially for MD (since errors compound). But 4 decimal places is less than even 32-bit floats provide (roughly 7 significant digits). Do you know what else is causing this error?
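
As a rough way to put "4 decimal places" next to float32 precision, the sketch below computes the spacing between adjacent float32 values at a given magnitude using only the standard library. The helper `float32_ulp` is hypothetical and written for this illustration, and the magnitude 0.3035 is simply the largest force component printed in the earlier output. A decimal=5 failure means at least one component differs by 1.5e-5 or more, and the ratio shows how many float32 spacings that is at this magnitude. It does not identify the cause of the discrepancy.

```python
import struct


def float32_ulp(x):
    """Gap between the float32 nearest to positive finite x and the next larger one."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lower = struct.unpack("<f", struct.pack("<I", bits))[0]
    upper = struct.unpack("<f", struct.pack("<I", bits + 1))[0]
    return upper - lower


force_magnitude = 0.3035   # largest force component shown in the earlier output
decimal5_threshold = 1.5 * 10 ** (-5)

spacing = float32_ulp(force_magnitude)   # 2**-25, about 2.98e-08
ratio = decimal5_threshold / spacing

print("float32 spacing near 1.0:   ", float32_ulp(1.0))   # 2**-23
print("float32 spacing near 0.3035:", spacing)
print("decimal=5 threshold / spacing:", ratio)
```

A gap of several hundred spacings is larger than a single rounded operation would produce, though rounding error does accumulate over the many operations in a forward and backward pass, so this is only a scale comparison.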


3 participants