Skip to content

Commit

Permalink
[spatial] Add EulerRotation.matrix_() setter
Browse files Browse the repository at this point in the history
  • Loading branch information
aschuh-hf committed Aug 3, 2023
1 parent 03d763f commit b2ac443
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/deepali/spatial/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,18 @@ def angles_(self: EulerRotation, arg: Tensor) -> EulerRotation:
self.data_(params)
return self

def matrix_(self: EulerRotation, arg: Tensor) -> EulerRotation:
r"""Set rotation angles from rotation matrix."""
if not isinstance(arg, Tensor):
raise TypeError("EulerRotation.matrix() 'arg' must be tensor")
if arg.ndim != 3:
raise ValueError("EulerRotation.matrix() 'arg' must be 3-dimensional tensor")
shape = (arg.shape[0], 3, 3)
if arg.shape != shape:
raise ValueError(f"Rotation matrix must have shape {shape!r}")
angles = U.euler_rotation_angles(arg, order=self.order)
return self.angles_(angles)

def tensor(self: EulerRotation) -> Tensor:
r"""Get tensor representation of this transformation
Expand Down

0 comments on commit b2ac443

Please sign in to comment.