36
36
37
37
from monai .data .meta_tensor import MetaTensor
38
38
from monai .data .utils import SUPPORTED_PICKLE_MOD , convert_tables_to_dicts , pickle_hashing
39
- from monai .transforms import (
40
- Compose ,
41
- Randomizable ,
42
- RandomizableTrait ,
43
- Transform ,
44
- apply_transform ,
45
- convert_to_contiguous ,
46
- reset_ops_id ,
47
- )
39
+ from monai .transforms import Compose , Randomizable , RandomizableTrait , Transform , convert_to_contiguous , reset_ops_id
48
40
from monai .utils import MAX_SEED , convert_to_tensor , get_seed , look_up_option , min_version , optional_import
49
41
from monai .utils .misc import first
50
42
@@ -77,15 +69,19 @@ class Dataset(_TorchDataset):
77
69
}, }, }]
78
70
"""
79
71
80
- def __init__ (self , data : Sequence , transform : Callable | None = None ) -> None :
72
+ def __init__ (self , data : Sequence , transform : Sequence [ Callable ] | Callable | None = None ) -> None :
81
73
"""
82
74
Args:
83
75
data: input data to load and transform to generate dataset for model.
84
- transform: a callable data transform on input data.
85
-
76
+ transform: a callable, sequence of callables or None. If transform is not
77
+ a `Compose` instance, it will be wrapped in a `Compose` instance. Sequences
78
+ of callables are applied in order and if `None` is passed, the data is returned as is.
86
79
"""
87
80
self .data = data
88
- self .transform : Any = transform
81
+ try :
82
+ self .transform = Compose (transform ) if not isinstance (transform , Compose ) else transform
83
+ except Exception as e :
84
+ raise ValueError ("`transform` must be a callable or a list of callables that is Composable" ) from e
89
85
90
86
def __len__ (self ) -> int :
91
87
return len (self .data )
@@ -95,7 +91,7 @@ def _transform(self, index: int):
95
91
Fetch single data item from `self.data`.
96
92
"""
97
93
data_i = self .data [index ]
98
- return apply_transform ( self .transform , data_i ) if self . transform is not None else data_i
94
+ return self .transform ( data_i )
99
95
100
96
def __getitem__ (self , index : int | slice | Sequence [int ]):
101
97
"""
@@ -264,8 +260,6 @@ def __init__(
264
260
using the cached content and with re-created transform instances.
265
261
266
262
"""
267
- if not isinstance (transform , Compose ):
268
- transform = Compose (transform )
269
263
super ().__init__ (data = data , transform = transform )
270
264
self .cache_dir = Path (cache_dir ) if cache_dir is not None else None
271
265
self .hash_func = hash_func
@@ -323,9 +317,6 @@ def _pre_transform(self, item_transformed):
323
317
random transform object
324
318
325
319
"""
326
- if not isinstance (self .transform , Compose ):
327
- raise ValueError ("transform must be an instance of monai.transforms.Compose." )
328
-
329
320
first_random = self .transform .get_index_of_first (
330
321
lambda t : isinstance (t , RandomizableTrait ) or not isinstance (t , Transform )
331
322
)
@@ -346,9 +337,6 @@ def _post_transform(self, item_transformed):
346
337
the transformed element through the random transforms
347
338
348
339
"""
349
- if not isinstance (self .transform , Compose ):
350
- raise ValueError ("transform must be an instance of monai.transforms.Compose." )
351
-
352
340
first_random = self .transform .get_index_of_first (
353
341
lambda t : isinstance (t , RandomizableTrait ) or not isinstance (t , Transform )
354
342
)
@@ -501,9 +489,6 @@ def _pre_transform(self, item_transformed):
501
489
Returns:
502
490
the transformed element up to the N transform object
503
491
"""
504
- if not isinstance (self .transform , Compose ):
505
- raise ValueError ("transform must be an instance of monai.transforms.Compose." )
506
-
507
492
item_transformed = self .transform (item_transformed , end = self .cache_n_trans , threading = True )
508
493
509
494
reset_ops_id (item_transformed )
@@ -519,9 +504,6 @@ def _post_transform(self, item_transformed):
519
504
Returns:
520
505
the final transformed result
521
506
"""
522
- if not isinstance (self .transform , Compose ):
523
- raise ValueError ("transform must be an instance of monai.transforms.Compose." )
524
-
525
507
return self .transform (item_transformed , start = self .cache_n_trans )
526
508
527
509
@@ -809,8 +791,6 @@ def __init__(
809
791
Not following these recommendations may lead to runtime errors or duplicated cache across processes.
810
792
811
793
"""
812
- if not isinstance (transform , Compose ):
813
- transform = Compose (transform )
814
794
super ().__init__ (data = data , transform = transform )
815
795
self .set_num = cache_num # tracking the user-provided `cache_num` option
816
796
self .set_rate = cache_rate # tracking the user-provided `cache_rate` option
@@ -1282,8 +1262,10 @@ def to_list(x):
1282
1262
data = []
1283
1263
for dataset in self .data :
1284
1264
data .extend (to_list (dataset [index ]))
1265
+
1285
1266
if self .transform is not None :
1286
- data = apply_transform (self .transform , data , map_items = False ) # transform the list data
1267
+ self .transform .map_items = False # Compose object map_items to false so transform is applied to list
1268
+ data = self .transform (data )
1287
1269
# use tuple instead of list as the default collate_fn callback of MONAI DataLoader flattens nested lists
1288
1270
return tuple (data )
1289
1271
@@ -1432,15 +1414,11 @@ def __len__(self):
1432
1414
1433
1415
def _transform (self , index : int ):
1434
1416
data = {k : v [index ] for k , v in self .arrays .items ()}
1435
-
1436
- if not self .transform :
1437
- return data
1438
-
1439
- result = apply_transform (self .transform , data )
1417
+ result = self .transform (data ) if self .transform is not None else data
1440
1418
1441
1419
if isinstance (result , dict ) or (isinstance (result , list ) and isinstance (result [0 ], dict )):
1442
1420
return result
1443
- raise AssertionError ("With a dict supplied to apply_transform , should return a dict or a list of dicts." )
1421
+ raise AssertionError ("With a dict supplied to Compose , should return a dict or a list of dicts." )
1444
1422
1445
1423
1446
1424
class CSVDataset (Dataset ):
0 commit comments