Skip to content

Commit

Permalink
add weight path
Browse files Browse the repository at this point in the history
  • Loading branch information
ptran1203 committed Jul 1, 2021
1 parent 36358bb commit cde5b83
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,15 +241,13 @@ def detect_single_image(self, image, crop_sizes=[], show=False, tiling=False):
return image, all_boxes, all_scores, all_classes


def get_inference_model():
def get_inference_model(weight_path):
num_of_classes = 7
model = m.RetinaNet(num_of_classes, backbone="densenet121")
model.compile(optimizer="adam", loss=losses.RetinaNetLoss(num_of_classes))

# Trick: fit model first so the model can load the weight
model.fit(np.random.rand(1, 896, 2304, 3), np.random.rand(1, 386694, 5))
image = tf.keras.Input(shape=[None, None, 3], name="image")
model.load_weights("./weight_dense.h5")
model.build((1, None, None, 3))
image = tf.keras.InputLayer(shape=[None, None, 3], name="image")
model.load_weights(weight_path)
predictions = model(image, training=False)
detections = m.DecodePredictions(confidence_threshold=0.5,
num_classes=num_of_classes,
Expand Down Expand Up @@ -315,6 +313,8 @@ def combine_prediction(
help="Path to input images")
parser.add_argument("--output", dest="output_path", metavar="O", type=str,
default="/data/result/submission.json", help="Output file path")
parser.add_argument("--weight", dest="weight_path", type=str,
default="pretrained_densenet121", help="Weight path")
args = parser.parse_args()

# Make prediction
Expand All @@ -336,7 +336,7 @@ def combine_prediction(
# Create submission.json
submission = []
idx = 0
predictor = Prediction(get_inference_model())
predictor = Prediction(get_inference_model(args.weight))

print("Start predict...")
start = datetime.datetime.now()
Expand Down

0 comments on commit cde5b83

Please sign in to comment.