Skip to content

Commit f147f43

Browse files
committed
Switch SE(3) backend to diffdrr.pose.RigidTransform
1 parent d82ff4e commit f147f43

File tree

12 files changed

+491
-3769
lines changed

12 files changed

+491
-3769
lines changed

diffdrr/_modidx.py

Lines changed: 62 additions & 109 deletions
Large diffs are not rendered by default.

diffdrr/detector.py

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -92,38 +92,20 @@ def _initialize_carm(self: Detector):
9292
return source, target
9393

9494
# %% ../notebooks/api/02_detector.ipynb 7
95-
from .utils import Transform3d, convert
95+
from .pose import RigidTransform
9696

9797

9898
@patch
9999
def forward(
100100
self: Detector,
101-
rotation: torch.Tensor, # Some (batched) representation of a rotation
102-
translation: torch.Tensor, # Batch of C-arm translation (bx, by, bz)
103-
parameterization: str, # Specifies the representation of the rotation
104-
convention: str, # If parameterization is Euler angles, specify convention
101+
pose: RigidTransform,
105102
):
106103
"""Create source and target points for X-rays to trace through the volume."""
107-
if parameterization == "euler_angles" and convention is None:
108-
raise ValueError(
109-
"convention for Euler angles must be specified as a 3 letter combination of [X, Y, Z]"
110-
)
111-
112-
# Convert rotation representation to a rotation matrix, R
113-
# Transpose R to convert to right-handed convention for PyTorch3D
114-
R = convert(rotation, parameterization, "matrix", input_convention=convention)
115-
R = R.transpose(-1, -2)
116-
t = Transform3d(device=rotation.device).rotate(R).translate(translation)
117-
source, target = make_xrays(t, self.source, self.target)
104+
source = pose(self.source)
105+
target = pose(self.target)
118106
return source, target
119107

120108
# %% ../notebooks/api/02_detector.ipynb 8
121-
def make_xrays(t: Transform3d, source: torch.Tensor, target: torch.Tensor):
122-
source = t.transform_points(source)
123-
target = t.transform_points(target)
124-
return source, target
125-
126-
# %% ../notebooks/api/02_detector.ipynb 9
127109
def diffdrr_to_deepdrr(euler_angles):
128110
alpha, beta, gamma = euler_angles.unbind(-1)
129111
return torch.stack([beta, alpha, gamma], dim=1)

diffdrr/drr.py

Lines changed: 11 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from .siddon import siddon_raycast
1313

1414
# %% auto 0
15-
__all__ = ['DRR', 'Registration']
15+
__all__ = ['DRR']
1616

1717
# %% ../notebooks/api/00_drr.ipynb 7
1818
class DRR(nn.Module):
@@ -92,18 +92,16 @@ def reshape_subsampled_drr(
9292
return drr
9393

9494
# %% ../notebooks/api/00_drr.ipynb 10
95-
from .detector import make_xrays
96-
from .utils import Transform3d
95+
# from diffdrr.se3 import RigidTransform, convert
96+
from .pose import convert
9797

9898

9999
@patch
100100
def forward(
101101
self: DRR,
102-
rotation: torch.Tensor,
103-
translation: torch.Tensor,
104-
parameterization: str,
105-
convention: str = None,
106-
pose: Transform3d = None, # If you have a preformed pose, can pass it directly
102+
*args, # Some batched representation of SE(3)
103+
parameterization: str = None, # Specifies the representation of the rotation
104+
convention: str = None, # If parameterization is Euler angles, specify convention
107105
bone_attenuation_multiplier: float = None, # Contrast ratio of bone to soft tissue
108106
):
109107
"""Generate DRR with rotational and translational parameters."""
@@ -112,18 +110,11 @@ def forward(
112110
if bone_attenuation_multiplier is not None:
113111
self.set_bone_attenuation_multiplier(bone_attenuation_multiplier)
114112

115-
if pose is None:
116-
assert len(rotation) == len(translation)
117-
batch_size = len(rotation)
118-
source, target = self.detector(
119-
rotation=rotation,
120-
translation=translation,
121-
parameterization=parameterization,
122-
convention=convention,
123-
)
113+
if parameterization is None:
114+
pose = args[0]
124115
else:
125-
batch_size = len(pose)
126-
source, target = make_xrays(pose, self.detector.source, self.detector.target)
116+
pose = convert(*args, parameterization=parameterization, convention=convention)
117+
source, target = self.detector(pose)
127118

128119
if self.patch_size is not None:
129120
n_points = target.shape[1] // self.n_patches
@@ -135,7 +126,7 @@ def forward(
135126
img = torch.cat(img, dim=1)
136127
else:
137128
img = siddon_raycast(source, target, self.density, self.spacing)
138-
return self.reshape_transform(img, batch_size=batch_size)
129+
return self.reshape_transform(img, batch_size=len(pose))
139130

140131
# %% ../notebooks/api/00_drr.ipynb 11
141132
@patch
@@ -169,52 +160,3 @@ def set_intrinsics(
169160
n_subsample=self.detector.n_subsample,
170161
reverse_x_axis=self.detector.reverse_x_axis,
171162
).to(self.volume)
172-
173-
# %% ../notebooks/api/00_drr.ipynb 14
174-
from .utils import convert
175-
176-
177-
class Registration(nn.Module):
178-
"""Perform automatic 2D-to-3D registration using differentiable rendering."""
179-
180-
def __init__(
181-
self,
182-
drr: DRR,
183-
rotation: torch.Tensor,
184-
translation: torch.Tensor,
185-
parameterization: str,
186-
input_convention: str = None,
187-
output_convention: str = "ZYX",
188-
):
189-
super().__init__()
190-
self.drr = drr
191-
self.rotation = nn.Parameter(rotation)
192-
self.translation = nn.Parameter(translation)
193-
self.parameterization = parameterization
194-
self.input_convention = input_convention
195-
self.output_convention = output_convention
196-
197-
def forward(self):
198-
return self.drr(
199-
self.rotation,
200-
self.translation,
201-
self.parameterization,
202-
self.input_convention,
203-
)
204-
205-
def get_rotation(self):
206-
return (
207-
convert(
208-
self.rotation,
209-
input_parameterization=self.parameterization,
210-
output_parameterization="euler_angles",
211-
input_convention=self.input_convention,
212-
output_convention=self.output_convention,
213-
)
214-
.clone()
215-
.detach()
216-
.cpu()
217-
)
218-
219-
def get_translation(self):
220-
return self.translation.clone().detach().cpu()

diffdrr/pose.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/api/06_pypose.ipynb.
2+
3+
# %% auto 0
4+
__all__ = ['convert']
5+
6+
# %% ../notebooks/api/06_pypose.ipynb 3
7+
import torch
8+
import pypose as pp
9+
10+
# %% ../notebooks/api/06_pypose.ipynb 4
11+
def convert(*args, parameterization, convention=None, **kwargs) -> pp.SE3:
12+
if parameterization == "euler_angles" and convention is None:
13+
raise ValueError(
14+
"convention for Euler angles must be specified as a 3 letter combination of [X, Y, Z]"
15+
)
16+
17+
if parameterization == "axis_angle":
18+
rotation, translation = args
19+
quaternion = pp.so3(rotation).Exp().tensor()
20+
return convert(quaternion, translation, parameterization="quaternion")
21+
elif parameterization == "euler_angles":
22+
rotation, translation = args
23+
rotmat = euler_angles_to_matrix(rotation, convention)
24+
matrix = torch.concat([rotmat, translation.unsqueeze(-1)], axis=-1)
25+
return convert(matrix, parameterization="matrix", check=False)
26+
elif parameterization == "matrix":
27+
return pp.from_matrix(*args, ltype=pp.SE3_type, **kwargs)
28+
elif parameterization == "quaternion":
29+
rotation, translation = args
30+
return pp.SE3(torch.concat([translation, rotation], axis=-1))
31+
elif parameterization == "quaternion_adjugate":
32+
rotation, translation = args
33+
quaternion = quaternion_adjugate_to_quaternion(rotation)
34+
return convert(quaternion, translation, parameterization="quaternion")
35+
elif parameterization == "rotation_6d":
36+
rotation, translation = args
37+
rotmat = rotation_6d_to_matrix(rotation)
38+
matrix = torch.concat([rotmat, translation.unsqueeze(-1)], axis=-1)
39+
return convert(matrix, parameterization="matrix", check=False)
40+
elif parameterization in ["rotation_10d"]:
41+
rotation, translation = args
42+
quaternion = rotation_10d_to_quaternion(rotation)
43+
return convert(quaternion, translation, parameterization="quaternion")
44+
elif parameterization == "se3":
45+
rotation, translation = args
46+
return pp.se3(torch.concat([translation, rotation], axis=-1)).Exp()
47+
elif parameterization == "SE3":
48+
return args[0]
49+
else:
50+
raise ValueError
51+
52+
# %% ../notebooks/api/06_pypose.ipynb 5
53+
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
54+
"""Source: http://arxiv.org/abs/1812.07035"""
55+
a1, a2 = d6[..., :3], d6[..., 3:]
56+
b1 = F.normalize(a1, dim=-1)
57+
b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
58+
b2 = F.normalize(b2, dim=-1)
59+
b3 = torch.cross(b1, b2, dim=-1)
60+
return torch.stack((b1, b2, b3), dim=-2)
61+
62+
63+
def rotation_10d_to_quaternion(rotation: torch.Tensor) -> torch.Tensor:
64+
"""
65+
Convert a 10-vector into a symmetric matrix, whose eigenvector corresponding
66+
to the eigenvalue of minimum modulus is the resulting quaternion.
67+
68+
Source: https://arxiv.org/abs/2006.01031
69+
"""
70+
A = _10vec_to_4x4symmetric(rotation) # A is a symmetric data matrix
71+
return torch.linalg.eigh(A).eigenvectors[..., 0]
72+
73+
74+
def quaternion_adjugate_to_quaternion(rotation: torch.Tensor) -> torch.Tensor:
75+
"""
76+
Convert a 10-vector in the quaternion adjugate, a symmetric matrix whose
77+
eigenvector corresponding to the eigenvalue of maximum modulus is the
78+
(unnormalized) quaternion. Uses a fast method to solve for the eigenvector
79+
without explicity computing the eigendecomposition.
80+
81+
Source: https://arxiv.org/abs/2205.09116
82+
"""
83+
A = _10vec_to_4x4symmetric(rotation) # A is the quaternion adjugate
84+
norms = A.norm(dim=1).amax(dim=1, keepdim=True)
85+
max_eigenvectors = torch.argmax(A.norm(dim=1), dim=1)
86+
return A[range(len(A)), max_eigenvectors] / norms
87+
88+
89+
def _10vec_to_4x4symmetric(vec):
90+
"""Convert a 10-vector to a symmetric 4x4 matrix."""
91+
b = len(vec)
92+
A = torch.zeros(b, 4, 4, device=vec.device)
93+
idx, jdx = torch.triu_indices(4, 4)
94+
A[..., idx, jdx] = vec
95+
A[..., jdx, idx] = vec
96+
return A
97+
98+
# %% ../notebooks/api/06_pypose.ipynb 6
99+
def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
100+
if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
101+
raise ValueError("Invalid input euler angles.")
102+
if len(convention) != 3:
103+
raise ValueError("Convention must have 3 letters.")
104+
if convention[1] in (convention[0], convention[2]):
105+
raise ValueError(f"Invalid convention {convention}.")
106+
for letter in convention:
107+
if letter not in ("X", "Y", "Z"):
108+
raise ValueError(f"Invalid letter {letter} in convention string.")
109+
matrices = [
110+
_axis_angle_rotation(c, e)
111+
for c, e in zip(convention, torch.unbind(euler_angles, -1))
112+
]
113+
return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])
114+
115+
116+
def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
117+
cos = torch.cos(angle)
118+
sin = torch.sin(angle)
119+
one = torch.ones_like(angle)
120+
zero = torch.zeros_like(angle)
121+
122+
if axis == "X":
123+
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
124+
elif axis == "Y":
125+
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
126+
elif axis == "Z":
127+
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
128+
else:
129+
raise ValueError("letter must be either X, Y or Z.")
130+
131+
return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))

0 commit comments

Comments
 (0)