sal360iqa_main.py
import os
import argparse
import random as python_random

import numpy as np
import tensorflow as tf
from scipy.stats import pearsonr, spearmanr
from sklearn.model_selection import train_test_split

import model
import dataset

# Fix the hash seed and force deterministic TF ops for reproducibility.
os.environ['PYTHONHASHSEED'] = '0'
os.environ['TF_DETERMINISTIC_OPS'] = '1'
print(tf.__version__)
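# Seed the Python, NumPy and TensorFlow RNGs so runs are repeatable.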
SEED = 123
np.random.seed(SEED)
python_random.seed(SEED)
tf.random.set_seed(SEED)
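# Patches sampled per image and the number of images in the OIQA set
# (assumed meaning of these constants).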
NBR_PATCHES = 64
OIQA_IMG = 320
def select_gpu(id_gpu):
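    # Make only the requested GPU visible to TensorFlow.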
    gpus = tf.config.list_physical_devices('GPU')
    tf.config.experimental.set_visible_devices(gpus[int(id_gpu)], 'GPU')
if __name__ == '__main__':
    EPOCHS = 100  # adjust as needed
    parser = argparse.ArgumentParser()
    parser.add_argument("-bs", "--bs", type=int,
                        help="Batch size.")
    parser.add_argument("-db", "--db",
                        help="Database name.")
    parser.add_argument("-gpu", "--gpu_id",
                        help="GPU ID to be used.")
    parser.add_argument("-val", "--val", type=int,
                        help="Validation split in tenths (e.g. 2 -> 0.2).")
    parser.add_argument("-norm", "--norm", type=int,
                        help="Whether to use normalization (1) or not (0).")
    parser.add_argument("-loss", "--ls",
                        help="The loss function: huber, mse or mae.")
    args = parser.parse_args()
    batch_size = args.bs
    database = args.db
    gpu_id = args.gpu_id
    val = args.val / 10  # e.g. -val 2 -> 0.2 validation split
    normalization = args.norm
    loss = args.ls
    select_gpu(gpu_id)
    out_p = os.path.join('Results/', database)
    os.makedirs(out_p, exist_ok=True)
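    # Adam with the legacy `decay` argument, which shrinks the learning rate
    # a little each step; newer TF releases replace it with LR schedules.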
    opt = tf.keras.optimizers.Adam(learning_rate=1e-3, decay=1e-4 / EPOCHS)
    if loss == 'huber':
        ls_func = tf.keras.losses.Huber()
    elif loss == 'mse':
        ls_func = tf.keras.losses.MeanSquaredError()
    elif loss == 'mae':
        ls_func = tf.keras.losses.MeanAbsoluteError()
    else:
        raise ValueError(f'Unknown loss function: {loss}')
    if normalization == 1:
        # Single-channel, locally contrast-normalized (LCN) patches.
        inp_s = (128, 128, 1)
        pre_norm = 'LCN'
        norm = True
    elif normalization == 0:
        # Raw RGB patches.
        inp_s = (256, 256, 3)
        pre_norm = 'RGB'
        norm = False
    else:
        raise ValueError('-norm must be 0 or 1')
    # Read your data: point these at your own patch annotations and folder
    # (placeholder paths, user-supplied).
    patches_data = 'path/to/patch_annotations.csv'
    patches_path = 'path/to/patches/'
    data = dataset.Dataset()
    patches, mos = data.get_input(
        patches_data, patches_path, OIQA_IMG, NBR_PATCHES, norm)
    # Split into training and testing sets (train_x, train_y), (test_x, test_y);
    # for illustration we use sklearn's train_test_split.
    train_x, test_x, train_y, test_y = train_test_split(
        patches, mos, test_size=0.2, random_state=42)
    out_dim = 1  # a single quality score per patch
    sal360iqa = model.Sal360Model()
    iqa_model = sal360iqa.build_model(inp_s, out_dim)
    print('[INFO] Compiling the model...')
    iqa_model.compile(loss=ls_func, optimizer=opt,
                      metrics=[tf.keras.metrics.RootMeanSquaredError(name='rmse')])
    cb = model.create_callbacks_fun(
        1, out_p, batch_size, pre_norm)
    print('[INFO] Training the model...')
    iqa_model.fit(x=train_x, y=train_y, validation_split=val,
                  epochs=EPOCHS, batch_size=batch_size, callbacks=cb, shuffle=True)
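    # Evaluate on the held-out split: correlate predicted scores with the
    # ground-truth MOS.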
    preds = iqa_model.predict(test_x, batch_size=batch_size).flatten()
    plcc = pearsonr(preds, test_y)
    srocc = spearmanr(preds, test_y)
    print(f'PLCC = {plcc[0]}, SROCC = {srocc[0]}')
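    # Example invocation (hypothetical values; adjust to your setup):
    #   python sal360iqa_main.py -bs 32 -db OIQA -gpu 0 -val 2 -norm 1 -loss huber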