Skip to content

Commit

Permalink
Merge branch 'main' of github.com:e3nn/e3nn-jax into main
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Sep 21, 2022
2 parents c353ec1 + d4d18b7 commit 84bbe98
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 31 deletions.
3 changes: 2 additions & 1 deletion e3nn_jax/_src/grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,12 @@ def grad(

def _grad(*args, **kwargs) -> e3nn.IrrepsArray:
args = list(args)
x = args[argnums]
x: e3nn.IrrepsArray = args[argnums]
if not isinstance(x, e3nn.IrrepsArray):
raise TypeError(f"arg{argnums} must be an e3nn.IrrepsArray.")
irreps_in = x.irreps
leading_shape_in = x.shape[:-1]
x = x.replace_none_with_zeros()
args[argnums] = x.list

def naked_fun(*args, **kwargs) -> List[jnp.ndarray]:
Expand Down
1 change: 1 addition & 0 deletions e3nn_jax/_src/reduced_tensor_product.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Partially based on https://github.com/songk42/ReducedTensorProduct.jl
import functools
import itertools
from typing import FrozenSet, List, Optional, Tuple, Union
Expand Down
9 changes: 5 additions & 4 deletions e3nn_jax/_src/spherical_harmonics.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,14 +444,15 @@ def _legendre_spherical_harmonics(lmax: int, x: jnp.ndarray, normalize: bool, no
sh_y = legendre(lmax, x[..., 1], 1.0) # [(lmax + 1) * (lmax + 2) // 2, ...]
sh_y = jnp.moveaxis(sh_y, 0, -1) # [..., (lmax + 1) * (lmax + 2) // 2]

f = np.array(
sh_y = sh_y * np.array(
[
math.sqrt(fractions.Fraction((2 * l + 1) * math.factorial(l - m), 4 * math.factorial(l + m)) / math.pi)
for l in range(lmax + 1)
for m in range(l + 1)
]
)
sh_y = f * sh_y

sh = jnp.zeros(x.shape[:-1] + ((lmax + 1) ** 2,))

def f(l, sh):
def g(m, sh):
Expand All @@ -468,5 +469,5 @@ def g(m, sh):

return jax.lax.fori_loop(-l, l + 1, g, sh)

sh = jnp.zeros(x.shape[:-1] + ((lmax + 1) ** 2,))
return jax.lax.fori_loop(0, lmax + 1, f, sh)
sh = jax.lax.fori_loop(0, lmax + 1, f, sh)
return sh
59 changes: 33 additions & 26 deletions e3nn_jax/_src/spherical_harmonics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,37 +7,34 @@


@pytest.fixture(
scope="module",
autouse=True,
params=[
"recursive,dense",
"recursive,sparse,custom_jvp",
"legendre,dense,custom_jvp",
"legendre,sparse,custom_jvp",
],
)
def algorithm(request):
e3nn.config("spherical_harmonics_algorithm", tuple(request.param.split(",")))
return tuple(request.param.split(","))


@pytest.mark.parametrize("l", [0, 1, 2, 3, 4, 5, 6, 7])
def test_equivariance(keys, l):
def test_equivariance(keys, algorithm, l):
input = e3nn.normal("1o", keys[0], (10,))

abc = e3nn.rand_angles(keys[1], ())
output1 = e3nn.spherical_harmonics(l, input.transform_by_angles(*abc), False)
output2 = e3nn.spherical_harmonics(l, input, False).transform_by_angles(*abc)
output1 = e3nn.spherical_harmonics(l, input.transform_by_angles(*abc), False, algorithm=algorithm)
output2 = e3nn.spherical_harmonics(l, input, False, algorithm=algorithm).transform_by_angles(*abc)

np.testing.assert_allclose(output1.array, output2.array, atol=1e-2, rtol=1e-2)


def test_closure(keys):
def test_closure(keys, algorithm):
r"""
integral of Ylm * Yjn = delta_lj delta_mn
integral of 1 over the unit sphere = 4 pi
"""
x = jax.random.normal(keys[0], (1_000_000, 3))
Ys = [e3nn.sh(l, x, True, "integral") for l in range(0, 3 + 1)]
Ys = [e3nn.sh(l, x, True, "integral", algorithm=algorithm) for l in range(0, 3 + 1)]
for l1, Y1 in enumerate(Ys):
for l2, Y2 in enumerate(Ys):
m = Y1[:, :, None] * Y2[:, None, :]
Expand All @@ -50,57 +47,64 @@ def test_closure(keys):


@pytest.mark.parametrize("l", range(13 + 1))
def test_normalization_integral(keys, l):
def test_normalization_integral(keys, algorithm, l):
irreps = e3nn.Irreps([l])

n = jnp.mean(
e3nn.spherical_harmonics(irreps, jax.random.normal(keys[l + 0], (3,)), normalize=True, normalization="integral").array
e3nn.spherical_harmonics(
irreps, jax.random.normal(keys[l + 0], (3,)), normalize=True, normalization="integral", algorithm=algorithm
).array
** 2
)
assert abs((4 * jnp.pi) * n - 1) < 7e-7 * max((l / 4) ** 8, 1)


@pytest.mark.parametrize("l", range(13 + 1))
def test_normalization_norm(keys, l):
def test_normalization_norm(keys, algorithm, l):
irreps = e3nn.Irreps([l])

n = jnp.sum(
e3nn.spherical_harmonics(irreps, jax.random.normal(keys[l + 1], (3,)), normalize=True, normalization="norm").array ** 2
e3nn.spherical_harmonics(
irreps, jax.random.normal(keys[l + 1], (3,)), normalize=True, normalization="norm", algorithm=algorithm
).array
** 2
)
assert abs(n - 1) < 6e-7 * max((l / 4) ** 8, 1)


@pytest.mark.parametrize("l", range(13 + 1))
def test_normalization_component(keys, l):
def test_normalization_component(keys, algorithm, l):
irreps = e3nn.Irreps([l])

n = jnp.mean(
e3nn.spherical_harmonics(irreps, jax.random.normal(keys[l + 2], (3,)), normalize=True, normalization="component").array
e3nn.spherical_harmonics(
irreps, jax.random.normal(keys[l + 2], (3,)), normalize=True, normalization="component", algorithm=algorithm
).array
** 2
)
assert abs(n - 1) < 6e-7 * max((l / 4) ** 8, 1)


@pytest.mark.parametrize("l", range(8 + 1))
def test_parity(keys, l):
def test_parity(keys, algorithm, l):
irreps = e3nn.Irreps([l])
x = jax.random.normal(next(keys), (3,))

y1 = (-1) ** l * e3nn.spherical_harmonics(irreps, x, normalize=True, normalization="integral")
y2 = e3nn.spherical_harmonics(irreps, -x, normalize=True, normalization="integral")
y1 = (-1) ** l * e3nn.spherical_harmonics(irreps, x, normalize=True, normalization="integral", algorithm=algorithm)
y2 = e3nn.spherical_harmonics(irreps, -x, normalize=True, normalization="integral", algorithm=algorithm)
np.testing.assert_allclose(y1.array, y2.array, atol=1e-6, rtol=1e-6)


@pytest.mark.parametrize("l", range(7 + 1))
def test_recurrence_relation(keys, l):
def test_recurrence_relation(keys, algorithm, l):
x = jax.random.normal(next(keys), (3,))

y1 = e3nn.spherical_harmonics(e3nn.Irreps([l + 1]), x, normalize=True, normalization="integral").array
y1 = e3nn.spherical_harmonics(e3nn.Irreps([l + 1]), x, normalize=True, normalization="integral", algorithm=algorithm).array
y2 = jnp.einsum(
"ijk,i,j->k",
e3nn.clebsch_gordan(1, l, l + 1),
x,
e3nn.spherical_harmonics(e3nn.Irreps([l]), x, normalize=True, normalization="integral").array,
e3nn.spherical_harmonics(e3nn.Irreps([l]), x, normalize=True, normalization="integral", algorithm=algorithm).array,
)

y1 = y1 / jnp.linalg.norm(y1)
Expand All @@ -110,9 +114,9 @@ def test_recurrence_relation(keys, l):

@pytest.mark.parametrize("normalization", ["integral", "norm", "component"])
@pytest.mark.parametrize("irreps", ["3x1o+2e+2x4e", "2x0e", "10e"])
def test_check_grads(keys, irreps, normalization):
def test_check_grads(keys, algorithm, irreps, normalization):
check_grads(
lambda x: e3nn.spherical_harmonics(irreps, x, normalize=False, normalization=normalization).array,
lambda x: e3nn.spherical_harmonics(irreps, x, normalize=False, normalization=normalization, algorithm=algorithm).array,
(jax.random.normal(keys[0], (10, 3)),),
1,
modes=["fwd", "rev"],
Expand All @@ -122,8 +126,11 @@ def test_check_grads(keys, irreps, normalization):


@pytest.mark.parametrize("l", range(7 + 1))
def test_normalize(keys, l):
def test_normalize(keys, algorithm, l):
x = jax.random.normal(keys[0], (10, 3))
y1 = e3nn.spherical_harmonics(e3nn.Irreps([l]), x, normalize=True).array * jnp.linalg.norm(x, axis=1, keepdims=True) ** l
y2 = e3nn.spherical_harmonics(e3nn.Irreps([l]), x, normalize=False).array
y1 = (
e3nn.spherical_harmonics(e3nn.Irreps([l]), x, normalize=True, algorithm=algorithm).array
* jnp.linalg.norm(x, axis=1, keepdims=True) ** l
)
y2 = e3nn.spherical_harmonics(e3nn.Irreps([l]), x, normalize=False, algorithm=algorithm).array
np.testing.assert_allclose(y1, y2, atol=1e-6, rtol=1e-5)

0 comments on commit 84bbe98

Please sign in to comment.