Skip to content

Commit

Permalink
Add optional alphaminimax inputs (#312)
Browse files Browse the repository at this point in the history
* Make alphaminimax independent for each pose in a batch

* Add optional alphaminmax flags
  • Loading branch information
eigenvivek authored Jul 16, 2024
1 parent 77628db commit 2dcae81
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
9 changes: 6 additions & 3 deletions diffdrr/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,13 +182,16 @@ def forward(
n_points=500,
align_corners=False,
mask=None,
alphamin=None,
alphamax=None,
):
dims = self.dims(volume)

# Sample points along the rays and rescale to [-1, 1]
alphamin, alphamax = _get_alpha_minmax(source, target, dims, self.eps)
alphamin = alphamin.min()
alphamax = alphamax.max()
if alphamin is None or alphamax is None:
alphamin, alphamax = _get_alpha_minmax(source, target, dims, self.eps)
alphamin = alphamin.min()
alphamax = alphamax.max()
alphas = torch.linspace(0, 1, n_points)[None, None].to(volume)
alphas = alphas * (alphamax - alphamin) + alphamin

Expand Down
17 changes: 12 additions & 5 deletions notebooks/api/01_renderers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,9 @@
" # https://stackoverflow.com/questions/78323859/broadcast-pytorch-array-across-channels-based-on-another-array/78324614#78324614\n",
" B, D, _ = img.shape\n",
" C = int(mask.max().item() + 1)\n",
" channels = _get_voxel(mask, xyzs, self.mode, align_corners=align_corners).long()\n",
" channels = _get_voxel(\n",
" mask, xyzs, self.mode, align_corners=align_corners\n",
" ).long()\n",
" img = (\n",
" torch.zeros(B, C, D)\n",
" .to(img)\n",
Expand Down Expand Up @@ -311,13 +313,16 @@
" n_points=500,\n",
" align_corners=False,\n",
" mask=None,\n",
" alphamin=None,\n",
" alphamax=None,\n",
" ):\n",
" dims = self.dims(volume)\n",
"\n",
" # Sample points along the rays and rescale to [-1, 1]\n",
" alphamin, alphamax = _get_alpha_minmax(source, target, dims, self.eps)\n",
" alphamin = alphamin.min()\n",
" alphamax = alphamax.max()\n",
" if alphamin is None or alphamax is None:\n",
" alphamin, alphamax = _get_alpha_minmax(source, target, dims, self.eps)\n",
" alphamin = alphamin.min()\n",
" alphamax = alphamax.max()\n",
" alphas = torch.linspace(0, 1, n_points)[None, None].to(volume)\n",
" alphas = alphas * (alphamax - alphamin) + alphamin\n",
"\n",
Expand All @@ -334,7 +339,9 @@
" else:\n",
" B, D, _ = img.shape\n",
" C = int(mask.max().item() + 1)\n",
" channels = _get_voxel(mask, xyzs, self.mode, align_corners=align_corners).long()\n",
" channels = _get_voxel(\n",
" mask, xyzs, self.mode, align_corners=align_corners\n",
" ).long()\n",
" img = (\n",
" torch.zeros(B, C, D)\n",
" .to(img)\n",
Expand Down

0 comments on commit 2dcae81

Please sign in to comment.