Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Jun 24, 2023
1 parent 9d7f205 commit ab2a223
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions e3nn_jax/_src/tensor_products.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ def _prepare_inputs(input1, input2):
input2 = e3nn.as_irreps_array(input2)

dtype = get_pytree_dtype(input1, input2)
if dtype.kind == "i":
dtype = jnp.ones(()).dtype

input1 = input1.astype(dtype)
input2 = input2.astype(dtype)

Expand Down

0 comments on commit ab2a223

Please sign in to comment.