Skip to content

Commit c4e7d39

Browse files
committed
Sum absolute difference and mean squared error metrics added
1 parent 0eba0ff commit c4e7d39

6 files changed

+48
-9
lines changed

Dataset/MattingDataset.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ def __init__(self, fgDir, bgDir, alphaDir, allTransform, imageTransforms):
1616
self.foregroundImageNames = os.listdir(self.fgDir)
1717
self.backgroundImageNames = os.listdir(self.bgDir)
1818
random.shuffle(self.backgroundImageNames) #TODO: Remove
19-
self.backgroundImageNames = self.backgroundImageNames[:3] #TODO: Remove 13
19+
self.backgroundImageNames = self.backgroundImageNames[:4] #TODO: Remove 13
2020
self.alphaImageNames = os.listdir(self.alphaDir)
2121
random.shuffle(self.alphaImageNames) #TODO: Remove
22-
self.alphaImageNames = self.alphaImageNames[:3] #TODO:Remove 23
22+
self.alphaImageNames = self.alphaImageNames[:4] #TODO:Remove 23
2323

2424
self.numForeground = len(self.foregroundImageNames)
2525
self.numBackground = len(self.backgroundImageNames)

TrainLoss16Items.png

19.4 KB
Loading

TrainLoss25Items.png

21.5 KB
Loading

dataset_transforms.py

+14
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,20 @@
44
import torchvision.transforms.functional as TF
55
from PIL import Image, ImageFilter
66

7+
class RandomAffine(object):
    """Randomly rotate an (image, trimap, mask) triple by one shared angle.

    On each call a uniform random number is compared against the stored
    probability. When the draw succeeds, a single integer angle in
    [-180, 180] is chosen and applied to all three items through
    ``TF.affine`` with zero translation, unit scale and zero shear, so the
    three items stay aligned. Otherwise the items are returned as received.

    NOTE(review): the items are presumably PIL images, since PIL resampling
    constants are passed to ``TF.affine`` -- confirm against the dataset.
    """

    def __init__(self, probability=0.5):
        # Chance (0..1) that a given call applies the rotation.
        self.p = probability

    def __call__(self, items):
        image, trimap, mask = items

        if not (random.random() < self.p):
            # Draw failed: hand the triple back untouched.
            return image, trimap, mask

        angle = random.randint(-180, 180)

        def rotate(item, resample):
            # Pure rotation: no translation, scaling or shearing.
            return TF.affine(item, angle, translate=[0, 0], scale=1.0,
                             shear=0, resample=resample)

        image = rotate(image, Image.BICUBIC)
        # Nearest-neighbour resampling keeps the trimap and alpha mask
        # values unchanged (no interpolated in-between values).
        trimap = rotate(trimap, Image.NEAREST)
        mask = rotate(mask, Image.NEAREST)
        return image, trimap, mask
20+
721
class RandomBlur(object):
822
def __init__(self, probability=0.5):
923
self.p = probability

loss.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,22 @@ def show(xf):
6161
sumTrueForeground = trueForeground.sum(dim=[2,3]) + eps
6262
totalLoss = rootDiff.sum(dim=[2,3]) / sumTrueForeground
6363
avgLoss = totalLoss.mean().mean() # average over the RGB channels and also across the batch
64-
return avgLoss
64+
return avgLoss
65+
66+
67+
def sum_absolute_difference(trueAlpha, predAlpha):
    """
    Return the sum of absolute differences (SAD) between predictions and
    ground truth, averaged over the batch.

    The absolute per-element error is summed over every non-batch
    dimension, giving one SAD value per sample; the mean of those values
    is returned.

    Args:
        trueAlpha: ground-truth tensor whose first dimension is the batch.
        predAlpha: predicted tensor, broadcastable to ``trueAlpha``.

    Returns:
        Zero-dimensional tensor holding the batch-averaged SAD.
    """
    # Take the magnitude before summing; otherwise over- and
    # under-estimates cancel and the metric can read zero for a bad
    # prediction.
    absDiff = (predAlpha - trueAlpha).abs()
    # Collapse everything except the batch axis so the result is a true
    # per-sample sum for both (B, H, W) and (B, C, H, W) inputs.
    perSample = absDiff.flatten(start_dim=1).sum(dim=1)
    return perSample.mean()
75+
76+
def mean_squared_error(trueAlpha, predAlpha, compositeImage):
    """
    Return the mean squared error between predicted and ground-truth
    alpha, measured only over the trimap's unknown region.

    Channel 3 of ``compositeImage`` is treated as the trimap; pixels whose
    value maps to 127 after rescaling by 255 are the unknown region.
    NOTE(review): this assumes the trimap channel is stored scaled to
    0..1 -- confirm against the dataset transforms.

    Args:
        trueAlpha: ground-truth alpha, shaped (B, H, W) or (B, C, H, W).
        predAlpha: predicted alpha, same shape as ``trueAlpha``.
        compositeImage: (B, >=4, H, W) tensor with the trimap in channel 3.

    Returns:
        Zero-dimensional tensor: squared error summed over unknown pixels
        divided by the number of unknown pixels. It is 0 when the batch
        contains no unknown pixels.
    """
    # Round after rescaling so floating-point error in value * 255 does
    # not defeat the comparison with 127.
    trimaps = torch.round(compositeImage[:, 3] * 255)
    unknownMask = trimaps == 127
    # Indexing dropped the channel axis; restore it when the alphas keep
    # theirs, so the mask broadcasts per sample instead of across the batch.
    if trueAlpha.dim() == unknownMask.dim() + 1:
        unknownMask = unknownMask.unsqueeze(1)
    unknownMask = unknownMask.to(trueAlpha.dtype)

    # Only errors inside the unknown region contribute to the metric.
    squaredError = torch.pow(predAlpha - trueAlpha, 2) * unknownMask
    # Normalise by the count of unknown pixels; clamp so an empty
    # unknown region gives 0 instead of dividing by zero.
    numUnknown = unknownMask.expand_as(squaredError).sum().clamp(min=1)
    mse = squaredError.sum() / numUnknown

    return mse

training.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from PIL import Image
99
from Dataset.MattingDataset import MattingDataset
1010
from model import EncoderDecoderNet, RefinementNet
11-
from dataset_transforms import RandomTrimapCrop, Resize, ToTensor, RandomHorizontalFlip, RandomRotation, RandomVerticalFlip, RandomBlur
12-
from loss import alpha_prediction_loss, compositional_loss
11+
from dataset_transforms import RandomTrimapCrop, Resize, ToTensor, RandomHorizontalFlip, RandomRotation, RandomVerticalFlip, RandomBlur, RandomAffine
12+
from loss import alpha_prediction_loss, compositional_loss, sum_absolute_difference, mean_squared_error
1313

1414

1515

@@ -45,7 +45,7 @@ def batch_collate_fn(batch):
4545
_TRAIN_ALPHA_DIR_ = "./Dataset/Training_set/CombinedAlpha"
4646
_NETWORK_INPUT_ = (320,320)
4747
_COMPUTE_DEVICE_ = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48-
_NUM_EPOCHS_ = 8 #200 #TODO: Remove 90
48+
_NUM_EPOCHS_ = 30 #200 #TODO: Remove 90
4949
_BATCH_SIZE_ = 8 #TODO: Increase this if using a GPU
5050
_NUM_WORKERS_ = multiprocessing.cpu_count()
5151
_LOSS_WEIGHT_ = 0.6
@@ -68,7 +68,7 @@ def batch_collate_fn(batch):
6868

6969
trainingDataset = MattingDataset(
7070
_TRAIN_FOREGROUND_DIR_, _TRAIN_BACKGROUND_DIR_, _TRAIN_ALPHA_DIR_,
71-
allTransform=tripleTransforms, imageTransforms=imageTransforms
71+
allTransform=tripleTransforms, imageTransforms=None
7272
)
7373
trainDataloader = torch.utils.data.DataLoader(
7474
trainingDataset, batch_size=_BATCH_SIZE_, shuffle=True, num_workers=_NUM_WORKERS_, collate_fn=batch_collate_fn)
@@ -115,10 +115,12 @@ def batch_collate_fn(batch):
115115
modelAlphaLoss = alpha_prediction_loss(predictedMasks, groundTruthMasks)
116116
refinedAlphaLoss = alpha_prediction_loss(refinedMasks, groundTruthMasks)
117117
lossAlpha = modelAlphaLoss + refinedAlphaLoss
118-
# lossComposition = compositional_loss(predictedMasks, groundTruthMasks, compositeImages)
119-
lossComposition = compositional_loss(refinedMasks, groundTruthMasks, compositeImages)
118+
lossComposition = compositional_loss(predictedMasks, groundTruthMasks, compositeImages)
120119
totalLoss = _LOSS_WEIGHT_ * lossAlpha + (1 - _LOSS_WEIGHT_) * lossComposition
121120
epochLoss += totalLoss.item()
121+
with torch.no_grad():
122+
sad = sum_absolute_difference(groundTruthMasks, refinedMasks)
123+
mse = mean_squared_error(groundTruthMasks, refinedMasks, compositeImages)
122124

123125
if idx % 100 == 0:
124126
print(f"\tIteration {idx+1}/{len(trainingDataset)}")
@@ -128,6 +130,11 @@ def batch_collate_fn(batch):
128130
print(f"\t Alpha loss = {lossAlpha}")
129131
print(f"\t Composition loss = {lossComposition}")
130132
print(f"\t Total Loss = {totalLoss}")
133+
print(f"\t {'***' * 5}")
134+
print(f"\t Metrics:")
135+
print(f"\t {'***' * 5}")
136+
print(f"\t Sum absolute difference: {sad}")
137+
print(f"\t Mean Squared Error: {mse}")
131138
print()
132139

133140
optimiser.zero_grad()

0 commit comments

Comments
 (0)