train_ray_dist.py (forked from megvii-research/mdistiller)
import argparse
import os
import time

import ray

from tools.train import main as train
from mdistiller.engine.cfg import CFG as cfg
from mdistiller.engine.utils import log_msg


@ray.remote
def run(cfg, resume, opts, worker_id, gpu_id):
    # Pin this worker to a single GPU before the training process touches CUDA.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    # torch.cuda.set_device()
    try:
        train(cfg, resume, opts, group_flag=True)
    except (Exception, KeyboardInterrupt) as e:
        print(log_msg(f"worker {worker_id} failed: {e}", "ERROR"))
        if cfg.LOG.WANDB:
            try:
                import wandb

                wandb.finish(exit_code=1)
            except Exception as wandb_err:
                print(log_msg(
                    f"worker {worker_id} failed to exit wandb: {wandb_err}", "ERROR"))
    else:
        if cfg.LOG.WANDB:
            try:
                import wandb

                wandb.finish(exit_code=0)
            except Exception as wandb_err:
                print(log_msg(
                    f"worker {worker_id} failed to exit wandb: {wandb_err}", "ERROR"))
if __name__ == "__main__":
    parser = argparse.ArgumentParser("training for knowledge distillation.")
    parser.add_argument("--cfg", type=str, default="")
    parser.add_argument("--num_tests", type=int, default=1)
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("opts", nargs="*")
    args = parser.parse_args()

    cfg.merge_from_file(args.cfg)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    # Round-robin over the GPUs visible to this process
    # (manually assign GPUs so a single GPU can be reused by several workers).
    gpu_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")
    gpu_ids = [int(i) for i in gpu_ids]
    gpu_cnt = 0
    print("num_tests:", args.num_tests)

    ray.init(num_cpus=args.num_tests)
    tasks = []
    try:
        for i in range(args.num_tests):
            print(f"Start test {i}, use GPU {gpu_ids[gpu_cnt]}")
            tasks.append(
                run.remote(
                    cfg=cfg,
                    resume=args.resume,
                    opts=args.opts,
                    worker_id=i,
                    gpu_id=gpu_ids[gpu_cnt],
                )
            )
            gpu_cnt = (gpu_cnt + 1) % len(gpu_ids)
        # Block until every worker has finished.
        ray.wait(tasks, num_returns=len(tasks))
    except (Exception, KeyboardInterrupt) as e:
        print(log_msg(f"Training failed: {e}", "ERROR"))
    finally:
        # Cancel any workers that are still running, then give them time to clean up.
        for task in tasks:
            ray.cancel(task)
        time.sleep(30)
        ray.shutdown()
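Example invocation (a sketch; the config path and option values are illustrative, assuming the script sits at the repository root next to tools/ and mdistiller/): `CUDA_VISIBLE_DEVICES=0,1 python train_ray_dist.py --cfg configs/cifar100/dkd/res32x4_res8x4.yaml --num_tests 2` launches two independent training runs that round-robin over the listed GPUs, with any trailing `opts` forwarded to the yacs config of each worker.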