Skip to content

Commit 90781c4

Browse files
committed
Add SO(3) conversion
1 parent e50a510 commit 90781c4

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

diffdrr/utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"quaternion_adjugate",
1818
"rotation_6d",
1919
"rotation_10d",
20+
"so3_log_map",
2021
]
2122

2223
# %% ../notebooks/api/06_utils.ipynb 6
@@ -97,10 +98,11 @@ def convert(
9798
euler_angles_to_matrix,
9899
quaternion_to_matrix,
99100
rotation_6d_to_matrix,
101+
so3_exp_map,
100102
)
101103

102104

103-
def _convert_to_rotation_matrix(rotation, parameterization, convention):
105+
def _convert_to_rotation_matrix(rotation, parameterization, convention, **kwargs):
104106
"""Convert any parameterization of a rotation to a matrix representation."""
105107
if parameterization == "axis_angle":
106108
R = axis_angle_to_matrix(rotation)
@@ -116,6 +118,8 @@ def _convert_to_rotation_matrix(rotation, parameterization, convention):
116118
R = quaternion_to_matrix(rotation_10d_to_quaternion(rotation))
117119
elif parameterization == "quaternion_adjugate":
118120
R = quaternion_to_matrix(quaternion_adjugate_to_quaternion(rotation))
121+
elif parameterization == "so3_log_map":
122+
R = so3_exp_map(R, **kwargs)
119123
else:
120124
raise ValueError(
121125
f"parameterization must be in {PARAMETERIZATIONS}, not {parameterization}"
@@ -128,10 +132,11 @@ def _convert_to_rotation_matrix(rotation, parameterization, convention):
128132
matrix_to_euler_angles,
129133
matrix_to_quaternion,
130134
matrix_to_rotation_6d,
135+
so3_log_map,
131136
)
132137

133138

134-
def _convert_from_rotation_matrix(matrix, parameterization, convention=None):
139+
def _convert_from_rotation_matrix(matrix, parameterization, convention=None, **kwargs):
135140
"Convert a rotation matrix to any allowed parameterization."
136141
if parameterization == "axis_angle":
137142
rotation = matrix_to_axis_angle(matrix)
@@ -149,6 +154,8 @@ def _convert_from_rotation_matrix(matrix, parameterization, convention=None):
149154
elif parameterization == "quaternion_adjugate":
150155
q = _convert_from_rotation_matrix(matrix, "quaternion")
151156
rotation = quaternion_to_quaternion_adjugate(q)
157+
elif parameterization == "so3_log_map":
158+
rotation = so3_log_map(R, **kwargs)
152159
else:
153160
raise ValueError(
154161
f"parameterization must be in {PARAMETERIZATIONS}, not {parameterization}"

notebooks/api/06_utils.ipynb

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
" \"quaternion_adjugate\",\n",
7171
" \"rotation_6d\",\n",
7272
" \"rotation_10d\",\n",
73+
" \"so3_log_map\",\n",
7374
"]"
7475
]
7576
},
@@ -189,10 +190,11 @@
189190
" euler_angles_to_matrix,\n",
190191
" quaternion_to_matrix,\n",
191192
" rotation_6d_to_matrix,\n",
193+
" so3_exp_map,\n",
192194
")\n",
193195
"\n",
194196
"\n",
195-
"def _convert_to_rotation_matrix(rotation, parameterization, convention):\n",
197+
"def _convert_to_rotation_matrix(rotation, parameterization, convention, **kwargs):\n",
196198
" \"\"\"Convert any parameterization of a rotation to a matrix representation.\"\"\"\n",
197199
" if parameterization == \"axis_angle\":\n",
198200
" R = axis_angle_to_matrix(rotation)\n",
@@ -208,6 +210,8 @@
208210
" R = quaternion_to_matrix(rotation_10d_to_quaternion(rotation))\n",
209211
" elif parameterization == \"quaternion_adjugate\":\n",
210212
" R = quaternion_to_matrix(quaternion_adjugate_to_quaternion(rotation))\n",
213+
" elif parameterization == \"so3_log_map\":\n",
214+
" R = so3_exp_map(R, **kwargs)\n",
211215
" else:\n",
212216
" raise ValueError(\n",
213217
" f\"parameterization must be in {PARAMETERIZATIONS}, not {parameterization}\"\n",
@@ -228,10 +232,11 @@
228232
" matrix_to_euler_angles,\n",
229233
" matrix_to_quaternion,\n",
230234
" matrix_to_rotation_6d,\n",
235+
" so3_log_map,\n",
231236
")\n",
232237
"\n",
233238
"\n",
234-
"def _convert_from_rotation_matrix(matrix, parameterization, convention=None):\n",
239+
"def _convert_from_rotation_matrix(matrix, parameterization, convention=None, **kwargs):\n",
235240
" \"Convert a rotation matrix to any allowed parameterization.\"\n",
236241
" if parameterization == \"axis_angle\":\n",
237242
" rotation = matrix_to_axis_angle(matrix)\n",
@@ -249,6 +254,8 @@
249254
" elif parameterization == \"quaternion_adjugate\":\n",
250255
" q = _convert_from_rotation_matrix(matrix, \"quaternion\")\n",
251256
" rotation = quaternion_to_quaternion_adjugate(q)\n",
257+
" elif parameterization == \"so3_log_map\":\n",
258+
" rotation = so3_log_map(R, **kwargs)\n",
252259
" else:\n",
253260
" raise ValueError(\n",
254261
" f\"parameterization must be in {PARAMETERIZATIONS}, not {parameterization}\"\n",

0 commit comments

Comments
 (0)