diff --git a/e3nn_jax/_src/tensor_products.py b/e3nn_jax/_src/tensor_products.py index 36c4483c..29446a3d 100644 --- a/e3nn_jax/_src/tensor_products.py +++ b/e3nn_jax/_src/tensor_products.py @@ -14,6 +14,9 @@ def _prepare_inputs(input1, input2): input2 = e3nn.as_irreps_array(input2) dtype = get_pytree_dtype(input1, input2) + if dtype.kind == "i": + dtype = jnp.ones(()).dtype + input1 = input1.astype(dtype) input2 = input2.astype(dtype)