import PIL.Image
import torch
import numpy as np
import wandb
from torchmetrics.image import PeakSignalNoiseRatio
import monai

def val_visualizations_over_batches(real_a, real_b, fake_b):
    """Build tiled validation images for the real_A, real_B and fake_B sets.

    Samples 4-7 of each batch (a batch size of at least 8 is assumed) are
    stacked along the height axis so that one image per set can be logged.
    """
    results = []
    # Real A: move to CPU, keep samples 4-7 of the batch, stack along height
    real_a = real_a.cpu().numpy()
    real_a = real_a[4:8, :, :, :]
    real_a = np.concatenate(real_a, axis=1)
    # Real B: same treatment
    real_b = real_b.cpu().numpy()
    real_b = real_b[4:8, :, :, :]
    real_b = np.concatenate(real_b, axis=1)
    # Prediction (fake B): same treatment
    fake_b = fake_b.cpu().numpy()
    fake_b = fake_b[4:8, :, :, :]
    fake_b = np.concatenate(fake_b, axis=1)

    # Collect the three tiled arrays
    results.append(real_a)
    results.append(real_b)
    results.append(fake_b)

    # Return the results as a tuple
    return tuple(results)
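

# Shape sketch (illustrative only, not executed anywhere in this module): with
# inputs of shape (N, C, H, W) and N >= 8, e.g.
#
#     a = torch.zeros(8, 1, 256, 256)
#     tiled_a, tiled_b, tiled_fake = val_visualizations_over_batches(a, a, a)
#     tiled_a.shape   # -> (1, 1024, 256): samples 4-7 stacked along height
#
# the 256x256 size is only an example; any spatial size works.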

def validation(val_set, model, opt):
    # ssim = StructuralSimilarityIndexMeasure(data_range=1.0)  # unused torchmetrics alternative
    # Mean absolute error (MAE)
    metric_mae = torch.nn.L1Loss()
    # Mean squared error (MSE)
    metric_mse = torch.nn.MSELoss()
    # Structural similarity (SSIM)
    metric_ssim = monai.metrics.SSIMMetric(spatial_dims=2, reduction='mean')
    # Peak signal-to-noise ratio (PSNR)
    metric_psnr = PeakSignalNoiseRatio()

    model.eval()
    # Initialise the running sums of the per-batch metrics
    mae_fake = 0.0
    mse_fake = 0.0
    ssim_fake = 0.0
    psnr_fake = 0.0

    # Total number of batches in the validation set
    batches = len(val_set.dataloader)
    # Pick the batch indices closest to the 20th, 40th, 60th and 80th percentiles
    percentile_values = np.percentile(np.arange(0, batches + 1), [20, 40, 60, 80]).astype(int)
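    # Illustration (values not from the original code): with 50 validation
    # batches this is np.percentile(np.arange(0, 51), [20, 40, 60, 80]).astype(int)
    # = [10, 20, 30, 40]; images from those batches are tiled and logged below.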

    for i, data in enumerate(val_set):  # loop over the validation batches

        model.set_input(data)  # unpack data from the dataloader and apply preprocessing
        model.test()           # forward pass in eval mode; no gradients or weight updates
        visuals = model.get_current_visuals()
        real_A = visuals['real_A']
        real_B = visuals['real_B']
        fake_B = visuals['fake_B']

        # Accumulate metrics comparing fake B with real B
        mae_fake += metric_mae(fake_B.cpu(), real_B.cpu())
        mse_fake += metric_mse(fake_B.cpu(), real_B.cpu())
        ssim_fake += metric_ssim(fake_B.cpu(), real_B.cpu()).mean()
        psnr_fake += metric_psnr(fake_B.cpu(), real_B.cpu())

        # Create visualizations at the selected batches so the logged panel
        # samples images from across the validation set
        if not opt.wdb_disabled:
            if i == percentile_values[0]:
                imgA_wb, imgB_wb, fakeB_wb = val_visualizations_over_batches(real_A, real_B, fake_B)
            elif i == percentile_values[1] or i == percentile_values[2]:
                imgA2_wb, imgB2_wb, fakeB2_wb = val_visualizations_over_batches(real_A, real_B, fake_B)
                imgA_wb = np.concatenate((imgA_wb, imgA2_wb), axis=2)
                imgB_wb = np.concatenate((imgB_wb, imgB2_wb), axis=2)
                fakeB_wb = np.concatenate((fakeB_wb, fakeB2_wb), axis=2)
            elif i == percentile_values[3]:
                # Last selected batch: finish the tiling and map the expected
                # [-1, 1] range to 8-bit grayscale PIL images
                imgA2_wb, imgB2_wb, fakeB2_wb = val_visualizations_over_batches(real_A, real_B, fake_B)
                imgA_wb = np.concatenate((imgA_wb, imgA2_wb), axis=2)
                imgA_wb = ((imgA_wb + 1) * 127.5).astype(np.uint8)
                imgA_wb = PIL.Image.fromarray(np.squeeze(imgA_wb))
                imgA_wb = imgA_wb.convert("L")
                imgB_wb = np.concatenate((imgB_wb, imgB2_wb), axis=2)
                imgB_wb = ((imgB_wb + 1) * 127.5).astype(np.uint8)
                imgB_wb = PIL.Image.fromarray(np.squeeze(imgB_wb))
                imgB_wb = imgB_wb.convert("L")
                fakeB_wb = np.concatenate((fakeB_wb, fakeB2_wb), axis=2)
                fakeB_wb = ((fakeB_wb + 1) * 127.5).astype(np.uint8)
                fakeB_wb = PIL.Image.fromarray(np.squeeze(fakeB_wb))
                fakeB_wb = fakeB_wb.convert("L")

    # Send data to Wandb
    if not opt.wdb_disabled:
        wandb.log({"val/examples": [wandb.Image(imgA_wb, caption="realA"),
                                    wandb.Image(imgB_wb, caption="realB"),
                                    wandb.Image(fakeB_wb, caption="fakeB")]})

    # Return the batch-averaged metrics comparing real B with fake B
    return ((mae_fake / batches).cpu().numpy(),
            (mse_fake / batches).cpu().numpy(),
            (ssim_fake / batches).cpu().numpy(),
            (psnr_fake / batches).cpu().numpy())
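

# Usage sketch (illustrative only; `create_dataset`, `create_model` and the
# options object are assumed to come from the surrounding pix2pix/CycleGAN
# style training script and are not defined in this module):
#
#     val_set = create_dataset(val_opt)   # dataset wrapper exposing .dataloader
#     model = create_model(opt)           # model exposing set_input/test/get_current_visuals
#     mae, mse, ssim, psnr = validation(val_set, model, opt)
#     print(f"MAE {mae:.4f}  MSE {mse:.4f}  SSIM {ssim:.4f}  PSNR {psnr:.2f}")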