From cde5b838848cd4ac6dce7487fae9b46b21ada4d2 Mon Sep 17 00:00:00 2001 From: ptran1203 Date: Thu, 1 Jul 2021 10:20:22 +0700 Subject: [PATCH] add weight path --- prediction.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/prediction.py b/prediction.py index e38df14..2a48d21 100644 --- a/prediction.py +++ b/prediction.py @@ -241,15 +241,13 @@ def detect_single_image(self, image, crop_sizes=[], show=False, tiling=False): return image, all_boxes, all_scores, all_classes -def get_inference_model(): +def get_inference_model(weight_path): num_of_classes = 7 model = m.RetinaNet(num_of_classes, backbone="densenet121") model.compile(optimizer="adam", loss=losses.RetinaNetLoss(num_of_classes)) - - # Trick: fit model first so the model can load the weight - model.fit(np.random.rand(1, 896, 2304, 3), np.random.rand(1, 386694, 5)) - image = tf.keras.Input(shape=[None, None, 3], name="image") - model.load_weights("./weight_dense.h5") + model.build((1, None, None, 3)) + image = tf.keras.InputLayer(shape=[None, None, 3], name="image") + model.load_weights(weight_path) predictions = model(image, training=False) detections = m.DecodePredictions(confidence_threshold=0.5, num_classes=num_of_classes, @@ -315,6 +313,8 @@ def combine_prediction( help="Path to input images") parser.add_argument("--output", dest="output_path", metavar="O", type=str, default="/data/result/submission.json", help="Output file path") + parser.add_argument("--weight", dest="weight_path", type=str, + default="pretrained_densenet121", help="Weight path") args = parser.parse_args() # Make prediction @@ -336,7 +336,7 @@ def combine_prediction( # Create submission.json submission = [] idx = 0 - predictor = Prediction(get_inference_model()) + predictor = Prediction(get_inference_model(args.weight)) print("Start predict...") start = datetime.datetime.now()