Skip to content

Commit

Permalink
Merge pull request #185 from eigenvivek/refactor-se3
Browse files Browse the repository at this point in the history
Switch SE(3) backend to `diffdrr.pose.RigidTransform`
  • Loading branch information
eigenvivek authored Feb 8, 2024
2 parents d82ff4e + c8507bc commit f707cc0
Show file tree
Hide file tree
Showing 25 changed files with 1,925 additions and 3,739 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ plt.show()

On a single NVIDIA RTX 2080 Ti GPU, producing such an image takes

33.3 ms ± 6.78 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
33.6 ms ± 27.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

The full example is available at
[`introduction.ipynb`](https://vivekg.dev/DiffDRR/tutorials/introduction.html).
Expand Down
160 changes: 57 additions & 103 deletions diffdrr/_modidx.py

Large diffs are not rendered by default.

26 changes: 4 additions & 22 deletions diffdrr/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,38 +92,20 @@ def _initialize_carm(self: Detector):
return source, target

# %% ../notebooks/api/02_detector.ipynb 7
from .utils import Transform3d, convert
from .pose import RigidTransform


@patch
def forward(
self: Detector,
rotation: torch.Tensor, # Some (batched) representation of a rotation
translation: torch.Tensor, # Batch of C-arm translation (bx, by, bz)
parameterization: str, # Specifies the representation of the rotation
convention: str, # If parameterization is Euler angles, specify convention
pose: RigidTransform,
):
"""Create source and target points for X-rays to trace through the volume."""
if parameterization == "euler_angles" and convention is None:
raise ValueError(
"convention for Euler angles must be specified as a 3 letter combination of [X, Y, Z]"
)

# Convert rotation representation to a rotation matrix, R
# Transpose R to convert to right-handed convention for PyTorch3D
R = convert(rotation, parameterization, "matrix", input_convention=convention)
R = R.transpose(-1, -2)
t = Transform3d(device=rotation.device).rotate(R).translate(translation)
source, target = make_xrays(t, self.source, self.target)
source = pose(self.source)
target = pose(self.target)
return source, target

# %% ../notebooks/api/02_detector.ipynb 8
def make_xrays(t: Transform3d, source: torch.Tensor, target: torch.Tensor):
source = t.transform_points(source)
target = t.transform_points(target)
return source, target

# %% ../notebooks/api/02_detector.ipynb 9
def diffdrr_to_deepdrr(euler_angles):
alpha, beta, gamma = euler_angles.unbind(-1)
return torch.stack([beta, alpha, gamma], dim=1)
61 changes: 17 additions & 44 deletions diffdrr/drr.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,14 @@ def __init__(
sdr: float, # Source-to-detector radius for the C-arm (half of the source-to-detector distance)
height: int, # Height of the rendered DRR
delx: float, # X-axis pixel size
width: int
| None = None, # Width of the rendered DRR (if not provided, set to `height`)
width: int | None = None, # Width of the rendered DRR (default to `height`)
dely: float | None = None, # Y-axis pixel size (if not provided, set to `delx`)
x0: float = 0.0, # Principal point X-offset
y0: float = 0.0, # Principal point Y-offset
p_subsample: float | None = None, # Proportion of pixels to randomly subsample
reshape: bool = True, # Return DRR with shape (b, 1, h, w)
reverse_x_axis: bool = False, # If pose includes reflection (in E(3) not SE(3)), reverse x-axis
patch_size: int
| None = None, # If the entire DRR can't fit in memory, render patches of the DRR in series
patch_size: int | None = None, # Render patches of the DRR in series
bone_attenuation_multiplier: float = 1.0, # Contrast ratio of bone to soft tissue
):
super().__init__()
Expand Down Expand Up @@ -92,18 +90,16 @@ def reshape_subsampled_drr(
return drr

# %% ../notebooks/api/00_drr.ipynb 10
from .detector import make_xrays
from .utils import Transform3d
# from diffdrr.se3 import RigidTransform, convert
from .pose import convert


@patch
def forward(
self: DRR,
rotation: torch.Tensor,
translation: torch.Tensor,
parameterization: str,
convention: str = None,
pose: Transform3d = None, # If you have a preformed pose, can pass it directly
*args, # Some batched representation of SE(3)
parameterization: str = None, # Specifies the representation of the rotation
convention: str = None, # If parameterization is Euler angles, specify convention
bone_attenuation_multiplier: float = None, # Contrast ratio of bone to soft tissue
):
"""Generate DRR with rotational and translational parameters."""
Expand All @@ -112,18 +108,11 @@ def forward(
if bone_attenuation_multiplier is not None:
self.set_bone_attenuation_multiplier(bone_attenuation_multiplier)

if pose is None:
assert len(rotation) == len(translation)
batch_size = len(rotation)
source, target = self.detector(
rotation=rotation,
translation=translation,
parameterization=parameterization,
convention=convention,
)
if parameterization is None:
pose = args[0]
else:
batch_size = len(pose)
source, target = make_xrays(pose, self.detector.source, self.detector.target)
pose = convert(*args, parameterization=parameterization, convention=convention)
source, target = self.detector(pose)

if self.patch_size is not None:
n_points = target.shape[1] // self.n_patches
Expand All @@ -135,7 +124,7 @@ def forward(
img = torch.cat(img, dim=1)
else:
img = siddon_raycast(source, target, self.density, self.spacing)
return self.reshape_transform(img, batch_size=batch_size)
return self.reshape_transform(img, batch_size=len(pose))

# %% ../notebooks/api/00_drr.ipynb 11
@patch
Expand Down Expand Up @@ -171,9 +160,6 @@ def set_intrinsics(
).to(self.volume)

# %% ../notebooks/api/00_drr.ipynb 14
from .utils import convert


class Registration(nn.Module):
"""Perform automatic 2D-to-3D registration using differentiable rendering."""

Expand All @@ -183,38 +169,25 @@ def __init__(
rotation: torch.Tensor,
translation: torch.Tensor,
parameterization: str,
input_convention: str = None,
output_convention: str = "ZYX",
convention: str = None,
):
super().__init__()
self.drr = drr
self.rotation = nn.Parameter(rotation)
self.translation = nn.Parameter(translation)
self.parameterization = parameterization
self.input_convention = input_convention
self.output_convention = output_convention
self.convention = convention

def forward(self):
return self.drr(
self.rotation,
self.translation,
self.parameterization,
self.input_convention,
parameterization=self.parameterization,
convention=self.convention,
)

def get_rotation(self):
return (
convert(
self.rotation,
input_parameterization=self.parameterization,
output_parameterization="euler_angles",
input_convention=self.input_convention,
output_convention=self.output_convention,
)
.clone()
.detach()
.cpu()
)
return self.rotation.clone().detach().cpu()

def get_translation(self):
return self.translation.clone().detach().cpu()
Loading

0 comments on commit f707cc0

Please sign in to comment.