Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Jun 24, 2023
1 parent 4525492 commit 9d7f205
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 12 deletions.
31 changes: 19 additions & 12 deletions e3nn_jax/_src/tensor_products.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def tensor_product(
f"irrep_normalization={irrep_normalization} not supported"
)

cg = cg.astype(x1.dtype)
chunk = jnp.einsum("...ui , ...vj , ijk -> ...uvk", x1, x2, cg)
chunk = jnp.reshape(
chunk, chunk.shape[:-3] + (mul_1 * mul_2, ir_out.dim)
Expand Down Expand Up @@ -183,20 +184,26 @@ def elementwise_tensor_product(
continue

irreps_out.append((mul, ir_out))
cg = e3nn.clebsch_gordan(ir_1.l, ir_2.l, ir_out.l)

if irrep_normalization == "component":
cg = cg * jnp.sqrt(ir_out.dim)
elif irrep_normalization == "norm":
cg = cg * jnp.sqrt(ir_1.dim * ir_2.dim)
elif irrep_normalization == "none":
pass

if x1 is not None and x2 is not None:
cg = e3nn.clebsch_gordan(ir_1.l, ir_2.l, ir_out.l)

if irrep_normalization == "component":
cg = cg * jnp.sqrt(ir_out.dim)
elif irrep_normalization == "norm":
cg = cg * jnp.sqrt(ir_1.dim * ir_2.dim)
elif irrep_normalization == "none":
pass
else:
raise ValueError(
f"irrep_normalization={irrep_normalization} not supported"
)

cg = cg.astype(x1.dtype)
chunk = jnp.einsum("...ui , ...uj , ijk -> ...uk", x1, x2, cg)
else:
raise ValueError(
f"irrep_normalization={irrep_normalization} not supported"
)
chunk = None

chunk = jnp.einsum("...ui , ...uj , ijk -> ...uk", x1, x2, cg)
chunks.append(chunk)

return e3nn.from_chunks(irreps_out, chunks, leading_shape, input1.dtype)
Expand Down
25 changes: 25 additions & 0 deletions e3nn_jax/_src/tensor_products_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

Expand All @@ -21,6 +22,14 @@ def test_tensor_product_with_zeros():
assert x3.zero_flags == (True, True, True)


def test_elementwise_with_zeros():
x1 = e3nn.from_chunks("1o", [None], (), jnp.float32)
x2 = e3nn.IrrepsArray("1o", jnp.array([0.0, 1.0, 0.0]))
x3 = e3nn.elementwise_tensor_product(x1, x2)
assert x3.irreps == "0e + 1e + 2e"
assert x3.zero_flags == (True, True, True)


def test_tensor_product_irreps():
irreps = e3nn.tensor_product("1o", "1o", filter_ir_out=("1e",))
assert irreps == e3nn.Irreps("1e")
Expand Down Expand Up @@ -106,3 +115,19 @@ def test_tensor_square_and_spherical_harmonics(keys):
)["2e"]
y2 = e3nn.spherical_harmonics("2e", x, normalize=True, normalization="component")
np.testing.assert_allclose(y1.array, y2.array, atol=1e-5)


def test_tensor_product_dtype():
jax.config.update("jax_enable_x64", True)
x1 = e3nn.IrrepsArray("1o", jnp.array([1.0, 0.0, 0.0]))
x2 = e3nn.IrrepsArray("1o", jnp.array([0.0, 1.0, 0.0]))
e3nn.utils.assert_output_dtype_matches_input_dtype(e3nn.tensor_product, x1, x2)


def test_elementwise_dtype():
jax.config.update("jax_enable_x64", True)
x1 = e3nn.IrrepsArray("1o", jnp.array([1.0, 0.0, 0.0]))
x2 = e3nn.IrrepsArray("1o", jnp.array([0.0, 1.0, 0.0]))
e3nn.utils.assert_output_dtype_matches_input_dtype(
e3nn.elementwise_tensor_product, x1, x2
)

0 comments on commit 9d7f205

Please sign in to comment.