Skip to content

Commit

Permalink
Fix memory leak by removing custom __copy__ logic (#1227)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicoloesch authored Jan 27, 2025
1 parent c7f3b82 commit 7f4b5a8
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 38 deletions.
26 changes: 0 additions & 26 deletions src/torchio/data/subject.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Optional

import numpy as np
Expand Down Expand Up @@ -69,9 +68,6 @@ def __repr__(self):
)
return string

def __copy__(self):
    """Support ``copy.copy`` by delegating to ``_subject_copy_helper``.

    The concrete class of this instance is passed along so the helper
    builds the copy with the same type.
    """
    subject_cls = type(self)
    return _subject_copy_helper(self, subject_cls)

def __len__(self):
    """Return the number of images, counting non-intensity ones too."""
    images = self.get_images(intensity_only=False)
    return len(images)

Expand Down Expand Up @@ -426,25 +422,3 @@ def plot(self, **kwargs) -> None:
from ..visualization import plot_subject # avoid circular import

plot_subject(self, **kwargs)


def _subject_copy_helper(
    old_obj: Subject,
    new_subj_cls: Callable[[dict[str, Any]], Subject],
):
    """Build a copy of ``old_obj`` using ``new_subj_cls`` as constructor.

    Entries that are ``Image`` instances are duplicated with ``copy.copy``;
    every other entry is duplicated with ``copy.deepcopy``.

    Args:
        old_obj: Subject whose entries and applied transforms are copied.
        new_subj_cls: Callable invoked with the copied entries as keyword
            arguments to create the new subject.

    Returns:
        The newly created subject, carrying its own copy of
        ``old_obj.applied_transforms``.
    """
    copied_entries = {
        name: copy.copy(item) if isinstance(item, Image) else copy.deepcopy(item)
        for name, item in old_obj.items()
    }

    duplicate = new_subj_cls(**copied_entries)  # type: ignore[call-arg]
    # Slice so the copy does not share the history container with old_obj.
    duplicate.applied_transforms = old_obj.applied_transforms[:]
    return duplicate


class _RawSubjectCopySubject(Subject):
    """Subject subclass whose shallow copies are built as plain ``Subject``."""

    def __copy__(self):
        """Copy via ``_subject_copy_helper``, constructing a ``Subject``."""
        target_cls = Subject
        return _subject_copy_helper(self, target_cls)
4 changes: 2 additions & 2 deletions src/torchio/datasets/fpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from ..data import LabelMap
from ..data import ScalarImage
from ..data.io import read_matrix
from ..data.subject import _RawSubjectCopySubject
from ..data.subject import Subject
from ..download import download_url
from ..utils import get_torchio_cache_dir


class FPG(_RawSubjectCopySubject):
class FPG(Subject):
"""3T :math:`T_1`-weighted brain MRI and corresponding parcellation.
Args:
Expand Down
4 changes: 2 additions & 2 deletions src/torchio/datasets/itk_snap/itk_snap.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

from ...data import LabelMap
from ...data import ScalarImage
from ...data.subject import _RawSubjectCopySubject
from ...data.subject import Subject
from ...download import download_and_extract_archive
from ...utils import get_torchio_cache_dir


class SubjectITKSNAP(_RawSubjectCopySubject):
class SubjectITKSNAP(Subject):
"""ITK-SNAP Image Data Downloads.
See `the ITK-SNAP website`_ for more information.
Expand Down
4 changes: 2 additions & 2 deletions src/torchio/datasets/mni/mni.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from ...data.subject import _RawSubjectCopySubject
from ...data.subject import Subject
from ...utils import get_torchio_cache_dir


class SubjectMNI(_RawSubjectCopySubject):
class SubjectMNI(Subject):
"""Atlases from the Montreal Neurological Institute (MNI).
See `the website <https://nist.mni.mcgill.ca/?page_id=714>`_ for more
Expand Down
4 changes: 2 additions & 2 deletions src/torchio/datasets/slicer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import urllib.parse

from ..data import ScalarImage
from ..data.subject import _RawSubjectCopySubject
from ..data.subject import Subject
from ..download import download_url
from ..utils import get_torchio_cache_dir

Expand Down Expand Up @@ -36,7 +36,7 @@
}


class Slicer(_RawSubjectCopySubject):
class Slicer(Subject):
"""Sample data provided by `3D Slicer <https://www.slicer.org/>`_.
See `the Slicer wiki <https://www.slicer.org/wiki/SampleData>`_
Expand Down
4 changes: 2 additions & 2 deletions src/torchio/transforms/augmentation/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __getitem__(self, index) -> Transform:
def __repr__(self) -> str:
return f'{self.name}({self.transforms})'

def get_base_args(self) -> Dict:
def get_base_args(self) -> dict:
init_args = super().get_base_args()
if 'parse_input' in init_args:
init_args.pop('parse_input')
Expand Down Expand Up @@ -109,7 +109,7 @@ def __init__(self, transforms: TypeTransformsDict, **kwargs):
super().__init__(parse_input=False, **kwargs)
self.transforms_dict = self._get_transforms_dict(transforms)

def get_base_args(self) -> Dict:
def get_base_args(self) -> dict:
init_args = super().get_base_args()
if 'parse_input' in init_args:
init_args.pop('parse_input')
Expand Down
4 changes: 2 additions & 2 deletions src/torchio/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,9 @@ def __call__(self, data: InputType) -> InputType:
if self.keep is not None:
images_to_keep = {}
for name, new_name in self.keep.items():
images_to_keep[new_name] = copy.copy(subject[name])
images_to_keep[new_name] = copy.deepcopy(subject[name])
if self.copy:
subject = copy.copy(subject)
subject = copy.deepcopy(subject)
with np.errstate(all='raise', under='ignore'):
transformed = self.apply_transform(subject)
if self.keep is not None:
Expand Down
26 changes: 26 additions & 0 deletions tests/data/test_subject.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,29 @@ def test_subjects_batch(self):
loader = tio.SubjectsLoader(subjects, batch_size=4)
batch = next(iter(loader))
assert batch.__class__ is dict

def test_deep_copy_subject(self):
    """A deep copy is a ``Subject`` whose image data is independent."""
    duplicate = copy.deepcopy(self.sample_subject)
    assert isinstance(duplicate, tio.data.Subject)

    # Overwrite the data of the copy only.
    duplicate_t1 = duplicate['t1']
    duplicate_t1.set_data(torch.ones_like(duplicate_t1.data))

    # The data of the original subject should not be modified
    original_data = self.sample_subject['t1'].data
    assert not torch.allclose(duplicate['t1'].data, original_data)

def test_shallow_copy_subject(self):
    """A shallow copy shares its images with the subject it was made from."""
    # Work on a deep copy so the shared sample subject is never mutated.
    detached = copy.deepcopy(self.sample_subject)
    shallow = copy.copy(detached)
    assert isinstance(shallow, tio.data.Subject)

    shallow_t1 = shallow['t1']
    shallow_t1.set_data(torch.ones_like(shallow_t1.data))

    # Both copies must hold the same data because the copy is shallow.
    assert torch.allclose(shallow['t1'].data, detached['t1'].data)

    # The data of the original subject should not be modified, so it must
    # differ from the data seen through either copy.
    for touched in (shallow, detached):
        assert not torch.allclose(
            touched['t1'].data, self.sample_subject['t1'].data
        )

0 comments on commit 7f4b5a8

Please sign in to comment.