-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_transformer.py
90 lines (77 loc) · 3.28 KB
/
run_transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import argparse
import os
from dywsss.tool import pyutils
from mmcv import Config, DictAction
from dywsss.pipeline.train_cam_transformer import train
from dywsss.pipeline.infer_multi_scale_transformer import infer_multi_scale
from dywsss.pipeline.infer_multi_scale_swin_transformer import infer_multi_scale_swin
from dywsss.tool.torch_utils import set_seed
import logging
def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, exactly as before;
            passing a list allows driving the script programmatically.

    Returns:
        argparse.Namespace holding the parsed options. ``options`` is the
        value produced by mmcv's ``DictAction`` for ``--options key=value``
        pairs, or ``None`` when ``--options`` is not given.
    """
    parser = argparse.ArgumentParser(description='Train a model')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')
    parser.add_argument('--tag', help='the tag')
    parser.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')

    # --device, --gpus and --gpu-ids are alternative ways to pick hardware;
    # argparse rejects more than one of them on the same command line.
    group_gpus = parser.add_mutually_exclusive_group()
    group_gpus.add_argument('--device', help='device used for training')
    group_gpus.add_argument(
        '--gpus',
        type=int,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='ids of gpus to use '
        '(only applicable to non-distributed training)')

    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--weights', type=str, default='best.pth',
        help='checkpoint file name (e.g. best.pth)')
    # Each `key=value` pair is turned into a config override by DictAction.
    parser.add_argument(
        '--options', nargs='+', action=DictAction, help='arguments in dict')
    return parser.parse_args(argv)
def file_stem(path):
    """Return the file name of *path* without its directory and extension.

    Everything from the FIRST ``.`` onwards is dropped
    (``"model.ep5.pth"`` -> ``"model"``), matching the names this script has
    always produced. ``os.path.basename`` is used instead of splitting on
    ``"/"`` so paths built by ``os.path.join`` on Windows (``"\\"``
    separator) are stripped of their directory as well.
    """
    return os.path.basename(path).split('.')[0]


def main():
    """Load the config, prepare the output tree and run the enabled stages.

    Stages are switched on by config flags:

    * ``train_multi_scale``       -- run ``train(cfg)``.
    * ``gen_mask_for_multi_crop`` -- run multi-scale CAM inference, using the
      Swin variant when ``'swin'`` appears in ``cfg.network``.

    Outputs are written under
    ``work_dirs/<session_name>/{model,test,log,tensorboard}``.
    """
    cli_args = parse_args()
    cfg = Config.fromfile(cli_args.config)
    if cli_args.options is not None:
        cfg.merge_from_dict(cli_args.options)

    # From here on everything is read from the config object, so fold in the
    # CLI values that would otherwise be silently dropped.
    # --seed defaults to None, so it only overrides the config when given.
    if cli_args.seed is not None:
        cfg.seed = cli_args.seed
    # --weights is only a fallback: a `weights` entry in the config still
    # wins. Without either, the mask-generation stage would fail on the
    # missing key.
    if cfg.get('weights') is None:
        cfg.weights = cli_args.weights
    # NOTE(review): --work-dir, --resume-from, --tag, --no-validate, --device,
    # --gpus, --gpu-ids and --deterministic are parsed but not used here.

    # NOTE(review): presumably set_seed reads cfg.seed -- confirm in
    # dywsss.tool.torch_utils.
    set_seed(cfg)

    # Output tree: work_dirs/<session_name>/{model,test,log,tensorboard}.
    session_root = os.path.join('work_dirs', cfg.session_name)
    cfg.model_dir = os.path.join(session_root, "model")
    cfg.test_dir = os.path.join(session_root, "test")
    cfg.log_dir = os.path.join(session_root, "log")
    cfg.tensorboard_dir = os.path.join(session_root, "tensorboard")
    # makedirs creates missing parents, so "work_dirs" needs no separate call.
    for out_dir in (cfg.model_dir, cfg.test_dir, cfg.log_dir,
                    cfg.tensorboard_dir):
        os.makedirs(out_dir, exist_ok=True)
    print(vars(cfg))

    if cfg.train_multi_scale:
        # Kept bound for the duration of the stage, as in the original script.
        timer = pyutils.Timer('train in multi-scale strategy:')
        train(cfg)

    if cfg.gen_mask_for_multi_crop:
        timer = pyutils.Timer('infer multi_scale cam and make rough mask:')
        cfg.weights = os.path.join(cfg.model_dir, cfg.weights)
        cfg.infer_list = os.path.join('voc12', cfg.infer_list)
        # e.g. <test_dir>/cam_train_aug_best for train_aug.txt + best.pth
        cfg.out_cam = os.path.join(
            cfg.test_dir,
            f'cam_{file_stem(cfg.infer_list)}_{file_stem(cfg.weights)}')
        cfg.out_crf = os.path.join(cfg.test_dir, 'train_mask')
        if 'swin' in cfg.network:
            infer_multi_scale_swin(cfg)
        else:
            infer_multi_scale(cfg)


if __name__ == '__main__':
    main()