visualize.py
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import os
import argparse
import sys

import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import (
    build_detection_test_loader,
    MetadataCatalog
)
from detectron2.data import detection_utils
from detectron2.engine import default_setup, launch
from detectron2.utils.visualizer import ColorMode

sys.path.append('../../..')
import tllib.vision.models.object_detection.meta_arch as models

import utils
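# Note (assumption about the surrounding repo layout): `utils` here is the local utils.py that
# sits next to this script in the examples directory and provides build_dataset() and
# VisualizerWithoutAreaSorting; it is not a pip-installed package. `models` exposes the
# meta-architectures that are looked up by name in main() below.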

def visualize(cfg, args, model):
    for dataset_name in args.test:
        data_loader = build_detection_test_loader(cfg, dataset_name)
        # create the output folder for this dataset
        dirname = os.path.join(args.save_path, dataset_name)
        os.makedirs(dirname, exist_ok=True)
        metadata = MetadataCatalog.get(dataset_name)
        n_current = 0

        # switch to eval mode
        model.eval()
        with torch.no_grad():
            for batch in data_loader:
                if n_current >= args.n_visualizations:
                    break
                batch_predictions = model(batch)
                for per_image, predictions in zip(batch, batch_predictions):
                    instances = predictions["instances"].to(torch.device("cpu"))
                    # keep only the top `n_bboxes` boxes (predictions are sorted by confidence)
                    instances = instances[0: args.n_bboxes]
                    # keep only boxes whose confidence exceeds the threshold
                    instances = instances[instances.scores > args.threshold]
                    # draw in reverse order of confidence so the most confident boxes end up on top
                    index = list(range(len(instances)))
                    index.reverse()
                    instances = instances[index]

                    img = per_image["image"].permute(1, 2, 0).cpu().detach().numpy()
                    img = detection_utils.convert_image_to_rgb(img, cfg.INPUT.FORMAT)
                    # rescale pred_boxes to the resolution of the image being drawn
                    # (assumes the test-time resize preserves the aspect ratio, so one ratio suffices)
                    ori_height, ori_width, _ = img.shape
                    height, width = instances.image_size
                    ratio = ori_width / width
                    for i in range(len(instances.pred_boxes)):
                        instances.pred_boxes[i].scale(ratio, ratio)

                    # draw the predictions on the image and save the visualization
                    visualizer = utils.VisualizerWithoutAreaSorting(img, metadata=metadata,
                                                                    instance_mode=ColorMode.IMAGE)
                    output = visualizer.draw_instance_predictions(predictions=instances)
                    filepath = os.path.join(dirname, str(n_current) + ".png")
                    output.save(filepath)

                    n_current += 1
                    if n_current >= args.n_visualizations:
                        break
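# Output layout sketch (hypothetical arguments, not part of the script): running with
# `--save-path vis --test Foo datasets/foo` makes the loop above write
#     vis/Foo/0.png, vis/Foo/1.png, ...
# one file per visualized image, capped at --n-visualizations per dataset.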

def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    # if you don't like any of the default setup, write your own setup code
    default_setup(cfg, args)
    return cfg
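# The config keys read elsewhere in this script are MODEL.META_ARCHITECTURE, MODEL.DEVICE,
# MODEL.WEIGHTS, INPUT.FORMAT and OUTPUT_DIR; each can be overridden through the trailing
# `opts` positional argument, e.g. (placeholder paths):
#     python visualize.py --config-file cfg.yaml ... MODEL.WEIGHTS logs/model_final.pth MODEL.DEVICE cpu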

def main(args):
    cfg = setup(args)
    # build the detection model named by the config and load the trained weights
    meta_arch = cfg.MODEL.META_ARCHITECTURE
    model = models.__dict__[meta_arch](cfg, finetune=True)
    model.to(torch.device(cfg.MODEL.DEVICE))
    DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
        cfg.MODEL.WEIGHTS, resume=False
    )
    visualize(cfg, args, model)
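# Note on detectron2 semantics: resume_or_load(path, resume=False) loads the weights given by
# cfg.MODEL.WEIGHTS directly and does not look for a "last checkpoint" in OUTPUT_DIR, which is
# the right behavior for a visualization-only script.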

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
    parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
    parser.add_argument(
        "--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
    )
    # PyTorch still may leave orphan processes in multi-gpu training.
    # Therefore we use a deterministic way to obtain the port,
    # so that users are aware of orphan processes by seeing the port occupied.
    port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
    parser.add_argument(
        "--dist-url",
        default="tcp://127.0.0.1:{}".format(port),
        help="initialization URL for pytorch distributed backend. See "
             "https://pytorch.org/docs/stable/distributed.html for details.",
    )
    parser.add_argument(
        "opts",
        help="Modify config options by adding 'KEY VALUE' pairs at the end of the command. "
             "See config references at "
             "https://detectron2.readthedocs.io/modules/config.html#config-references",
        default=None,
        nargs=argparse.REMAINDER,
    )
    parser.add_argument('--test', nargs='+',
                        help='test domain(s), given as alternating (dataset name, dataset root) pairs')
    parser.add_argument('--save-path', type=str,
                        help='where to save visualization results')
    parser.add_argument('--n-visualizations', default=100, type=int,
                        help='maximum number of images to visualize per dataset (default: 100)')
    parser.add_argument('--threshold', default=0.5, type=float,
                        help='confidence threshold of bounding boxes to visualize (default: 0.5)')
    parser.add_argument('--n-bboxes', default=10, type=int,
                        help='maximum number of bounding boxes to visualize in a single image (default: 10)')
    args = parser.parse_args()
    print("Command Line Args:", args)
    # --test is parsed as alternating (name, root) pairs; utils.build_dataset turns them into
    # the registered dataset names iterated over in visualize()
    args.test = utils.build_dataset(args.test[::2], args.test[1::2])
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )
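# Example invocation (a sketch only; the config file, dataset name and dataset root below are
# placeholders and must match datasets that utils.build_dataset knows how to register):
#     CUDA_VISIBLE_DEVICES=0 python visualize.py \
#         --config-file config/faster_rcnn_R_101_C4.yaml \
#         --test WaterColorTest datasets/watercolor \
#         --save-path visualizations \
#         MODEL.WEIGHTS logs/model_final.pth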