diff --git a/e3nn_jax/_src/rotation.py b/e3nn_jax/_src/rotation.py index e1f33b81..a8750ae5 100644 --- a/e3nn_jax/_src/rotation.py +++ b/e3nn_jax/_src/rotation.py @@ -1,5 +1,3 @@ -from functools import partial - import jax import jax.numpy as jnp @@ -36,7 +34,6 @@ def identity_angles(shape, dtype=jnp.float32): return jnp.zeros(shape, dtype), jnp.zeros(shape, dtype), jnp.zeros(shape, dtype) -@partial(jax.jit, static_argnums=(1,), inline=True) def rand_angles(key, shape, dtype=jnp.float32): r"""Random rotation angles. @@ -108,7 +105,6 @@ def identity_quaternion(shape, dtype=jnp.float32): return q.at[..., 0].set(1) # or -1... -@partial(jax.jit, static_argnums=(1,), inline=True) def rand_quaternion(key, shape, dtype=jnp.float32): r"""Generate random quaternion. @@ -122,7 +118,6 @@ def rand_quaternion(key, shape, dtype=jnp.float32): return angles_to_quaternion(*rand_angles(key, shape, dtype)) -@partial(jax.jit, inline=True) def compose_quaternion(q1, q2): r"""Compose two quaternions: :math:`q_1 \circ q_2`. @@ -197,7 +192,6 @@ def compose_axis_angle(axis1, angle1, axis2, angle2): # conversions -@partial(jax.jit, inline=True) def matrix_x(angle): r"""Matrix of rotation around X axis. @@ -221,7 +215,6 @@ def matrix_x(angle): ) -@partial(jax.jit, inline=True) def matrix_y(angle): r"""Matrix of rotation around Y axis. @@ -245,7 +238,6 @@ def matrix_y(angle): ) -@partial(jax.jit, inline=True) def matrix_z(angle): r"""Matrix of rotation around Z axis. @@ -277,7 +269,6 @@ def angles_to_matrix(alpha, beta, gamma): return matrix_y(alpha) @ matrix_x(beta) @ matrix_y(gamma) -@partial(jax.jit, inline=True) def matrix_to_angles(R): r"""Conversion from matrix to angles. @@ -290,14 +281,13 @@ def matrix_to_angles(R): gamma (`jax.numpy.ndarray`): array of shape :math:`(...)` """ # assert jnp.allclose(jnp.linalg.det(R), 1) - x = R @ jnp.array([0.0, 1.0, 0.0]) + x = R @ jnp.array([0.0, 1.0, 0.0], dtype=R.dtype) a, b = xyz_to_angles(x) R = jnp.swapaxes(angles_to_matrix(a, b, 0.0), -1, -2) @ R c = jnp.arctan2(R[..., 0, 2], R[..., 0, 0]) return a, b, c -@partial(jax.jit, inline=True) def angles_to_quaternion(alpha, beta, gamma): r"""Conversion from angles to quaternion. @@ -310,9 +300,9 @@ def angles_to_quaternion(alpha, beta, gamma): q (`jax.numpy.ndarray`): array of shape :math:`(..., 4)` """ alpha, beta, gamma = jnp.broadcast_arrays(alpha, beta, gamma) - qa = axis_angle_to_quaternion(jnp.array([0.0, 1.0, 0.0]), alpha) - qb = axis_angle_to_quaternion(jnp.array([1.0, 0.0, 0.0]), beta) - qc = axis_angle_to_quaternion(jnp.array([0.0, 1.0, 0.0]), gamma) + qa = axis_angle_to_quaternion(jnp.array([0.0, 1.0, 0.0], alpha.dtype), alpha) + qb = axis_angle_to_quaternion(jnp.array([1.0, 0.0, 0.0], beta.dtype), beta) + qc = axis_angle_to_quaternion(jnp.array([0.0, 1.0, 0.0], gamma.dtype), gamma) return compose_quaternion(qa, compose_quaternion(qb, qc)) @@ -328,7 +318,6 @@ def matrix_to_quaternion(R): return axis_angle_to_quaternion(*matrix_to_axis_angle(R)) -@partial(jax.jit, inline=True) def axis_angle_to_quaternion(xyz, angle): r"""Conversion from axis-angle to quaternion. @@ -346,7 +335,6 @@ def axis_angle_to_quaternion(xyz, angle): return jnp.concatenate([c, xyz * s], axis=-1) -@partial(jax.jit, inline=True) def quaternion_to_axis_angle(q): r"""Conversion from quaternion to axis-angle. @@ -367,7 +355,6 @@ def _normalize(x): return x / jnp.where(n > 0, n, 1.0) -@partial(jax.jit, inline=True) def matrix_to_axis_angle(R): r"""Conversion from matrix to axis-angle. @@ -408,7 +395,6 @@ def angles_to_axis_angle(alpha, beta, gamma): return matrix_to_axis_angle(angles_to_matrix(alpha, beta, gamma)) -@partial(jax.jit, inline=True) def axis_angle_to_matrix(axis, angle): r"""Conversion from axis-angle to matrix. @@ -470,7 +456,6 @@ def axis_angle_to_angles(axis, angle): # point on the sphere -@partial(jax.jit, inline=True) def angles_to_xyz(alpha, beta): r"""Convert :math:`(\alpha, \beta)` into a point :math:`(x, y, z)` on the sphere. @@ -492,7 +477,6 @@ def angles_to_xyz(alpha, beta): return jnp.stack([x, y, z], axis=-1) -@partial(jax.jit, inline=True) def xyz_to_angles(xyz): r"""Convert a point :math:`\vec r = (x, y, z)` on the sphere into angles :math:`(\alpha, \beta)`. diff --git a/e3nn_jax/_src/rotation_test.py b/e3nn_jax/_src/rotation_test.py index 2750b883..e0567ffb 100644 --- a/e3nn_jax/_src/rotation_test.py +++ b/e3nn_jax/_src/rotation_test.py @@ -50,19 +50,23 @@ def identity(x): g = conv[i][j](g) return g - R1 = e3nn.rand_matrix(next(keys), (100,)) + R1 = e3nn.rand_matrix(next(keys), (100,), dtype=jnp.float64) R2 = f(R1) + assert R2.dtype == jnp.float64 np.testing.assert_allclose(R1, R2, rtol=0, atol=1e-10) + R1 = e3nn.rand_matrix(next(keys), (2,), dtype=jnp.float32) + assert f(R1).dtype == jnp.float32 + jax.config.update("jax_enable_x64", False) def test_compose(keys): jax.config.update("jax_enable_x64", True) - q1 = e3nn.rand_quaternion(keys[1], (10,)) - q2 = e3nn.rand_quaternion(keys[2], (10,)) + q1 = e3nn.rand_quaternion(keys[1], (10,), dtype=jnp.float64) + q2 = e3nn.rand_quaternion(keys[2], (10,), dtype=jnp.float64) q = e3nn.compose_quaternion(q1, q2)