Skip to content

Commit

Permalink
fix rotation dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Dec 14, 2022
1 parent e62da34 commit a658419
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 23 deletions.
24 changes: 4 additions & 20 deletions e3nn_jax/_src/rotation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from functools import partial

import jax
import jax.numpy as jnp

Expand Down Expand Up @@ -36,7 +34,6 @@ def identity_angles(shape, dtype=jnp.float32):
return jnp.zeros(shape, dtype), jnp.zeros(shape, dtype), jnp.zeros(shape, dtype)


@partial(jax.jit, static_argnums=(1,), inline=True)
def rand_angles(key, shape, dtype=jnp.float32):
r"""Random rotation angles.
Expand Down Expand Up @@ -108,7 +105,6 @@ def identity_quaternion(shape, dtype=jnp.float32):
return q.at[..., 0].set(1) # or -1...


@partial(jax.jit, static_argnums=(1,), inline=True)
def rand_quaternion(key, shape, dtype=jnp.float32):
r"""Generate random quaternion.
Expand All @@ -122,7 +118,6 @@ def rand_quaternion(key, shape, dtype=jnp.float32):
return angles_to_quaternion(*rand_angles(key, shape, dtype))


@partial(jax.jit, inline=True)
def compose_quaternion(q1, q2):
r"""Compose two quaternions: :math:`q_1 \circ q_2`.
Expand Down Expand Up @@ -197,7 +192,6 @@ def compose_axis_angle(axis1, angle1, axis2, angle2):
# conversions


@partial(jax.jit, inline=True)
def matrix_x(angle):
r"""Matrix of rotation around X axis.
Expand All @@ -221,7 +215,6 @@ def matrix_x(angle):
)


@partial(jax.jit, inline=True)
def matrix_y(angle):
r"""Matrix of rotation around Y axis.
Expand All @@ -245,7 +238,6 @@ def matrix_y(angle):
)


@partial(jax.jit, inline=True)
def matrix_z(angle):
r"""Matrix of rotation around Z axis.
Expand Down Expand Up @@ -277,7 +269,6 @@ def angles_to_matrix(alpha, beta, gamma):
return matrix_y(alpha) @ matrix_x(beta) @ matrix_y(gamma)


@partial(jax.jit, inline=True)
def matrix_to_angles(R):
r"""Conversion from matrix to angles.
Expand All @@ -290,14 +281,13 @@ def matrix_to_angles(R):
gamma (`jax.numpy.ndarray`): array of shape :math:`(...)`
"""
# assert jnp.allclose(jnp.linalg.det(R), 1)
x = R @ jnp.array([0.0, 1.0, 0.0])
x = R @ jnp.array([0.0, 1.0, 0.0], dtype=R.dtype)
a, b = xyz_to_angles(x)
R = jnp.swapaxes(angles_to_matrix(a, b, 0.0), -1, -2) @ R
c = jnp.arctan2(R[..., 0, 2], R[..., 0, 0])
return a, b, c


@partial(jax.jit, inline=True)
def angles_to_quaternion(alpha, beta, gamma):
r"""Conversion from angles to quaternion.
Expand All @@ -310,9 +300,9 @@ def angles_to_quaternion(alpha, beta, gamma):
q (`jax.numpy.ndarray`): array of shape :math:`(..., 4)`
"""
alpha, beta, gamma = jnp.broadcast_arrays(alpha, beta, gamma)
qa = axis_angle_to_quaternion(jnp.array([0.0, 1.0, 0.0]), alpha)
qb = axis_angle_to_quaternion(jnp.array([1.0, 0.0, 0.0]), beta)
qc = axis_angle_to_quaternion(jnp.array([0.0, 1.0, 0.0]), gamma)
qa = axis_angle_to_quaternion(jnp.array([0.0, 1.0, 0.0], alpha.dtype), alpha)
qb = axis_angle_to_quaternion(jnp.array([1.0, 0.0, 0.0], beta.dtype), beta)
qc = axis_angle_to_quaternion(jnp.array([0.0, 1.0, 0.0], gamma.dtype), gamma)
return compose_quaternion(qa, compose_quaternion(qb, qc))


Expand All @@ -328,7 +318,6 @@ def matrix_to_quaternion(R):
return axis_angle_to_quaternion(*matrix_to_axis_angle(R))


@partial(jax.jit, inline=True)
def axis_angle_to_quaternion(xyz, angle):
r"""Conversion from axis-angle to quaternion.
Expand All @@ -346,7 +335,6 @@ def axis_angle_to_quaternion(xyz, angle):
return jnp.concatenate([c, xyz * s], axis=-1)


@partial(jax.jit, inline=True)
def quaternion_to_axis_angle(q):
r"""Conversion from quaternion to axis-angle.
Expand All @@ -367,7 +355,6 @@ def _normalize(x):
return x / jnp.where(n > 0, n, 1.0)


@partial(jax.jit, inline=True)
def matrix_to_axis_angle(R):
r"""Conversion from matrix to axis-angle.
Expand Down Expand Up @@ -408,7 +395,6 @@ def angles_to_axis_angle(alpha, beta, gamma):
return matrix_to_axis_angle(angles_to_matrix(alpha, beta, gamma))


@partial(jax.jit, inline=True)
def axis_angle_to_matrix(axis, angle):
r"""Conversion from axis-angle to matrix.
Expand Down Expand Up @@ -470,7 +456,6 @@ def axis_angle_to_angles(axis, angle):
# point on the sphere


@partial(jax.jit, inline=True)
def angles_to_xyz(alpha, beta):
r"""Convert :math:`(\alpha, \beta)` into a point :math:`(x, y, z)` on the sphere.
Expand All @@ -492,7 +477,6 @@ def angles_to_xyz(alpha, beta):
return jnp.stack([x, y, z], axis=-1)


@partial(jax.jit, inline=True)
def xyz_to_angles(xyz):
r"""Convert a point :math:`\vec r = (x, y, z)` on the sphere into angles :math:`(\alpha, \beta)`.
Expand Down
10 changes: 7 additions & 3 deletions e3nn_jax/_src/rotation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,23 @@ def identity(x):
g = conv[i][j](g)
return g

R1 = e3nn.rand_matrix(next(keys), (100,))
R1 = e3nn.rand_matrix(next(keys), (100,), dtype=jnp.float64)
R2 = f(R1)

assert R2.dtype == jnp.float64
np.testing.assert_allclose(R1, R2, rtol=0, atol=1e-10)

R1 = e3nn.rand_matrix(next(keys), (2,), dtype=jnp.float32)
assert f(R1).dtype == jnp.float32

jax.config.update("jax_enable_x64", False)


def test_compose(keys):
jax.config.update("jax_enable_x64", True)

q1 = e3nn.rand_quaternion(keys[1], (10,))
q2 = e3nn.rand_quaternion(keys[2], (10,))
q1 = e3nn.rand_quaternion(keys[1], (10,), dtype=jnp.float64)
q2 = e3nn.rand_quaternion(keys[2], (10,), dtype=jnp.float64)

q = e3nn.compose_quaternion(q1, q2)

Expand Down

0 comments on commit a658419

Please sign in to comment.