From e62da34c0a853cfb332cf37264867479a1f06c68 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 14 Dec 2022 09:59:10 -0500 Subject: [PATCH] remove jax_enable_x64 from equivariance_test --- e3nn_jax/_src/activation.py | 3 ++- e3nn_jax/_src/core_tensor_product.py | 16 ++++------------ e3nn_jax/_src/gate.py | 4 ++-- e3nn_jax/_src/irreps_array.py | 2 +- e3nn_jax/_src/rotation.py | 24 ++++++++++++------------ e3nn_jax/_src/util/dtype.py | 10 ++++++++++ e3nn_jax/_src/util/test.py | 11 +++++------ 7 files changed, 36 insertions(+), 34 deletions(-) create mode 100644 e3nn_jax/_src/util/dtype.py diff --git a/e3nn_jax/_src/activation.py b/e3nn_jax/_src/activation.py index 5d39c5f1..a8d7cccc 100644 --- a/e3nn_jax/_src/activation.py +++ b/e3nn_jax/_src/activation.py @@ -11,13 +11,14 @@ def normalize_function(phi): k = jax.random.PRNGKey(0) x = jax.random.normal(k, (1_000_000,), dtype=jnp.float64) c = jnp.mean(phi(x) ** 2) ** 0.5 + c = c.item() if jnp.allclose(c, 1.0): return phi else: def rho(x): - return phi(x) / c.astype(x.dtype) + return phi(x) / c return rho diff --git a/e3nn_jax/_src/core_tensor_product.py b/e3nn_jax/_src/core_tensor_product.py index df8a469a..50de82df 100644 --- a/e3nn_jax/_src/core_tensor_product.py +++ b/e3nn_jax/_src/core_tensor_product.py @@ -10,9 +10,9 @@ import jax.numpy as jnp from e3nn_jax import Instruction, Irreps, IrrepsArray, clebsch_gordan, config -from e3nn_jax._src.util.prod import prod - from e3nn_jax._src.einsum import einsum as opt_einsum +from e3nn_jax._src.util.dtype import get_pytree_dtype +from e3nn_jax._src.util.prod import prod class FunctionalTensorProduct: @@ -293,14 +293,6 @@ def update(instruction: Instruction) -> float: return [update(instruction) for instruction in instructions] -def _get_dtype(*args): - leaves = jax.tree_util.tree_leaves(args) - if len(leaves) == 0: - return jnp.float32 - - return jax.eval_shape(lambda xs: sum(jnp.sum(x) for x in xs), leaves).dtype - - @partial(jax.jit, static_argnums=(0,), static_argnames=("custom_einsum_jvp", "fused")) @partial(jax.profiler.annotate_function, name="TensorProduct.left_right") def _left_right( @@ -312,7 +304,7 @@ def _left_right( custom_einsum_jvp: bool = False, fused: bool = False, ): - dtype = _get_dtype(weights, input1, input2) + dtype = get_pytree_dtype(weights, input1, input2) if self.irreps_in1.dim == 0 or self.irreps_in2.dim == 0 or self.irreps_out.dim == 0: return IrrepsArray.zeros(self.irreps_out, ()) @@ -570,7 +562,7 @@ def _right( *, custom_einsum_jvp: bool = False, ) -> jnp.ndarray: - dtype = _get_dtype(weights, input2) + dtype = get_pytree_dtype(weights, input2) # = Short-circut for zero dimensional = if self.irreps_in1.dim == 0 or self.irreps_in2.dim == 0 or self.irreps_out.dim == 0: diff --git a/e3nn_jax/_src/gate.py b/e3nn_jax/_src/gate.py index 8ea5365a..a30448a1 100644 --- a/e3nn_jax/_src/gate.py +++ b/e3nn_jax/_src/gate.py @@ -13,8 +13,8 @@ def _gate(input: IrrepsArray, even_act, odd_act, even_gate_act, odd_gate_act) -> act = {1: even_act, -1: odd_act} gate_act = {1: even_gate_act, -1: odd_gate_act} - scalars = input.filtered(["0e", "0o"]) - vectors = input.filtered(lambda mul_ir: mul_ir.ir.l > 0) + scalars = input.filtered(keep=["0e", "0o"]) + vectors = input.filtered(drop=["0e", "0o"]) del input if vectors.shape[-1] == 0: diff --git a/e3nn_jax/_src/irreps_array.py b/e3nn_jax/_src/irreps_array.py index 676403aa..a670aacc 100644 --- a/e3nn_jax/_src/irreps_array.py +++ b/e3nn_jax/_src/irreps_array.py @@ -99,7 +99,7 @@ def __init__( ) assert all(x.dtype == self.array.dtype for x in self._list if x is not None), ( f"IrrepsArray: List dtypes {[None if x is None else x.dtype for x in self._list]} " - "incompatible with array dtype {self.array.dtype}." + f"incompatible with array dtype {self.array.dtype}." ) @staticmethod diff --git a/e3nn_jax/_src/rotation.py b/e3nn_jax/_src/rotation.py index 4c1e2ffe..e1f33b81 100644 --- a/e3nn_jax/_src/rotation.py +++ b/e3nn_jax/_src/rotation.py @@ -6,7 +6,7 @@ # matrix -def rand_matrix(key, shape): +def rand_matrix(key, shape, dtype=jnp.float32): r"""Random rotation matrix. Args: @@ -16,13 +16,13 @@ def rand_matrix(key, shape): Returns: `jax.numpy.ndarray`: array of shape :math:`(..., 3, 3)` """ - return angles_to_matrix(*rand_angles(key, shape)) + return angles_to_matrix(*rand_angles(key, shape, dtype=dtype)) # angles -def identity_angles(shape): +def identity_angles(shape, dtype=jnp.float32): r"""Angles of the identity rotation. Args: @@ -33,11 +33,11 @@ def identity_angles(shape): beta (`jax.numpy.ndarray`): array of shape :math:`(...)` gamma (`jax.numpy.ndarray`): array of shape :math:`(...)` """ - return jnp.zeros(shape), jnp.zeros(shape), jnp.zeros(shape) + return jnp.zeros(shape, dtype), jnp.zeros(shape, dtype), jnp.zeros(shape, dtype) @partial(jax.jit, static_argnums=(1,), inline=True) -def rand_angles(key, shape): +def rand_angles(key, shape, dtype=jnp.float32): r"""Random rotation angles. Args: @@ -49,7 +49,7 @@ def rand_angles(key, shape): beta (`jax.numpy.ndarray`): array of shape :math:`(...)` gamma (`jax.numpy.ndarray`): array of shape :math:`(...)` """ - x, y, z = jax.random.uniform(key, (3,) + shape) + x, y, z = jax.random.uniform(key, (3,) + shape, dtype=dtype) return 2 * jnp.pi * x, jnp.arccos(2 * z - 1), 2 * jnp.pi * y @@ -95,7 +95,7 @@ def inverse_angles(a, b, c): # quaternions -def identity_quaternion(shape): +def identity_quaternion(shape, dtype=jnp.float32): r"""Quaternion of identity rotation. Args: @@ -104,12 +104,12 @@ def identity_quaternion(shape): Returns: `jax.numpy.ndarray`: array of shape :math:`(..., 4)` """ - q = jnp.zeros(shape, 4) + q = jnp.zeros(shape, 4, dtype=dtype) return q.at[..., 0].set(1) # or -1... @partial(jax.jit, static_argnums=(1,), inline=True) -def rand_quaternion(key, shape): +def rand_quaternion(key, shape, dtype=jnp.float32): r"""Generate random quaternion. Args: @@ -119,7 +119,7 @@ def rand_quaternion(key, shape): Returns: `jax.numpy.ndarray`: array of shape :math:`(..., 4)` """ - return angles_to_quaternion(*rand_angles(key, shape)) + return angles_to_quaternion(*rand_angles(key, shape, dtype)) @partial(jax.jit, inline=True) @@ -162,7 +162,7 @@ def inverse_quaternion(q): # axis-angle -def rand_axis_angle(key, shape): +def rand_axis_angle(key, shape, dtype=jnp.float32): r"""Generate random rotation as axis-angle. Args: @@ -173,7 +173,7 @@ def rand_axis_angle(key, shape): axis (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` angle (`jax.numpy.ndarray`): array of shape :math:`(...)` """ - return angles_to_axis_angle(*rand_angles(key, shape)) + return angles_to_axis_angle(*rand_angles(key, shape, dtype)) def compose_axis_angle(axis1, angle1, axis2, angle2): diff --git a/e3nn_jax/_src/util/dtype.py b/e3nn_jax/_src/util/dtype.py new file mode 100644 index 00000000..a26d3f2b --- /dev/null +++ b/e3nn_jax/_src/util/dtype.py @@ -0,0 +1,10 @@ +import jax +import jax.numpy as jnp + + +def get_pytree_dtype(*args, default_dtype=jnp.float32): + leaves = jax.tree_util.tree_leaves(args) + if len(leaves) == 0: + return default_dtype + + return jax.eval_shape(lambda xs: sum(jnp.sum(x) for x in xs), leaves).dtype diff --git a/e3nn_jax/_src/util/test.py b/e3nn_jax/_src/util/test.py index 964b7ec6..4da40c30 100644 --- a/e3nn_jax/_src/util/test.py +++ b/e3nn_jax/_src/util/test.py @@ -1,27 +1,26 @@ from typing import Callable, Optional, Tuple -import e3nn_jax as e3nn import jax import jax.numpy as jnp import numpy as np +import e3nn_jax as e3nn +from e3nn_jax._src.util.dtype import get_pytree_dtype + def equivariance_test( fun: Callable[[e3nn.IrrepsArray], e3nn.IrrepsArray], rng_key: jnp.ndarray, *args, ): - jax_enable_x64 = jax.config.read("jax_enable_x64") - jax.config.update("jax_enable_x64", True) - assert all(isinstance(arg, e3nn.IrrepsArray) for arg in args) + dtype = get_pytree_dtype(args) - R = -e3nn.rand_matrix(rng_key, ()) # random rotation and inversion + R = -e3nn.rand_matrix(rng_key, (), dtype=dtype) # random rotation and inversion out1 = fun(*[arg.transform_by_matrix(R) for arg in args]) out2 = fun(*args).transform_by_matrix(R) - jax.config.update("jax_enable_x64", jax_enable_x64) return out1, out2