Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix memory leak by removing custom __copy__ logic #1227

Merged
merged 4 commits into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 0 additions & 26 deletions src/torchio/data/subject.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import pprint
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
Expand Down Expand Up @@ -72,9 +71,6 @@ def __repr__(self):
)
return string

def __copy__(self):
return _subject_copy_helper(self, type(self))

def __len__(self):
return len(self.get_images(intensity_only=False))

Expand Down Expand Up @@ -429,25 +425,3 @@ def plot(self, **kwargs) -> None:
from ..visualization import plot_subject # avoid circular import

plot_subject(self, **kwargs)


def _subject_copy_helper(
old_obj: Subject,
new_subj_cls: Callable[[Dict[str, Any]], Subject],
):
result_dict = {}
for key, value in old_obj.items():
if isinstance(value, Image):
value = copy.copy(value)
else:
value = copy.deepcopy(value)
result_dict[key] = value

new = new_subj_cls(**result_dict) # type: ignore[call-arg]
new.applied_transforms = old_obj.applied_transforms[:]
return new


class _RawSubjectCopySubject(Subject):
def __copy__(self):
return _subject_copy_helper(self, Subject)
4 changes: 2 additions & 2 deletions src/torchio/datasets/fpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from ..data import LabelMap
from ..data import ScalarImage
from ..data.io import read_matrix
from ..data.subject import _RawSubjectCopySubject
from ..data.subject import Subject
from ..download import download_url
from ..utils import get_torchio_cache_dir


class FPG(_RawSubjectCopySubject):
class FPG(Subject):
"""3T :math:`T_1`-weighted brain MRI and corresponding parcellation.

Args:
Expand Down
4 changes: 2 additions & 2 deletions src/torchio/datasets/itk_snap/itk_snap.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

from ...data import LabelMap
from ...data import ScalarImage
from ...data.subject import _RawSubjectCopySubject
from ...data.subject import Subject
from ...download import download_and_extract_archive
from ...utils import get_torchio_cache_dir


class SubjectITKSNAP(_RawSubjectCopySubject):
class SubjectITKSNAP(Subject):
"""ITK-SNAP Image Data Downloads.

See `the ITK-SNAP website`_ for more information.
Expand Down
4 changes: 2 additions & 2 deletions src/torchio/datasets/mni/mni.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from ...data.subject import _RawSubjectCopySubject
from ...data.subject import Subject
from ...utils import get_torchio_cache_dir


class SubjectMNI(_RawSubjectCopySubject):
class SubjectMNI(Subject):
"""Atlases from the Montreal Neurological Institute (MNI).

See `the website <https://nist.mni.mcgill.ca/?page_id=714>`_ for more
Expand Down
4 changes: 2 additions & 2 deletions src/torchio/datasets/slicer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import urllib.parse

from ..data import ScalarImage
from ..data.subject import _RawSubjectCopySubject
from ..data.subject import Subject
from ..download import download_url
from ..utils import get_torchio_cache_dir

Expand Down Expand Up @@ -36,7 +36,7 @@
}


class Slicer(_RawSubjectCopySubject):
class Slicer(Subject):
fepegar marked this conversation as resolved.
Show resolved Hide resolved
"""Sample data provided by `3D Slicer <https://www.slicer.org/>`_.

See `the Slicer wiki <https://www.slicer.org/wiki/SampleData>`_
Expand Down
4 changes: 2 additions & 2 deletions src/torchio/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,9 @@ def __call__(self, data: InputType) -> InputType:
if self.keep is not None:
images_to_keep = {}
for name, new_name in self.keep.items():
images_to_keep[new_name] = copy.copy(subject[name])
images_to_keep[new_name] = copy.deepcopy(subject[name])
if self.copy:
subject = copy.copy(subject)
subject = copy.deepcopy(subject)
with np.errstate(all='raise', under='ignore'):
transformed = self.apply_transform(subject)
if self.keep is not None:
Expand Down
27 changes: 27 additions & 0 deletions tests/data/test_subject.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,30 @@ def test_subjects_batch(self):
loader = tio.SubjectsLoader(subjects, batch_size=4)
batch = next(iter(loader))
assert batch.__class__ is dict

def test_deep_copy_subject(self):
sub_copy = copy.deepcopy(self.sample_subject)
assert isinstance(sub_copy, tio.data.Subject)

new_tensor = torch.ones_like(sub_copy['t1'].data)
sub_copy['t1'].set_data(new_tensor)
# The data of the original subject should not be modified
assert not torch.allclose(sub_copy['t1'].data, self.sample_subject['t1'].data)

def test_shallow_copy_subject(self):
# We are creating a deep copy of the original subject first to not modify the original subject
copy_original_subj = copy.deepcopy(self.sample_subject)
sub_copy = copy.copy(copy_original_subj)
assert isinstance(sub_copy, tio.data.Subject)

new_tensor = torch.ones_like(sub_copy['t1'].data)
sub_copy['t1'].set_data(new_tensor)

# The data of both copies needs to be the same as we are using a shallow copy
assert torch.allclose(sub_copy['t1'].data, copy_original_subj['t1'].data)
# The data of the original subject should not be modified
assert not torch.allclose(
sub_copy['t1'].data, self.sample_subject['t1'].data
) and not torch.allclose(
copy_original_subj['t1'].data, self.sample_subject['t1'].data
)
nicoloesch marked this conversation as resolved.
Show resolved Hide resolved
Loading