-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtest.py
74 lines (63 loc) · 2.66 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import argparse
from networks.attention_unet import AttentionUNet
from networks.segnet import Segnet
from networks.unet import UNet
from networks.squeeze_unet import SqueezeUNet
from networks.att_squeeze_unet import AttSqueezeUNet
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import mixed_precision
from tensorflow.keras.layers import Input
from utils import *
from loss import *
import random
import cv2
import cv2 as cv
from os.path import exists
import numpy as np
# Run every tf.function eagerly and report the resulting mode.
# NOTE(review): eager execution is usually a debugging aid and tends to slow
# down model.evaluate considerably; confirm it is still wanted here.
tf.config.run_functions_eagerly(True)
print(tf.executing_eagerly())

# This script insists on at least one GPU. An explicit check (instead of
# ``assert``) keeps the guard active when Python runs with ``-O``; otherwise
# the indexing/iteration below would fail with a less helpful error.
physical_devices = tf.config.list_physical_devices('GPU')
if not physical_devices:
    raise RuntimeError("Not enough GPU hardware devices available")
# Allocate GPU memory on demand. The setting is applied to every visible GPU
# because TensorFlow refuses to initialise devices whose memory-growth
# settings differ from one another.
for gpu_device in physical_devices:
    tf.config.experimental.set_memory_growth(gpu_device, True)

# Explicitly pin the global Keras dtype policy to full float32 precision.
policy = mixed_precision.Policy('float32')
mixed_precision.set_global_policy(policy)
def argparser(argv=None):
    """Parse the command-line options for evaluating a trained network.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]`` exactly like the CLI does.

    Returns:
        argparse.Namespace with ``network``, ``test_dir`` and ``resume``.

    Raises:
        SystemExit: (via argparse) when a required option is missing or
            ``--network`` is not one of the supported architectures.
    """
    parser = argparse.ArgumentParser(description="Attention Squeeze U-Net")
    parser.add_argument(
        '--network',
        dest='network',
        type=str,
        default="attention_squeeze_unet",
        # Reject unknown names at parse time instead of after data loading.
        choices=("attention_squeeze_unet", "squeeze_unet", "attention_unet", "unet", "segnet"),
        help='Select network: attention_squeeze_unet, squeeze_unet, attention_unet, unet, segnet',
    )
    # Both paths are mandatory: main() globs images out of test_dir and loads
    # weights from resume, and cannot do anything useful without either.
    parser.add_argument(
        "--test_dir",
        required=True,
        help="directory containing the test *.jpg images and their *.png masks",
    )
    parser.add_argument(
        "--resume",
        required=True,
        help="path to the trained model weights to evaluate",
    )
    return parser.parse_args(argv)
def _collect_test_files(test_dir):
    """Return the sorted ``*.jpg`` image paths and ``*.png`` mask paths in *test_dir*.

    Raises:
        ValueError: if *test_dir* is not given, holds no images, or the number
            of images differs from the number of masks.
    """
    import os
    from glob import escape, glob

    if not test_dir:
        raise ValueError("No test directory given (use --test_dir)")
    # Escape the directory so characters such as '[' in its name are matched
    # literally rather than treated as glob wildcards.
    root = escape(test_dir)
    list_images = sorted(glob(os.path.join(root, "*.jpg")))
    list_maps = sorted(glob(os.path.join(root, "*.png")))
    if not list_images:
        raise ValueError("Error the testing image array is empty!")
    # Images and masks are paired purely by sorted order, so the counts must agree.
    if len(list_images) != len(list_maps):
        raise ValueError("Error the testing image number differs from the number of masks")
    return list_images, list_maps


def _check_weights_path(path):
    """Raise ValueError unless *path* names an existing weights file or TF checkpoint prefix."""
    import os

    # Weights saved in TensorFlow checkpoint format are written as
    # "<prefix>.index" plus data shards, so the bare prefix passed to
    # load_weights does not itself exist on disk.
    if not path or not (os.path.exists(path) or os.path.exists(path + ".index")):
        raise ValueError("File {file} does not exist!".format(file=path))


def _build_model(network, size):
    """Instantiate the segmentation network called *network*.

    Raises:
        ValueError: if *network* is not one of the supported names.
    """
    # Only some constructors are given ``size``; the squeeze variants are
    # created with their own defaults, exactly as the original script did.
    constructors = {
        "attention_unet": lambda: AttentionUNet(size=size),
        "attention_squeeze_unet": lambda: AttSqueezeUNet(),
        "squeeze_unet": lambda: SqueezeUNet(),
        "segnet": lambda: Segnet(size=size),
        "unet": lambda: UNet(size=size),
    }
    try:
        factory = constructors[network]
    except KeyError:
        raise ValueError("Network " + network + " unknown!") from None
    return factory()


def main(args):
    """Evaluate trained weights of the selected network on a test directory.

    Args:
        args: Namespace with ``network``, ``test_dir`` and ``resume``
            attributes (as produced by ``argparser``).

    Returns:
        The value returned by Keras ``Model.evaluate`` (the loss and the
        compiled metrics: accuracy, Dice and Jaccard coefficients).

    Raises:
        ValueError: if the test directory is empty, images and masks are
            unbalanced, the weights path does not exist, or the network
            name is unknown.
    """
    # Validate all inputs before doing any expensive model construction.
    list_images, list_maps = _collect_test_files(args.test_dir)
    _check_weights_path(args.resume)

    size = (384, 512)
    test_gen = test_generator(list_images, list_maps, size=size)
    model = _build_model(args.network, size)
    # build() receives (size[1], size[0]) as the spatial dims; presumably
    # ``size`` is (width, height) as in cv2 — confirm against test_generator.
    model.build(input_shape=(1, size[1], size[0], 3))
    # ``learning_rate`` is the supported keyword; the legacy ``lr`` alias is
    # ignored by newer Keras optimizers and rejected by Keras 3.
    model.compile(
        loss=focal_tversky_loss,
        optimizer=Adam(learning_rate=0.001),
        metrics=["acc", dice_coef, jaccard_coef],
    )
    model.load_weights(args.resume)
    return model.evaluate(test_gen)
if __name__ == "__main__":
    # Script entry point: read the command line, then run the evaluation.
    main(argparser())