This is an attempt to implement Re-Aging GANs (the kind used in face-aging filters).
The main framework for this implementation is PyTorch.
- Lifespan Age Transformation Synthesis (LATS)
- PFA-GAN: Progressive Face Aging with Generative Adversarial Network
- Re-Aging GAN: Toward Personalized Face Age Transformation
- Age Gap Reducer-GAN for Recognizing Age-Separated Faces
- The discriminator is essentially a Pix2Pix discriminator, which follows a PatchGAN-style structure.
- To make the model aware of the input and output ages, we build an `Embedding` layer for each of them, in which every age group of the input dataset is encoded.
- Here, we use a smaller version of this scheme with 3 age categories, although it could be extended to a larger number of categories.
- If the images have shape `(BATCH_SIZE, 3, 224, 224)`, we build an `Embedding` layer with `embedding_dim` set to $224^2$, so that the embedding can be reshaped to `(BATCH_SIZE, 1, 224, 224)` and concatenated to the image along dimension 1 (the channel axis). This gives a tensor of shape `(BATCH_SIZE, 4, 224, 224)`. The models printed below were built for 256×256 images, hence `Embedding(3, 65536)`, since $256^2 = 65536$.
- We do the same for the output image and its age, and then, as in Pix2Pix, concatenate the two 4-channel tensors into an 8-channel input and pass it through the model.
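A minimal PyTorch sketch of this conditioning step, assuming the age groups are integer labels in `{0, 1, 2}`. The helper name `add_age_channel` and the random tensors are hypothetical and only illustrate the shapes. It uses the 224×224 example from the text; any size works as long as `embedding_dim` equals `H * W`.

```python
import torch
import torch.nn as nn

NUM_AGE_GROUPS = 3  # three age categories, as described above


def add_age_channel(images: torch.Tensor, ages: torch.Tensor, embed: nn.Embedding) -> torch.Tensor:
    """Append one learned age channel to a batch of images (hypothetical helper).

    images: (B, C, H, W); ages: (B,) integer age-group indices.
    embed must have embedding_dim == H * W.
    """
    b, _, h, w = images.shape
    age_map = embed(ages).view(b, 1, h, w)      # (B, H*W) -> (B, 1, H, W)
    return torch.cat([images, age_map], dim=1)  # concatenate along the channel axis


img_size = 224
input_embed = nn.Embedding(NUM_AGE_GROUPS, img_size * img_size)
output_embed = nn.Embedding(NUM_AGE_GROUPS, img_size * img_size)

x = torch.randn(2, 3, img_size, img_size)      # input images
x_out = torch.randn(2, 3, img_size, img_size)  # output images
y = torch.tensor([0, 2])                       # input age groups
y_out = torch.tensor([1, 1])                   # output age groups

d_in = torch.cat([add_age_channel(x, y, input_embed),
                  add_age_channel(x_out, y_out, output_embed)], dim=1)
print(d_in.shape)  # torch.Size([2, 8, 224, 224])
```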
```
Discriminator(
  (inital): Sequential(
    (0): Conv2d(8, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
  )
  (model): Sequential(
    (0): ConvBlock(
      (conv): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
    )
    (1): ConvBlock(
      (conv): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
    )
    (2): ConvBlock(
      (conv): Sequential(
        (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
    )
    (3): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  )
  (input_embed): Embedding(3, 65536)
  (output_embed): Embedding(3, 65536)
)
```
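The printout above can be approximately reconstructed as follows. This is a sketch, not the repository's code: the layer shapes and attribute names (including `inital`) are copied from the printout, while the `ConvBlock` constructor arguments and the `forward` method, including the order in which the four tensors are concatenated, are assumptions.

```python
import torch
import torch.nn as nn


class ConvBlock(nn.Module):
    """Conv -> BatchNorm -> ReLU, as in the printed ConvBlock entries."""

    def __init__(self, in_ch, out_ch, stride):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=stride, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.conv(x)


class Discriminator(nn.Module):
    def __init__(self, num_ages=3, img_size=256):
        super().__init__()
        # 8 input channels: (3 image + 1 age map) for the input side and the output side
        self.inital = nn.Sequential(  # attribute name kept exactly as printed
            nn.Conv2d(8, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        )
        self.model = nn.Sequential(
            ConvBlock(64, 128, stride=2),
            ConvBlock(128, 256, stride=2),
            ConvBlock(256, 512, stride=1),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),  # one logit per patch
        )
        self.input_embed = nn.Embedding(num_ages, img_size * img_size)
        self.output_embed = nn.Embedding(num_ages, img_size * img_size)

    def forward(self, x, y_in, x_out, y_out):
        b, _, h, w = x.shape
        y_in_map = self.input_embed(y_in).view(b, 1, h, w)
        y_out_map = self.output_embed(y_out).view(b, 1, h, w)
        pair = torch.cat([x, y_in_map, x_out, y_out_map], dim=1)  # (B, 8, H, W)
        return self.model(self.inital(pair))


D = Discriminator()
x, x_out = torch.randn(2, 3, 256, 256), torch.randn(2, 3, 256, 256)
patch_logits = D(x, torch.tensor([0, 2]), x_out, torch.tensor([1, 1]))
print(patch_logits.shape)  # torch.Size([2, 1, 30, 30])
```

With 256×256 inputs, the three stride-2 convolutions and two stride-1 convolutions produce a 30×30 grid of logits. Each logit scores one overlapping patch of the image pair rather than the whole image, which is the PatchGAN idea.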
- The generator is also inspired by Pix2Pix, which in turn builds on U-Net: Convolutional Networks for Biomedical Image Segmentation.
- I condition on the ages exactly as in the discriminator, with one small difference: I concatenate both the input-age and the output-age embeddings with the input image, which results in a tensor of shape `(BATCH_SIZE, 5, 224, 224)`.
- Then I simply pass it through the U-Net architecture.
```
Generator(
  (input_embed): Embedding(3, 65536)
  (output_embed): Embedding(3, 65536)
  (init_down): ConvBlock(
    (conv): Sequential(
      (0): Conv2d(5, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (down1): ConvBlock(
    (conv): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (down2): ConvBlock(
    (conv): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (down3): ConvBlock(
    (conv): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (down4): ConvBlock(
    (conv): Sequential(
      (0): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (down5): ConvBlock(
    (conv): Sequential(
      (0): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (down6): ConvBlock(
    (conv): Sequential(
      (0): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (bottle_neck): Sequential(
    (0): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
    (1): LeakyReLU(negative_slope=0.2)
  )
  (up1): TransposeConvBlock(
    (tran_conv): Sequential(
      (0): ConvTranspose2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (up2): TransposeConvBlock(
    (tran_conv): Sequential(
      (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (up3): TransposeConvBlock(
    (tran_conv): Sequential(
      (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (up4): TransposeConvBlock(
    (tran_conv): Sequential(
      (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (up5): TransposeConvBlock(
    (tran_conv): Sequential(
      (0): ConvTranspose2d(1024, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (up6): TransposeConvBlock(
    (tran_conv): Sequential(
      (0): ConvTranspose2d(512, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (up7): TransposeConvBlock(
    (tran_conv): Sequential(
      (0): ConvTranspose2d(256, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (final_up): Sequential(
    (0): ConvTranspose2d(128, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): Tanh()
  )
)
```
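Likewise, here is a sketch that reproduces the generator printout. The skip-connection pairing (each decoder block receives the matching encoder output) is inferred from the channel counts in the printout; for example, `up2` takes 1024 = 512 + 512 channels. The `forward` method, the concatenation order, and the block constructor arguments are assumptions.

```python
import torch
import torch.nn as nn


class ConvBlock(nn.Module):
    """Stride-2 Conv -> BatchNorm -> ReLU (halves H and W)."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.conv(x)


class TransposeConvBlock(nn.Module):
    """Stride-2 ConvTranspose -> BatchNorm -> ReLU (doubles H and W)."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.tran_conv = nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.tran_conv(x)


class Generator(nn.Module):
    def __init__(self, num_ages=3, img_size=256):
        super().__init__()
        self.input_embed = nn.Embedding(num_ages, img_size * img_size)
        self.output_embed = nn.Embedding(num_ages, img_size * img_size)
        # Encoder: 5 input channels = 3 (image) + 1 (input age) + 1 (output age)
        self.init_down = ConvBlock(5, 64)
        self.down1 = ConvBlock(64, 128)
        self.down2 = ConvBlock(128, 256)
        self.down3 = ConvBlock(256, 512)
        self.down4 = ConvBlock(512, 512)
        self.down5 = ConvBlock(512, 512)
        self.down6 = ConvBlock(512, 512)
        self.bottle_neck = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )
        # Decoder: from up2 on, inputs are [previous decoder output, skip connection]
        self.up1 = TransposeConvBlock(512, 512)
        self.up2 = TransposeConvBlock(1024, 512)
        self.up3 = TransposeConvBlock(1024, 512)
        self.up4 = TransposeConvBlock(1024, 512)
        self.up5 = TransposeConvBlock(1024, 256)
        self.up6 = TransposeConvBlock(512, 128)
        self.up7 = TransposeConvBlock(256, 64)
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x, y_in, y_out):
        b, _, h, w = x.shape
        y_in_map = self.input_embed(y_in).view(b, 1, h, w)
        y_out_map = self.output_embed(y_out).view(b, 1, h, w)
        d0 = self.init_down(torch.cat([x, y_in_map, y_out_map], dim=1))  # 64 x 128 x 128
        d1 = self.down1(d0)                # 128 x 64 x 64
        d2 = self.down2(d1)                # 256 x 32 x 32
        d3 = self.down3(d2)                # 512 x 16 x 16
        d4 = self.down4(d3)                # 512 x 8 x 8
        d5 = self.down5(d4)                # 512 x 4 x 4
        d6 = self.down6(d5)                # 512 x 2 x 2
        bottleneck = self.bottle_neck(d6)  # 512 x 1 x 1
        u = self.up1(bottleneck)                 # 512 x 2 x 2
        u = self.up2(torch.cat([u, d6], dim=1))  # 512 x 4 x 4
        u = self.up3(torch.cat([u, d5], dim=1))  # 512 x 8 x 8
        u = self.up4(torch.cat([u, d4], dim=1))  # 512 x 16 x 16
        u = self.up5(torch.cat([u, d3], dim=1))  # 256 x 32 x 32
        u = self.up6(torch.cat([u, d2], dim=1))  # 128 x 64 x 64
        u = self.up7(torch.cat([u, d1], dim=1))  # 64 x 128 x 128
        return self.final_up(torch.cat([u, d0], dim=1))  # 3 x 256 x 256, values in [-1, 1]


G = Generator()
x = torch.randn(2, 3, 256, 256)
fake = G(x, torch.tensor([0, 2]), torch.tensor([2, 0]))
print(fake.shape)  # torch.Size([2, 3, 256, 256])
```

The eight stride-2 convolutions from `init_down` through `bottle_neck` halve the resolution from 256 down to 1×1, which is one reason this sketch, like the printed model with `Embedding(3, 65536)`, expects 256×256 images.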
The framework operates on three inputs: an input image $x$, its input age group $y$, and a target (output) age group $y'$.
The reconstruction loss covers the case where the input and output ages are the same. Ideally, the model should then return the input image unchanged, i.e. $x_{rec} = G(x, y, y)$ should equal $x$, which is enforced with an L1 loss.
$$\mathcal{L}_{rec}(G) = \lVert x - x_{rec} \rVert_1$$
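A minimal sketch of $\mathcal{L}_{rec}$, assuming the generator is called as `G(image, input_age, output_age)`. Note that `F.l1_loss` averages over all elements, so it is the L1 norm scaled by a constant. The identity generator below is a hypothetical stand-in, used only to show that a perfect reconstruction gives zero loss.

```python
import torch
import torch.nn.functional as F


def reconstruction_loss(G, x, y):
    """L_rec: ask G to map x from age group y to the same group y and compare with x.

    Assumes G has the signature G(image, input_age, output_age).
    """
    x_rec = G(x, y, y)
    # Mean absolute error: ||x - x_rec||_1 divided by the number of elements
    return F.l1_loss(x_rec, x)


# Hypothetical stand-in generator that returns its input unchanged
identity_G = lambda img, y_in, y_out: img

x = torch.rand(2, 3, 8, 8)
y = torch.tensor([0, 2])
loss_rec = reconstruction_loss(identity_G, x, y)
print(loss_rec.item())  # 0.0
```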
For the cycle-consistency loss, we take the age-transformed image and feed it back into the generator with the ages swapped: the new input age is the original output age, and the new output age is the original input age. Here too, the result $x_{cycle} = G(G(x, y, y'), y', y)$ should be as close as possible to the input image, which is again enforced with an L1 loss.
$$\mathcal{L}_{cyc}(G) = \lVert x - x_{cycle} \rVert_1$$
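The cycle loss can be sketched the same way, under the same assumed `G(image, input_age, output_age)` signature. The toy `shift_G` is hypothetical: it shifts pixel values by the age-group difference, so translating forward and then back cancels out and the loss is close to zero.

```python
import torch
import torch.nn.functional as F


def cycle_loss(G, x, y, y_target):
    """L_cyc: age x from y to y_target, then back from y_target to y, and compare with x.

    Assumes G has the signature G(image, input_age, output_age).
    """
    x_fake = G(x, y, y_target)        # age-transformed image
    x_cycle = G(x_fake, y_target, y)  # ages swapped on the way back
    return F.l1_loss(x_cycle, x)


# Hypothetical toy generator: adds the age-group difference to every pixel,
# so the reverse translation subtracts it again
def shift_G(img, y_in, y_out):
    return img + (y_out - y_in).float().view(-1, 1, 1, 1)


x = torch.rand(2, 3, 8, 8)
loss_cyc = cycle_loss(shift_G, x, torch.tensor([0, 2]), torch.tensor([1, 1]))
print(loss_cyc.item())  # close to 0; only floating-point rounding remains
```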
In general, GAN training is a zero-sum min-max game, so we also use the standard adversarial loss, which the discriminator $D$ maximizes and the generator $G$ minimizes.
$$\mathcal{L}_{adv}(G,D) = \mathbb{E}_{x,y}[\log D_y(x)] + \mathbb{E}_{x,y'}[\log(1 - D_{y'}(x'))]$$
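In practice, this objective is commonly implemented with binary cross-entropy on the discriminator's logits. The sketch below assumes that approach and swaps in the widely used non-saturating generator loss (minimizing $-\log D(x')$ instead of $\log(1 - D(x'))$), which the text above does not specify. The all-zero logits are placeholders standing in for a 30×30 PatchGAN output.

```python
import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()  # applies the sigmoid itself, so D should output raw logits


def discriminator_loss(real_logits, fake_logits):
    # -[log D(real) + log(1 - D(fake))]: minimizing this maximizes L_adv
    real_loss = bce(real_logits, torch.ones_like(real_logits))
    fake_loss = bce(fake_logits, torch.zeros_like(fake_logits))
    return real_loss + fake_loss


def generator_loss(fake_logits):
    # Non-saturating variant: minimize -log D(fake) rather than log(1 - D(fake))
    return bce(fake_logits, torch.ones_like(fake_logits))


# With all-zero logits, D outputs 0.5 everywhere
logits = torch.zeros(2, 1, 30, 30)
print(discriminator_loss(logits, logits).item())  # approximately 2 * ln 2 = 1.3863
print(generator_loss(logits).item())              # approximately ln 2 = 0.6931
```

When updating the discriminator, the fake logits should be computed from a detached generated image (`fake.detach()`), so that the discriminator step does not backpropagate into the generator.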