8
8
from PIL import Image
9
9
from Dataset .MattingDataset import MattingDataset
10
10
from model import EncoderDecoderNet , RefinementNet
11
- from dataset_transforms import RandomTrimapCrop , Resize , ToTensor , RandomHorizontalFlip , RandomRotation , RandomVerticalFlip , RandomBlur
12
- from loss import alpha_prediction_loss , compositional_loss
11
+ from dataset_transforms import RandomTrimapCrop , Resize , ToTensor , RandomHorizontalFlip , RandomRotation , RandomVerticalFlip , RandomBlur , RandomAffine
12
+ from loss import alpha_prediction_loss , compositional_loss , sum_absolute_difference , mean_squared_error
13
13
14
14
15
15
@@ -45,7 +45,7 @@ def batch_collate_fn(batch):
45
45
_TRAIN_ALPHA_DIR_ = "./Dataset/Training_set/CombinedAlpha"
46
46
_NETWORK_INPUT_ = (320 ,320 )
47
47
_COMPUTE_DEVICE_ = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
48
- _NUM_EPOCHS_ = 8 #200 #TODO: Remove 90
48
+ _NUM_EPOCHS_ = 30 #200 #TODO: Remove 90
49
49
_BATCH_SIZE_ = 8 #TODO: Increase this if using a GPU
50
50
_NUM_WORKERS_ = multiprocessing .cpu_count ()
51
51
_LOSS_WEIGHT_ = 0.6
@@ -68,7 +68,7 @@ def batch_collate_fn(batch):
68
68
69
69
trainingDataset = MattingDataset (
70
70
_TRAIN_FOREGROUND_DIR_ , _TRAIN_BACKGROUND_DIR_ , _TRAIN_ALPHA_DIR_ ,
71
- allTransform = tripleTransforms , imageTransforms = imageTransforms
71
+ allTransform = tripleTransforms , imageTransforms = None
72
72
)
73
73
trainDataloader = torch .utils .data .DataLoader (
74
74
trainingDataset , batch_size = _BATCH_SIZE_ , shuffle = True , num_workers = _NUM_WORKERS_ , collate_fn = batch_collate_fn )
@@ -115,10 +115,12 @@ def batch_collate_fn(batch):
115
115
modelAlphaLoss = alpha_prediction_loss (predictedMasks , groundTruthMasks )
116
116
refinedAlphaLoss = alpha_prediction_loss (refinedMasks , groundTruthMasks )
117
117
lossAlpha = modelAlphaLoss + refinedAlphaLoss
118
- # lossComposition = compositional_loss(predictedMasks, groundTruthMasks, compositeImages)
119
- lossComposition = compositional_loss (refinedMasks , groundTruthMasks , compositeImages )
118
+ lossComposition = compositional_loss (predictedMasks , groundTruthMasks , compositeImages )
120
119
totalLoss = _LOSS_WEIGHT_ * lossAlpha + (1 - _LOSS_WEIGHT_ ) * lossComposition
121
120
epochLoss += totalLoss .item ()
121
+ with torch .no_grad ():
122
+ sad = sum_absolute_difference (groundTruthMasks , refinedMasks )
123
+ mse = mean_squared_error (groundTruthMasks , refinedMasks , compositeImages )
122
124
123
125
if idx % 100 == 0 :
124
126
print (f"\t Iteration { idx + 1 } /{ len (trainingDataset )} " )
@@ -128,6 +130,11 @@ def batch_collate_fn(batch):
128
130
print (f"\t Alpha loss = { lossAlpha } " )
129
131
print (f"\t Composition loss = { lossComposition } " )
130
132
print (f"\t Total Loss = { totalLoss } " )
133
+ print (f"\t { '***' * 5 } " )
134
+ print (f"\t Metrics:" )
135
+ print (f"\t { '***' * 5 } " )
136
+ print (f"\t Sum absolute difference: { sad } " )
137
+ print (f"\t Mean Squared Error: { mse } " )
131
138
print ()
132
139
133
140
optimiser .zero_grad ()
0 commit comments