diff --git a/diffdrr/detector.py b/diffdrr/detector.py index 1e491bb11..dc44a09be 100644 --- a/diffdrr/detector.py +++ b/diffdrr/detector.py @@ -20,7 +20,7 @@ class Detector(torch.nn.Module): def __init__( self, - sdr: float, # Source-to-detector radius (half of the source-to-detector distance) + sdd: float, # Source-to-detector distance (i.e., focal length) height: int, # Height of the X-ray detector width: int, # Width of the X-ray detector delx: float, # Pixel spacing in the X-direction @@ -31,7 +31,7 @@ def __init__( reverse_x_axis: bool = False, # If pose includes reflection (in E(3) not SE(3)), reverse x-axis ): super().__init__() - self.sdr = sdr + self.sdd = sdd self.height = height self.width = width self.delx = delx @@ -48,30 +48,30 @@ def __init__( self.register_buffer("source", source) self.register_buffer("target", target) - # Anatomy to world coordinates - flip_xz = torch.tensor( - [ - [0.0, 0.0, -1.0, 0.0], - [0.0, 1.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0], - ] - ) - translate = torch.tensor( - [ - [1.0, 0.0, 0.0, -self.sdr], - [0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 1.0], - ] - ) - self.register_buffer("_flip_xz", flip_xz) - self.register_buffer("_translate", translate) + # # Anatomy to world coordinates + # flip_xz = torch.tensor( + # [ + # [0.0, 0.0, -1.0, 0.0], + # [0.0, 1.0, 0.0, 0.0], + # [1.0, 0.0, 0.0, 0.0], + # [0.0, 0.0, 0.0, 1.0], + # ] + # ) + # translate = torch.tensor( + # [ + # [1.0, 0.0, 0.0, -self.sdr], + # [0.0, 1.0, 0.0, 0.0], + # [0.0, 0.0, 1.0, 0.0], + # [0.0, 0.0, 0.0, 1.0], + # ] + # ) + # self.register_buffer("_flip_xz", flip_xz) + # self.register_buffer("_translate", translate) @property def intrinsic(self): return make_intrinsic_matrix( - self.sdr, + self.sdd, self.delx, self.dely, self.height, @@ -93,16 +93,16 @@ def translate(self): def _initialize_carm(self: Detector): """Initialize the default position for the source and detector plane.""" try: - device = self.sdr.device + device = self.sdd.device except AttributeError: device = torch.device("cpu") - # Initialize the source on the x-axis and the center of the detector plane on the negative x-axis - source = torch.tensor([[1.0, 0.0, 0.0]], device=device) * self.sdr - center = torch.tensor([[-1.0, 0.0, 0.0]], device=device) * self.sdr + # Initialize the source at the origin and the center of the detector plane on the positive z-axis + source = torch.tensor([[0.0, 0.0, 0.0]], device=device) + center = torch.tensor([[0.0, 0.0, 1.0]], device=device) * self.sdd # Use the standard basis for the detector plane - basis = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], device=device) + basis = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], device=device) # Construct the detector plane with different offsets for even or odd heights h_off = 1.0 if self.height % 2 else 0.5 @@ -126,7 +126,7 @@ def _initialize_carm(self: Detector): target = target.unsqueeze(0) # Apply principal point offset - target[..., 2] -= self.x0 + target[..., 0] -= self.x0 target[..., 1] -= self.y0 if self.n_subsample is not None: diff --git a/diffdrr/drr.py b/diffdrr/drr.py index ac182120c..94237ecdc 100644 --- a/diffdrr/drr.py +++ b/diffdrr/drr.py @@ -23,7 +23,7 @@ def __init__( volume: np.ndarray, # CT volume origin: tuple, # Origin of the CT volume in world coordinates spacing: tuple, # Dimensions in the CT volume in world coordinates - sdr: float, # Source-to-detector radius for the C-arm (half of the source-to-detector distance) + sdd: float, # Source-to-detector distance (i.e., the C-arm's focal length) height: int, # Height of the rendered DRR delx: float, # X-axis pixel size width: int | None = None, # Width of the rendered DRR (default to `height`) @@ -48,7 +48,7 @@ def __init__( int(height * width * p_subsample) if p_subsample is not None else None ) self.detector = Detector( - sdr, + sdd, height, width, delx, @@ -173,14 +173,14 @@ def set_bone_attenuation_multiplier(self: DRR, bone_attenuation_multiplier: floa @patch def set_intrinsics( self: DRR, - sdr: float = None, + sdd: float = None, delx: float = None, dely: float = None, x0: float = None, y0: float = None, ): self.detector = Detector( - sdr if sdr is not None else self.detector.sdr, + sdd if sdd is not None else self.detector.sdd, self.detector.height, self.detector.width, delx if delx is not None else self.detector.delx, @@ -201,9 +201,7 @@ def perspective_projection( pose: RigidTransform, pts: torch.Tensor, ): - extrinsic = ( - pose.inverse().compose(self.detector.translate).compose(self.detector.flip_xz) - ) + extrinsic = pose.inverse().compose(self.detector.flip_z) x = extrinsic(pts) x = torch.einsum("ij, bnj -> bni", self.detector.intrinsic, x) z = x[..., -1].unsqueeze(-1).clone() @@ -225,14 +223,10 @@ def inverse_projection( .compose(self.detector.translate.inverse()) .compose(pose) ) - x = ( - -2 - * self.detector.sdr - * torch.einsum( - "ij, bnj -> bni", - self.detector.intrinsic.inverse(), - pad(pts, (0, 1), value=1), # Convert to homogenous coordinates - ) + x = -self.detector.sdd * torch.einsum( + "ij, bnj -> bni", + self.detector.intrinsic.inverse(), + pad(pts, (0, 1), value=1), # Convert to homogenous coordinates ) return extrinsic(x) diff --git a/notebooks/api/00_drr.ipynb b/notebooks/api/00_drr.ipynb index 99e3236e3..3e485d748 100644 --- a/notebooks/api/00_drr.ipynb +++ b/notebooks/api/00_drr.ipynb @@ -60,7 +60,7 @@ "## DRR\n", "`DRR` is a PyTorch module that compues differentiable digitally reconstructed radiographs. The viewing angle for the DRR (known generally in computer graphics as the *camera pose*) is parameterized by the following parameters:\n", "\n", - "- SDR : Source-to-Detector radius (half of the source-to-detector distance)\n", + "- SDD : source-to-detector distance (i.e., the focal length of the C-arm)\n", "- $\\mathbf R \\in \\mathrm{SO}(3)$ : a rotation\n", "- $\\mathbf t \\in \\mathbb R^3$ : a translation" ] @@ -118,7 +118,7 @@ " volume: np.ndarray, # CT volume\n", " origin: tuple, # Origin of the CT volume in world coordinates\n", " spacing: tuple, # Dimensions in the CT volume in world coordinates\n", - " sdr: float, # Source-to-detector radius for the C-arm (half of the source-to-detector distance)\n", + " sdd: float, # Source-to-detector distance (i.e., the C-arm's focal length)\n", " height: int, # Height of the rendered DRR\n", " delx: float, # X-axis pixel size\n", " width: int | None = None, # Width of the rendered DRR (default to `height`)\n", @@ -143,7 +143,7 @@ " int(height * width * p_subsample) if p_subsample is not None else None\n", " )\n", " self.detector = Detector(\n", - " sdr,\n", + " sdd,\n", " height,\n", " width,\n", " delx,\n", @@ -308,14 +308,14 @@ "@patch\n", "def set_intrinsics(\n", " self: DRR,\n", - " sdr: float = None,\n", + " sdd: float = None,\n", " delx: float = None,\n", " dely: float = None,\n", " x0: float = None,\n", " y0: float = None,\n", "):\n", " self.detector = Detector(\n", - " sdr if sdr is not None else self.detector.sdr,\n", + " sdd if sdd is not None else self.detector.sdd,\n", " self.detector.height,\n", " self.detector.width,\n", " delx if delx is not None else self.detector.delx,\n", @@ -344,9 +344,7 @@ " pose: RigidTransform,\n", " pts: torch.Tensor,\n", "):\n", - " extrinsic = (\n", - " pose.inverse().compose(self.detector.translate).compose(self.detector.flip_xz)\n", - " )\n", + " extrinsic = pose.inverse().compose(self.detector.flip_z)\n", " x = extrinsic(pts)\n", " x = torch.einsum(\"ij, bnj -> bni\", self.detector.intrinsic, x)\n", " z = x[..., -1].unsqueeze(-1).clone()\n", @@ -376,14 +374,10 @@ " .compose(self.detector.translate.inverse())\n", " .compose(pose)\n", " )\n", - " x = (\n", - " -2\n", - " * self.detector.sdr\n", - " * torch.einsum(\n", - " \"ij, bnj -> bni\",\n", - " self.detector.intrinsic.inverse(),\n", - " pad(pts, (0, 1), value=1), # Convert to homogenous coordinates\n", - " )\n", + " x = -self.detector.sdd * torch.einsum(\n", + " \"ij, bnj -> bni\",\n", + " self.detector.intrinsic.inverse(),\n", + " pad(pts, (0, 1), value=1), # Convert to homogenous coordinates\n", " )\n", " return extrinsic(x)" ] diff --git a/notebooks/api/02_detector.ipynb b/notebooks/api/02_detector.ipynb index 1d3522c2f..5a5394249 100644 --- a/notebooks/api/02_detector.ipynb +++ b/notebooks/api/02_detector.ipynb @@ -75,7 +75,7 @@ "\n", " def __init__(\n", " self,\n", - " sdr: float, # Source-to-detector radius (half of the source-to-detector distance)\n", + " sdd: float, # Source-to-detector distance (i.e., focal length)\n", " height: int, # Height of the X-ray detector\n", " width: int, # Width of the X-ray detector\n", " delx: float, # Pixel spacing in the X-direction\n", @@ -86,7 +86,7 @@ " reverse_x_axis: bool = False, # If pose includes reflection (in E(3) not SE(3)), reverse x-axis\n", " ):\n", " super().__init__()\n", - " self.sdr = sdr\n", + " self.sdd = sdd\n", " self.height = height\n", " self.width = width\n", " self.delx = delx\n", @@ -103,30 +103,30 @@ " self.register_buffer(\"source\", source)\n", " self.register_buffer(\"target\", target)\n", "\n", - " # Anatomy to world coordinates\n", - " flip_xz = torch.tensor(\n", - " [\n", - " [0.0, 0.0, -1.0, 0.0],\n", - " [0.0, 1.0, 0.0, 0.0],\n", - " [1.0, 0.0, 0.0, 0.0],\n", - " [0.0, 0.0, 0.0, 1.0],\n", - " ]\n", - " )\n", - " translate = torch.tensor(\n", - " [\n", - " [1.0, 0.0, 0.0, -self.sdr],\n", - " [0.0, 1.0, 0.0, 0.0],\n", - " [0.0, 0.0, 1.0, 0.0],\n", - " [0.0, 0.0, 0.0, 1.0],\n", - " ]\n", - " )\n", - " self.register_buffer(\"_flip_xz\", flip_xz)\n", - " self.register_buffer(\"_translate\", translate)\n", + " # # Anatomy to world coordinates\n", + " # flip_xz = torch.tensor(\n", + " # [\n", + " # [0.0, 0.0, -1.0, 0.0],\n", + " # [0.0, 1.0, 0.0, 0.0],\n", + " # [1.0, 0.0, 0.0, 0.0],\n", + " # [0.0, 0.0, 0.0, 1.0],\n", + " # ]\n", + " # )\n", + " # translate = torch.tensor(\n", + " # [\n", + " # [1.0, 0.0, 0.0, -self.sdr],\n", + " # [0.0, 1.0, 0.0, 0.0],\n", + " # [0.0, 0.0, 1.0, 0.0],\n", + " # [0.0, 0.0, 0.0, 1.0],\n", + " # ]\n", + " # )\n", + " # self.register_buffer(\"_flip_xz\", flip_xz)\n", + " # self.register_buffer(\"_translate\", translate)\n", "\n", " @property\n", " def intrinsic(self):\n", " return make_intrinsic_matrix(\n", - " self.sdr,\n", + " self.sdd,\n", " self.delx,\n", " self.dely,\n", " self.height,\n", @@ -156,16 +156,16 @@ "def _initialize_carm(self: Detector):\n", " \"\"\"Initialize the default position for the source and detector plane.\"\"\"\n", " try:\n", - " device = self.sdr.device\n", + " device = self.sdd.device\n", " except AttributeError:\n", " device = torch.device(\"cpu\")\n", "\n", - " # Initialize the source on the x-axis and the center of the detector plane on the negative x-axis\n", - " source = torch.tensor([[1.0, 0.0, 0.0]], device=device) * self.sdr\n", - " center = torch.tensor([[-1.0, 0.0, 0.0]], device=device) * self.sdr\n", + " # Initialize the source at the origin and the center of the detector plane on the positive z-axis\n", + " source = torch.tensor([[0.0, 0.0, 0.0]], device=device)\n", + " center = torch.tensor([[0.0, 0.0, 1.0]], device=device) * self.sdd\n", "\n", " # Use the standard basis for the detector plane\n", - " basis = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], device=device)\n", + " basis = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], device=device)\n", "\n", " # Construct the detector plane with different offsets for even or odd heights\n", " h_off = 1.0 if self.height % 2 else 0.5\n", @@ -189,7 +189,7 @@ " target = target.unsqueeze(0)\n", "\n", " # Apply principal point offset\n", - " target[..., 2] -= self.x0\n", + " target[..., 0] -= self.x0\n", " target[..., 1] -= self.y0\n", "\n", " if self.n_subsample is not None:\n",