forked from mattdutson/xview2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: test.py
72 lines (52 loc) · 2.21 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#!/usr/bin/env python3
import argparse
from data_generator import TestDataGenerator
from unet import create_model
from util import *
def postprocess(pred):
    """Collapse per-class scores into a single-channel uint8 label map.

    For each position, the index of the highest score along the last axis is
    taken (``tf.argmax``). Those indices are cast to ``tf.uint8``, and a
    trailing axis of length one is appended to the result.

    Args:
        pred: score tensor whose last axis indexes classes
            (assumed to be (H, W, n_classes) here, judging from how the caller
            slices ``model.predict(...)[0, :, :, :]`` — TODO confirm).

    Returns:
        A uint8 tensor holding the winning class index, with an explicit
        channel axis of size 1 at the end.
    """
    class_ids = tf.cast(tf.argmax(pred, axis=-1), tf.uint8)
    return tf.expand_dims(class_ids, axis=-1)
def test(args):
    """Run inference over the test set and write per-sample prediction PNGs.

    For every sample yielded by ``TestDataGenerator(args.test_dir)``, two
    files are written into ``args.prediction_dir``:

    * ``test_damage_{index:05d}_prediction.png``: the argmax of the 5-class
      model. When a localization model is supplied, this is multiplied by
      that model's mask, so labels outside its predicted buildings become 0.
    * ``test_localization_{index:05d}_prediction.png``: a 0/1 mask. This is
      the argmax of the 2-class localization model when one is supplied.
      Otherwise it is 1 wherever the damage prediction is non-zero.

    Args:
        args: parsed command-line namespace providing ``model`` (weights for
            the 5-class model), ``localization_model`` (weights for an
            optional 2-class model, or None), ``prediction_dir`` and
            ``test_dir``.
    """
    model = create_model(n_classes=5)
    model.load_weights(args.model)

    localization_model = None
    if args.localization_model is not None:
        localization_model = create_model(n_classes=2)
        localization_model.load_weights(args.localization_model)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.prediction_dir, exist_ok=True)

    test_gen = TestDataGenerator(args.test_dir)
    n_samples = len(test_gen)

    print("Performing test inference...")
    print("Progress: {:3.1f}%\r".format(0.0), end="")
    for i in range(n_samples):
        pre_post, index = test_gen[i]

        pred = postprocess(model.predict(pre_post)[0, :, :, :])
        if localization_model is not None:
            pred_localization = postprocess(
                localization_model.predict(pre_post)[0, :, :, :])
            # Zero out damage labels wherever the localization model predicts
            # class 0.
            pred = pred * pred_localization
        else:
            # No dedicated localization model: treat class 0 as "no building",
            # consistent with the masking branch above, and mark every
            # non-zero damage label as building (1).
            pred_localization = tf.cast(tf.not_equal(pred, 0), tf.uint8)

        write_png(pred, os.path.join(
            args.prediction_dir,
            "test_damage_{:05d}_prediction.png".format(index)))
        # Previously the damage map itself was written here, which put class
        # values 0-4 into the localization file and discarded the separate
        # localization model's output.
        write_png(pred_localization, os.path.join(
            args.prediction_dir,
            "test_localization_{:05d}_prediction.png".format(index)))

        # Derive progress from the loop counter instead of accumulating a
        # float step. This ends at exactly 100% and never divides by zero on
        # an empty test set (the loop body simply does not run).
        progress = 100.0 * (i + 1) / n_samples
        print("Progress: {:3.1f}%\r".format(progress), end="")
    print("\nDone.")
if __name__ == "__main__":
    # Each entry is (option flags, keyword arguments for add_argument), listed
    # in the order the options should appear in --help.
    # NOTE(review): the "--model" default "model.json" is passed straight to
    # model.load_weights in test(); confirm it really names a weights file.
    cli_options = (
        (("-l", "--localization_model"),
         dict(type=str, default=None,
              help="path for a separate localization model")),
        (("-m", "--model"),
         dict(type=str, default="model.json",
              help="path for loading model weights")),
        (("-o", "--prediction_dir"),
         dict(type=str, default="predictions",
              help="path for saving predictions")),
        (("-t", "--test_dir"),
         dict(default=os.path.join("dataset", "test"),
              help="folder containing test data")),
    )

    cli = argparse.ArgumentParser()
    for flags, options in cli_options:
        cli.add_argument(*flags, **options)

    test(cli.parse_args())