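"""ResNet-based image-to-image generator.

Two stride-2 convolutions downsample the input, six residual blocks process
it, and two transposed convolutions upsample back to the input resolution;
the output is an RGB image in [0, 1]. With inception=True, an extra crop
layer turns 3x300x300 outputs into 3x299x299 (the Inception input size).
"""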
import torch
import torch.nn as nn

###########################
# Generator: Resnet
###########################

# Base number of feature maps in the generator
ngf = 64
class GeneratorResnet(nn.Module):
    def __init__(self, inception=False):
        '''
        :param inception: if True, a crop layer is added to go from 3x300x300 to 3x299x299.
        '''
        super(GeneratorResnet, self).__init__()
        # Input size = 3 x n x n
        self.inception = inception
        self.block1 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, ngf, kernel_size=7, padding=0, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True)
        )
        # Input size = ngf x n x n
        self.block2 = nn.Sequential(
            nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True)
        )
        # Input size = (ngf * 2) x n/2 x n/2
        self.block3 = nn.Sequential(
            nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True)
        )
        # Input size = (ngf * 4) x n/4 x n/4
        # Residual Blocks: 6
        self.resblock1 = ResidualBlock(ngf * 4)
        self.resblock2 = ResidualBlock(ngf * 4)
        self.resblock3 = ResidualBlock(ngf * 4)
        self.resblock4 = ResidualBlock(ngf * 4)
        self.resblock5 = ResidualBlock(ngf * 4)
        self.resblock6 = ResidualBlock(ngf * 4)

        # Input size = (ngf * 4) x n/4 x n/4
        self.upsampl1 = nn.Sequential(
            nn.ConvTranspose2d(ngf * 4, ngf * 2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True)
        )
        # Input size = (ngf * 2) x n/2 x n/2
        self.upsampl2 = nn.Sequential(
            nn.ConvTranspose2d(ngf * 2, ngf, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True)
        )
        # Input size = ngf x n x n
        self.blockf = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, 3, kernel_size=7, padding=0)
        )
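        # Negative padding crops 1 pixel from the right and the top: 300x300 -> 299x299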
        self.crop = nn.ConstantPad2d((0, -1, -1, 0), 0)

    def forward(self, input):
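        """Run the encoder, six residual blocks, and decoder; rescale the tanh output to [0, 1]."""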
        x = self.block1(input)
        x = self.block2(x)
        x = self.block3(x)
        x = self.resblock1(x)
        x = self.resblock2(x)
        x = self.resblock3(x)
        x = self.resblock4(x)
        x = self.resblock5(x)
        x = self.resblock6(x)
        x = self.upsampl1(x)
        x = self.upsampl2(x)
        x = self.blockf(x)
        if self.inception:
            x = self.crop(x)
        return (torch.tanh(x) + 1) / 2  # Output range [0, 1]

class ResidualBlock(nn.Module):
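    """Residual block: two reflection-padded 3x3 convolutions with BatchNorm,
    ReLU, and dropout, added to the input through an identity skip connection."""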
    def __init__(self, num_filters):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels=num_filters, out_channels=num_filters, kernel_size=3, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(num_filters),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels=num_filters, out_channels=num_filters, kernel_size=3, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(num_filters)
        )

    def forward(self, x):
        residual = self.block(x)
        return x + residual


if __name__ == '__main__':
    # Smoke test: check output shape and trainable parameter count
    netG = GeneratorResnet()
    test_sample = torch.rand(1, 3, 32, 32)
    print('Generator output:', netG(test_sample).size())
    print('Generator parameters:', sum(p.numel() for p in netG.parameters() if p.requires_grad))
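
    # Optional sanity check: the inception crop path maps a 3x300x300 input
    # to a 3x299x299 output (the generator itself preserves spatial size)
    netG_inc = GeneratorResnet(inception=True)
    print('Inception-mode output:', netG_inc(torch.rand(1, 3, 300, 300)).size())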