diff --git a/setup.py b/setup.py index ac49b6e..9714940 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ # Set this first for easier replacement -version = "2022.2.9.12.46.0" +version = "2022.2.9.19.55.57" if "win" in platform.lower() and not "darwin" in platform.lower(): extra_compile_args = ["/O2"] diff --git a/src/quaternion/__init__.py b/src/quaternion/__init__.py index 573d557..3e86dd7 100644 --- a/src/quaternion/__init__.py +++ b/src/quaternion/__init__.py @@ -3,7 +3,7 @@ # Copyright (c) 2021, Michael Boyle # See LICENSE file for details: -__version__ = "2022.2.9.12.46.0" +__version__ = "2022.2.9.19.55.57" __doc_title__ = "Quaternion dtype for NumPy" __doc__ = "Adds a quaternion dtype to NumPy." __all__ = ['quaternion', @@ -703,13 +703,12 @@ def rotate_vectors(R, v, axis=-1): if v.shape[axis] != 3: raise ValueError("Input `v` axis {0} has length {1}, not 3.".format(axis, v.shape[axis])) m = as_rotation_matrix(R) - m_axes = list(range(m.ndim)) - v_axes = list(range(m.ndim, m.ndim+v.ndim)) - mv_axes = list(v_axes) - mv_axes[axis] = m_axes[-2] - mv_axes = m_axes[:-2] + mv_axes - v_axes[axis] = m_axes[-1] - return np.einsum(m, m_axes, v, v_axes, mv_axes) + tensordot_axis = m.ndim-2 + final_axis = tensordot_axis + (axis % v.ndim) + return np.moveaxis( + np.tensordot(m, v, axes=(-1, axis)), + tensordot_axis, final_axis + ) def isclose(a, b, rtol=4*np.finfo(float).eps, atol=0.0, equal_nan=False): diff --git a/tests/test_quaternion.py b/tests/test_quaternion.py index eefe4f1..72a11a8 100644 --- a/tests/test_quaternion.py +++ b/tests/test_quaternion.py @@ -416,6 +416,20 @@ def test_rotate_vectors(Rs): rtol=1e-15, atol=1e-15) assert quats.shape + vecs.shape == vecsprime.shape, ("Out of shape!", quats.shape, vecs.shape, vecsprime.shape) + for Rshape in [(1,), (10,), (100,), (1000,), (5, 7), (5, 7, 23)]: + R = np.random.normal(size=Rshape+(4,)) + R = quaternion.from_float_array(R / np.linalg.norm(R, axis=-1)[..., np.newaxis]) + for vshape in [(1,), (2,), (3,), (4,), (20,), (200,), (2000,), (11, 13), (11, 13, 29)]: + v = np.random.normal(size=vshape+(3,)) + Rprime = quaternion.rotate_vectors(R, v) + expected_shape = Rshape + vshape + (3,) + assert Rprime.shape == expected_shape + for vshape, axis in [((7, 3, 5), 1), ((7, 3, 5), -2), ((7, 3, 5, 11), 1), ((7, 3, 5, 11), -3)]: + v = np.random.normal(size=vshape) + Rprime = quaternion.rotate_vectors(R, v, axis=axis) + expected_shape = Rshape + vshape + assert Rprime.shape == expected_shape + def test_allclose(Qs): for q in Qs[Qs_nonnan]: