forked from Project-MONAI/tutorials
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbrats_training_ddp.py
472 lines (404 loc) · 18.2 KB
/
brats_training_ddp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This example shows how to execute distributed training based on the PyTorch native `DistributedDataParallel` module.
It can run on several nodes with multiple GPU devices on every node.
This example is a real-world task based on Decathlon challenge Task01: Brain Tumor segmentation,
so it is more complicated than other distributed training demo examples.

Main steps to set up the distributed training:

- Execute `torch.distributed.launch` to create processes on every node for every GPU.
  It receives parameters as below:
    `--nproc_per_node=NUM_GPUS_PER_NODE`
    `--nnodes=NUM_NODES`
    `--node_rank=INDEX_CURRENT_NODE`
    `--master_addr="192.168.1.1"`
    `--master_port=1234`
  For more details, refer to https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py.
  Alternatively, we can also use `torch.multiprocessing.spawn` to start the program, but in that case we need to
  handle all the above parameters and compute `rank` manually, then pass it to `init_process_group`, etc.
  `torch.distributed.launch` is even more efficient than `torch.multiprocessing.spawn` during training.
- Use `init_process_group` to initialize every process; every GPU runs in a separate process with a unique rank.
  Here we use `NVIDIA NCCL` as the backend and must set `init_method="env://"` when using `torch.distributed.launch`.
- Wrap the model with `DistributedDataParallel` after moving it to the expected device.
- Partition the dataset before training, so every rank process will only handle its own data partition.

Note:
    `torch.distributed.launch` will launch `nnodes * nproc_per_node = world_size` processes in total.
    We suggest setting up exactly the same software environment on every node, especially `PyTorch`, `nccl`, etc.
    A good practice is to use the same MONAI docker image for all nodes directly.
    Example script to execute this program on every node:
        python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_PER_NODE
               --nnodes=NUM_NODES --node_rank=INDEX_CURRENT_NODE
               --master_addr="192.168.1.1" --master_port=1234
               brats_training_ddp.py -d DIR_OF_TESTDATA

This example was tested with [Ubuntu 16.04/20.04], [NCCL 2.6.3].

Referring to: https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
Some code is taken from https://github.com/pytorch/examples/blob/master/imagenet/main.py
"""
import argparse
import os
import sys
import time
import warnings
import numpy as np
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.tensorboard import SummaryWriter
from monai.apps import DecathlonDataset
from monai.data import DataLoader, partition_dataset
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.networks.nets import SegResNet, UNet
from monai.transforms import (
Activations,
AsChannelFirstd,
AsDiscrete,
CenterSpatialCropd,
Compose,
LoadImaged,
MapTransform,
NormalizeIntensityd,
Orientationd,
RandFlipd,
RandScaleIntensityd,
RandShiftIntensityd,
RandSpatialCropd,
Spacingd,
ToTensord,
)
from monai.utils import set_determinism
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    Convert labels to multi channels based on brats classes:
    label 1 is the peritumoral edema
    label 2 is the GD-enhancing tumor
    label 3 is the necrotic and non-enhancing tumor core
    The possible classes are TC (Tumor core), WT (Whole tumor)
    and ET (Enhancing tumor).

    For every key in ``self.keys`` the label array is replaced (in a shallow copy
    of the input mapping) by a float32 array with a new leading axis of size 3,
    stacked in the order [TC, WT, ET]:

    - TC: voxels labelled 2 or 3
    - WT: voxels labelled 1, 2 or 3
    - ET: voxels labelled 2
    """

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            label = d[key]
            # merge label 2 and label 3 to construct TC
            tumor_core = np.logical_or(label == 2, label == 3)
            # WT is TC plus label 1; reuse the TC mask instead of recomputing it
            whole_tumor = np.logical_or(tumor_core, label == 1)
            # label 2 is ET
            enhancing_tumor = label == 2
            d[key] = np.stack([tumor_core, whole_tumor, enhancing_tumor], axis=0).astype(np.float32)
        return d
class BratsCacheDataset(DecathlonDataset):
    """Cached Decathlon Task01 (BraTS) dataset that holds only the current rank's share of the data.

    The file list for ``section`` is produced by the parent ``_generate_data_list`` and is then
    split with ``partition_dataset`` into ``dist.get_world_size()`` parts. This instance keeps
    part ``dist.get_rank()``, so it must be constructed after ``dist.init_process_group``.

    Args:
        root_dir: directory that contains the extracted ``Task01_BrainTumour`` folder.
        section: stored as ``self.section``; presumably ``"training"`` or ``"validation"`` and read
            by the parent's data-list logic (TODO confirm against the MONAI version in use).
        transform: transform applied to each item; forwarded to the cache dataset constructor.
        cache_rate: forwarded to the cache dataset constructor.
        num_workers: forwarded to the cache dataset constructor.
        shuffle: shuffle before partitioning; also requests evenly divisible partitions.

    Raises:
        ValueError: if ``root_dir`` is not a directory.
        RuntimeError: if ``root_dir/Task01_BrainTumour`` does not exist.
    """

    def __init__(
        self,
        root_dir,
        section,
        # NOTE(review): this LoadImaged instance is created once, when the class is defined,
        # and is shared by every instance that relies on the default.
        transform=LoadImaged(["image", "label"]),
        cache_rate=1.0,
        num_workers=0,
        shuffle=False,
    ) -> None:
        if not os.path.isdir(root_dir):
            raise ValueError("Root directory root_dir must be a directory.")
        self.section = section
        self.shuffle = shuffle
        # fraction of cases for the validation section; presumably read by the parent's split logic
        self.val_frac = 0.2
        # same literal seed on every rank, presumably so all ranks derive an identical
        # train/validation split before it is partitioned per rank
        self.set_random_state(seed=0)
        dataset_dir = os.path.join(root_dir, "Task01_BrainTumour")
        if not os.path.exists(dataset_dir):
            raise RuntimeError(
                f"Cannot find dataset directory: {dataset_dir}, please download it from Decathlon challenge."
            )
        data = self._generate_data_list(dataset_dir)
        # deliberately bypasses DecathlonDataset.__init__ and calls the next class in the MRO
        # (presumably CacheDataset) with the already-partitioned list
        super(DecathlonDataset, self).__init__(data, transform, cache_rate=cache_rate, num_workers=num_workers)

    def _generate_data_list(self, dataset_dir):
        """Return this rank's partition of the parent's data list for ``dataset_dir``."""
        data = super()._generate_data_list(dataset_dir)
        # partition dataset based on current rank number, every rank trains with its own data
        # it can avoid duplicated caching content in each rank, but will not do global shuffle before every epoch
        return partition_dataset(
            data=data,
            num_partitions=dist.get_world_size(),
            shuffle=self.shuffle,
            seed=0,
            drop_last=False,
            # training (shuffle=True) asks for equal-sized partitions, presumably so every rank runs
            # the same number of iterations; validation (shuffle=False) keeps uneven partitions
            even_divisible=self.shuffle,
        )[dist.get_rank()]
def main_worker(args):
    """Train (and periodically validate) the BraTS segmentation network in one DDP process.

    One copy of this function runs per GPU. It initialises the NCCL process group
    from the ``env://`` rendezvous variables and builds this rank's cached
    training/validation partitions. It then wraps the network in
    ``DistributedDataParallel`` and runs ``args.epochs`` epochs, validating every
    ``args.val_interval`` epochs. Only global rank 0 writes TensorBoard scalars
    (to ``args.log_dir``) and saves ``final_model.pth`` in the working directory.

    Args:
        args: namespace produced by ``main``; reads ``local_rank``, ``dir``,
            ``cache_rate``, ``batch_size``, ``workers``, ``log_dir``, ``network``,
            ``lr``, ``epochs``, ``print_freq`` and ``val_interval``.

    Raises:
        FileNotFoundError: if ``args.dir`` does not exist.
    """
    # disable logging for processes except 0 on every node; the devnull handle
    # intentionally stays open for the process lifetime because it replaces stdout/stderr
    if args.local_rank != 0:
        f = open(os.devnull, "w")
        sys.stdout = sys.stderr = f

    if not os.path.exists(args.dir):
        raise FileNotFoundError(f"Missing directory {args.dir}")

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    total_start = time.time()
    train_transforms = Compose(
        [
            # load 4 Nifti images and stack them together
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys="image"),
            ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            RandSpatialCropd(keys=["image", "label"], roi_size=[128, 128, 64], random_size=False),
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
            RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
            RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # create a training data loader; each rank only sees its own partition,
    # so plain per-rank shuffling is used instead of a DistributedSampler
    train_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=train_transforms,
        section="training",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=True,
    )
    train_loader = DataLoader(
        train_ds, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True
    )

    # validation transforms and dataset
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys="image"),
            ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            CenterSpatialCropd(keys=["image", "label"], roi_size=[128, 128, 64]),
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            ToTensord(keys=["image", "label"]),
        ]
    )
    val_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=val_transforms,
        section="validation",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=False,
    )
    val_loader = DataLoader(
        val_ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
    )

    # Logging for TensorBoard happens on global rank 0 only
    writer = None
    if dist.get_rank() == 0:
        writer = SummaryWriter(log_dir=args.log_dir)

    # create the network (UNet or SegResNet), DiceLoss and Adam optimizer
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    if args.network == "UNet":
        model = UNet(
            dimensions=3,
            in_channels=4,
            out_channels=3,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
    else:
        model = SegResNet(in_channels=4, out_channels=3, init_filters=16, dropout_prob=0.2).to(device)
    loss_function = DiceLoss(to_onehot_y=False, sigmoid=True, squared_pred=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5, amsgrad=True)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[device])

    # start a typical PyTorch training
    total_epoch = args.epochs
    best_metric = -1000000
    best_metric_epoch = -1
    epoch_time = AverageMeter("Time", ":6.3f")
    progress = ProgressMeter(total_epoch, [epoch_time], prefix="Epoch: ")
    # guard against --print_freq 0, which would otherwise raise ZeroDivisionError
    print_freq = max(1, args.print_freq)
    end = time.time()
    print(f"Time elapsed before training: {end-total_start}")
    for epoch in range(total_epoch):

        train_loss = train(train_loader, model, loss_function, optimizer, epoch, args, device)
        epoch_time.update(time.time() - end)

        if epoch % print_freq == 0:
            progress.display(epoch)

        if dist.get_rank() == 0:
            writer.add_scalar("Loss/train", train_loss, epoch)

        if (epoch + 1) % args.val_interval == 0:
            # every rank must call evaluate(): it ends with collective ops (barrier/all_reduce)
            metric, metric_tc, metric_wt, metric_et = evaluate(model, val_loader, device)

            if dist.get_rank() == 0:
                writer.add_scalar("Mean Dice/val", metric, epoch)
                writer.add_scalar("Mean Dice TC/val", metric_tc, epoch)
                writer.add_scalar("Mean Dice WT/val", metric_wt, epoch)
                writer.add_scalar("Mean Dice ET/val", metric_et, epoch)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                print(
                    f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                    f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
                    f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
                )
        end = time.time()
        print(f"Time elapsed after epoch {epoch + 1} is {end - total_start}")

    if dist.get_rank() == 0:
        print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
        # all processes should see same parameters as they all start from same
        # random parameters and gradients are synchronized in backward passes,
        # therefore, saving it in one process is sufficient.
        # NOTE: the state dict is taken from the DDP wrapper, so keys carry a "module." prefix.
        torch.save(model.state_dict(), "final_model.pth")
        # close() flushes pending events and releases the event file handle
        writer.close()
    dist.destroy_process_group()
def train(train_loader, model, criterion, optimizer, epoch, args, device):
    """Run one training epoch and return the batch-size-weighted mean loss.

    Args:
        train_loader: iterable of dict batches with ``"image"`` and ``"label"`` entries.
        model: network to optimise (``DistributedDataParallel``-wrapped in this script).
        criterion: loss callable invoked as ``criterion(output, target)``.
        optimizer: optimiser stepping the model's parameters.
        epoch: epoch index, used only in the progress prefix.
        args: namespace; ``args.print_freq`` sets how often (in iterations) progress is printed.
        device: device the batches are moved to.

    Returns:
        Mean of ``loss.item()`` over the epoch, weighted by each batch's size.
    """
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.4e")
    progress = ProgressMeter(len(train_loader), [batch_time, data_time, losses], prefix=f"Epoch: [{epoch}]")
    # honour --print_freq like main_worker does; guard against zero
    print_freq = max(1, args.print_freq)

    # switch to train mode
    model.train()

    end = time.time()
    for i, batch_data in enumerate(train_loader):
        # measure data loading time (time spent waiting on the loader, before the device copy)
        data_time.update(time.time() - end)

        image = batch_data["image"].to(device, non_blocking=True)
        target = batch_data["label"].to(device, non_blocking=True)

        # compute output
        optimizer.zero_grad()
        output = model(image)
        loss = criterion(output, target)

        # record loss, weighted by the number of samples in this batch
        losses.update(loss.item(), image.size(0))

        # compute gradient and do GD step
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            progress.display(i)

    return losses.avg
def evaluate(model, data_loader, device):
    """Compute Dice scores on this rank's validation partition and sum them across all ranks.

    Must be called on every rank, because it ends with ``dist.barrier`` and ``dist.all_reduce``.

    Args:
        model: segmentation network producing 3-channel logits.
        data_loader: iterable of dict batches with ``"image"`` and 3-channel ``"label"`` entries.
        device: device the batches are moved to.

    Returns:
        Tuple ``(mean, tc, wt, et)``. ``mean`` is the Dice over all channels. ``tc``, ``wt`` and
        ``et`` are the Dice of label channels 0, 1 and 2. Each is averaged over all non-NaN
        samples from every rank. An entry is ``nan`` if no rank produced a valid sample for it.
    """
    # accumulators laid out as [sum_0, count_0, sum_1, count_1, ...] for (all, TC, WT, ET)
    metric = torch.zeros(8, dtype=torch.float, device=device)

    model.eval()
    with torch.no_grad():
        dice_metric = DiceMetric(include_background=True, reduction="mean")
        post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
        # slice(None) keeps every channel (overall mean dice); the others select TC, WT, ET
        channel_slices = (slice(None), slice(0, 1), slice(1, 2), slice(2, 3))
        for val_data in data_loader:
            val_inputs = val_data["image"].to(device, non_blocking=True)
            val_labels = val_data["label"].to(device, non_blocking=True)
            val_outputs = post_trans(model(val_inputs))
            for idx, channels in enumerate(channel_slices):
                value, not_nans = dice_metric(y_pred=val_outputs[:, channels], y=val_labels[:, channels])
                # value is a mean over `not_nans` valid samples; turn it back into a sum so
                # results can be combined across batches and ranks
                metric[2 * idx] += value.squeeze() * not_nans
                metric[2 * idx + 1] += not_nans

        # synchronizes all processes and reduce results
        dist.barrier()
        dist.all_reduce(metric, op=torch.distributed.ReduceOp.SUM)
        totals = metric.tolist()

    return tuple(
        totals[2 * i] / totals[2 * i + 1] if totals[2 * i + 1] else float("nan") for i in range(len(channel_slices))
    )
def main():
    """Parse command-line arguments, optionally seed all RNGs, and run ``main_worker``.

    Intended to be started once per GPU process by a distributed launcher. The process's
    local rank is taken from ``--local_rank`` / ``--local-rank``. If neither is given, it
    falls back to the ``LOCAL_RANK`` environment variable (default 0).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--dir", default="./testdata", type=str, help="directory of Brain Tumor dataset.")
    # the launcher supplies this process's rank on the local node: older ``torch.distributed.launch``
    # passes ``--local_rank=N``, PyTorch >= 2.0 passes ``--local-rank=N``, and ``torchrun`` passes
    # neither and sets the LOCAL_RANK environment variable instead (handled after parsing)
    parser.add_argument(
        "--local_rank",
        "--local-rank",
        dest="local_rank",
        type=int,
        default=None,
        help="local process rank (GPU index) on the current node, provided by the launcher",
    )
    parser.add_argument(
        "-j", "--workers", default=1, type=int, metavar="N", help="number of data loading workers (default: 1)"
    )
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument("--lr", default=1e-4, type=float, help="learning rate")
    parser.add_argument(
        "-b",
        "--batch_size",
        default=4,
        type=int,
        metavar="N",
        help="mini-batch size per process/GPU (default: 4); every rank builds "
        "its own DataLoader with this batch size",
    )
    parser.add_argument("-p", "--print_freq", default=10, type=int, metavar="N", help="print frequency (default: 10)")
    # NOTE(review): --evaluate is parsed but not read anywhere in this file.
    parser.add_argument(
        "-e", "--evaluate", dest="evaluate", action="store_true", help="evaluate model on validation set"
    )
    parser.add_argument("--seed", default=None, type=int, help="seed for initializing training.")
    parser.add_argument("--cache_rate", type=float, default=1.0)
    parser.add_argument("--val_interval", type=int, default=5)
    parser.add_argument("--network", type=str, default="UNet", choices=["UNet", "SegResNet"])
    parser.add_argument("--log_dir", type=str, default=None)
    args = parser.parse_args()

    if args.local_rank is None:
        # launched without a rank argument (e.g. torchrun): read it from the environment
        args.local_rank = int(os.environ.get("LOCAL_RANK", 0))

    if args.seed is not None:
        set_determinism(seed=args.seed)
        warnings.warn(
            "You have chosen to seed training. "
            "This will turn on the CUDNN deterministic setting, "
            "which can slow down your training considerably! "
            "You may see unexpected behavior when restarting "
            "from checkpoints."
        )

    main_worker(args=args)
class AverageMeter:
    """Keep the most recent value of a scalar together with its running weighted mean.

    ``str(meter)`` renders ``"<name> <latest> (<average>)"`` with both numbers
    formatted by the format spec ``fmt`` (for example ``":6.3f"``).
    """

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero the latest value, running average, running sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted with weight ``n`` in the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = f"{{name}} {{val{self.fmt}}} ({{avg{self.fmt}}})"
        return template.format(**self.__dict__)
class ProgressMeter:
    """Print one tab-separated progress line: a ``[current/total]`` counter followed by each meter."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the counter for ``batch`` and the string form of every meter, joined by tabs."""
        head = self.prefix + self.batch_fmtstr.format(batch)
        print("\t".join([head, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        # right-align the running counter to the width of the total, e.g. "[{:3d}/100]"
        width = len(str(num_batches // 1))
        return f"[{{:{width}d}}/{num_batches:{width}d}]"
# Usage example (refer to https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py);
# run the same command on every node, changing only --node_rank:
# python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_PER_NODE
#        --nnodes=NUM_NODES --node_rank=INDEX_CURRENT_NODE
#        --master_addr="192.168.1.1" --master_port=1234
#        brats_training_ddp.py -d DIR_OF_TESTDATA
if __name__ == "__main__":
    main()