Skip to content

Commit

Permalink
Speed up rotate_vectors and test (#193)
Browse files Browse the repository at this point in the history
Closes #191
  • Loading branch information
moble authored Feb 10, 2022
1 parent 62cd142 commit 8cf5abe
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 9 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


# Set this first for easier replacement
version = "2022.2.9.12.46.0"
version = "2022.2.9.19.55.57"

if "win" in platform.lower() and not "darwin" in platform.lower():
extra_compile_args = ["/O2"]
Expand Down
15 changes: 7 additions & 8 deletions src/quaternion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Copyright (c) 2021, Michael Boyle
# See LICENSE file for details: <https://github.com/moble/quaternion/blob/main/LICENSE>

__version__ = "2022.2.9.12.46.0"
__version__ = "2022.2.9.19.55.57"
__doc_title__ = "Quaternion dtype for NumPy"
__doc__ = "Adds a quaternion dtype to NumPy."
__all__ = ['quaternion',
Expand Down Expand Up @@ -703,13 +703,12 @@ def rotate_vectors(R, v, axis=-1):
if v.shape[axis] != 3:
raise ValueError("Input `v` axis {0} has length {1}, not 3.".format(axis, v.shape[axis]))
m = as_rotation_matrix(R)
m_axes = list(range(m.ndim))
v_axes = list(range(m.ndim, m.ndim+v.ndim))
mv_axes = list(v_axes)
mv_axes[axis] = m_axes[-2]
mv_axes = m_axes[:-2] + mv_axes
v_axes[axis] = m_axes[-1]
return np.einsum(m, m_axes, v, v_axes, mv_axes)
tensordot_axis = m.ndim-2
final_axis = tensordot_axis + (axis % v.ndim)
return np.moveaxis(
np.tensordot(m, v, axes=(-1, axis)),
tensordot_axis, final_axis
)


def isclose(a, b, rtol=4*np.finfo(float).eps, atol=0.0, equal_nan=False):
Expand Down
14 changes: 14 additions & 0 deletions tests/test_quaternion.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,20 @@ def test_rotate_vectors(Rs):
rtol=1e-15, atol=1e-15)
assert quats.shape + vecs.shape == vecsprime.shape, ("Out of shape!", quats.shape, vecs.shape, vecsprime.shape)

for Rshape in [(1,), (10,), (100,), (1000,), (5, 7), (5, 7, 23)]:
R = np.random.normal(size=Rshape+(4,))
R = quaternion.from_float_array(R / np.linalg.norm(R, axis=-1)[..., np.newaxis])
for vshape in [(1,), (2,), (3,), (4,), (20,), (200,), (2000,), (11, 13), (11, 13, 29)]:
v = np.random.normal(size=vshape+(3,))
Rprime = quaternion.rotate_vectors(R, v)
expected_shape = Rshape + vshape + (3,)
assert Rprime.shape == expected_shape
for vshape, axis in [((7, 3, 5), 1), ((7, 3, 5), -2), ((7, 3, 5, 11), 1), ((7, 3, 5, 11), -3)]:
v = np.random.normal(size=vshape)
Rprime = quaternion.rotate_vectors(R, v, axis=axis)
expected_shape = Rshape + vshape
assert Rprime.shape == expected_shape


def test_allclose(Qs):
for q in Qs[Qs_nonnan]:
Expand Down

0 comments on commit 8cf5abe

Please sign in to comment.