Skip to content

Commit 7d4b77c

Browse files
committed
More transform functions added
1 parent a154953 commit 7d4b77c

7 files changed

+74
-10
lines changed

Dataset/MattingDataset.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,18 @@
88
from PIL import Image, ImageFilter, ImageChops
99

1010
class MattingDataset(data.Dataset):
11-
def __init__(self, fgDir, bgDir, alphaDir, allTransform):
11+
def __init__(self, fgDir, bgDir, alphaDir, allTransform, imageTransforms):
1212
self.fgDir = fgDir
1313
self.bgDir = bgDir
1414
self.alphaDir = alphaDir
1515

1616
self.foregroundImageNames = os.listdir(self.fgDir)
1717
self.backgroundImageNames = os.listdir(self.bgDir)
1818
random.shuffle(self.backgroundImageNames) #TODO: Remove
19-
self.backgroundImageNames = self.backgroundImageNames[:10] #TODO: Remove
19+
self.backgroundImageNames = self.backgroundImageNames[:12] #TODO: Remove
2020
self.alphaImageNames = os.listdir(self.alphaDir)
2121
random.shuffle(self.alphaImageNames) #TODO: Remove
22-
self.alphaImageNames = self.alphaImageNames[:20] #TODO:Remove
22+
self.alphaImageNames = self.alphaImageNames[:22] #TODO:Remove
2323

2424
self.numForeground = len(self.foregroundImageNames)
2525
self.numBackground = len(self.backgroundImageNames)
@@ -31,6 +31,7 @@ def __init__(self, fgDir, bgDir, alphaDir, allTransform):
3131
self.imageBackgroundPair = sorted(self.imageBackgroundPair, key=lambda x: x[0])
3232

3333
self.allTransform = allTransform
34+
self.imageTransform = imageTransforms
3435

3536
# assert len(self.imageBackgroundPair) == len(self) #TODO: Remove
3637

@@ -50,6 +51,9 @@ def __getitem__(self, idx):
5051
compositeImage = self.composite_image(foregroundImage, backgroundImage, alphaMask)
5152

5253
assert compositeImage.size == trimap.size, f"composite size = {compositeImage.size} and trimap = {trimap.size} and foreground size = {foregroundImage.size}"
54+
55+
if self.imageTransform:
56+
compositeImage = self.imageTransform(compositeImage)
5357

5458
if self.allTransform:
5559
compositeImage, trimap, alphaMask = self.allTransform((compositeImage, trimap, alphaMask))

TrainLoss200Items.png

405 Bytes
Loading

TrainLoss231Items.png

21.2 KB
Loading

TrainLoss264Items.png

23.5 KB
Loading

dataset_transforms.py

+47-1
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,54 @@
22
import random
33
import torch
44
import torchvision.transforms.functional as TF
5-
from PIL import Image
5+
from PIL import Image, ImageFilter
66

7+
class RandomBlur(object):
    """Randomly apply a Gaussian blur to an image.

    Args:
        probability: Chance that the blur is applied on a given call;
            the blur happens when ``random.random() < probability``.
        radius: Radius handed to ``ImageFilter.GaussianBlur``. Defaults to 2,
            the value that used to be hard-coded.
    """

    def __init__(self, probability=0.5, radius=2):
        self.p = probability
        self.radius = radius

    def __call__(self, image):
        """Return ``image.filter(GaussianBlur)`` or the untouched ``image``.

        ``image`` must expose a ``filter`` method accepting an
        ``ImageFilter`` filter (presumably a PIL image -- confirm at caller).
        """
        if random.random() < self.p:
            return image.filter(ImageFilter.GaussianBlur(radius=self.radius))
        # Not selected this time: hand back the very same object.
        return image
15+
16+
class RandomRotation(object):
    """Randomly rotate an (image, trimap, mask) triple by one shared angle.

    Args:
        probability: Chance that the rotation is applied on a given call;
            it happens when ``random.random() < probability``.
        angle: Bound on the rotation. The applied angle is an integer drawn
            from ``[-abs(angle), abs(angle)]`` and passed to ``TF.rotate``.
    """

    def __init__(self, probability=0.5, angle=45):
        self.p = probability
        self.angle = angle

    def __call__(self, items):
        """Return the triple, with all three items rotated by the same angle
        when selected, otherwise unchanged."""
        image, trimap, mask = items
        # random.randint(a, b) raises ValueError when a > b, which is what a
        # negative bound used to produce; abs() keeps the range valid.
        bound = abs(self.angle)
        # The angle is still drawn before the probability check so the
        # sequence of random draws matches the previous implementation.
        angle = random.randint(-bound, bound)
        if random.random() < self.p:
            # One shared angle keeps image, trimap and mask aligned.
            image = TF.rotate(image, angle)
            trimap = TF.rotate(trimap, angle)
            mask = TF.rotate(mask, angle)
        return image, trimap, mask
29+
30+
class RandomVerticalFlip(object):
    """Flip an (image, trimap, mask) triple with ``TF.vflip`` at random.

    Args:
        probability: Chance that the flip is applied on a given call;
            it happens when ``random.random() < probability``.
    """

    def __init__(self, probability=0.5):
        self.p = probability

    def __call__(self, items):
        """Return the triple, either all three items flipped or all untouched."""
        image, trimap, mask = items
        # Guard clause: not selected this time, pass everything through.
        if random.random() >= self.p:
            return image, trimap, mask
        return TF.vflip(image), TF.vflip(trimap), TF.vflip(mask)
41+
42+
class RandomHorizontalFlip(object):
    """Flip an (image, trimap, mask) triple with ``TF.hflip`` at random.

    Args:
        probability: Chance that the flip is applied on a given call;
            it happens when ``random.random() < probability``.
    """

    def __init__(self, probability=0.5):
        self.p = probability

    def __call__(self, items):
        """Return the triple, either all three items flipped or all untouched."""
        image, trimap, mask = items
        should_flip = random.random() < self.p
        if should_flip:
            # Apply the same flip to every member so they stay aligned.
            image, trimap, mask = (
                TF.hflip(member) for member in (image, trimap, mask)
            )
        return image, trimap, mask
753

854
class ToTensor(object):
955
def __call__(self, items):

model.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ def forward(self, x):
3838
"""
3939
x = self.encoder(x)
4040
x = self.decoder(x)
41-
4241
return x
4342

4443

@@ -51,9 +50,13 @@ def __init__(self):
5150
"""
5251
self.encoderBlocks = nn.Sequential(
5352
convBatchNormReLU(4, 64, 3),
53+
convBatchNormReLU(64, 64, 1, pad=0, stride=1),
5454
convBatchNormReLU(64, 128, 3),
55+
convBatchNormReLU(128, 128, 1, pad=0, stride=1),
5556
convBatchNormReLU(128, 256, 3),
56-
convBatchNormReLU(256, 512, 3)
57+
convBatchNormReLU(256, 256, 1, pad=0, stride=1),
58+
convBatchNormReLU(256, 512, 3),
59+
convBatchNormReLU(512, 512, 1, pad=0, stride=1),
5760
)
5861

5962
def forward(self, x):

training.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from PIL import Image
99
from Dataset.MattingDataset import MattingDataset
1010
from model import EncoderDecoderNet, RefinementNet
11-
from dataset_transforms import RandomTrimapCrop, Resize, ToTensor
11+
from dataset_transforms import RandomTrimapCrop, Resize, ToTensor, RandomHorizontalFlip, RandomRotation, RandomVerticalFlip, RandomBlur
1212
from loss import alpha_prediction_loss, compositional_loss
1313

1414

@@ -45,21 +45,30 @@ def batch_collate_fn(batch):
4545
_TRAIN_ALPHA_DIR_ = "./Dataset/Training_set/CombinedAlpha"
4646
_NETWORK_INPUT_ = (320,320)
4747
_COMPUTE_DEVICE_ = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48-
_NUM_EPOCHS_ = 30 #200 #TODO: Remove
48+
_NUM_EPOCHS_ = 60 #200 #TODO: Remove
4949
_BATCH_SIZE_ = 8 #TODO: Increase this if using a GPU
5050
_NUM_WORKERS_ = multiprocessing.cpu_count()
5151
_LOSS_WEIGHT_ = 0.4 #0.5
5252
_GRADIENT_CLIP_ = 2.5
5353

5454
tripleTransforms = transforms.Compose([
55+
RandomRotation(probability=0.5, angle=180),
56+
RandomVerticalFlip(probability=0.5),
57+
RandomHorizontalFlip(probability=0.5),
5558
RandomTrimapCrop([(320, 320), (480, 480), (640, 640)], probability=0.7),
5659
Resize(_NETWORK_INPUT_),
5760
ToTensor()
5861
])
5962

63+
imageTransforms = transforms.Compose([
64+
transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.25),
65+
transforms.RandomGrayscale(p=0.3),
66+
RandomBlur(probability=0.1)
67+
])
68+
6069
trainingDataset = MattingDataset(
6170
_TRAIN_FOREGROUND_DIR_, _TRAIN_BACKGROUND_DIR_, _TRAIN_ALPHA_DIR_,
62-
allTransform=tripleTransforms
71+
allTransform=tripleTransforms, imageTransforms=imageTransforms
6372
)
6473
trainDataloader = torch.utils.data.DataLoader(
6574
trainingDataset, batch_size=_BATCH_SIZE_, shuffle=True, num_workers=_NUM_WORKERS_, collate_fn=batch_collate_fn)
@@ -142,9 +151,11 @@ def batch_collate_fn(batch):
142151
plt.title(f"Training loss using a dataset of {len(trainingDataset)} images")
143152
plt.savefig(f"TrainLoss{len(trainingDataset)}Items.png")
144153

154+
trainingElapsed = time.time() - trainStart
155+
print(f"\nTotal training time is {trainingElapsed//60:.0f}m {trainingElapsed%60:.0f}s")
145156
#Make a sample prediction
146157
idx = random.choice(range(0, len(trainingDataset)))
147-
img_, trimap, gMasks = trainingDataset[0]
158+
img_, trimap, gMasks = trainingDataset[idx]
148159
trimap = trimap.unsqueeze(0)
149160
gMasks = gMasks.unsqueeze(0)
150161
img = torch.cat([img_, trimap], 0).unsqueeze(0)

0 commit comments

Comments
 (0)