From 36358bb057b88178f9712b604c2e5b14f0f6a0a4 Mon Sep 17 00:00:00 2001 From: ptran1203 Date: Thu, 1 Jul 2021 10:12:53 +0700 Subject: [PATCH] improve inference --- prediction.py | 52 +++++++++++++++++---------------------------------- 1 file changed, 17 insertions(+), 35 deletions(-) diff --git a/prediction.py b/prediction.py index 7303c7c..e38df14 100644 --- a/prediction.py +++ b/prediction.py @@ -9,6 +9,8 @@ import json import argparse import datetime +import glob +import cv2 try: from google.colab.patches import cv2_imshow @@ -84,10 +86,8 @@ def get_slice_indices(self, full_size): return slices - def get_input_img(self, sample, crop=False, crop_size=512): - sample = tf.io.parse_single_example(sample, data_processing.image_feature_description) - - image = tf.image.decode_png(sample["image"]) + def get_input_img(self, image, crop=False, crop_size=512): + image = tf.convert_to_tensor(image) if self.dynamic_size: shape = image.shape @@ -150,7 +150,7 @@ def big_box_filter(image, boxes, scores, classes, threshold=.12): return tf.stack(fboxes), tf.stack(fscores), tf.stack(fclasses) - def detect_single_image(self, sample, crop_sizes=[], show=False, tiling=False): + def detect_single_image(self, image, crop_sizes=[], show=False, tiling=False): all_boxes = [] all_scores = [] all_classes = [] @@ -162,7 +162,7 @@ def detect_single_image(self, sample, crop_sizes=[], show=False, tiling=False): detected = False if tiling: - input_img, image, ratio = self.get_input_img(sample, crop=True, crop_size=self.tiling_size) + input_img, image, ratio = self.get_input_img(image, crop=True, crop_size=self.tiling_size) detections = self.inference_model.predict_on_batch(tf.concat(input_img, 0)) @@ -187,7 +187,7 @@ def detect_single_image(self, sample, crop_sizes=[], show=False, tiling=False): show and print(f"Found {small_detections} objects in small parts - {sscores}") for crop_size in crop_sizes: - input_img, image, ratio = self.get_input_img(sample, crop=False, crop_size=crop_size) + input_img, image, ratio = self.get_input_img(image, crop=False, crop_size=crop_size) detections = self.inference_model.predict(input_img) num_detections = detections.valid_detections[0] @@ -307,28 +307,12 @@ def combine_prediction( tf.gather(scores / highest, selected_indices), tf.gather(classes, selected_indices)) -def get_test_data_info(input_path): - id_list = os.listdir(input_path) - id_list = sorted([int(c.split(".")[0]) for c in id_list]) - - data_info = [ - { - "bbox": [[0, 0, 0, 0]], - "label": [0], - "id": x, - } for x in id_list - ] - - return data_info - if __name__ == "__main__": parser = argparse.ArgumentParser(description='Traffic sign detection') parser.add_argument("--input", dest="input_path", metavar="I", type=str, default="/data/images", help="Path to input images") - parser.add_argument("--test_file", dest="test_file", default="./images_private_test.tfrecords", - metavar="F", type=str, help="Tfrecords test file",) parser.add_argument("--output", dest="output_path", metavar="O", type=str, default="/data/result/submission.json", help="Output file path") args = parser.parse_args() @@ -338,19 +322,16 @@ def get_test_data_info(input_path): output_path = args.output_path if output_path.split(".")[-1] != "json": - raise("Output file should be json format") - - TFRECORDS_FILE_PRIVATE_TEST = args.test_file + raise ValueError("Output file should be json format") # Get list of test images - data_info = get_test_data_info(input_path) - - print("Test on {} images".format(len(data_info))) - - print("Create tfrecords dataset") - data_processing.write_tfrecords(data_info, TFRECORDS_FILE_PRIVATE_TEST, input_path) + if os.path.isdir(input_path): + image_files = glob.glob(os.path.join(input_path, '*')) + else: + # it's file + image_files = [input_path] - test_dataset = tf.data.TFRecordDataset(TFRECORDS_FILE_PRIVATE_TEST) + print(f"Test on {len(image_files)} images") # Create submission.json submission = [] @@ -359,8 +340,9 @@ def get_test_data_info(input_path): print("Start predict...") start = datetime.datetime.now() - for sample in test_dataset: - image, boxes, scores, classes = predictor.detect_single_image(sample) + for file_path in image_files: + image = cv2.imread(file_path)[..., ::-1] + image, boxes, scores, classes = predictor.detect_single_image(image) if not isinstance(boxes, list): boxes = boxes.numpy() scores = scores.numpy()