From be969eab865b66dee2a315723dabdb9c7bc38c0d Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 16 Sep 2022 15:59:34 -0400 Subject: [PATCH 1/3] fix e3nn.grad --- e3nn_jax/_src/grad.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/e3nn_jax/_src/grad.py b/e3nn_jax/_src/grad.py index 40a33dba..c3368469 100644 --- a/e3nn_jax/_src/grad.py +++ b/e3nn_jax/_src/grad.py @@ -31,11 +31,12 @@ def grad( def _grad(*args, **kwargs) -> e3nn.IrrepsArray: args = list(args) - x = args[argnums] + x: e3nn.IrrepsArray = args[argnums] if not isinstance(x, e3nn.IrrepsArray): raise TypeError(f"arg{argnums} must be an e3nn.IrrepsArray.") irreps_in = x.irreps leading_shape_in = x.shape[:-1] + x = x.replace_none_with_zeros() args[argnums] = x.list def naked_fun(*args, **kwargs) -> List[jnp.ndarray]: From 8a12a27c54c24a8400e8a8e413ef207837233c04 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 16 Sep 2022 16:31:30 -0400 Subject: [PATCH 2/3] add ref --- e3nn_jax/_src/reduced_tensor_product.py | 1 + 1 file changed, 1 insertion(+) diff --git a/e3nn_jax/_src/reduced_tensor_product.py b/e3nn_jax/_src/reduced_tensor_product.py index d1498f41..6dc36521 100644 --- a/e3nn_jax/_src/reduced_tensor_product.py +++ b/e3nn_jax/_src/reduced_tensor_product.py @@ -1,3 +1,4 @@ +# Partially based on https://github.com/songk42/ReducedTensorProduct.jl import functools import itertools from typing import FrozenSet, List, Optional, Tuple, Union From d4d18b79013472540bf62cea2736d978d1ecfde5 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 20 Sep 2022 17:06:02 -0400 Subject: [PATCH 3/3] fix algorithm tests in sh --- e3nn_jax/_src/spherical_harmonics.py | 9 ++-- e3nn_jax/_src/spherical_harmonics_test.py | 59 +++++++++++++---------- 2 files changed, 38 insertions(+), 30 deletions(-) diff --git a/e3nn_jax/_src/spherical_harmonics.py b/e3nn_jax/_src/spherical_harmonics.py index c27a6e7c..442456c9 100644 --- a/e3nn_jax/_src/spherical_harmonics.py +++ b/e3nn_jax/_src/spherical_harmonics.py @@ -444,14 +444,15 @@ def _legendre_spherical_harmonics(lmax: int, x: jnp.ndarray, normalize: bool, no sh_y = legendre(lmax, x[..., 1], 1.0) # [(lmax + 1) * (lmax + 2) // 2, ...] sh_y = jnp.moveaxis(sh_y, 0, -1) # [..., (lmax + 1) * (lmax + 2) // 2] - f = np.array( + sh_y = sh_y * np.array( [ math.sqrt(fractions.Fraction((2 * l + 1) * math.factorial(l - m), 4 * math.factorial(l + m)) / math.pi) for l in range(lmax + 1) for m in range(l + 1) ] ) - sh_y = f * sh_y + + sh = jnp.zeros(x.shape[:-1] + ((lmax + 1) ** 2,)) def f(l, sh): def g(m, sh): @@ -468,5 +469,5 @@ def g(m, sh): return jax.lax.fori_loop(-l, l + 1, g, sh) - sh = jnp.zeros(x.shape[:-1] + ((lmax + 1) ** 2,)) - return jax.lax.fori_loop(0, lmax + 1, f, sh) + sh = jax.lax.fori_loop(0, lmax + 1, f, sh) + return sh diff --git a/e3nn_jax/_src/spherical_harmonics_test.py b/e3nn_jax/_src/spherical_harmonics_test.py index 9e3dbe8c..fd801784 100644 --- a/e3nn_jax/_src/spherical_harmonics_test.py +++ b/e3nn_jax/_src/spherical_harmonics_test.py @@ -7,37 +7,34 @@ @pytest.fixture( - scope="module", - autouse=True, params=[ "recursive,dense", "recursive,sparse,custom_jvp", "legendre,dense,custom_jvp", - "legendre,sparse,custom_jvp", ], ) def algorithm(request): - e3nn.config("spherical_harmonics_algorithm", tuple(request.param.split(","))) + return tuple(request.param.split(",")) @pytest.mark.parametrize("l", [0, 1, 2, 3, 4, 5, 6, 7]) -def test_equivariance(keys, l): +def test_equivariance(keys, algorithm, l): input = e3nn.normal("1o", keys[0], (10,)) abc = e3nn.rand_angles(keys[1], ()) - output1 = e3nn.spherical_harmonics(l, input.transform_by_angles(*abc), False) - output2 = e3nn.spherical_harmonics(l, input, False).transform_by_angles(*abc) + output1 = e3nn.spherical_harmonics(l, input.transform_by_angles(*abc), False, algorithm=algorithm) + output2 = e3nn.spherical_harmonics(l, input, False, algorithm=algorithm).transform_by_angles(*abc) np.testing.assert_allclose(output1.array, output2.array, atol=1e-2, rtol=1e-2) -def test_closure(keys): +def test_closure(keys, algorithm): r""" integral of Ylm * Yjn = delta_lj delta_mn integral of 1 over the unit sphere = 4 pi """ x = jax.random.normal(keys[0], (1_000_000, 3)) - Ys = [e3nn.sh(l, x, True, "integral") for l in range(0, 3 + 1)] + Ys = [e3nn.sh(l, x, True, "integral", algorithm=algorithm) for l in range(0, 3 + 1)] for l1, Y1 in enumerate(Ys): for l2, Y2 in enumerate(Ys): m = Y1[:, :, None] * Y2[:, None, :] @@ -50,57 +47,64 @@ def test_closure(keys): @pytest.mark.parametrize("l", range(13 + 1)) -def test_normalization_integral(keys, l): +def test_normalization_integral(keys, algorithm, l): irreps = e3nn.Irreps([l]) n = jnp.mean( - e3nn.spherical_harmonics(irreps, jax.random.normal(keys[l + 0], (3,)), normalize=True, normalization="integral").array + e3nn.spherical_harmonics( + irreps, jax.random.normal(keys[l + 0], (3,)), normalize=True, normalization="integral", algorithm=algorithm + ).array ** 2 ) assert abs((4 * jnp.pi) * n - 1) < 7e-7 * max((l / 4) ** 8, 1) @pytest.mark.parametrize("l", range(13 + 1)) -def test_normalization_norm(keys, l): +def test_normalization_norm(keys, algorithm, l): irreps = e3nn.Irreps([l]) n = jnp.sum( - e3nn.spherical_harmonics(irreps, jax.random.normal(keys[l + 1], (3,)), normalize=True, normalization="norm").array ** 2 + e3nn.spherical_harmonics( + irreps, jax.random.normal(keys[l + 1], (3,)), normalize=True, normalization="norm", algorithm=algorithm + ).array + ** 2 ) assert abs(n - 1) < 6e-7 * max((l / 4) ** 8, 1) @pytest.mark.parametrize("l", range(13 + 1)) -def test_normalization_component(keys, l): +def test_normalization_component(keys, algorithm, l): irreps = e3nn.Irreps([l]) n = jnp.mean( - e3nn.spherical_harmonics(irreps, jax.random.normal(keys[l + 2], (3,)), normalize=True, normalization="component").array + e3nn.spherical_harmonics( + irreps, jax.random.normal(keys[l + 2], (3,)), normalize=True, normalization="component", algorithm=algorithm + ).array ** 2 ) assert abs(n - 1) < 6e-7 * max((l / 4) ** 8, 1) @pytest.mark.parametrize("l", range(8 + 1)) -def test_parity(keys, l): +def test_parity(keys, algorithm, l): irreps = e3nn.Irreps([l]) x = jax.random.normal(next(keys), (3,)) - y1 = (-1) ** l * e3nn.spherical_harmonics(irreps, x, normalize=True, normalization="integral") - y2 = e3nn.spherical_harmonics(irreps, -x, normalize=True, normalization="integral") + y1 = (-1) ** l * e3nn.spherical_harmonics(irreps, x, normalize=True, normalization="integral", algorithm=algorithm) + y2 = e3nn.spherical_harmonics(irreps, -x, normalize=True, normalization="integral", algorithm=algorithm) np.testing.assert_allclose(y1.array, y2.array, atol=1e-6, rtol=1e-6) @pytest.mark.parametrize("l", range(7 + 1)) -def test_recurrence_relation(keys, l): +def test_recurrence_relation(keys, algorithm, l): x = jax.random.normal(next(keys), (3,)) - y1 = e3nn.spherical_harmonics(e3nn.Irreps([l + 1]), x, normalize=True, normalization="integral").array + y1 = e3nn.spherical_harmonics(e3nn.Irreps([l + 1]), x, normalize=True, normalization="integral", algorithm=algorithm).array y2 = jnp.einsum( "ijk,i,j->k", e3nn.clebsch_gordan(1, l, l + 1), x, - e3nn.spherical_harmonics(e3nn.Irreps([l]), x, normalize=True, normalization="integral").array, + e3nn.spherical_harmonics(e3nn.Irreps([l]), x, normalize=True, normalization="integral", algorithm=algorithm).array, ) y1 = y1 / jnp.linalg.norm(y1) @@ -110,9 +114,9 @@ def test_recurrence_relation(keys, l): @pytest.mark.parametrize("normalization", ["integral", "norm", "component"]) @pytest.mark.parametrize("irreps", ["3x1o+2e+2x4e", "2x0e", "10e"]) -def test_check_grads(keys, irreps, normalization): +def test_check_grads(keys, algorithm, irreps, normalization): check_grads( - lambda x: e3nn.spherical_harmonics(irreps, x, normalize=False, normalization=normalization).array, + lambda x: e3nn.spherical_harmonics(irreps, x, normalize=False, normalization=normalization, algorithm=algorithm).array, (jax.random.normal(keys[0], (10, 3)),), 1, modes=["fwd", "rev"], @@ -122,8 +126,11 @@ def test_check_grads(keys, irreps, normalization): @pytest.mark.parametrize("l", range(7 + 1)) -def test_normalize(keys, l): +def test_normalize(keys, algorithm, l): x = jax.random.normal(keys[0], (10, 3)) - y1 = e3nn.spherical_harmonics(e3nn.Irreps([l]), x, normalize=True).array * jnp.linalg.norm(x, axis=1, keepdims=True) ** l - y2 = e3nn.spherical_harmonics(e3nn.Irreps([l]), x, normalize=False).array + y1 = ( + e3nn.spherical_harmonics(e3nn.Irreps([l]), x, normalize=True, algorithm=algorithm).array + * jnp.linalg.norm(x, axis=1, keepdims=True) ** l + ) + y2 = e3nn.spherical_harmonics(e3nn.Irreps([l]), x, normalize=False, algorithm=algorithm).array np.testing.assert_allclose(y1, y2, atol=1e-6, rtol=1e-5)