Skip to content

Commit

Permalink
Remove old functions
Browse files Browse the repository at this point in the history
  • Loading branch information
eigenvivek committed Mar 18, 2024
1 parent 089631f commit 0a0b8d3
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 95 deletions.
44 changes: 4 additions & 40 deletions diffdrr/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch.nn.functional import normalize

# %% auto 0
-__all__ = ['Detector', 'diffdrr_to_deepdrr']
+__all__ = ['Detector']

# %% ../notebooks/api/02_detector.ipynb 5
from .pose import RigidTransform
Expand All @@ -28,7 +28,7 @@ def __init__(
x0: float, # Principal point X-offset
y0: float, # Principal point Y-offset
n_subsample: int | None = None, # Number of target points to randomly sample
-reverse_x_axis: bool = False, # If pose includes reflection (in E(3) not SE(3)), reverse x-axis
+reverse_x_axis: bool = True, # If pose includes reflection (in E(3) not SE(3)), reverse x-axis
):
super().__init__()
self.sdd = sdd
Expand All @@ -48,26 +48,6 @@ def __init__(
self.register_buffer("source", source)
self.register_buffer("target", target)

-# # Anatomy to world coordinates
-# flip_xz = torch.tensor(
-#     [
-#         [0.0, 0.0, -1.0, 0.0],
-#         [0.0, 1.0, 0.0, 0.0],
-#         [1.0, 0.0, 0.0, 0.0],
-#         [0.0, 0.0, 0.0, 1.0],
-#     ]
-# )
-# translate = torch.tensor(
-#     [
-#         [1.0, 0.0, 0.0, -self.sdr],
-#         [0.0, 1.0, 0.0, 0.0],
-#         [0.0, 0.0, 1.0, 0.0],
-#         [0.0, 0.0, 0.0, 1.0],
-#     ]
-# )
-# self.register_buffer("_flip_xz", flip_xz)
-# self.register_buffer("_translate", translate)

@property
def intrinsic(self):
return make_intrinsic_matrix(
Expand All @@ -78,15 +58,7 @@ def intrinsic(self):
self.width,
self.x0,
self.y0,
-).to(self._flip_xz)
-
-@property
-def flip_xz(self):
-    return RigidTransform(self._flip_xz)
-
-@property
-def translate(self):
-    return RigidTransform(self._translate)
+).to(self.source)

# %% ../notebooks/api/02_detector.ipynb 6
@patch
Expand Down Expand Up @@ -140,16 +112,8 @@ def _initialize_carm(self: Detector):


@patch
-def forward(
-    self: Detector,
-    pose: RigidTransform,
-):
+def forward(self: Detector, pose: RigidTransform):
"""Create source and target points for X-rays to trace through the volume."""
source = pose(self.source)
target = pose(self.target)
return source, target

-# %% ../notebooks/api/02_detector.ipynb 8
-def diffdrr_to_deepdrr(euler_angles):
-    alpha, beta, gamma = euler_angles.unbind(-1)
-    return torch.stack([beta, alpha, gamma], dim=1)
7 changes: 3 additions & 4 deletions diffdrr/drr.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,6 @@ def __init__(
# Initialize the X-ray detector
width = height if width is None else width
dely = delx if dely is None else dely
-n_subsample = (
-    int(height * width * p_subsample) if p_subsample is not None else None
-)
self.detector = Detector(
sdd,
height,
Expand All @@ -55,8 +52,10 @@ def __init__(
dely,
x0,
y0,
-n_subsample=n_subsample,
 reverse_x_axis=reverse_x_axis,
+n_subsample=int(height * width * p_subsample)
+if p_subsample is not None
+else None,
)

# Initialize the volume
Expand Down
7 changes: 3 additions & 4 deletions notebooks/api/00_drr.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,6 @@
" # Initialize the X-ray detector\n",
" width = height if width is None else width\n",
" dely = delx if dely is None else dely\n",
-" n_subsample = (\n",
-" int(height * width * p_subsample) if p_subsample is not None else None\n",
-" )\n",
" self.detector = Detector(\n",
" sdd,\n",
" height,\n",
Expand All @@ -150,8 +147,10 @@
" dely,\n",
" x0,\n",
" y0,\n",
-" n_subsample=n_subsample,\n",
 " reverse_x_axis=reverse_x_axis,\n",
+" n_subsample=int(height * width * p_subsample)\n",
+" if p_subsample is not None\n",
+" else None,\n",
" )\n",
"\n",
" # Initialize the volume\n",
Expand Down
50 changes: 3 additions & 47 deletions notebooks/api/02_detector.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
" x0: float, # Principal point X-offset\n",
" y0: float, # Principal point Y-offset\n",
" n_subsample: int | None = None, # Number of target points to randomly sample\n",
-" reverse_x_axis: bool = False, # If pose includes reflection (in E(3) not SE(3)), reverse x-axis\n",
+" reverse_x_axis: bool = True, # If pose includes reflection (in E(3) not SE(3)), reverse x-axis\n",
" ):\n",
" super().__init__()\n",
" self.sdd = sdd\n",
Expand All @@ -103,26 +103,6 @@
" self.register_buffer(\"source\", source)\n",
" self.register_buffer(\"target\", target)\n",
"\n",
" # # Anatomy to world coordinates\n",
" # flip_xz = torch.tensor(\n",
" # [\n",
" # [0.0, 0.0, -1.0, 0.0],\n",
" # [0.0, 1.0, 0.0, 0.0],\n",
" # [1.0, 0.0, 0.0, 0.0],\n",
" # [0.0, 0.0, 0.0, 1.0],\n",
" # ]\n",
" # )\n",
" # translate = torch.tensor(\n",
" # [\n",
" # [1.0, 0.0, 0.0, -self.sdr],\n",
" # [0.0, 1.0, 0.0, 0.0],\n",
" # [0.0, 0.0, 1.0, 0.0],\n",
" # [0.0, 0.0, 0.0, 1.0],\n",
" # ]\n",
" # )\n",
" # self.register_buffer(\"_flip_xz\", flip_xz)\n",
" # self.register_buffer(\"_translate\", translate)\n",
"\n",
" @property\n",
" def intrinsic(self):\n",
" return make_intrinsic_matrix(\n",
Expand All @@ -133,15 +113,7 @@
" self.width,\n",
" self.x0,\n",
" self.y0,\n",
-" ).to(self._flip_xz)\n",
-"\n",
-" @property\n",
-" def flip_xz(self):\n",
-" return RigidTransform(self._flip_xz)\n",
-"\n",
-" @property\n",
-" def translate(self):\n",
-" return RigidTransform(self._translate)"
+" ).to(self.source)"
]
},
{
Expand Down Expand Up @@ -211,29 +183,13 @@
"\n",
"\n",
"@patch\n",
-"def forward(\n",
-" self: Detector,\n",
-" pose: RigidTransform,\n",
-"):\n",
+"def forward(self: Detector, pose: RigidTransform):\n",
" \"\"\"Create source and target points for X-rays to trace through the volume.\"\"\"\n",
" source = pose(self.source)\n",
" target = pose(self.target)\n",
" return source, target"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "13e35b1b-13d1-4067-b96c-ecf0b2045d94",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"def diffdrr_to_deepdrr(euler_angles):\n",
" alpha, beta, gamma = euler_angles.unbind(-1)\n",
" return torch.stack([beta, alpha, gamma], dim=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down

0 comments on commit 0a0b8d3

Please sign in to comment.