Skip to content

Commit

Permalink
Try updating C-arm intial pose
Browse files Browse the repository at this point in the history
  • Loading branch information
eigenvivek committed Mar 14, 2024
1 parent 2ff6aa2 commit 4c9bd35
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 87 deletions.
56 changes: 28 additions & 28 deletions diffdrr/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class Detector(torch.nn.Module):

def __init__(
self,
sdr: float, # Source-to-detector radius (half of the source-to-detector distance)
sdd: float, # Source-to-detector distance (i.e., focal length)
height: int, # Height of the X-ray detector
width: int, # Width of the X-ray detector
delx: float, # Pixel spacing in the X-direction
Expand All @@ -31,7 +31,7 @@ def __init__(
reverse_x_axis: bool = False, # If pose includes reflection (in E(3) not SE(3)), reverse x-axis
):
super().__init__()
self.sdr = sdr
self.sdd = sdd
self.height = height
self.width = width
self.delx = delx
Expand All @@ -48,30 +48,30 @@ def __init__(
self.register_buffer("source", source)
self.register_buffer("target", target)

# Anatomy to world coordinates
flip_xz = torch.tensor(
[
[0.0, 0.0, -1.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
]
)
translate = torch.tensor(
[
[1.0, 0.0, 0.0, -self.sdr],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
]
)
self.register_buffer("_flip_xz", flip_xz)
self.register_buffer("_translate", translate)
# # Anatomy to world coordinates
# flip_xz = torch.tensor(
# [
# [0.0, 0.0, -1.0, 0.0],
# [0.0, 1.0, 0.0, 0.0],
# [1.0, 0.0, 0.0, 0.0],
# [0.0, 0.0, 0.0, 1.0],
# ]
# )
# translate = torch.tensor(
# [
# [1.0, 0.0, 0.0, -self.sdr],
# [0.0, 1.0, 0.0, 0.0],
# [0.0, 0.0, 1.0, 0.0],
# [0.0, 0.0, 0.0, 1.0],
# ]
# )
# self.register_buffer("_flip_xz", flip_xz)
# self.register_buffer("_translate", translate)

@property
def intrinsic(self):
return make_intrinsic_matrix(
self.sdr,
self.sdd,
self.delx,
self.dely,
self.height,
Expand All @@ -93,16 +93,16 @@ def translate(self):
def _initialize_carm(self: Detector):
"""Initialize the default position for the source and detector plane."""
try:
device = self.sdr.device
device = self.sdd.device
except AttributeError:
device = torch.device("cpu")

# Initialize the source on the x-axis and the center of the detector plane on the negative x-axis
source = torch.tensor([[1.0, 0.0, 0.0]], device=device) * self.sdr
center = torch.tensor([[-1.0, 0.0, 0.0]], device=device) * self.sdr
# Initialize the source at the origin and the center of the detector plane on the positive z-axis
source = torch.tensor([[0.0, 0.0, 0.0]], device=device)
center = torch.tensor([[0.0, 0.0, 1.0]], device=device) * self.sdd

# Use the standard basis for the detector plane
basis = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], device=device)
basis = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], device=device)

# Construct the detector plane with different offsets for even or odd heights
h_off = 1.0 if self.height % 2 else 0.5
Expand All @@ -126,7 +126,7 @@ def _initialize_carm(self: Detector):
target = target.unsqueeze(0)

# Apply principal point offset
target[..., 2] -= self.x0
target[..., 0] -= self.x0
target[..., 1] -= self.y0

if self.n_subsample is not None:
Expand Down
24 changes: 9 additions & 15 deletions diffdrr/drr.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(
volume: np.ndarray, # CT volume
origin: tuple, # Origin of the CT volume in world coordinates
spacing: tuple, # Dimensions in the CT volume in world coordinates
sdr: float, # Source-to-detector radius for the C-arm (half of the source-to-detector distance)
sdd: float, # Source-to-detector distance (i.e., the C-arm's focal length)
height: int, # Height of the rendered DRR
delx: float, # X-axis pixel size
width: int | None = None, # Width of the rendered DRR (default to `height`)
Expand All @@ -48,7 +48,7 @@ def __init__(
int(height * width * p_subsample) if p_subsample is not None else None
)
self.detector = Detector(
sdr,
sdd,
height,
width,
delx,
Expand Down Expand Up @@ -173,14 +173,14 @@ def set_bone_attenuation_multiplier(self: DRR, bone_attenuation_multiplier: floa
@patch
def set_intrinsics(
self: DRR,
sdr: float = None,
sdd: float = None,
delx: float = None,
dely: float = None,
x0: float = None,
y0: float = None,
):
self.detector = Detector(
sdr if sdr is not None else self.detector.sdr,
sdd if sdd is not None else self.detector.sdd,
self.detector.height,
self.detector.width,
delx if delx is not None else self.detector.delx,
Expand All @@ -201,9 +201,7 @@ def perspective_projection(
pose: RigidTransform,
pts: torch.Tensor,
):
extrinsic = (
pose.inverse().compose(self.detector.translate).compose(self.detector.flip_xz)
)
extrinsic = pose.inverse().compose(self.detector.flip_z)
x = extrinsic(pts)
x = torch.einsum("ij, bnj -> bni", self.detector.intrinsic, x)
z = x[..., -1].unsqueeze(-1).clone()
Expand All @@ -225,14 +223,10 @@ def inverse_projection(
.compose(self.detector.translate.inverse())
.compose(pose)
)
x = (
-2
* self.detector.sdr
* torch.einsum(
"ij, bnj -> bni",
self.detector.intrinsic.inverse(),
pad(pts, (0, 1), value=1), # Convert to homogenous coordinates
)
x = -self.detector.sdd * torch.einsum(
"ij, bnj -> bni",
self.detector.intrinsic.inverse(),
pad(pts, (0, 1), value=1), # Convert to homogenous coordinates
)
return extrinsic(x)

Expand Down
26 changes: 10 additions & 16 deletions notebooks/api/00_drr.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
"## DRR\n",
"`DRR` is a PyTorch module that compues differentiable digitally reconstructed radiographs. The viewing angle for the DRR (known generally in computer graphics as the *camera pose*) is parameterized by the following parameters:\n",
"\n",
"- SDR : Source-to-Detector radius (half of the source-to-detector distance)\n",
"- SDD : source-to-detector distance (i.e., the focal length of the C-arm)\n",
"- $\\mathbf R \\in \\mathrm{SO}(3)$ : a rotation\n",
"- $\\mathbf t \\in \\mathbb R^3$ : a translation"
]
Expand Down Expand Up @@ -118,7 +118,7 @@
" volume: np.ndarray, # CT volume\n",
" origin: tuple, # Origin of the CT volume in world coordinates\n",
" spacing: tuple, # Dimensions in the CT volume in world coordinates\n",
" sdr: float, # Source-to-detector radius for the C-arm (half of the source-to-detector distance)\n",
" sdd: float, # Source-to-detector distance (i.e., the C-arm's focal length)\n",
" height: int, # Height of the rendered DRR\n",
" delx: float, # X-axis pixel size\n",
" width: int | None = None, # Width of the rendered DRR (default to `height`)\n",
Expand All @@ -143,7 +143,7 @@
" int(height * width * p_subsample) if p_subsample is not None else None\n",
" )\n",
" self.detector = Detector(\n",
" sdr,\n",
" sdd,\n",
" height,\n",
" width,\n",
" delx,\n",
Expand Down Expand Up @@ -308,14 +308,14 @@
"@patch\n",
"def set_intrinsics(\n",
" self: DRR,\n",
" sdr: float = None,\n",
" sdd: float = None,\n",
" delx: float = None,\n",
" dely: float = None,\n",
" x0: float = None,\n",
" y0: float = None,\n",
"):\n",
" self.detector = Detector(\n",
" sdr if sdr is not None else self.detector.sdr,\n",
" sdd if sdd is not None else self.detector.sdd,\n",
" self.detector.height,\n",
" self.detector.width,\n",
" delx if delx is not None else self.detector.delx,\n",
Expand Down Expand Up @@ -344,9 +344,7 @@
" pose: RigidTransform,\n",
" pts: torch.Tensor,\n",
"):\n",
" extrinsic = (\n",
" pose.inverse().compose(self.detector.translate).compose(self.detector.flip_xz)\n",
" )\n",
" extrinsic = pose.inverse().compose(self.detector.flip_z)\n",
" x = extrinsic(pts)\n",
" x = torch.einsum(\"ij, bnj -> bni\", self.detector.intrinsic, x)\n",
" z = x[..., -1].unsqueeze(-1).clone()\n",
Expand Down Expand Up @@ -376,14 +374,10 @@
" .compose(self.detector.translate.inverse())\n",
" .compose(pose)\n",
" )\n",
" x = (\n",
" -2\n",
" * self.detector.sdr\n",
" * torch.einsum(\n",
" \"ij, bnj -> bni\",\n",
" self.detector.intrinsic.inverse(),\n",
" pad(pts, (0, 1), value=1), # Convert to homogenous coordinates\n",
" )\n",
" x = -self.detector.sdd * torch.einsum(\n",
" \"ij, bnj -> bni\",\n",
" self.detector.intrinsic.inverse(),\n",
" pad(pts, (0, 1), value=1), # Convert to homogenous coordinates\n",
" )\n",
" return extrinsic(x)"
]
Expand Down
56 changes: 28 additions & 28 deletions notebooks/api/02_detector.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
"\n",
" def __init__(\n",
" self,\n",
" sdr: float, # Source-to-detector radius (half of the source-to-detector distance)\n",
" sdd: float, # Source-to-detector distance (i.e., focal length)\n",
" height: int, # Height of the X-ray detector\n",
" width: int, # Width of the X-ray detector\n",
" delx: float, # Pixel spacing in the X-direction\n",
Expand All @@ -86,7 +86,7 @@
" reverse_x_axis: bool = False, # If pose includes reflection (in E(3) not SE(3)), reverse x-axis\n",
" ):\n",
" super().__init__()\n",
" self.sdr = sdr\n",
" self.sdd = sdd\n",
" self.height = height\n",
" self.width = width\n",
" self.delx = delx\n",
Expand All @@ -103,30 +103,30 @@
" self.register_buffer(\"source\", source)\n",
" self.register_buffer(\"target\", target)\n",
"\n",
" # Anatomy to world coordinates\n",
" flip_xz = torch.tensor(\n",
" [\n",
" [0.0, 0.0, -1.0, 0.0],\n",
" [0.0, 1.0, 0.0, 0.0],\n",
" [1.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 1.0],\n",
" ]\n",
" )\n",
" translate = torch.tensor(\n",
" [\n",
" [1.0, 0.0, 0.0, -self.sdr],\n",
" [0.0, 1.0, 0.0, 0.0],\n",
" [0.0, 0.0, 1.0, 0.0],\n",
" [0.0, 0.0, 0.0, 1.0],\n",
" ]\n",
" )\n",
" self.register_buffer(\"_flip_xz\", flip_xz)\n",
" self.register_buffer(\"_translate\", translate)\n",
" # # Anatomy to world coordinates\n",
" # flip_xz = torch.tensor(\n",
" # [\n",
" # [0.0, 0.0, -1.0, 0.0],\n",
" # [0.0, 1.0, 0.0, 0.0],\n",
" # [1.0, 0.0, 0.0, 0.0],\n",
" # [0.0, 0.0, 0.0, 1.0],\n",
" # ]\n",
" # )\n",
" # translate = torch.tensor(\n",
" # [\n",
" # [1.0, 0.0, 0.0, -self.sdr],\n",
" # [0.0, 1.0, 0.0, 0.0],\n",
" # [0.0, 0.0, 1.0, 0.0],\n",
" # [0.0, 0.0, 0.0, 1.0],\n",
" # ]\n",
" # )\n",
" # self.register_buffer(\"_flip_xz\", flip_xz)\n",
" # self.register_buffer(\"_translate\", translate)\n",
"\n",
" @property\n",
" def intrinsic(self):\n",
" return make_intrinsic_matrix(\n",
" self.sdr,\n",
" self.sdd,\n",
" self.delx,\n",
" self.dely,\n",
" self.height,\n",
Expand Down Expand Up @@ -156,16 +156,16 @@
"def _initialize_carm(self: Detector):\n",
" \"\"\"Initialize the default position for the source and detector plane.\"\"\"\n",
" try:\n",
" device = self.sdr.device\n",
" device = self.sdd.device\n",
" except AttributeError:\n",
" device = torch.device(\"cpu\")\n",
"\n",
" # Initialize the source on the x-axis and the center of the detector plane on the negative x-axis\n",
" source = torch.tensor([[1.0, 0.0, 0.0]], device=device) * self.sdr\n",
" center = torch.tensor([[-1.0, 0.0, 0.0]], device=device) * self.sdr\n",
" # Initialize the source at the origin and the center of the detector plane on the positive z-axis\n",
" source = torch.tensor([[0.0, 0.0, 0.0]], device=device)\n",
" center = torch.tensor([[0.0, 0.0, 1.0]], device=device) * self.sdd\n",
"\n",
" # Use the standard basis for the detector plane\n",
" basis = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], device=device)\n",
" basis = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], device=device)\n",
"\n",
" # Construct the detector plane with different offsets for even or odd heights\n",
" h_off = 1.0 if self.height % 2 else 0.5\n",
Expand All @@ -189,7 +189,7 @@
" target = target.unsqueeze(0)\n",
"\n",
" # Apply principal point offset\n",
" target[..., 2] -= self.x0\n",
" target[..., 0] -= self.x0\n",
" target[..., 1] -= self.y0\n",
"\n",
" if self.n_subsample is not None:\n",
Expand Down

0 comments on commit 4c9bd35

Please sign in to comment.