Skip to content

Commit 17a8ea9

Browse files
authored
Merge pull request #185 from eigenvivek/refactor-se3
Switch SE(3) backend to `diffdrr.pose.RigidTransform`
2 parents 9daefba + 07397e6 commit 17a8ea9

25 files changed

+1925
-3739
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ plt.show()
7272

7373
On a single NVIDIA RTX 2080 Ti GPU, producing such an image takes
7474

75-
33.3 ms ± 6.78 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
75+
33.6 ms ± 27.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
7676

7777
The full example is available at
7878
[`introduction.ipynb`](https://vivekg.dev/DiffDRR/tutorials/introduction.html).

diffdrr/_modidx.py

Lines changed: 57 additions & 103 deletions
Large diffs are not rendered by default.

diffdrr/detector.py

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -92,38 +92,20 @@ def _initialize_carm(self: Detector):
9292
return source, target
9393

9494
# %% ../notebooks/api/02_detector.ipynb 7
95-
from .utils import Transform3d, convert
95+
from .pose import RigidTransform
9696

9797

9898
@patch
9999
def forward(
100100
self: Detector,
101-
rotation: torch.Tensor, # Some (batched) representation of a rotation
102-
translation: torch.Tensor, # Batch of C-arm translation (bx, by, bz)
103-
parameterization: str, # Specifies the representation of the rotation
104-
convention: str, # If parameterization is Euler angles, specify convention
101+
pose: RigidTransform,
105102
):
106103
"""Create source and target points for X-rays to trace through the volume."""
107-
if parameterization == "euler_angles" and convention is None:
108-
raise ValueError(
109-
"convention for Euler angles must be specified as a 3 letter combination of [X, Y, Z]"
110-
)
111-
112-
# Convert rotation representation to a rotation matrix, R
113-
# Transpose R to convert to right-handed convention for PyTorch3D
114-
R = convert(rotation, parameterization, "matrix", input_convention=convention)
115-
R = R.transpose(-1, -2)
116-
t = Transform3d(device=rotation.device).rotate(R).translate(translation)
117-
source, target = make_xrays(t, self.source, self.target)
104+
source = pose(self.source)
105+
target = pose(self.target)
118106
return source, target
119107

120108
# %% ../notebooks/api/02_detector.ipynb 8
121-
def make_xrays(t: Transform3d, source: torch.Tensor, target: torch.Tensor):
122-
source = t.transform_points(source)
123-
target = t.transform_points(target)
124-
return source, target
125-
126-
# %% ../notebooks/api/02_detector.ipynb 9
127109
def diffdrr_to_deepdrr(euler_angles):
128110
alpha, beta, gamma = euler_angles.unbind(-1)
129111
return torch.stack([beta, alpha, gamma], dim=1)

diffdrr/drr.py

Lines changed: 17 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,14 @@ def __init__(
2525
sdr: float, # Source-to-detector radius for the C-arm (half of the source-to-detector distance)
2626
height: int, # Height of the rendered DRR
2727
delx: float, # X-axis pixel size
28-
width: int
29-
| None = None, # Width of the rendered DRR (if not provided, set to `height`)
28+
width: int | None = None, # Width of the rendered DRR (default to `height`)
3029
dely: float | None = None, # Y-axis pixel size (if not provided, set to `delx`)
3130
x0: float = 0.0, # Principal point X-offset
3231
y0: float = 0.0, # Principal point Y-offset
3332
p_subsample: float | None = None, # Proportion of pixels to randomly subsample
3433
reshape: bool = True, # Return DRR with shape (b, 1, h, w)
3534
reverse_x_axis: bool = False, # If pose includes reflection (in E(3) not SE(3)), reverse x-axis
36-
patch_size: int
37-
| None = None, # If the entire DRR can't fit in memory, render patches of the DRR in series
35+
patch_size: int | None = None, # Render patches of the DRR in series
3836
bone_attenuation_multiplier: float = 1.0, # Contrast ratio of bone to soft tissue
3937
):
4038
super().__init__()
@@ -92,18 +90,16 @@ def reshape_subsampled_drr(
9290
return drr
9391

9492
# %% ../notebooks/api/00_drr.ipynb 10
95-
from .detector import make_xrays
96-
from .utils import Transform3d
93+
# from diffdrr.se3 import RigidTransform, convert
94+
from .pose import convert
9795

9896

9997
@patch
10098
def forward(
10199
self: DRR,
102-
rotation: torch.Tensor,
103-
translation: torch.Tensor,
104-
parameterization: str,
105-
convention: str = None,
106-
pose: Transform3d = None, # If you have a preformed pose, can pass it directly
100+
*args, # Some batched representation of SE(3)
101+
parameterization: str = None, # Specifies the representation of the rotation
102+
convention: str = None, # If parameterization is Euler angles, specify convention
107103
bone_attenuation_multiplier: float = None, # Contrast ratio of bone to soft tissue
108104
):
109105
"""Generate DRR with rotational and translational parameters."""
@@ -112,18 +108,11 @@ def forward(
112108
if bone_attenuation_multiplier is not None:
113109
self.set_bone_attenuation_multiplier(bone_attenuation_multiplier)
114110

115-
if pose is None:
116-
assert len(rotation) == len(translation)
117-
batch_size = len(rotation)
118-
source, target = self.detector(
119-
rotation=rotation,
120-
translation=translation,
121-
parameterization=parameterization,
122-
convention=convention,
123-
)
111+
if parameterization is None:
112+
pose = args[0]
124113
else:
125-
batch_size = len(pose)
126-
source, target = make_xrays(pose, self.detector.source, self.detector.target)
114+
pose = convert(*args, parameterization=parameterization, convention=convention)
115+
source, target = self.detector(pose)
127116

128117
if self.patch_size is not None:
129118
n_points = target.shape[1] // self.n_patches
@@ -135,7 +124,7 @@ def forward(
135124
img = torch.cat(img, dim=1)
136125
else:
137126
img = siddon_raycast(source, target, self.density, self.spacing)
138-
return self.reshape_transform(img, batch_size=batch_size)
127+
return self.reshape_transform(img, batch_size=len(pose))
139128

140129
# %% ../notebooks/api/00_drr.ipynb 11
141130
@patch
@@ -171,9 +160,6 @@ def set_intrinsics(
171160
).to(self.volume)
172161

173162
# %% ../notebooks/api/00_drr.ipynb 14
174-
from .utils import convert
175-
176-
177163
class Registration(nn.Module):
178164
"""Perform automatic 2D-to-3D registration using differentiable rendering."""
179165

@@ -183,38 +169,25 @@ def __init__(
183169
rotation: torch.Tensor,
184170
translation: torch.Tensor,
185171
parameterization: str,
186-
input_convention: str = None,
187-
output_convention: str = "ZYX",
172+
convention: str = None,
188173
):
189174
super().__init__()
190175
self.drr = drr
191176
self.rotation = nn.Parameter(rotation)
192177
self.translation = nn.Parameter(translation)
193178
self.parameterization = parameterization
194-
self.input_convention = input_convention
195-
self.output_convention = output_convention
179+
self.convention = convention
196180

197181
def forward(self):
198182
return self.drr(
199183
self.rotation,
200184
self.translation,
201-
self.parameterization,
202-
self.input_convention,
185+
parameterization=self.parameterization,
186+
convention=self.convention,
203187
)
204188

205189
def get_rotation(self):
206-
return (
207-
convert(
208-
self.rotation,
209-
input_parameterization=self.parameterization,
210-
output_parameterization="euler_angles",
211-
input_convention=self.input_convention,
212-
output_convention=self.output_convention,
213-
)
214-
.clone()
215-
.detach()
216-
.cpu()
217-
)
190+
return self.rotation.clone().detach().cpu()
218191

219192
def get_translation(self):
220193
return self.translation.clone().detach().cpu()

0 commit comments

Comments
 (0)