-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference.py
146 lines (106 loc) · 3.83 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import torch
# import some common libraries
import numpy as np
import os, json, cv2, random
import warnings
warnings.filterwarnings("ignore")
# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
from .exceptions import handle_exception
import matplotlib.pyplot as plt
from detectron2.utils.visualizer import ColorMode, GenericMask
import time
import uuid
def format_output(output, custom_metadata):
    """Summarize detectron2 predictions into a JSON-friendly ``items`` list.

    Groups predicted instances by class name; for each class records the
    instance count and, per detection, the first polygon outline of its mask
    plus a rounded confidence percentage.

    Args:
        output: dict holding a detectron2 "instances" result plus "height"
            and "width" keys (image size used to rasterize the masks).
        custom_metadata: detectron2 metadata whose ``thing_classes`` maps
            class ids to human-readable names.

    Returns:
        The same ``output`` dict with an added "items" key: a list of
        ``{"name", "count", "location": [{"coordinates", "accuracy_pct"}]}``
        entries. When "instances" is absent, "items" is an empty list.
    """
    items = []
    if "instances" in output:
        pred_classes = output["instances"].pred_classes
        pred_scores = output["instances"].scores.tolist()
        # Map numeric class ids to names for grouping.
        pred_classes = np.array([custom_metadata.thing_classes[i] for i in pred_classes])
        pred_classes_counts = np.unique(pred_classes, return_counts=True)
        for i, pred in enumerate(pred_classes_counts[0]):
            obj = {}
            # str()/int(): np.unique yields numpy scalars (np.str_/np.int64);
            # np.int64 in particular is not JSON serializable, which would
            # break serialization of the response downstream.
            obj["name"] = str(pred)
            obj["count"] = int(pred_classes_counts[1][i])
            obj["location"] = []
            idxs = np.where(pred_classes == pred)
            for idx in idxs[0]:
                # Rasterized mask -> polygon outlines; only the first polygon
                # of each mask is kept (original behavior preserved).
                generic_mask = GenericMask(
                    np.array(output["instances"].pred_masks[idx], dtype=np.uint8),
                    output["height"], output["width"])
                mask = generic_mask.polygons
                mask = ",".join(str(v) for v in mask[0])
                obj["location"].append({
                    "coordinates": mask,
                    "accuracy_pct": round(pred_scores[idx]*100)
                })
            items.append(obj)
    output["items"] = items
    return output
def prepare_model(data_repo, model_repo):
    """Register class metadata and build a detectron2 ``DefaultPredictor``.

    Prefers ``advanced_config.yaml`` when it exists in ``data_repo``,
    otherwise falls back to ``basic_config.yaml``.

    Args:
        data_repo: directory containing ``metadata.json`` (a list of
            ``{"id", "name"}`` category dicts) and the YAML config file(s).
        model_repo: directory holding the trained weights ``model_final.pth``.

    Returns:
        (predictor, infer_debug, custom_metadata): the ready-to-use
        ``DefaultPredictor``, the value of the custom ``INFER_DEBUG`` config
        key (None unless set by the merged YAML), and the registered
        metadata for "custom_test".
    """
    basic_config = f"{data_repo}/basic_config.yaml"
    adv_config = f"{data_repo}/advanced_config.yaml"
    # Load the category metadata produced at training time.
    with open(f"{data_repo}/metadata.json") as f:
        categories = json.load(f)
    ids = [category["id"] for category in categories]
    categories = [category["name"] for category in categories]
    # Detectron2 indexes thing_classes from 0; when the annotation ids are
    # 1-based, pad with a dummy "N/A" class so indices line up.
    if 0 not in ids:
        categories.insert(0, "N/A")
    # Register metadata under a fixed dataset name, then fetch it back.
    MetadataCatalog.get("custom_test").set(thing_classes=categories)
    custom_metadata = MetadataCatalog.get("custom_test")
    # NOTE: dead local `model_to_use` (a model-zoo yaml name that was never
    # used) has been removed; weights always come from `model_repo`.
    cfg = get_cfg()
    # Pre-declare the custom key "INFER_DEBUG" so merge_from_file may set it.
    cfg.INFER_DEBUG = None
    if os.path.exists(adv_config):
        cfg.merge_from_file(adv_config)
    else:
        print("Merging custom configurations")
        cfg.merge_from_file(basic_config)
    cfg.INPUT.MASK_FORMAT="polygon"
    # Fall back to CPU when no GPU is available.
    cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    # Point the predictor at the trained weights.
    cfg.MODEL.WEIGHTS = f"{model_repo}/model_final.pth"
    predictor = DefaultPredictor(cfg)
    return predictor, cfg.INFER_DEBUG, custom_metadata
def infer(debug_path, predictor, debug, meta, files):
    """Decode uploaded image(s), run the predictor, and return formatted results.

    NOTE(review): both the success path and the error path ``return`` from
    inside the ``for`` loop, so only the FIRST entry of ``files`` is ever
    processed — confirm this is intentional (it looks like a single-file
    endpoint that accepts a list). Indentation was reconstructed from a
    whitespace-stripped source; verify nesting against the original file.

    Args:
        debug_path: directory where annotated debug images are written
            (created if missing).
        predictor: detectron2 ``DefaultPredictor`` instance.
        debug: truthy to save a visualization of the predictions to disk.
        meta: metadata with ``thing_classes``, used for labels and formatting.
        files: iterable of upload objects exposing a BytesIO at
            ``file_.file._file`` — presumably FastAPI ``UploadFile``s; verify.

    Returns:
        dict: on success, the predictor output augmented with "scale_factor",
        "image_name" (debug mode only), "height", "width" and the formatted
        "items"; on failure, the default dict with "error" set.
    """
    os.makedirs(debug_path, exist_ok=True)
    for file_ in files:
        # Get file BytesIO object
        byte_img = file_.file._file
        # Converting to NumPy array to make it compatible for cv2 decoding
        file_bytes = np.asarray(bytearray(byte_img.read()), dtype=np.uint8)
        # Defaults returned when decoding/prediction fails.
        output = {}
        output["scale_factor"] = 0.0
        output["image_name"] = ""
        output["height"], output["width"] = 0, 0
        output["error"] = None
        try:
            # Loading image through bytes
            im = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
            scale = 1
            # NOTE(review): rebinds ``output`` to the predictor's result dict,
            # discarding the defaults initialized above.
            output = predictor(im)
            if debug:
                v = Visualizer(im[:, :, ::-1],
                metadata=meta,
                scale=scale,
                instance_mode=ColorMode.IMAGE_BW
                )
                out = v.draw_instance_predictions(output["instances"].to("cpu"))
                # Save the annotated image under a random unique name.
                im = out.get_image()[:, :, ::-1]
                file_name = str(uuid.uuid4())+".jpg"
                cv2.imwrite(f"{debug_path}/{file_name}", im)
                output["scale_factor"] = scale
                output["image_name"] = file_name
            # Image size is needed by format_output to rasterize masks.
            output["height"], output["width"] = im.shape[:2]
        except Exception as e:
            # Translate the exception into the API's error payload and bail.
            output["error"] = handle_exception(e)
            return output
        output = format_output(output, meta)
        return output