Commit e0e7215

Nic-Ma and monai-bot authored
416 add determinism utility (Project-MONAI#422)
* [DLMED] implement API to set determinism
* [MONAI] python code formatting
* [DLMED] fix PyTorch known issue
* [DLMED] update according to comments
* [MONAI] python code formatting

Co-authored-by: monai-bot <monai.miccai2019@gmail.com>
1 parent f278e4f

15 files changed: +217 −166 lines changed

docs/source/highlights.md

Lines changed: 4 additions & 3 deletions
@@ -88,9 +88,10 @@ train_transforms = monai.transforms.Compose([
 ])
 # set determinism for reproducibility
 train_transforms.set_random_state(seed=0)
-torch.manual_seed(0)
-torch.backends.cudnn.deterministic = True
-torch.backends.cudnn.benchmark = False
+```
+Users can also enable/disable deterministic training directly:
+```py
+monai.utils.set_determinism(seed=0, additional_settings=None)
 ```
 
 ### 6. Cache IO and transforms data to accelerate training
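For orientation, here is a minimal sketch of what a determinism helper like this consolidates, based on the manual calls the hunk above removes. Only the `set_determinism(seed=0, additional_settings=None)` signature comes from the diff; `set_determinism_sketch` is a hypothetical name and its body is an assumption for illustration, not MONAI's actual implementation.

```python
import random

import numpy as np
import torch


def set_determinism_sketch(seed=0, additional_settings=None):
    """Illustrative sketch (not MONAI's code): seed the common RNGs and
    force deterministic cuDNN behavior, mirroring the manual calls this
    commit removes from highlights.md."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # cuDNN autotuning picks the fastest (possibly non-deterministic) kernels;
    # disable it and request deterministic algorithms instead
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    if additional_settings is not None:
        # hypothetical hook: extra callables that apply framework-specific seeds
        settings = additional_settings if isinstance(additional_settings, (list, tuple)) else [additional_settings]
        for func in settings:
            func(seed)
```

Calling one helper at the top of a script replaces the three scattered PyTorch lines and keeps the seeding policy consistent across examples.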

docs/source/utils.rst

Lines changed: 5 additions & 0 deletions
@@ -19,3 +19,8 @@ Aliases
 -------
 .. automodule:: monai.utils.aliases
   :members:
+
+Misc
+-------
+.. automodule:: monai.utils.misc
+  :members:

examples/notebooks/cache_dataset_speed.ipynb

Lines changed: 41 additions & 55 deletions
@@ -48,7 +48,8 @@
 "from monai.networks.layers import Norm\n",
 "from monai.networks.nets import UNet\n",
 "from monai.losses import DiceLoss\n",
-"from monai.metrics import compute_meandice"
+"from monai.metrics import compute_meandice\n",
+"from monai.utils import set_determinism"
 ]
 },
 {
@@ -99,48 +100,30 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"train_transforms = Compose([\n",
-"    LoadNiftid(keys=['image', 'label']),\n",
-"    AddChanneld(keys=['image', 'label']),\n",
-"    Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2.), interp_order=(3, 0)),\n",
-"    Orientationd(keys=['image', 'label'], axcodes='RAS'),\n",
-"    ScaleIntensityRanged(keys=['image'], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),\n",
-"    CropForegroundd(keys=['image', 'label'], source_key='image'),\n",
-"    # randomly crop out patch samples from big image based on pos / neg ratio\n",
-"    # the image centers of negative samples must be in valid image area\n",
-"    RandCropByPosNegLabeld(keys=['image', 'label'], label_key='label', size=(96, 96, 96), pos=1,\n",
-"                           neg=1, num_samples=4, image_key='image', image_threshold=0),\n",
-"    ToTensord(keys=['image', 'label'])\n",
-"])\n",
-"val_transforms = Compose([\n",
-"    LoadNiftid(keys=['image', 'label']),\n",
-"    AddChanneld(keys=['image', 'label']),\n",
-"    Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2.), interp_order=(3, 0)),\n",
-"    Orientationd(keys=['image', 'label'], axcodes='RAS'),\n",
-"    ScaleIntensityRanged(keys=['image'], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),\n",
-"    CropForegroundd(keys=['image', 'label'], source_key='image'),\n",
-"    ToTensord(keys=['image', 'label'])\n",
-"])"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {},
-"source": [
-"## Define deterministic training for reproducibility"
-]
-},
-{
-"cell_type": "code",
-"execution_count": 4,
-"metadata": {},
-"outputs": [],
-"source": [
-"def set_deterministic():\n",
-"    train_transforms.set_random_state(seed=0)\n",
-"    torch.manual_seed(0)\n",
-"    torch.backends.cudnn.deterministic = True\n",
-"    torch.backends.cudnn.benchmark = False"
+"def transformations():\n",
+"    train_transforms = Compose([\n",
+"        LoadNiftid(keys=['image', 'label']),\n",
+"        AddChanneld(keys=['image', 'label']),\n",
+"        Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2.), interp_order=(3, 0)),\n",
+"        Orientationd(keys=['image', 'label'], axcodes='RAS'),\n",
+"        ScaleIntensityRanged(keys=['image'], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),\n",
+"        CropForegroundd(keys=['image', 'label'], source_key='image'),\n",
+"        # randomly crop out patch samples from big image based on pos / neg ratio\n",
+"        # the image centers of negative samples must be in valid image area\n",
+"        RandCropByPosNegLabeld(keys=['image', 'label'], label_key='label', size=(96, 96, 96), pos=1,\n",
+"                               neg=1, num_samples=4, image_key='image', image_threshold=0),\n",
+"        ToTensord(keys=['image', 'label'])\n",
+"    ])\n",
+"    val_transforms = Compose([\n",
+"        LoadNiftid(keys=['image', 'label']),\n",
+"        AddChanneld(keys=['image', 'label']),\n",
+"        Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2.), interp_order=(3, 0)),\n",
+"        Orientationd(keys=['image', 'label'], axcodes='RAS'),\n",
+"        ScaleIntensityRanged(keys=['image'], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),\n",
+"        CropForegroundd(keys=['image', 'label'], source_key='image'),\n",
+"        ToTensord(keys=['image', 'label'])\n",
+"    ])\n",
+"    return train_transforms, val_transforms"
 ]
 },
 {
@@ -152,13 +135,14 @@
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": 4,
 "metadata": {
 "scrolled": true
 },
 "outputs": [],
 "source": [
 "def train_process(train_ds, val_ds):\n",
+"    \n",
 "    # use batch_size=2 to load images and use RandCropByPosNegLabeld\n",
 "    # to generate 2 x 4 images for network training\n",
 "    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4, collate_fn=list_data_collate)\n",
@@ -169,7 +153,7 @@
 "    loss_function = DiceLoss(to_onehot_y=True, do_softmax=True)\n",
 "    optimizer = torch.optim.Adam(model.parameters(), 1e-4)\n",
 "\n",
-"    epoch_num = 600\n",
+"    epoch_num = 2\n",
 "    val_interval = 1 # do validation for every epoch\n",
 "    best_metric = -1\n",
 "    best_metric_epoch = -1\n",
@@ -234,18 +218,19 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"## Define regular Dataset for training and validation"
+"## Enable deterministic training and define regular Datasets"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 6,
+"execution_count": 5,
 "metadata": {},
 "outputs": [],
 "source": [
-"set_deterministic()\n",
-"train_ds = Dataset(data=train_files, transform=train_transforms)\n",
-"val_ds = Dataset(data=val_files, transform=val_transforms)"
+"set_determinism(seed=0)\n",
+"train_trans, val_trans = transformations()\n",
+"train_ds = Dataset(data=train_files, transform=train_trans)\n",
+"val_ds = Dataset(data=val_files, transform=val_trans)"
 ]
 },
 {
@@ -269,12 +254,12 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"## Define Cache Dataset for training and validation"
+"## Enable deterministic training and define Cache Datasets"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 8,
+"execution_count": 7,
 "metadata": {},
 "outputs": [
 {
@@ -289,10 +274,11 @@
 }
 ],
 "source": [
-"set_deterministic()\n",
+"set_determinism(seed=0)\n",
+"train_trans, val_trans = transformations()\n",
 "cache_init_start = time.time()\n",
-"cache_train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4)\n",
-"cache_val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)\n",
+"cache_train_ds = CacheDataset(data=train_files, transform=train_trans, cache_rate=1.0, num_workers=4)\n",
+"cache_val_ds = CacheDataset(data=val_files, transform=val_trans, cache_rate=1.0, num_workers=4)\n",
 "cache_init_time = time.time() - cache_init_start"
 ]
 },
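After this refactor, the notebook reseeds and rebuilds the transform pipeline before each timing experiment. Below is a minimal, self-contained sketch of that pattern with `CacheDataset`; the file lists and the trimmed `transformations()` stand-in are hypothetical (the notebook uses the full pipeline from the diff above), and the import paths assume the MONAI layout of this era.

```python
from monai.data import CacheDataset
from monai.transforms import AddChanneld, Compose, LoadNiftid, ToTensord
from monai.utils import set_determinism

# hypothetical data dicts; the notebook builds these from downloaded Nifti files
train_files = [{'image': 'tr_img0.nii.gz', 'label': 'tr_seg0.nii.gz'}]
val_files = [{'image': 'val_img0.nii.gz', 'label': 'val_seg0.nii.gz'}]


def transformations():
    # trimmed stand-in for the notebook's full transform factory
    trans = Compose([
        LoadNiftid(keys=['image', 'label']),
        AddChanneld(keys=['image', 'label']),
        ToTensord(keys=['image', 'label']),
    ])
    return trans, trans


# seed globally, then rebuild the transforms so their internal random
# state starts from the same point in every experiment
set_determinism(seed=0)
train_trans, val_trans = transformations()

# CacheDataset precomputes the deterministic transforms once and keeps them in memory
cache_train_ds = CacheDataset(data=train_files, transform=train_trans, cache_rate=1.0, num_workers=4)
cache_val_ds = CacheDataset(data=val_files, transform=val_trans, cache_rate=1.0, num_workers=4)
```

The notebooks rebuild the transforms after each `set_determinism` call so that every dataset variant replays the same random stream, which is what makes the timing comparison fair.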

examples/notebooks/persistent_dataset_speed.ipynb

Lines changed: 44 additions & 59 deletions
@@ -58,7 +58,8 @@
 "from monai.networks.layers import Norm\n",
 "from monai.networks.nets import UNet\n",
 "from monai.losses import DiceLoss\n",
-"from monai.metrics import compute_meandice"
+"from monai.metrics import compute_meandice\n",
+"from monai.utils import set_determinism"
 ]
 },
 {
@@ -146,27 +147,6 @@
 "    return epoch_num, time.time() - total_start, epoch_loss_values, metric_values, epoch_times"
 ]
 },
-{
-"cell_type": "markdown",
-"metadata": {},
-"source": [
-"## Define deterministic training for reproducibility"
-]
-},
-{
-"cell_type": "code",
-"execution_count": 3,
-"metadata": {},
-"outputs": [],
-"source": [
-"def set_deterministic(train_transforms):\n",
-"    train_transforms.set_random_state(seed=0)\n",
-"    torch.manual_seed(0)\n",
-"    torch.backends.cudnn.deterministic = True\n",
-"    torch.backends.cudnn.benchmark = False\n",
-"    return train_transforms"
-]
-},
 {
 "cell_type": "markdown",
 "metadata": {},
@@ -185,7 +165,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 4,
+"execution_count": 3,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -220,41 +200,43 @@
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": 4,
 "metadata": {},
 "outputs": [],
 "source": [
-"train_transforms = Compose([\n",
-"    LoadNiftid(keys=['image', 'label']),\n",
-"    AddChanneld(keys=['image', 'label']),\n",
-"    Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2.), interp_order=(3, 0)),\n",
-"    Orientationd(keys=['image', 'label'], axcodes='RAS'),\n",
-"    ScaleIntensityRanged(keys=['image'], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),\n",
-"    CropForegroundd(keys=['image', 'label'], source_key='image'),\n",
-"    # randomly crop out patch samples from big image based on pos / neg ratio\n",
-"    # the image centers of negative samples must be in valid image area\n",
-"    RandCropByPosNegLabeld(keys=['image', 'label'], label_key='label', size=(96, 96, 96), pos=1,\n",
-"                           neg=1, num_samples=4, image_key='image', image_threshold=0),\n",
-"    ToTensord(keys=['image', 'label'])\n",
-"])\n",
+"def transformations():\n",
+"    train_transforms = Compose([\n",
+"        LoadNiftid(keys=['image', 'label']),\n",
+"        AddChanneld(keys=['image', 'label']),\n",
+"        Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2.), interp_order=(3, 0)),\n",
+"        Orientationd(keys=['image', 'label'], axcodes='RAS'),\n",
+"        ScaleIntensityRanged(keys=['image'], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),\n",
+"        CropForegroundd(keys=['image', 'label'], source_key='image'),\n",
+"        # randomly crop out patch samples from big image based on pos / neg ratio\n",
+"        # the image centers of negative samples must be in valid image area\n",
+"        RandCropByPosNegLabeld(keys=['image', 'label'], label_key='label', size=(96, 96, 96), pos=1,\n",
+"                               neg=1, num_samples=4, image_key='image', image_threshold=0),\n",
+"        ToTensord(keys=['image', 'label'])\n",
+"    ])\n",
 "\n",
-"# NOTE: No random cropping in the validation data, we will evaluate the entire image using a sliding window.\n",
-"val_transforms = Compose([\n",
-"    LoadNiftid(keys=['image', 'label']),\n",
-"    AddChanneld(keys=['image', 'label']),\n",
-"    Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2.), interp_order=(3, 0)),\n",
-"    Orientationd(keys=['image', 'label'], axcodes='RAS'),\n",
-"    ScaleIntensityRanged(keys=['image'], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),\n",
-"    CropForegroundd(keys=['image', 'label'], source_key='image'),\n",
-"    ToTensord(keys=['image', 'label'])\n",
-"])"
+"    # NOTE: No random cropping in the validation data, we will evaluate the entire image using a sliding window.\n",
+"    val_transforms = Compose([\n",
+"        LoadNiftid(keys=['image', 'label']),\n",
+"        AddChanneld(keys=['image', 'label']),\n",
+"        Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2.), interp_order=(3, 0)),\n",
+"        Orientationd(keys=['image', 'label'], axcodes='RAS'),\n",
+"        ScaleIntensityRanged(keys=['image'], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),\n",
+"        CropForegroundd(keys=['image', 'label'], source_key='image'),\n",
+"        ToTensord(keys=['image', 'label'])\n",
+"    ])\n",
+"    return train_transforms, val_transforms"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"## `Dataset`: Define, train and validate\n",
+"## Enable deterministic training and regular `Dataset`\n",
 "\n",
 "Load each original dataset and transform each time it is needed."
 ]
@@ -265,9 +247,10 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"train_transforms = set_deterministic(train_transforms)\n",
-"train_ds = Dataset(data=train_files, transform=train_transforms)\n",
-"val_ds = Dataset(data=val_files, transform=val_transforms)\n",
+"set_determinism(seed=0)\n",
+"train_trans, val_trans = transformations()\n",
+"train_ds = Dataset(data=train_files, transform=train_trans)\n",
+"val_ds = Dataset(data=val_files, transform=val_trans)\n",
 "\n",
 "epoch_num, total_time, epoch_loss_values, metric_values, epoch_times = train_process(train_ds, val_ds)\n",
 "print('Total training time of {} epochs with regular Dataset: {:.4f}'.format(epoch_num, total_time))"
@@ -277,7 +260,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"## `PersistentDataset`: Define, train and validate\n",
+"## Enable deterministic training and `PersistentDataset`\n",
 "\n",
 "Use persistent storage of non-random transformed training and validation data computed once and stored persistently across runs"
 ]
@@ -291,9 +274,10 @@
 "persistent_cache: Path = Path(\"./persistent_cache\")\n",
 "persistent_cache.mkdir(parents=True, exist_ok=True)\n",
 "\n",
-"train_transforms = set_deterministic(train_transforms)\n",
-"train_persitence_ds = PersistentDataset(data=train_files, transform=train_transforms, cache_dir=persistent_cache)\n",
-"val_persitence_ds = PersistentDataset(data=val_files, transform=val_transforms, cache_dir=persistent_cache)\n",
+"set_determinism(seed=0)\n",
+"train_trans, val_trans = transformations()\n",
+"train_persitence_ds = PersistentDataset(data=train_files, transform=train_trans, cache_dir=persistent_cache)\n",
+"val_persitence_ds = PersistentDataset(data=val_files, transform=val_trans, cache_dir=persistent_cache)\n",
 "\n",
 "persistence_epoch_num, persistence_total_time, persistence_epoch_loss_values, \\\n",
 "    persistence_metric_values, persistence_epoch_times = \\\n",
@@ -306,7 +290,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"## `CacheDataset`: Define, train and validate\n",
+"## Enable deterministic training and `CacheDataset`\n",
 "\n",
 "Precompute all non-random transforms of original data and store in memory."
 ]
@@ -317,10 +301,11 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"train_transforms = set_deterministic(train_transforms)\n",
+"set_determinism(seed=0)\n",
+"train_trans, val_trans = transformations()\n",
 "cache_init_start = time.time()\n",
-"cache_train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4)\n",
-"cache_val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)\n",
+"cache_train_ds = CacheDataset(data=train_files, transform=train_trans, cache_rate=1.0, num_workers=4)\n",
+"cache_val_ds = CacheDataset(data=val_files, transform=val_trans, cache_rate=1.0, num_workers=4)\n",
 "cache_init_time = time.time() - cache_init_start\n",
 "\n",
 "cache_epoch_num, cache_total_time, cache_epoch_loss_values, cache_metric_values, cache_epoch_times = \\\n",
