Skip to content

Commit cd274b0

Browse files
committed
Add convertion from SE(3) to Euclidean
1 parent e398237 commit cd274b0

File tree

3 files changed

+81
-18
lines changed

3 files changed

+81
-18
lines changed

diffdrr/_modidx.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
'diffdrr.pose.RigidTransform.__init__': ('api/pose.html#rigidtransform.__init__', 'diffdrr/pose.py'),
5151
'diffdrr.pose.RigidTransform.__len__': ('api/pose.html#rigidtransform.__len__', 'diffdrr/pose.py'),
5252
'diffdrr.pose.RigidTransform.compose': ('api/pose.html#rigidtransform.compose', 'diffdrr/pose.py'),
53+
'diffdrr.pose.RigidTransform.convert': ('api/pose.html#rigidtransform.convert', 'diffdrr/pose.py'),
5354
'diffdrr.pose.RigidTransform.forward': ('api/pose.html#rigidtransform.forward', 'diffdrr/pose.py'),
5455
'diffdrr.pose.RigidTransform.get_se3_log': ('api/pose.html#rigidtransform.get_se3_log', 'diffdrr/pose.py'),
5556
'diffdrr.pose.RigidTransform.inverse': ('api/pose.html#rigidtransform.inverse', 'diffdrr/pose.py'),

diffdrr/pose.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/api/06_pose.ipynb.
22

33
# %% auto 0
4-
__all__ = ['RigidTransform', 'make_matrix', 'convert', 'rotation_10d_to_quaternion', 'quaternion_to_rotation_10d',
4+
__all__ = ['RigidTransform', 'convert', 'rotation_10d_to_quaternion', 'quaternion_to_rotation_10d',
55
'quaternion_adjugate_to_quaternion', 'quaternion_to_quaternion_adjugate']
66

77
# %% ../notebooks/api/06_pose.ipynb 4
@@ -44,10 +44,37 @@ def compose(self, T):
4444
matrix = torch.einsum("bij, bjk -> bik", T.matrix, self.matrix)
4545
return RigidTransform(matrix)
4646

47-
def get_se3_log(self):
48-
return se3_log_map(self.matrix.mT)
47+
def convert(self, parameterization, convention=None):
48+
translation = T.matrix[..., :3, 3]
49+
if parameterization == "axis_angle":
50+
rotation = matrix_to_axis_angle(T.matrix[..., :3, :3])
51+
elif parameterization == "euler_angles":
52+
rotation = matrix_to_euler_angles(T.matrix[..., :3, :3], convention)
53+
elif parameterization == "matrix":
54+
rotation = T.matrix[..., :3, :3]
55+
elif parameterization == "quaternion":
56+
rotation = matrix_to_quaternion(T.matrix[..., :3, :3])
57+
elif parameterization == "quaternion_adjugate":
58+
quaternion = matrix_to_quaternion(T.matrix[..., :3, :3])
59+
rotation = quaternion_to_quaternion_adjugate(quaternion)
60+
elif parameterization == "rotation_6d":
61+
rotation = matrix_to_rotation_6d(T.matrix[..., :3, :3])
62+
elif parameterization == "rotation_10d":
63+
quaternion = matrix_to_quaternion(T.matrix[..., :3, :3])
64+
rotation = quaternion_to_rotation_10d(quaternion)
65+
elif parameterization == "se3_log_map":
66+
rotation, translation = self.get_se3_log()
67+
else:
68+
raise ValueError(f"Must be in {PARAMETERIZATIONS}, not {parameterization}")
69+
return rotation, translation
4970

71+
def get_se3_log(self):
72+
params = se3_log_map(self.matrix.mT)
73+
rotation = params[..., 3:]
74+
translation = params[..., :3]
75+
return rotation, translation
5076

77+
# %% ../notebooks/api/06_pose.ipynb 5
5178
def make_matrix(R, t):
5279
assert (batch_size := len(R)) == len(t)
5380
matrix = torch.zeros(batch_size, 4, 4).to(R)
@@ -56,17 +83,17 @@ def make_matrix(R, t):
5683
matrix[..., -1, -1] = 1.0
5784
return matrix
5885

59-
# %% ../notebooks/api/06_pose.ipynb 5
86+
# %% ../notebooks/api/06_pose.ipynb 6
6087
from scipy.spatial.transform import Rotation
6188

6289

63-
def random_rigid_transform(batch_size):
90+
def random_rigid_transform(batch_size=1):
6491
"""Helper function for testing implementations."""
6592
R = torch.from_numpy(Rotation.random(batch_size).as_matrix()).to(torch.float32)
6693
t = 100 * torch.randn((batch_size, 3))
6794
return RigidTransform(make_matrix(R, t))
6895

69-
# %% ../notebooks/api/06_pose.ipynb 7
96+
# %% ../notebooks/api/06_pose.ipynb 8
7097
PARAMETERIZATIONS = [
7198
"axis_angle",
7299
"euler_angles",
@@ -78,7 +105,7 @@ def random_rigid_transform(batch_size):
78105
"se3_log_map",
79106
]
80107

81-
# %% ../notebooks/api/06_pose.ipynb 8
108+
# %% ../notebooks/api/06_pose.ipynb 9
82109
def convert(*args, parameterization, convention=None) -> RigidTransform:
83110
if parameterization == "euler_angles" and convention is None:
84111
raise ValueError(
@@ -119,7 +146,7 @@ def convert(*args, parameterization, convention=None) -> RigidTransform:
119146

120147
return convert(matrix, parameterization="matrix")
121148

122-
# %% ../notebooks/api/06_pose.ipynb 10
149+
# %% ../notebooks/api/06_pose.ipynb 11
123150
def _10vec_to_4x4symmetric(vec):
124151
"""Convert a 10-vector to a symmetric 4x4 matrix."""
125152
b = len(vec)
@@ -129,7 +156,7 @@ def _10vec_to_4x4symmetric(vec):
129156
A[..., jdx, idx] = vec
130157
return A
131158

132-
# %% ../notebooks/api/06_pose.ipynb 11
159+
# %% ../notebooks/api/06_pose.ipynb 12
133160
def rotation_10d_to_quaternion(rotation: torch.Tensor) -> torch.Tensor:
134161
"""
135162
Convert a 10-vector into a symmetric matrix, whose eigenvector corresponding
@@ -146,7 +173,7 @@ def quaternion_to_rotation_10d(q: torch.Tensor) -> torch.Tensor:
146173
idx, jdx = torch.triu_indices(4, 4)
147174
return A[..., idx, jdx]
148175

149-
# %% ../notebooks/api/06_pose.ipynb 12
176+
# %% ../notebooks/api/06_pose.ipynb 13
150177
def quaternion_adjugate_to_quaternion(rotation: torch.Tensor) -> torch.Tensor:
151178
"""
152179
Convert a 10-vector in the quaternion adjugate, a symmetric matrix whose
@@ -167,7 +194,7 @@ def quaternion_to_quaternion_adjugate(q: torch.Tensor) -> torch.Tensor:
167194
idx, jdx = torch.triu_indices(4, 4)
168195
return A[..., idx, jdx]
169196

170-
# %% ../notebooks/api/06_pose.ipynb 15
197+
# %% ../notebooks/api/06_pose.ipynb 16
171198
# pytorch3d/transforms/rotation_conversions.py
172199

173200
from typing import Optional, Union
@@ -700,7 +727,7 @@ def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
700727
batch_dim = matrix.size()[:-2]
701728
return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
702729

703-
# %% ../notebooks/api/06_pose.ipynb 16
730+
# %% ../notebooks/api/06_pose.ipynb 17
704731
# pytorch3d/transforms/math.py
705732
from typing import Tuple
706733

@@ -778,7 +805,7 @@ def _dacos_dx(x: float) -> float:
778805
"""
779806
return (-1.0) / math.sqrt(1.0 - x * x)
780807

781-
# %% ../notebooks/api/06_pose.ipynb 17
808+
# %% ../notebooks/api/06_pose.ipynb 18
782809
# pytorch3d/transforms/so3.py
783810

784811
import warnings
@@ -1038,7 +1065,7 @@ def hat(v: torch.Tensor) -> torch.Tensor:
10381065

10391066
return h
10401067

1041-
# %% ../notebooks/api/06_pose.ipynb 18
1068+
# %% ../notebooks/api/06_pose.ipynb 19
10421069
# pytorch3d/transforms/se3.py
10431070

10441071

notebooks/api/06_pose.ipynb

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,45 @@
112112
" matrix = torch.einsum(\"bij, bjk -> bik\", T.matrix, self.matrix)\n",
113113
" return RigidTransform(matrix)\n",
114114
"\n",
115-
" def get_se3_log(self):\n",
116-
" return se3_log_map(self.matrix.mT)\n",
117-
"\n",
115+
" def convert(self, parameterization, convention=None):\n",
116+
" translation = T.matrix[..., :3, 3]\n",
117+
" if parameterization == \"axis_angle\":\n",
118+
" rotation = matrix_to_axis_angle(T.matrix[..., :3, :3])\n",
119+
" elif parameterization == \"euler_angles\":\n",
120+
" rotation = matrix_to_euler_angles(T.matrix[..., :3, :3], convention)\n",
121+
" elif parameterization == \"matrix\":\n",
122+
" rotation = T.matrix[..., :3, :3]\n",
123+
" elif parameterization == \"quaternion\":\n",
124+
" rotation = matrix_to_quaternion(T.matrix[..., :3, :3])\n",
125+
" elif parameterization == \"quaternion_adjugate\":\n",
126+
" quaternion = matrix_to_quaternion(T.matrix[..., :3, :3])\n",
127+
" rotation = quaternion_to_quaternion_adjugate(quaternion)\n",
128+
" elif parameterization == \"rotation_6d\":\n",
129+
" rotation = matrix_to_rotation_6d(T.matrix[..., :3, :3])\n",
130+
" elif parameterization == \"rotation_10d\":\n",
131+
" quaternion = matrix_to_quaternion(T.matrix[..., :3, :3])\n",
132+
" rotation = quaternion_to_rotation_10d(quaternion)\n",
133+
" elif parameterization == \"se3_log_map\":\n",
134+
" rotation, translation = self.get_se3_log()\n",
135+
" else:\n",
136+
" raise ValueError(f\"Must be in {PARAMETERIZATIONS}, not {parameterization}\")\n",
137+
" return rotation, translation\n",
118138
"\n",
139+
" def get_se3_log(self):\n",
140+
" params = se3_log_map(self.matrix.mT)\n",
141+
" rotation = params[..., 3:]\n",
142+
" translation = params[..., :3]\n",
143+
" return rotation, translation"
144+
]
145+
},
146+
{
147+
"cell_type": "code",
148+
"execution_count": null,
149+
"id": "5caffa4c-ed95-4a1e-ac56-3f940f21bbb5",
150+
"metadata": {},
151+
"outputs": [],
152+
"source": [
153+
"#| exporti\n",
119154
"def make_matrix(R, t):\n",
120155
" assert (batch_size := len(R)) == len(t)\n",
121156
" matrix = torch.zeros(batch_size, 4, 4).to(R)\n",
@@ -136,7 +171,7 @@
136171
"from scipy.spatial.transform import Rotation\n",
137172
"\n",
138173
"\n",
139-
"def random_rigid_transform(batch_size):\n",
174+
"def random_rigid_transform(batch_size=1):\n",
140175
" \"\"\"Helper function for testing implementations.\"\"\"\n",
141176
" R = torch.from_numpy(Rotation.random(batch_size).as_matrix()).to(torch.float32)\n",
142177
" t = 100 * torch.randn((batch_size, 3))\n",

0 commit comments

Comments
 (0)