diff --git a/.gitignore b/.gitignore index 01ff62ca1..92dfba872 100644 --- a/.gitignore +++ b/.gitignore @@ -96,3 +96,6 @@ target/ # Mypy cache .mypy_cache/ + +# Jupyter notebooks +**/.last_checked diff --git a/diffdrr/renderers.py b/diffdrr/renderers.py index d26ade481..4d1b691fc 100644 --- a/diffdrr/renderers.py +++ b/diffdrr/renderers.py @@ -73,7 +73,9 @@ def forward( # https://stackoverflow.com/questions/78323859/broadcast-pytorch-array-across-channels-based-on-another-array/78324614#78324614 B, D, _ = img.shape C = int(mask.max().item() + 1) - channels = _get_voxel(mask, xyzs, align_corners=align_corners).long() + channels = _get_voxel( + mask, xyzs, self.mode, align_corners=align_corners + ).long() img = ( torch.zeros(B, C, D) .to(img) @@ -203,7 +205,9 @@ def forward( else: B, D, _ = img.shape C = int(mask.max().item() + 1) - channels = _get_voxel(mask, xyzs, align_corners=align_corners).long() + channels = _get_voxel( + mask, xyzs, self.mode, align_corners=align_corners + ).long() img = ( torch.zeros(B, C, D) .to(img) diff --git a/notebooks/api/01_renderers.ipynb b/notebooks/api/01_renderers.ipynb index af20b3e47..1947303f9 100644 --- a/notebooks/api/01_renderers.ipynb +++ b/notebooks/api/01_renderers.ipynb @@ -175,7 +175,7 @@ " # https://stackoverflow.com/questions/78323859/broadcast-pytorch-array-across-channels-based-on-another-array/78324614#78324614\n", " B, D, _ = img.shape\n", " C = int(mask.max().item() + 1)\n", - " channels = _get_voxel(mask, xyzs, align_corners=align_corners).long()\n", + " channels = _get_voxel(mask, xyzs, self.mode, align_corners=align_corners).long()\n", " img = (\n", " torch.zeros(B, C, D)\n", " .to(img)\n", @@ -334,7 +334,7 @@ " else:\n", " B, D, _ = img.shape\n", " C = int(mask.max().item() + 1)\n", - " channels = _get_voxel(mask, xyzs, align_corners=align_corners).long()\n", + " channels = _get_voxel(mask, xyzs, self.mode, align_corners=align_corners).long()\n", " img = (\n", " torch.zeros(B, C, D)\n", " .to(img)\n",