Skip to content

Commit 7b95524

Browse files
adding validation script
1 parent 1b7c139 commit 7b95524

File tree

3 files changed

+114
-20
lines changed

3 files changed

+114
-20
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@ opencv-python==3.4.8.29
66
pillow
77
wandb
88
protobuf==3.20.*
9-
elasticdeform==0.5.1
9+
elasticdeform==0.5.1

train.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,26 +23,12 @@
2323
from data import create_dataset
2424
from models import create_model
2525
from util.visualizer import Visualizer
26-
#from torchmetrics import StructuralSimilarityIndexMeasure
26+
from util.validation import validation
2727
import torch
2828
import wandb
2929
import copy
3030

3131

32-
def validate(val_set, model):
33-
#ssim = StructuralSimilarityIndexMeasure(data_range=1.0)
34-
metric = torch.nn.L1Loss()
35-
model.eval()
36-
errors = 0
37-
for i, data in enumerate(val_set): # inner loop within one epoch
38-
model.set_input(data) # unpack data from dataset and apply preprocessing
39-
model.test() # calculate loss functions, get gradients, update network weights
40-
visuals = model.get_current_visuals()
41-
real = visuals['real_B']
42-
pred = visuals['fake_B']
43-
errors += metric(pred.cpu(), real.cpu())
44-
return (errors/len(val_set)).item()
45-
4632
if __name__ == '__main__':
4733
opt = TrainOptions().parse() # get training options
4834
wandb.init(project="testing-maskgan", name=opt.name)
@@ -97,14 +83,18 @@ def validate(val_set, model):
9783
iter_data_time = time.time()
9884

9985
if epoch % opt.save_epoch_freq == 0: # cache our model every <save_epoch_freq> epochs
100-
perf = validate(val_dataset, model)
101-
print('saving the model at the end of epoch %d, iters %d, MAE %d' % (epoch, total_iters, perf))
86+
perf = validation(val_dataset, model, val_opt)
87+
if not opt.wdb_disabled:
88+
metrics_val = {"val/MAE_fake": perf[0], "val/MSE_fake": perf[1], "val/SSIM_fake" : perf[2], "val/PSNR_fake": perf[3]}
89+
# Send metrics to WANDB
90+
wandb.log(metrics_val)
91+
print('saving the model at the end of epoch %d, iters %d, MAE %d' % (epoch, total_iters, perf[0]))
10292
model.save_networks('latest')
10393
model.save_networks(epoch)
104-
if best > perf:
94+
if best > perf[0]:
10595
print(f"Best Model with MAE={best}")
10696
model.save_networks('best')
107-
best = perf
97+
best = perf[0]
10898

10999
print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
110100
model.update_learning_rate() # update learning rates at the end of every epoch.

util/validation.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
2+
import PIL
3+
import torch
4+
import numpy as np
5+
import wandb
6+
from torchmetrics.image import PeakSignalNoiseRatio
7+
import monai
8+
9+
def val_visualizations_over_batches(real_a,
10+
real_b,
11+
fake_b):
12+
"""We define a function to visualizate val images of sets
13+
real_A, real_B and fake_B
14+
"""
15+
results = []
16+
# Save real A
17+
real_a = real_a.cpu().numpy()
18+
real_a = real_a[4:8,:,:,:]
19+
real_a = np.concatenate(real_a, axis=1)
20+
# Save real B
21+
real_b = real_b.cpu().numpy()
22+
real_b = real_b[4:8,:,:,:]
23+
real_b = np.concatenate(real_b, axis=1)
24+
# Save prediction
25+
fake_b = fake_b.cpu().numpy()
26+
fake_b = fake_b[4:8,:,:,:]
27+
fake_b = np.concatenate(fake_b, axis=1)
28+
29+
# Append results to a final output
30+
results.append(real_a)
31+
results.append(real_b)
32+
results.append(fake_b)
33+
34+
# Transforming results to tuple
35+
return tuple(results)
36+
37+
def validation(val_set, model, opt):
38+
#ssim = StructuralSimilarityIndexMeasure(data_range=1.0)
39+
# Getting MAE
40+
metric_mae = torch.nn.L1Loss()
41+
# Getting MSE
42+
metric_mse = torch.nn.MSELoss()
43+
# Getting SSIM
44+
metric_ssim = monai.metrics.SSIMMetric(spatial_dims=2, reduction = 'mean')
45+
# Getting PSNR
46+
metric_psnr = PeakSignalNoiseRatio()
47+
48+
model.eval()
49+
# Set zero
50+
mae_fake = 0.0
51+
mse_fake = 0.0
52+
ssim_fake = 0.0
53+
psnr_fake = 0.0
54+
55+
# Check number of total batches for validation set
56+
batches = len(val_set.dataloader)
57+
# Select values which represent 20th, 40th, 60th, and 80th percentiles
58+
percentile_values = np.percentile(np.arange(0,batches + 1), [20,40,60,80]).astype(int)
59+
60+
for i, data in enumerate(val_set): # inner loop within one epoch
61+
62+
model.set_input(data) # unpack data from dataset and apply preprocessing
63+
model.test() # calculate loss functions, get gradients, update network weights
64+
visuals = model.get_current_visuals()
65+
real_A = visuals['real_A']
66+
real_B = visuals['real_B']
67+
fake_B = visuals['fake_B']
68+
69+
# Get metrics comparing fake B with real B
70+
mae_fake += metric_mae(fake_B.cpu(), real_B.cpu())
71+
mse_fake += metric_mse(fake_B.cpu(), real_B.cpu())
72+
ssim_fake += metric_ssim(fake_B.cpu(), real_B.cpu()).mean()
73+
psnr_fake += metric_psnr(fake_B.cpu(), real_B.cpu())
74+
75+
# Create visualizations
76+
if not opt.wdb_disabled:
77+
if i == percentile_values[0]:
78+
imgA_wb, imgB_wb, fakeB_wb = val_visualizations_over_batches(real_A,real_B,fake_B)
79+
elif i == percentile_values[1] or i == percentile_values[2]:
80+
imgA2_wb, imgB2_wb, fakeB2_wb = val_visualizations_over_batches(real_A,real_B,fake_B)
81+
imgA_wb = np.concatenate((imgA_wb, imgA2_wb), axis=2)
82+
imgB_wb = np.concatenate((imgB_wb, imgB2_wb), axis=2)
83+
fakeB_wb = np.concatenate((fakeB_wb, fakeB2_wb), axis=2)
84+
elif i == percentile_values[3]:
85+
imgA2_wb, imgB2_wb, fakeB2_wb = val_visualizations_over_batches(real_A,real_B,fake_B)
86+
imgA_wb = np.concatenate((imgA_wb, imgA2_wb), axis=2)
87+
imgA_wb = ((imgA_wb + 1) * 127.5).astype(np.uint8)
88+
imgA_wb = PIL.Image.fromarray(np.squeeze(imgA_wb))
89+
imgA_wb = imgA_wb.convert("L")
90+
imgB_wb = np.concatenate((imgB_wb, imgB2_wb), axis=2)
91+
imgB_wb = ((imgB_wb + 1) * 127.5).astype(np.uint8)
92+
imgB_wb = PIL.Image.fromarray(np.squeeze(imgB_wb))
93+
imgB_wb = imgB_wb.convert("L")
94+
fakeB_wb = np.concatenate((fakeB_wb, fakeB2_wb), axis=2)
95+
fakeB_wb = ((fakeB_wb + 1) * 127.5).astype(np.uint8)
96+
fakeB_wb = PIL.Image.fromarray(np.squeeze(fakeB_wb))
97+
fakeB_wb = fakeB_wb.convert("L")
98+
99+
# Send data to Wandb
100+
if not opt.wdb_disabled:
101+
wandb.log({"val/examples": [wandb.Image(imgA_wb, caption="realA"),wandb.Image(imgB_wb, caption="realB"),wandb.Image(fakeB_wb, caption="fakeB")]})
102+
103+
# Return metrics for comparison B to A and B to B_hat
104+
return (mae_fake/batches).cpu().numpy(), (mse_fake/batches).cpu().numpy(), (ssim_fake/batches).cpu().numpy(), (psnr_fake/batches).cpu().numpy()

0 commit comments

Comments
 (0)