-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path03-predict.py
27 lines (22 loc) · 1.07 KB
/
03-predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from tqdm import tqdm
import glob
from utils import *
from constants import *
from models.model_bce import ModelBCE
def test(path_to_images, path_output_maps, model_to_test=None):
list_img_files = [k.split('/')[-1].split('.')[0] for k in glob.glob(os.path.join(path_to_images, '*'))]
# Load Data
list_img_files.sort()
for curr_file in tqdm(list_img_files, ncols=20):
print os.path.join(path_to_images, curr_file + '.jpg')
img = cv2.cvtColor(cv2.imread(os.path.join(path_to_images, curr_file + '.jpg'), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
predict(model=model_to_test, image_stimuli=img, name=curr_file, path_output_maps=path_output_maps)
def main():
# Create network
model = ModelBCE(INPUT_SIZE[0], INPUT_SIZE[1], batch_size=8)
# Here need to specify the epoch of model sanpshot
load_weights(model.net['output'], path='gen_', epochtoload=90)
# Here need to specify the path to images and output path
test(path_to_images='../images/', path_output_maps='../saliency/', model_to_test=model)
if __name__ == "__main__":
main()