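"""R2C-GAN: CycleGAN-style translation between domains A and B with
auxiliary classification.

Two generators (G_A2B, G_B2A) translate images and also return class
predictions; two discriminators (D_A, D_B) score realism. Each model has
operational, convolutional, and light convolutional variants (see the
set_* builders below).
"""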
import tensorflow as tf
import module
import utils


class r2c_gan:
    def __init__(self):
        # Sub-models; built in init() once the run configuration is known.
        self.G_A2B = None
        self.G_B2A = None
        self.D_A = None
        self.D_B = None
        # Adversarial (LSGAN), reconstruction, and classification criteria.
        self.d_loss_fn, self.g_loss_fn = utils.lsgan_loss()
        self.cycle_loss_fn = tf.losses.MeanAbsoluteError()
        self.identity_loss_fn = tf.losses.MeanAbsoluteError()
        self.class_loss_fn = tf.losses.CategoricalCrossentropy()
        self.G_lr_scheduler = None
        self.D_lr_scheduler = None
        self.G_optimizer = None
        self.D_optimizer = None
        self.cycle_weights = None
        self.identity_weight = None
        self.filter = None

    def init(self, args, len_dataset):
        self.filter = args['method']
        self.cycle_weights = args['cycle_loss_weight']
        self.identity_weight = args['identity_loss_weight']
        # Linear learning-rate decay schedules (identical for G and D).
        self.G_lr_scheduler = module.LinearDecay(args['lr'], args['epochs'] * len_dataset, args['epoch_decay'] * len_dataset)
        self.D_lr_scheduler = module.LinearDecay(args['lr'], args['epochs'] * len_dataset, args['epoch_decay'] * len_dataset)
        self.G_optimizer = tf.keras.optimizers.Adam(learning_rate=self.G_lr_scheduler, beta_1=args['beta_1'])
        self.D_optimizer = tf.keras.optimizers.Adam(learning_rate=self.D_lr_scheduler, beta_1=args['beta_1'])
        # Create the models.
        input_shape = (args['crop_size'], args['crop_size'], 3)
        self.set_G_A2B(input_shape=input_shape, q=args['q'])
        self.set_G_B2A(input_shape=input_shape, q=args['q'])
        self.set_D_A(input_shape=input_shape, q=args['q'])
        self.set_D_B(input_shape=input_shape, q=args['q'])

    def set_G_A2B(self, input_shape, q):
        if self.filter == 'operational':
            self.G_A2B = module.OpGenerator(input_shape=input_shape, q=q)
        elif self.filter == 'convolutional':
            self.G_A2B = module.ConvGenerator(input_shape=input_shape)
        elif self.filter == 'convolutional-light':
            self.G_A2B = module.ConvCompGenerator(input_shape=input_shape)
        else:
            # Fail fast instead of silently leaving the model unset.
            raise ValueError('Undefined filtering method!')

    def set_G_B2A(self, input_shape, q):
        if self.filter == 'operational':
            self.G_B2A = module.OpGenerator(input_shape=input_shape, q=q)
        elif self.filter == 'convolutional':
            self.G_B2A = module.ConvGenerator(input_shape=input_shape)
        elif self.filter == 'convolutional-light':
            self.G_B2A = module.ConvCompGenerator(input_shape=input_shape)
        else:
            raise ValueError('Undefined filtering method!')

    def set_D_A(self, input_shape, q):
        if self.filter == 'operational':
            self.D_A = module.OpDiscriminator(input_shape=input_shape, q=q)
        elif self.filter == 'convolutional':
            self.D_A = module.ConvDiscriminator(input_shape=input_shape)
        elif self.filter == 'convolutional-light':
            self.D_A = module.ConvCompDiscriminator(input_shape=input_shape)
        else:
            raise ValueError('Undefined filtering method!')

    def set_D_B(self, input_shape, q):
        if self.filter == 'operational':
            self.D_B = module.OpDiscriminator(input_shape=input_shape, q=q)
        elif self.filter == 'convolutional':
            self.D_B = module.ConvDiscriminator(input_shape=input_shape)
        elif self.filter == 'convolutional-light':
            self.D_B = module.ConvCompDiscriminator(input_shape=input_shape)
        else:
            raise ValueError('Undefined filtering method!')
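
    # `q` is consumed only by the operational builders above; it presumably
    # sets the degree of the operational (Self-ONN-style) layers defined in
    # module, while the convolutional variants simply ignore it.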
    @tf.function
    def train_G(self, A, B):
        # A and B are (images, one-hot labels) pairs.
        with tf.GradientTape() as t:
            # Translations, cycle reconstructions, and identity mappings;
            # each generator also returns class predictions.
            A2B, y_A2B = self.G_A2B(A[0], training=True)      # label_A
            B2A, y_B2A = self.G_B2A(B[0], training=True)      # label_B
            A2B2A, y_A2B2A = self.G_B2A(A2B, training=True)   # label_A
            B2A2B, y_B2A2B = self.G_A2B(B2A, training=True)   # label_B
            A2A, y_A2A = self.G_B2A(A[0], training=True)      # label_A
            B2B, y_B2B = self.G_A2B(B[0], training=True)      # label_B

            # Adversarial losses.
            A2B_d_logits = self.D_B(A2B, training=True)
            B2A_d_logits = self.D_A(B2A, training=True)
            A2B_g_loss = self.g_loss_fn(A2B_d_logits)
            B2A_g_loss = self.g_loss_fn(B2A_d_logits)
            # Cycle-consistency and identity losses.
            A2B2A_cycle_loss = self.cycle_loss_fn(A[0], A2B2A)
            B2A2B_cycle_loss = self.cycle_loss_fn(B[0], B2A2B)
            A2A_id_loss = self.identity_loss_fn(A[0], A2A)
            B2B_id_loss = self.identity_loss_fn(B[0], B2B)
            # Classification losses.
            A2B_c_loss = self.class_loss_fn(A[1], y_A2B)       # label_A
            A2B2A_c_loss = self.class_loss_fn(A[1], y_A2B2A)   # label_A
            A2A_c_loss = self.class_loss_fn(A[1], y_A2A)       # label_A
            B2A_c_loss = self.class_loss_fn(B[1], y_B2A)       # label_B
            B2A2B_c_loss = self.class_loss_fn(B[1], y_B2A2B)   # label_B
            B2B_c_loss = self.class_loss_fn(B[1], y_B2B)       # label_B

            # Total: each term carries its classification loss at a small
            # fixed weight; the cycle and identity terms are further scaled
            # by their configured weights.
            G_loss = (A2B_g_loss + B2A_g_loss + (0.1 * (A2B_c_loss + B2A_c_loss))) + (
                A2B2A_cycle_loss + B2A2B_cycle_loss + (0.01 * (A2B2A_c_loss + B2A2B_c_loss))) * self.cycle_weights + (
                A2A_id_loss + B2B_id_loss + (0.02 * (A2A_c_loss + B2B_c_loss))) * self.identity_weight

        G_grad = t.gradient(G_loss, self.G_A2B.trainable_variables + self.G_B2A.trainable_variables)
        self.G_optimizer.apply_gradients(zip(G_grad, self.G_A2B.trainable_variables + self.G_B2A.trainable_variables))
        # Return the translated images so train_D can reuse them.
        return A2B, B2A
    @tf.function
    def train_D(self, A, B, A2B, B2A):
        with tf.GradientTape() as t:
            A_d_logits = self.D_A(A, training=True)
            B2A_d_logits = self.D_A(B2A, training=True)
            B_d_logits = self.D_B(B, training=True)
            A2B_d_logits = self.D_B(A2B, training=True)
            A_d_loss, B2A_d_loss = self.d_loss_fn(A_d_logits, B2A_d_logits)
            B_d_loss, A2B_d_loss = self.d_loss_fn(B_d_logits, A2B_d_logits)
            D_loss = (A_d_loss + B2A_d_loss) + (B_d_loss + A2B_d_loss)

        D_grad = t.gradient(D_loss, self.D_A.trainable_variables + self.D_B.trainable_variables)
        self.D_optimizer.apply_gradients(zip(D_grad, self.D_A.trainable_variables + self.D_B.trainable_variables))
    @tf.function
    def sample(self, A, B):
        A2B, _ = self.G_A2B(A, training=False)
        B2A, _ = self.G_B2A(B, training=False)
        A2B2A, _ = self.G_B2A(A2B, training=False)
        B2A2B, _ = self.G_A2B(B2A, training=False)
        return A2B, B2A, A2B2A, B2A2B
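

# ---------------------------------------------------------------------------
# Minimal usage sketch. This block is a hypothetical illustration, not part
# of the repository's training pipeline: the `args` values and the dataset
# shape (zipped batches of (images, one-hot labels) for both domains) are
# assumptions; only the dict keys match what init() actually reads.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    args = {
        'method': 'operational',   # or 'convolutional' / 'convolutional-light'
        'q': 3,
        'crop_size': 256,
        'lr': 2e-4,
        'beta_1': 0.5,
        'epochs': 100,
        'epoch_decay': 50,
        'cycle_loss_weight': 10.0,
        'identity_loss_weight': 5.0,
    }

    model = r2c_gan()
    # `dataset` is assumed to yield ((imgs_A, labels_A), (imgs_B, labels_B)):
    # model.init(args, len_dataset=len(dataset))
    # for epoch in range(args['epochs']):
    #     for A, B in dataset:
    #         A2B, B2A = model.train_G(A, B)       # updates both generators
    #         model.train_D(A[0], B[0], A2B, B2A)  # updates both discriminators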