diff --git a/src/deepali/spatial/linear.py b/src/deepali/spatial/linear.py index 3322e51..03e190c 100644 --- a/src/deepali/spatial/linear.py +++ b/src/deepali/spatial/linear.py @@ -213,6 +213,18 @@ def angles_(self: EulerRotation, arg: Tensor) -> EulerRotation: self.data_(params) return self + def matrix_(self: EulerRotation, arg: Tensor) -> EulerRotation: + r"""Set rotation angles from rotation matrix.""" + if not isinstance(arg, Tensor): + raise TypeError("EulerRotation.matrix() 'arg' must be tensor") + if arg.ndim != 3: + raise ValueError("EulerRotation.matrix() 'arg' must be 3-dimensional tensor") + shape = (arg.shape[0], 3, 3) + if arg.shape != shape: + raise ValueError(f"Rotation matrix must have shape {shape!r}") + angles = U.euler_rotation_angles(arg, order=self.order) + return self.angles_(angles) + def tensor(self: EulerRotation) -> Tensor: r"""Get tensor representation of this transformation