Skip to content

Commit

Permalink
Add missing parameter mode (#300)
Browse files Browse the repository at this point in the history
  • Loading branch information
haarisr authored Jul 8, 2024
1 parent 4782d31 commit fdec965
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 4 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,6 @@ target/

# Mypy cache
.mypy_cache/

# Jupyter notebooks
**/.last_checked
8 changes: 6 additions & 2 deletions diffdrr/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def forward(
# https://stackoverflow.com/questions/78323859/broadcast-pytorch-array-across-channels-based-on-another-array/78324614#78324614
B, D, _ = img.shape
C = int(mask.max().item() + 1)
channels = _get_voxel(mask, xyzs, align_corners=align_corners).long()
channels = _get_voxel(
mask, xyzs, self.mode, align_corners=align_corners
).long()
img = (
torch.zeros(B, C, D)
.to(img)
Expand Down Expand Up @@ -203,7 +205,9 @@ def forward(
else:
B, D, _ = img.shape
C = int(mask.max().item() + 1)
channels = _get_voxel(mask, xyzs, align_corners=align_corners).long()
channels = _get_voxel(
mask, xyzs, self.mode, align_corners=align_corners
).long()
img = (
torch.zeros(B, C, D)
.to(img)
Expand Down
4 changes: 2 additions & 2 deletions notebooks/api/01_renderers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@
" # https://stackoverflow.com/questions/78323859/broadcast-pytorch-array-across-channels-based-on-another-array/78324614#78324614\n",
" B, D, _ = img.shape\n",
" C = int(mask.max().item() + 1)\n",
" channels = _get_voxel(mask, xyzs, align_corners=align_corners).long()\n",
" channels = _get_voxel(mask, xyzs, self.mode, align_corners=align_corners).long()\n",
" img = (\n",
" torch.zeros(B, C, D)\n",
" .to(img)\n",
Expand Down Expand Up @@ -334,7 +334,7 @@
" else:\n",
" B, D, _ = img.shape\n",
" C = int(mask.max().item() + 1)\n",
" channels = _get_voxel(mask, xyzs, align_corners=align_corners).long()\n",
" channels = _get_voxel(mask, xyzs, self.mode, align_corners=align_corners).long()\n",
" img = (\n",
" torch.zeros(B, C, D)\n",
" .to(img)\n",
Expand Down

0 comments on commit fdec965

Please sign in to comment.