mini.py
import torch
import torch.distributed as dist
import torch.backends.cudnn as cudnn
import os
import random
import argparse
import numpy as np
import datetime
import warnings

warnings.filterwarnings('ignore',
                        'Argument interpolation should be of type InterpolationMode instead of int',
                        UserWarning)
warnings.filterwarnings('ignore',
                        'Leaking Caffe2 thread-pool after fork',
                        UserWarning)


def init_distributed_mode(args):
    # Expect RANK / WORLD_SIZE / LOCAL_RANK to be set by the launcher (or by main() below);
    # otherwise fall back to single-process mode.
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ['RANK'])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True
    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank,
                                         timeout=datetime.timedelta(seconds=3600))
    torch.distributed.barrier()


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0
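
# Illustrative usage note (not part of the original script): is_main_process() is the
# usual guard for rank-0-only side effects such as logging or checkpointing, e.g.
#     if is_main_process():
#         torch.save(state_dict, 'checkpoint.pth')  # hypothetical state_dict and path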


def get_args_parser():
    parser = argparse.ArgumentParser('training and evaluation script', add_help=False)
    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument('--local_rank', type=int, required=True, help='local rank for DistributedDataParallel')
    parser.add_argument('--device', default='cuda', help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    return parser


def main(args):
    # Mirror --local_rank into the environment so init_distributed_mode() can read it.
    os.environ['LOCAL_RANK'] = str(args.local_rank)
    init_distributed_mode(args)
    device = torch.device(args.device)
    print(args)

    # fix the seed for reproducibility (offset by rank so each process differs)
    seed = args.seed + get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True
    print('Distributed initialization complete')


if __name__ == '__main__':
    parser = argparse.ArgumentParser('training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()
    main(args)
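
# Launch sketch (an assumption, not part of the original file): --local_rank is a required
# argument and RANK / WORLD_SIZE are read from the environment, which matches the legacy
# launcher, e.g. on a single node with 2 GPUs:
#
#     python -m torch.distributed.launch --nproc_per_node=2 mini.py
#
# torchrun instead sets LOCAL_RANK in the environment and does not pass --local_rank,
# so using it would mean dropping required=True and reading LOCAL_RANK directly.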