Skip to content

Commit

Permalink
improve inference
Browse files Browse the repository at this point in the history
  • Loading branch information
ptran1203 committed Jul 1, 2021
1 parent 9a3998b commit 36358bb
Showing 1 changed file with 17 additions and 35 deletions.
52 changes: 17 additions & 35 deletions prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import json
import argparse
import datetime
import glob
import cv2

try:
from google.colab.patches import cv2_imshow
Expand Down Expand Up @@ -84,10 +86,8 @@ def get_slice_indices(self, full_size):

return slices

def get_input_img(self, sample, crop=False, crop_size=512):
sample = tf.io.parse_single_example(sample, data_processing.image_feature_description)

image = tf.image.decode_png(sample["image"])
def get_input_img(self, image, crop=False, crop_size=512):
image = tf.convert_to_tensor(image)

if self.dynamic_size:
shape = image.shape
Expand Down Expand Up @@ -150,7 +150,7 @@ def big_box_filter(image, boxes, scores, classes, threshold=.12):
return tf.stack(fboxes), tf.stack(fscores), tf.stack(fclasses)


def detect_single_image(self, sample, crop_sizes=[], show=False, tiling=False):
def detect_single_image(self, image, crop_sizes=[], show=False, tiling=False):
all_boxes = []
all_scores = []
all_classes = []
Expand All @@ -162,7 +162,7 @@ def detect_single_image(self, sample, crop_sizes=[], show=False, tiling=False):

detected = False
if tiling:
input_img, image, ratio = self.get_input_img(sample, crop=True, crop_size=self.tiling_size)
input_img, image, ratio = self.get_input_img(image, crop=True, crop_size=self.tiling_size)

detections = self.inference_model.predict_on_batch(tf.concat(input_img, 0))

Expand All @@ -187,7 +187,7 @@ def detect_single_image(self, sample, crop_sizes=[], show=False, tiling=False):
show and print(f"Found {small_detections} objects in small parts - {sscores}")

for crop_size in crop_sizes:
input_img, image, ratio = self.get_input_img(sample, crop=False, crop_size=crop_size)
input_img, image, ratio = self.get_input_img(image, crop=False, crop_size=crop_size)
detections = self.inference_model.predict(input_img)
num_detections = detections.valid_detections[0]

Expand Down Expand Up @@ -307,28 +307,12 @@ def combine_prediction(
tf.gather(scores / highest, selected_indices),
tf.gather(classes, selected_indices))

def get_test_data_info(input_path):
id_list = os.listdir(input_path)
id_list = sorted([int(c.split(".")[0]) for c in id_list])

data_info = [
{
"bbox": [[0, 0, 0, 0]],
"label": [0],
"id": x,
} for x in id_list
]

return data_info


if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Traffic sign detection')
parser.add_argument("--input", dest="input_path",
metavar="I", type=str, default="/data/images",
help="Path to input images")
parser.add_argument("--test_file", dest="test_file", default="./images_private_test.tfrecords",
metavar="F", type=str, help="Tfrecords test file",)
parser.add_argument("--output", dest="output_path", metavar="O", type=str,
default="/data/result/submission.json", help="Output file path")
args = parser.parse_args()
Expand All @@ -338,19 +322,16 @@ def get_test_data_info(input_path):
output_path = args.output_path

if output_path.split(".")[-1] != "json":
raise("Output file should be json format")

TFRECORDS_FILE_PRIVATE_TEST = args.test_file
raise ValueError("Output file should be json format")

# Get list of test images
data_info = get_test_data_info(input_path)

print("Test on {} images".format(len(data_info)))

print("Create tfrecords dataset")
data_processing.write_tfrecords(data_info, TFRECORDS_FILE_PRIVATE_TEST, input_path)
if os.path.isdir(input_path):
image_files = glob.glob(os.path.join(input_path, '*'))
else:
# it's file
image_files = [input_path]

test_dataset = tf.data.TFRecordDataset(TFRECORDS_FILE_PRIVATE_TEST)
print(f"Test on {len(image_files)} images")

# Create submission.json
submission = []
Expand All @@ -359,8 +340,9 @@ def get_test_data_info(input_path):

print("Start predict...")
start = datetime.datetime.now()
for sample in test_dataset:
image, boxes, scores, classes = predictor.detect_single_image(sample)
for file_path in image_files:
image = cv2.imread(file_path)[..., ::-1]
image, boxes, scores, classes = predictor.detect_single_image(image)
if not isinstance(boxes, list):
boxes = boxes.numpy()
scores = scores.numpy()
Expand Down

0 comments on commit 36358bb

Please sign in to comment.