Skip to content

Commit

Permalink
remove jax_enable_x64 from equivariance_test
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Dec 14, 2022
1 parent 7a93375 commit e62da34
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 34 deletions.
3 changes: 2 additions & 1 deletion e3nn_jax/_src/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ def normalize_function(phi):
k = jax.random.PRNGKey(0)
x = jax.random.normal(k, (1_000_000,), dtype=jnp.float64)
c = jnp.mean(phi(x) ** 2) ** 0.5
c = c.item()

if jnp.allclose(c, 1.0):
return phi
else:

def rho(x):
return phi(x) / c.astype(x.dtype)
return phi(x) / c

return rho

Expand Down
16 changes: 4 additions & 12 deletions e3nn_jax/_src/core_tensor_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import jax.numpy as jnp

from e3nn_jax import Instruction, Irreps, IrrepsArray, clebsch_gordan, config
from e3nn_jax._src.util.prod import prod

from e3nn_jax._src.einsum import einsum as opt_einsum
from e3nn_jax._src.util.dtype import get_pytree_dtype
from e3nn_jax._src.util.prod import prod


class FunctionalTensorProduct:
Expand Down Expand Up @@ -293,14 +293,6 @@ def update(instruction: Instruction) -> float:
return [update(instruction) for instruction in instructions]


def _get_dtype(*args):
leaves = jax.tree_util.tree_leaves(args)
if len(leaves) == 0:
return jnp.float32

return jax.eval_shape(lambda xs: sum(jnp.sum(x) for x in xs), leaves).dtype


@partial(jax.jit, static_argnums=(0,), static_argnames=("custom_einsum_jvp", "fused"))
@partial(jax.profiler.annotate_function, name="TensorProduct.left_right")
def _left_right(
Expand All @@ -312,7 +304,7 @@ def _left_right(
custom_einsum_jvp: bool = False,
fused: bool = False,
):
dtype = _get_dtype(weights, input1, input2)
dtype = get_pytree_dtype(weights, input1, input2)

if self.irreps_in1.dim == 0 or self.irreps_in2.dim == 0 or self.irreps_out.dim == 0:
return IrrepsArray.zeros(self.irreps_out, ())
Expand Down Expand Up @@ -570,7 +562,7 @@ def _right(
*,
custom_einsum_jvp: bool = False,
) -> jnp.ndarray:
dtype = _get_dtype(weights, input2)
dtype = get_pytree_dtype(weights, input2)

# = Short-circut for zero dimensional =
if self.irreps_in1.dim == 0 or self.irreps_in2.dim == 0 or self.irreps_out.dim == 0:
Expand Down
4 changes: 2 additions & 2 deletions e3nn_jax/_src/gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ def _gate(input: IrrepsArray, even_act, odd_act, even_gate_act, odd_gate_act) ->
act = {1: even_act, -1: odd_act}
gate_act = {1: even_gate_act, -1: odd_gate_act}

scalars = input.filtered(["0e", "0o"])
vectors = input.filtered(lambda mul_ir: mul_ir.ir.l > 0)
scalars = input.filtered(keep=["0e", "0o"])
vectors = input.filtered(drop=["0e", "0o"])
del input

if vectors.shape[-1] == 0:
Expand Down
2 changes: 1 addition & 1 deletion e3nn_jax/_src/irreps_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __init__(
)
assert all(x.dtype == self.array.dtype for x in self._list if x is not None), (
f"IrrepsArray: List dtypes {[None if x is None else x.dtype for x in self._list]} "
"incompatible with array dtype {self.array.dtype}."
f"incompatible with array dtype {self.array.dtype}."
)

@staticmethod
Expand Down
24 changes: 12 additions & 12 deletions e3nn_jax/_src/rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# matrix


def rand_matrix(key, shape):
def rand_matrix(key, shape, dtype=jnp.float32):
r"""Random rotation matrix.
Args:
Expand All @@ -16,13 +16,13 @@ def rand_matrix(key, shape):
Returns:
`jax.numpy.ndarray`: array of shape :math:`(..., 3, 3)`
"""
return angles_to_matrix(*rand_angles(key, shape))
return angles_to_matrix(*rand_angles(key, shape, dtype=dtype))


# angles


def identity_angles(shape):
def identity_angles(shape, dtype=jnp.float32):
r"""Angles of the identity rotation.
Args:
Expand All @@ -33,11 +33,11 @@ def identity_angles(shape):
beta (`jax.numpy.ndarray`): array of shape :math:`(...)`
gamma (`jax.numpy.ndarray`): array of shape :math:`(...)`
"""
return jnp.zeros(shape), jnp.zeros(shape), jnp.zeros(shape)
return jnp.zeros(shape, dtype), jnp.zeros(shape, dtype), jnp.zeros(shape, dtype)


@partial(jax.jit, static_argnums=(1,), inline=True)
def rand_angles(key, shape):
def rand_angles(key, shape, dtype=jnp.float32):
r"""Random rotation angles.
Args:
Expand All @@ -49,7 +49,7 @@ def rand_angles(key, shape):
beta (`jax.numpy.ndarray`): array of shape :math:`(...)`
gamma (`jax.numpy.ndarray`): array of shape :math:`(...)`
"""
x, y, z = jax.random.uniform(key, (3,) + shape)
x, y, z = jax.random.uniform(key, (3,) + shape, dtype=dtype)
return 2 * jnp.pi * x, jnp.arccos(2 * z - 1), 2 * jnp.pi * y


Expand Down Expand Up @@ -95,7 +95,7 @@ def inverse_angles(a, b, c):
# quaternions


def identity_quaternion(shape):
def identity_quaternion(shape, dtype=jnp.float32):
r"""Quaternion of identity rotation.
Args:
Expand All @@ -104,12 +104,12 @@ def identity_quaternion(shape):
Returns:
`jax.numpy.ndarray`: array of shape :math:`(..., 4)`
"""
q = jnp.zeros(shape, 4)
q = jnp.zeros(shape, 4, dtype=dtype)
return q.at[..., 0].set(1) # or -1...


@partial(jax.jit, static_argnums=(1,), inline=True)
def rand_quaternion(key, shape):
def rand_quaternion(key, shape, dtype=jnp.float32):
r"""Generate random quaternion.
Args:
Expand All @@ -119,7 +119,7 @@ def rand_quaternion(key, shape):
Returns:
`jax.numpy.ndarray`: array of shape :math:`(..., 4)`
"""
return angles_to_quaternion(*rand_angles(key, shape))
return angles_to_quaternion(*rand_angles(key, shape, dtype))


@partial(jax.jit, inline=True)
Expand Down Expand Up @@ -162,7 +162,7 @@ def inverse_quaternion(q):
# axis-angle


def rand_axis_angle(key, shape):
def rand_axis_angle(key, shape, dtype=jnp.float32):
r"""Generate random rotation as axis-angle.
Args:
Expand All @@ -173,7 +173,7 @@ def rand_axis_angle(key, shape):
axis (`jax.numpy.ndarray`): array of shape :math:`(..., 3)`
angle (`jax.numpy.ndarray`): array of shape :math:`(...)`
"""
return angles_to_axis_angle(*rand_angles(key, shape))
return angles_to_axis_angle(*rand_angles(key, shape, dtype))


def compose_axis_angle(axis1, angle1, axis2, angle2):
Expand Down
10 changes: 10 additions & 0 deletions e3nn_jax/_src/util/dtype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import jax
import jax.numpy as jnp


def get_pytree_dtype(*args, default_dtype=jnp.float32):
leaves = jax.tree_util.tree_leaves(args)
if len(leaves) == 0:
return default_dtype

return jax.eval_shape(lambda xs: sum(jnp.sum(x) for x in xs), leaves).dtype
11 changes: 5 additions & 6 deletions e3nn_jax/_src/util/test.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,26 @@
from typing import Callable, Optional, Tuple

import e3nn_jax as e3nn
import jax
import jax.numpy as jnp
import numpy as np

import e3nn_jax as e3nn
from e3nn_jax._src.util.dtype import get_pytree_dtype


def equivariance_test(
fun: Callable[[e3nn.IrrepsArray], e3nn.IrrepsArray],
rng_key: jnp.ndarray,
*args,
):
jax_enable_x64 = jax.config.read("jax_enable_x64")
jax.config.update("jax_enable_x64", True)

assert all(isinstance(arg, e3nn.IrrepsArray) for arg in args)
dtype = get_pytree_dtype(args)

R = -e3nn.rand_matrix(rng_key, ()) # random rotation and inversion
R = -e3nn.rand_matrix(rng_key, (), dtype=dtype) # random rotation and inversion

out1 = fun(*[arg.transform_by_matrix(R) for arg in args])
out2 = fun(*args).transform_by_matrix(R)

jax.config.update("jax_enable_x64", jax_enable_x64)
return out1, out2


Expand Down

0 comments on commit e62da34

Please sign in to comment.