forked from facebookresearch/Detic
-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
97 lines (89 loc) · 3.83 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import sys
import cv2
import tempfile
from pathlib import Path
import cog
import time
# import some common detectron2 utilities
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
# Detic libraries
sys.path.insert(0, 'third_party/CenterNet2/')
from centernet.config import add_centernet_config
from detic.config import add_detic_config
from detic.modeling.utils import reset_cls_test
from detic.modeling.text.text_encoder import build_text_encoder
class Predictor(cog.Predictor):
def setup(self):
cfg = get_cfg()
add_centernet_config(cfg)
add_detic_config(cfg)
cfg.merge_from_file("configs/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml")
cfg.MODEL.WEIGHTS = 'Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth'
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # set threshold for this model
cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH = 'rand'
cfg.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL = True
self.predictor = DefaultPredictor(cfg)
self.BUILDIN_CLASSIFIER = {
'lvis': 'datasets/metadata/lvis_v1_clip_a+cname.npy',
'objects365': 'datasets/metadata/o365_clip_a+cnamefix.npy',
'openimages': 'datasets/metadata/oid_clip_a+cname.npy',
'coco': 'datasets/metadata/coco_clip_a+cname.npy',
}
self.BUILDIN_METADATA_PATH = {
'lvis': 'lvis_v1_val',
'objects365': 'objects365_v2_val',
'openimages': 'oid_val_expanded',
'coco': 'coco_2017_val',
}
@cog.input(
"image",
type=Path,
help="input image",
)
@cog.input(
"vocabulary",
type=str,
default='lvis',
options=['lvis', 'objects365', 'openimages', 'coco', 'custom'],
help="Choose vocabulary",
)
@cog.input(
"custom_vocabulary",
type=str,
default=None,
help="Type your own vocabularies, separated by coma ','",
)
def predict(self, image, vocabulary, custom_vocabulary):
image = cv2.imread(str(image))
if not vocabulary == 'custom':
metadata = MetadataCatalog.get(self.BUILDIN_METADATA_PATH[vocabulary])
classifier = self.BUILDIN_CLASSIFIER[vocabulary]
num_classes = len(metadata.thing_classes)
reset_cls_test(self.predictor.model, classifier, num_classes)
else:
assert custom_vocabulary is not None and len(custom_vocabulary.split(',')) > 0, \
"Please provide your own vocabularies when vocabulary is set to 'custom'."
metadata = MetadataCatalog.get(str(time.time()))
metadata.thing_classes = custom_vocabulary.split(',')
classifier = get_clip_embeddings(metadata.thing_classes)
num_classes = len(metadata.thing_classes)
reset_cls_test(self.predictor.model, classifier, num_classes)
# Reset visualization threshold
output_score_threshold = 0.3
for cascade_stages in range(len(self.predictor.model.roi_heads.box_predictor)):
self.predictor.model.roi_heads.box_predictor[cascade_stages].test_score_thresh = output_score_threshold
outputs = self.predictor(image)
v = Visualizer(image[:, :, ::-1], metadata)
out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
out_path = Path(tempfile.mkdtemp()) / "out.png"
cv2.imwrite(str(out_path), out.get_image()[:, :, ::-1])
return out_path
def get_clip_embeddings(vocabulary, prompt='a '):
text_encoder = build_text_encoder(pretrain=True)
text_encoder.eval()
texts = [prompt + x for x in vocabulary]
emb = text_encoder(texts).detach().permute(1, 0).contiguous().cpu()
return emb