Merge branch 'main' into pre-commit-ci-update-config
fepegar authored Sep 20, 2024
2 parents 8b8e710 + 1c6a6f6 commit d675de7
Showing 30 changed files with 139 additions and 61 deletions.
9 changes: 9 additions & 0 deletions .all-contributorsrc
@@ -815,6 +815,15 @@
"contributions": [
"bug"
]
},
{
"login": "rickymwalsh",
"name": "Ricky Walsh",
"avatar_url": "https://avatars.githubusercontent.com/u/70853488?v=4",
"profile": "https://www.linkedin.com/in/ricky-walsh/",
"contributions": [
"code"
]
}
],
"contributorsPerLine": 7,
2 changes: 1 addition & 1 deletion .bumpversion.cfg
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.19.9
current_version = 0.20.0
commit = True
tag = True

3 changes: 0 additions & 3 deletions .github/workflows/publish.yml
@@ -40,8 +40,6 @@ jobs:
- name: Publish package to TestPyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.TEST_PYPI_API_TOKEN }}
repository-url: https://test.pypi.org/legacy/
verbose: true
skip-existing: true
@@ -51,5 +49,4 @@ jobs:
if: startsWith(github.ref, 'refs/tags')
uses: pypa/gh-action-pypi-publish@release/v1
with:
password: ${{ secrets.PYPI_API_TOKEN }}
verbose: true
1 change: 1 addition & 0 deletions README.md
@@ -408,6 +408,7 @@ Thanks goes to all these people ([emoji key](https://allcontributors.org/docs/en
<td align="center" valign="top" width="14.28%"><a href="https://github.com/marius-sm"><img src="https://avatars.githubusercontent.com/u/40166021?v=4?s=100" width="100px;" alt="marius-sm"/><br /><sub><b>marius-sm</b></sub></a><br /><a href="#ideas-marius-sm" title="Ideas, Planning, & Feedback">🤔</a></td>
<td align="center" valign="top" width="14.28%"><a href="https://github.com/haarisr"><img src="https://avatars.githubusercontent.com/u/122410226?v=4?s=100" width="100px;" alt="haarisr"/><br /><sub><b>haarisr</b></sub></a><br /><a href="https://github.com/fepegar/torchio/commits?author=haarisr" title="Code">💻</a></td>
<td align="center" valign="top" width="14.28%"><a href="https://github.com/c-winder"><img src="https://avatars.githubusercontent.com/u/50587864?v=4?s=100" width="100px;" alt="Chris Winder"/><br /><sub><b>Chris Winder</b></sub></a><br /><a href="https://github.com/fepegar/torchio/issues?q=author%3Ac-winder" title="Bug reports">🐛</a></td>
<td align="center" valign="top" width="14.28%"><a href="https://www.linkedin.com/in/ricky-walsh/"><img src="https://avatars.githubusercontent.com/u/70853488?v=4?s=100" width="100px;" alt="Ricky Walsh"/><br /><sub><b>Ricky Walsh</b></sub></a><br /> <a href="https://github.com/fepegar/torchio/commits?author=rickymwalsh" title="Code">💻</a></td>
</tr>
</tbody>
</table>
2 changes: 1 addition & 1 deletion docs/examples/plot_3d_to_2d.py
@@ -32,7 +32,7 @@

def plot_batch(sampler):
queue = tio.Queue(dataset, max_queue_length, patches_per_volume, sampler)
loader = torch.utils.data.DataLoader(queue, batch_size=16)
loader = tio.SubjectsLoader(queue, batch_size=16)
batch = tio.utils.get_first_item(loader)

fig, axes = plt.subplots(4, 4, figsize=(12, 10))
2 changes: 1 addition & 1 deletion docs/examples/plot_history.py
@@ -43,7 +43,7 @@
) # noqa: T201, B950
print(transformed.get_inverse_transform(ignore_intensity=False)) # noqa: T201

loader = torch.utils.data.DataLoader(
loader = tio.SubjectsLoader(
dataset,
batch_size=batch_size,
collate_fn=tio.utils.history_collate,
12 changes: 6 additions & 6 deletions docs/source/datasets.rst
@@ -206,7 +206,7 @@ MedMNIST
from einops import rearrange
rows, cols = 16, 28
dataset = tio.datasets.OrganMNIST3D('train')
loader = torch.utils.data.DataLoader(dataset, batch_size=rows * cols)
loader = tio.SubjectsLoader(dataset, batch_size=rows * cols)
batch = tio.utils.get_first_item(loader)
tensor = batch['image'][tio.DATA]
pattern = '(b1 b2) c x y z -> c x (b1 y) (b2 z)'
@@ -224,7 +224,7 @@ MedMNIST
from einops import rearrange
rows, cols = 16, 28
dataset = tio.datasets.NoduleMNIST3D('train')
loader = torch.utils.data.DataLoader(dataset, batch_size=rows * cols)
loader = tio.SubjectsLoader(dataset, batch_size=rows * cols)
batch = tio.utils.get_first_item(loader)
tensor = batch['image'][tio.DATA]
pattern = '(b1 b2) c x y z -> c x (b1 y) (b2 z)'
@@ -242,7 +242,7 @@ MedMNIST
from einops import rearrange
rows, cols = 16, 28
dataset = tio.datasets.AdrenalMNIST3D('train')
loader = torch.utils.data.DataLoader(dataset, batch_size=rows * cols)
loader = tio.SubjectsLoader(dataset, batch_size=rows * cols)
batch = tio.utils.get_first_item(loader)
tensor = batch['image'][tio.DATA]
pattern = '(b1 b2) c x y z -> c x (b1 y) (b2 z)'
@@ -260,7 +260,7 @@ MedMNIST
from einops import rearrange
rows, cols = 16, 28
dataset = tio.datasets.FractureMNIST3D('train')
loader = torch.utils.data.DataLoader(dataset, batch_size=rows * cols)
loader = tio.SubjectsLoader(dataset, batch_size=rows * cols)
batch = tio.utils.get_first_item(loader)
tensor = batch['image'][tio.DATA]
pattern = '(b1 b2) c x y z -> c x (b1 y) (b2 z)'
@@ -278,7 +278,7 @@ MedMNIST
from einops import rearrange
rows, cols = 16, 28
dataset = tio.datasets.VesselMNIST3D('train')
loader = torch.utils.data.DataLoader(dataset, batch_size=rows * cols)
loader = tio.SubjectsLoader(dataset, batch_size=rows * cols)
batch = tio.utils.get_first_item(loader)
tensor = batch['image'][tio.DATA]
pattern = '(b1 b2) c x y z -> c x (b1 y) (b2 z)'
@@ -296,7 +296,7 @@ MedMNIST
from einops import rearrange
rows, cols = 16, 28
dataset = tio.datasets.SynapseMNIST3D('train')
loader = torch.utils.data.DataLoader(dataset, batch_size=rows * cols)
loader = tio.SubjectsLoader(dataset, batch_size=rows * cols)
batch = tio.utils.get_first_item(loader)
tensor = batch['image'][tio.DATA]
pattern = '(b1 b2) c x y z -> c x (b1 y) (b2 z)'
2 changes: 1 addition & 1 deletion docs/source/patches/patch_inference.rst
@@ -17,7 +17,7 @@ inference across a 3D image using patches::
... patch_size,
... patch_overlap,
... )
>>> patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=4)
>>> patch_loader = tio.SubjectsLoader(grid_sampler, batch_size=4)
>>> aggregator = tio.inference.GridAggregator(grid_sampler)
>>> model = nn.Identity().eval()
>>> with torch.no_grad():
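The hunk is cut off at this point by the diff view. For context, the documented inference pattern typically continues roughly as sketched below; this is not necessarily the exact hidden content, and the 't1' image key is an assumption:

    >>> with torch.no_grad():
    ...     for patches_batch in patch_loader:
    ...         input_tensor = patches_batch['t1'][tio.DATA]
    ...         locations = patches_batch[tio.LOCATION]
    ...         logits = model(input_tensor)
    ...         aggregator.add_batch(logits, locations)
    >>> output_tensor = aggregator.get_output_tensor()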
13 changes: 9 additions & 4 deletions docs/source/quickstart.rst
@@ -37,13 +37,12 @@ Hello, World!

This example shows the basic usage of TorchIO, where an instance of
:class:`~torchio.SubjectsDataset` is passed to
a PyTorch :class:`~torch.utils.data.DataLoader` to generate training batches
a TorchIO :class:`~torchio.SubjectsLoader` to generate training batches
of 3D images that are loaded, preprocessed and augmented on the fly,
in parallel::

import torch
import torchio as tio
from torch.utils.data import DataLoader

# Each instance of tio.Subject is passed arbitrary keyword arguments.
# Typically, these arguments will be instances of tio.Image
@@ -91,8 +90,14 @@ in parallel::
# SubjectsDataset is a subclass of torch.utils.data.Dataset
subjects_dataset = tio.SubjectsDataset(subjects_list, transform=transform)

# Images are processed in parallel thanks to a PyTorch DataLoader
training_loader = DataLoader(subjects_dataset, batch_size=4, num_workers=4)
# Images are processed in parallel thanks to a SubjectsLoader
# (which inherits from torch.utils.data.DataLoader)
training_loader = tio.SubjectsLoader(
subjects_dataset,
batch_size=4,
num_workers=4,
shuffle=True,
)

# Training epoch
for subjects_batch in training_loader:
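The rest of the training loop is truncated by the diff view. A typical epoch body looks roughly like the sketch below; the 'mri' and 'brain' keys are assumptions for illustration, not necessarily the names defined earlier in the quickstart:

    for subjects_batch in training_loader:
        inputs = subjects_batch['mri'][tio.DATA]    # batch of image tensors
        targets = subjects_batch['brain'][tio.DATA]  # batch of label maps
        # ... forward pass, loss computation and optimizer step go here ...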
2 changes: 1 addition & 1 deletion setup.cfg
@@ -1,6 +1,6 @@
[metadata]
name = torchio
version = 0.19.9
version = 0.20.0
description = Tools for medical image processing with PyTorch
long_description = file: README.md
long_description_content_type = text/markdown
4 changes: 3 additions & 1 deletion src/torchio/__init__.py
@@ -2,7 +2,7 @@

__author__ = """Fernando Perez-Garcia"""
__email__ = 'fepegar@gmail.com'
__version__ = '0.19.9'
__version__ = '0.20.0'


from . import utils
@@ -13,6 +13,7 @@
sampler,
inference,
SubjectsDataset,
SubjectsLoader,
Image,
ScalarImage,
LabelMap,
@@ -34,6 +35,7 @@
'sampler',
'inference',
'SubjectsDataset',
'SubjectsLoader',
'Image',
'ScalarImage',
'LabelMap',
2 changes: 2 additions & 0 deletions src/torchio/data/__init__.py
@@ -3,6 +3,7 @@
from .image import LabelMap
from .image import ScalarImage
from .inference import GridAggregator
from .loader import SubjectsLoader
from .queue import Queue
from .sampler import GridSampler
from .sampler import LabelSampler
@@ -16,6 +17,7 @@
'Queue',
'Subject',
'SubjectsDataset',
'SubjectsLoader',
'Image',
'ScalarImage',
'LabelMap',
4 changes: 2 additions & 2 deletions src/torchio/data/dataset.py
@@ -17,8 +17,8 @@ class SubjectsDataset(Dataset):
"""Base TorchIO dataset.
Reader of 3D medical images that directly inherits from the PyTorch
:class:`~torch.utils.data.Dataset`. It can be used with a PyTorch
:class:`~torch.utils.data.DataLoader` for efficient loading and
:class:`~torch.utils.data.Dataset`. It can be used with a
:class:`~torchio.SubjectsLoader` for efficient loading and
augmentation. It receives a list of instances of :class:`~torchio.Subject`
and an optional transform applied to the volumes after loading.
6 changes: 3 additions & 3 deletions src/torchio/data/inference/aggregator.py
@@ -139,8 +139,8 @@ def add_batch(
extracted using ``batch[torchio.LOCATION]``.
"""
batch = batch_tensor.cpu()
locations = locations.cpu().numpy()
patch_sizes = locations[:, 3:] - locations[:, :3]
locations_array = locations.cpu().numpy()
patch_sizes = locations_array[:, 3:] - locations_array[:, :3]
# There should be only one patch size
assert len(np.unique(patch_sizes, axis=0)) == 1
input_spatial_shape = tuple(batch.shape[-3:])
@@ -155,7 +155,7 @@
self._initialize_output_tensor(batch)
assert isinstance(self._output_tensor, torch.Tensor)
if self.overlap_mode == 'crop':
for patch, location in zip(batch, locations):
for patch, location in zip(batch, locations_array):
cropped_patch, new_location = self._crop_patch(
patch,
location,
5 changes: 3 additions & 2 deletions src/torchio/data/io.py
@@ -181,18 +181,19 @@ def _write_sitk(
) -> None:
assert tensor.ndim == 4
path = Path(path)
array = tensor.numpy()
if path.suffix in ('.png', '.jpg', '.jpeg', '.bmp'):
warnings.warn(
f'Casting to uint 8 before saving to {path}',
RuntimeWarning,
stacklevel=2,
)
tensor = tensor.numpy().astype(np.uint8)
array = array.astype(np.uint8)
if squeeze is None:
force_3d = path.suffix not in IMAGE_2D_FORMATS
else:
force_3d = not squeeze
image = nib_to_sitk(tensor, affine, force_3d=force_3d)
image = nib_to_sitk(array, affine, force_3d=force_3d)
sitk.WriteImage(image, str(path), use_compression)


64 changes: 64 additions & 0 deletions src/torchio/data/loader.py
@@ -0,0 +1,64 @@
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import TypeVar

import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from .subject import Subject


T = TypeVar('T')


class SubjectsLoader(DataLoader):
def __init__(
self,
dataset: Dataset,
collate_fn: Optional[Callable[[List[T]], Any]] = None,
**kwargs,
):
if collate_fn is None:
collate_fn = self._collate # type: ignore[assignment]
super().__init__(
dataset=dataset,
collate_fn=collate_fn,
**kwargs,
)

@staticmethod
def _collate(subjects: List[Subject]) -> Dict[str, Any]:
first_subject = subjects[0]
batch_dict = {}
for key in first_subject.keys():
collated_value = _stack([subject[key] for subject in subjects])
batch_dict[key] = collated_value
return batch_dict


def _stack(x):
"""Determine the type of the input and stack it accordingly.
Args:
x: List of elements to stack.
Returns:
Stacked elements, as either a torch.Tensor, np.ndarray, dict or list.
"""
first_element = x[0]
if isinstance(first_element, torch.Tensor):
return torch.stack(x, dim=0)
elif isinstance(first_element, np.ndarray):
return np.stack(x, axis=0)
elif isinstance(first_element, dict):
# Assume that all elements have the same keys
collated_dict = {}
for key in first_element.keys():
collated_dict[key] = _stack([element[key] for element in x])
return collated_dict
else:
return x
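
For context, a minimal sketch of how the new SubjectsLoader can be used; the 'mri' key and the synthetic subjects are assumptions for illustration, while SubjectsLoader, SubjectsDataset, ScalarImage and tio.DATA come from the code in this commit:

    import torch
    import torchio as tio

    # Two synthetic subjects, each with a random 3D scalar image.
    subjects = [
        tio.Subject(mri=tio.ScalarImage(tensor=torch.rand(1, 16, 16, 16)))
        for _ in range(2)
    ]
    dataset = tio.SubjectsDataset(subjects)

    # SubjectsLoader collates subjects into a dictionary batch by default.
    loader = tio.SubjectsLoader(dataset, batch_size=2)
    batch = next(iter(loader))

    # Image tensors are stacked along a new batch dimension by _stack().
    print(batch['mri'][tio.DATA].shape)  # expected: torch.Size([2, 1, 16, 16, 16])
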
5 changes: 2 additions & 3 deletions src/torchio/data/queue.py
@@ -115,7 +115,6 @@ class Queue(Dataset):
>>> import torch
>>> import torchio as tio
>>> from torch.utils.data import DataLoader
>>> patch_size = 96
>>> queue_length = 300
>>> samples_per_volume = 10
@@ -129,7 +128,7 @@
... sampler,
... num_workers=4,
... )
>>> patches_loader = DataLoader(
>>> patches_loader = tio.SubjectsLoader(
... patches_queue,
... batch_size=16,
... num_workers=0, # this must be 0
@@ -168,7 +167,7 @@ class Queue(Dataset):
... num_workers=4,
... subject_sampler=subject_sampler,
... )
>>> patches_loader = DataLoader(
>>> patches_loader = tio.SubjectsLoader(
... patches_queue,
... batch_size=16,
... num_workers=0, # this must be 0
2 changes: 1 addition & 1 deletion src/torchio/data/subject.py
@@ -442,7 +442,7 @@ def _subject_copy_helper(
value = copy.deepcopy(value)
result_dict[key] = value

new = new_subj_cls(**result_dict)
new = new_subj_cls(**result_dict) # type: ignore[call-arg]
new.applied_transforms = old_obj.applied_transforms[:]
return new

9 changes: 5 additions & 4 deletions src/torchio/transforms/preprocessing/intensity/rescale.py
@@ -4,6 +4,7 @@
import numpy as np
import torch

from ....data.image import Image
from ....data.subject import Subject
from ....typing import TypeDoubleFloat
from .normalization_transform import NormalizationTransform
@@ -84,7 +85,7 @@ def apply_normalization(
image_name: str,
mask: torch.Tensor,
) -> None:
image = subject[image_name]
image: Image = subject[image_name]
image.set_data(self.rescale(image.data, mask, image_name))

def rescale(
@@ -95,16 +96,16 @@ def rescale(
) -> torch.Tensor:
# The tensor is cloned as in-place operations will be used
array = tensor.clone().float().numpy()
mask = mask.numpy()
if not mask.any():
mask_array = mask.numpy()
if not mask_array.any():
message = (
f'Rescaling image "{image_name}" not possible'
' because the mask to compute the statistics is empty'
)
warnings.warn(message, RuntimeWarning, stacklevel=2)
return tensor

values = array[mask]
values = array[mask_array]
cutoff = np.percentile(values, self.percentiles)
np.clip(array, *cutoff, out=array) # type: ignore[call-overload]
