Skip to content

Commit 4029c42

Browse files
Authored by surajpai, pre-commit-ci[bot], KumoLiu, and benmurray
Refactor Dataset to use Compose for transforms (Project-MONAI#7784)
Fixes Project-MONAI#7646 ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Suraj Pai <b.pai@maastrichtuniversity.nl> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Ben Murray <ben.murray@gmail.com>
1 parent 0d7f772 commit 4029c42

File tree

4 files changed

+86
-40
lines changed

4 files changed

+86
-40
lines changed

monai/data/dataset.py

+15-37
Original file line number | Diff line number | Diff line change
@@ -36,15 +36,7 @@
3636

3737
from monai.data.meta_tensor import MetaTensor
3838
from monai.data.utils import SUPPORTED_PICKLE_MOD, convert_tables_to_dicts, pickle_hashing
39-
from monai.transforms import (
40-
Compose,
41-
Randomizable,
42-
RandomizableTrait,
43-
Transform,
44-
apply_transform,
45-
convert_to_contiguous,
46-
reset_ops_id,
47-
)
39+
from monai.transforms import Compose, Randomizable, RandomizableTrait, Transform, convert_to_contiguous, reset_ops_id
4840
from monai.utils import MAX_SEED, convert_to_tensor, get_seed, look_up_option, min_version, optional_import
4941
from monai.utils.misc import first
5042

@@ -77,15 +69,19 @@ class Dataset(_TorchDataset):
7769
}, }, }]
7870
"""
7971

80-
def __init__(self, data: Sequence, transform: Callable | None = None) -> None:
72+
def __init__(self, data: Sequence, transform: Sequence[Callable] | Callable | None = None) -> None:
8173
"""
8274
Args:
8375
data: input data to load and transform to generate dataset for model.
84-
transform: a callable data transform on input data.
85-
76+
transform: a callable, sequence of callables or None. If transform is not
77+
a `Compose` instance, it will be wrapped in a `Compose` instance. Sequences
78+
of callables are applied in order and if `None` is passed, the data is returned as is.
8679
"""
8780
self.data = data
88-
self.transform: Any = transform
81+
try:
82+
self.transform = Compose(transform) if not isinstance(transform, Compose) else transform
83+
except Exception as e:
84+
raise ValueError("`transform` must be a callable or a list of callables that is Composable") from e
8985

9086
def __len__(self) -> int:
9187
return len(self.data)
@@ -95,7 +91,7 @@ def _transform(self, index: int):
9591
Fetch single data item from `self.data`.
9692
"""
9793
data_i = self.data[index]
98-
return apply_transform(self.transform, data_i) if self.transform is not None else data_i
94+
return self.transform(data_i)
9995

10096
def __getitem__(self, index: int | slice | Sequence[int]):
10197
"""
@@ -264,8 +260,6 @@ def __init__(
264260
using the cached content and with re-created transform instances.
265261
266262
"""
267-
if not isinstance(transform, Compose):
268-
transform = Compose(transform)
269263
super().__init__(data=data, transform=transform)
270264
self.cache_dir = Path(cache_dir) if cache_dir is not None else None
271265
self.hash_func = hash_func
@@ -323,9 +317,6 @@ def _pre_transform(self, item_transformed):
323317
random transform object
324318
325319
"""
326-
if not isinstance(self.transform, Compose):
327-
raise ValueError("transform must be an instance of monai.transforms.Compose.")
328-
329320
first_random = self.transform.get_index_of_first(
330321
lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
331322
)
@@ -346,9 +337,6 @@ def _post_transform(self, item_transformed):
346337
the transformed element through the random transforms
347338
348339
"""
349-
if not isinstance(self.transform, Compose):
350-
raise ValueError("transform must be an instance of monai.transforms.Compose.")
351-
352340
first_random = self.transform.get_index_of_first(
353341
lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
354342
)
@@ -501,9 +489,6 @@ def _pre_transform(self, item_transformed):
501489
Returns:
502490
the transformed element up to the N transform object
503491
"""
504-
if not isinstance(self.transform, Compose):
505-
raise ValueError("transform must be an instance of monai.transforms.Compose.")
506-
507492
item_transformed = self.transform(item_transformed, end=self.cache_n_trans, threading=True)
508493

509494
reset_ops_id(item_transformed)
@@ -519,9 +504,6 @@ def _post_transform(self, item_transformed):
519504
Returns:
520505
the final transformed result
521506
"""
522-
if not isinstance(self.transform, Compose):
523-
raise ValueError("transform must be an instance of monai.transforms.Compose.")
524-
525507
return self.transform(item_transformed, start=self.cache_n_trans)
526508

527509

@@ -809,8 +791,6 @@ def __init__(
809791
Not following these recommendations may lead to runtime errors or duplicated cache across processes.
810792
811793
"""
812-
if not isinstance(transform, Compose):
813-
transform = Compose(transform)
814794
super().__init__(data=data, transform=transform)
815795
self.set_num = cache_num # tracking the user-provided `cache_num` option
816796
self.set_rate = cache_rate # tracking the user-provided `cache_rate` option
@@ -1282,8 +1262,10 @@ def to_list(x):
12821262
data = []
12831263
for dataset in self.data:
12841264
data.extend(to_list(dataset[index]))
1265+
12851266
if self.transform is not None:
1286-
data = apply_transform(self.transform, data, map_items=False) # transform the list data
1267+
self.transform.map_items = False # Compose object map_items to false so transform is applied to list
1268+
data = self.transform(data)
12871269
# use tuple instead of list as the default collate_fn callback of MONAI DataLoader flattens nested lists
12881270
return tuple(data)
12891271

@@ -1432,15 +1414,11 @@ def __len__(self):
14321414

14331415
def _transform(self, index: int):
14341416
data = {k: v[index] for k, v in self.arrays.items()}
1435-
1436-
if not self.transform:
1437-
return data
1438-
1439-
result = apply_transform(self.transform, data)
1417+
result = self.transform(data) if self.transform is not None else data
14401418

14411419
if isinstance(result, dict) or (isinstance(result, list) and isinstance(result[0], dict)):
14421420
return result
1443-
raise AssertionError("With a dict supplied to apply_transform, should return a dict or a list of dicts.")
1421+
raise AssertionError("With a dict supplied to Compose, should return a dict or a list of dicts.")
14441422

14451423

14461424
class CSVDataset(Dataset):

tests/test_arraydataset.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -41,7 +41,7 @@
4141

4242
class TestCompose(Compose):
4343

44-
def __call__(self, input_, lazy):
44+
def __call__(self, input_, lazy=False):
4545
img = self.transforms[0](input_)
4646
metadata = img.meta
4747
img = self.transforms[1](img)

tests/test_dataset.py

+67-1
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@
2323
from parameterized import parameterized
2424

2525
from monai.data import Dataset
26-
from monai.transforms import Compose, LoadImaged, SimulateDelayd
26+
from monai.transforms import Compose, Lambda, LoadImage, LoadImaged, SimulateDelay, SimulateDelayd
2727
from tests.test_compose import TEST_COMPOSE_LAZY_ON_CALL_LOGGING_TEST_CASES, data_from_keys
2828

2929
TEST_CASE_1 = [(128, 128, 128)]
@@ -99,6 +99,72 @@ def test_dataset_lazy_on_call(self):
9999
data[0, 0:2, 0:2] = 1
100100

101101

102+
class TestTupleDataset(unittest.TestCase):
103+
104+
@parameterized.expand([TEST_CASE_1])
105+
def test_shape(self, expected_shape):
106+
test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))
107+
with tempfile.TemporaryDirectory() as tempdir:
108+
nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz"))
109+
nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz"))
110+
nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz"))
111+
nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz"))
112+
test_data = [
113+
(os.path.join(tempdir, "test_image1.nii.gz"), os.path.join(tempdir, "test_label1.nii.gz")),
114+
(os.path.join(tempdir, "test_image2.nii.gz"), os.path.join(tempdir, "test_label2.nii.gz")),
115+
]
116+
117+
test_transform = Compose([LoadImage(), SimulateDelay(delay_time=1e-5)])
118+
119+
# Here test_transform is applied element by element for the tuple.
120+
dataset = Dataset(data=test_data, transform=test_transform)
121+
data1 = dataset[0]
122+
data2 = dataset[1]
123+
124+
# Output is a list/tuple
125+
self.assertTrue(isinstance(data1, (list, tuple)))
126+
self.assertTrue(isinstance(data2, (list, tuple)))
127+
128+
# Number of elements are 2
129+
self.assertEqual(len(data1), 2)
130+
self.assertEqual(len(data2), 2)
131+
132+
# Output shapes are as expected
133+
self.assertTupleEqual(data1[0].shape, expected_shape)
134+
self.assertTupleEqual(data1[1].shape, expected_shape)
135+
self.assertTupleEqual(data2[0].shape, expected_shape)
136+
self.assertTupleEqual(data2[1].shape, expected_shape)
137+
138+
# Here test_transform is applied to the tuple as a whole.
139+
test_transform = Compose(
140+
[
141+
# LoadImage creates a channel-stacked image when applied to a tuple
142+
LoadImage(),
143+
# Get the channel-stacked image and the label
144+
Lambda(func=lambda x: (x[0].permute(2, 1, 0), x[1])),
145+
],
146+
map_items=False,
147+
)
148+
149+
dataset = Dataset(data=test_data, transform=test_transform)
150+
data1 = dataset[0]
151+
data2 = dataset[1]
152+
153+
# Output is a list/tuple
154+
self.assertTrue(isinstance(data1, (list, tuple)))
155+
self.assertTrue(isinstance(data2, (list, tuple)))
156+
157+
# Number of elements are 2
158+
self.assertEqual(len(data1), 2)
159+
self.assertEqual(len(data2), 2)
160+
161+
# Output shapes are as expected
162+
self.assertTupleEqual(data1[0].shape, expected_shape)
163+
self.assertTupleEqual(data1[1].shape, expected_shape)
164+
self.assertTupleEqual(data2[0].shape, expected_shape)
165+
self.assertTupleEqual(data2[1].shape, expected_shape)
166+
167+
102168
class TestDatsesetWithLazy(unittest.TestCase):
103169
LOGGER_NAME = "a_logger_name"
104170

tests/test_profiling.py

+3-1
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@ def setUp(self):
3535

3636
self.scale = mt.ScaleIntensity()
3737
self.scale_call_name = "ScaleIntensity.__call__"
38+
self.compose_call_name = "Compose.__call__"
3839
self.test_comp = mt.Compose([mt.ScaleIntensity(), mt.RandAxisFlip(0.5)])
3940
self.test_image = torch.rand(1, 16, 16, 16)
4041
self.pid = os.getpid()
@@ -82,7 +83,7 @@ def test_profile_multithread(self):
8283
self.assertSequenceEqual(batch.shape, (4, 1, 16, 16, 16))
8384

8485
results = wp.get_results()
85-
self.assertSequenceEqual(list(results), [self.scale_call_name])
86+
self.assertSequenceEqual(list(results), [self.scale_call_name, self.compose_call_name])
8687

8788
prs = results[self.scale_call_name]
8889

@@ -98,6 +99,7 @@ def test_profile_context(self):
9899
self.scale(self.test_image)
99100

100101
results = wp.get_results()
102+
101103
self.assertSequenceEqual(set(results), {"ScaleIntensity.__call__", "context"})
102104

103105
prs = results["context"]

0 commit comments

Comments (0)