From 9d7f205e253790ae382e15f445d8672f6ee403df Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Sat, 24 Jun 2023 02:05:41 -0400 Subject: [PATCH] fix --- e3nn_jax/_src/tensor_products.py | 31 ++++++++++++++++----------- e3nn_jax/_src/tensor_products_test.py | 25 +++++++++++++++++++++ 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/e3nn_jax/_src/tensor_products.py b/e3nn_jax/_src/tensor_products.py index a8883925..36c4483c 100644 --- a/e3nn_jax/_src/tensor_products.py +++ b/e3nn_jax/_src/tensor_products.py @@ -114,6 +114,7 @@ def tensor_product( f"irrep_normalization={irrep_normalization} not supported" ) + cg = cg.astype(x1.dtype) chunk = jnp.einsum("...ui , ...vj , ijk -> ...uvk", x1, x2, cg) chunk = jnp.reshape( chunk, chunk.shape[:-3] + (mul_1 * mul_2, ir_out.dim) @@ -183,20 +184,26 @@ def elementwise_tensor_product( continue irreps_out.append((mul, ir_out)) - cg = e3nn.clebsch_gordan(ir_1.l, ir_2.l, ir_out.l) - - if irrep_normalization == "component": - cg = cg * jnp.sqrt(ir_out.dim) - elif irrep_normalization == "norm": - cg = cg * jnp.sqrt(ir_1.dim * ir_2.dim) - elif irrep_normalization == "none": - pass + + if x1 is not None and x2 is not None: + cg = e3nn.clebsch_gordan(ir_1.l, ir_2.l, ir_out.l) + + if irrep_normalization == "component": + cg = cg * jnp.sqrt(ir_out.dim) + elif irrep_normalization == "norm": + cg = cg * jnp.sqrt(ir_1.dim * ir_2.dim) + elif irrep_normalization == "none": + pass + else: + raise ValueError( + f"irrep_normalization={irrep_normalization} not supported" + ) + + cg = cg.astype(x1.dtype) + chunk = jnp.einsum("...ui , ...uj , ijk -> ...uk", x1, x2, cg) else: - raise ValueError( - f"irrep_normalization={irrep_normalization} not supported" - ) + chunk = None - chunk = jnp.einsum("...ui , ...uj , ijk -> ...uk", x1, x2, cg) chunks.append(chunk) return e3nn.from_chunks(irreps_out, chunks, leading_shape, input1.dtype) diff --git a/e3nn_jax/_src/tensor_products_test.py b/e3nn_jax/_src/tensor_products_test.py index 3ff12d15..a6feb13b 100644 --- a/e3nn_jax/_src/tensor_products_test.py +++ b/e3nn_jax/_src/tensor_products_test.py @@ -1,4 +1,5 @@ import haiku as hk +import jax import jax.numpy as jnp import numpy as np @@ -21,6 +22,14 @@ def test_tensor_product_with_zeros(): assert x3.zero_flags == (True, True, True) +def test_elementwise_with_zeros(): + x1 = e3nn.from_chunks("1o", [None], (), jnp.float32) + x2 = e3nn.IrrepsArray("1o", jnp.array([0.0, 1.0, 0.0])) + x3 = e3nn.elementwise_tensor_product(x1, x2) + assert x3.irreps == "0e + 1e + 2e" + assert x3.zero_flags == (True, True, True) + + def test_tensor_product_irreps(): irreps = e3nn.tensor_product("1o", "1o", filter_ir_out=("1e",)) assert irreps == e3nn.Irreps("1e") @@ -106,3 +115,19 @@ def test_tensor_square_and_spherical_harmonics(keys): )["2e"] y2 = e3nn.spherical_harmonics("2e", x, normalize=True, normalization="component") np.testing.assert_allclose(y1.array, y2.array, atol=1e-5) + + +def test_tensor_product_dtype(): + jax.config.update("jax_enable_x64", True) + x1 = e3nn.IrrepsArray("1o", jnp.array([1.0, 0.0, 0.0])) + x2 = e3nn.IrrepsArray("1o", jnp.array([0.0, 1.0, 0.0])) + e3nn.utils.assert_output_dtype_matches_input_dtype(e3nn.tensor_product, x1, x2) + + +def test_elementwise_dtype(): + jax.config.update("jax_enable_x64", True) + x1 = e3nn.IrrepsArray("1o", jnp.array([1.0, 0.0, 0.0])) + x2 = e3nn.IrrepsArray("1o", jnp.array([0.0, 1.0, 0.0])) + e3nn.utils.assert_output_dtype_matches_input_dtype( + e3nn.elementwise_tensor_product, x1, x2 + )