-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathCLSWGAN.py
126 lines (111 loc) · 5.65 KB
/
CLSWGAN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#implement WGANGP
#######################################################################
from __future__ import print_function, division
from LossFunctions import wasserstein_loss,gradient_penalty_loss
from keras.layers import Input
from keras.models import Model
from keras.optimizers import RMSprop
from functools import partial
import numpy as np
from RandomWeightedAverage import RandomWeightedAverage
from Critic import build_critic
from Generator import build_generator
from keras.utils import plot_model
from readData import readH5file,numberOfClass
#######################################################################
class CLSWGANGP():
    """Conditional Wasserstein GAN with gradient penalty (f-CLSWGAN style).

    Builds and compiles two training graphs over a shared generator/critic pair:

    * ``critic_model``    -- updates the critic on real, fake and interpolated
      feature batches using Wasserstein losses plus a gradient penalty
      (weights 1, 1, 10 as in the WGAN-GP paper).
    * ``generator_model`` -- updates the generator against the frozen critic
      and a frozen, pretrained classification layer loaded from
      ``./models/classifierLayer.h5`` (Wasserstein + categorical cross-entropy).

    Conditioning is an 85-dim attribute vector fed alongside the 312-dim noise.
    """

    def __init__(self):
        ##########################################
        self.features_shape = (2048,)  # dimensionality of the (CNN) feature vectors
        self.latent_dim = 312          # noise vector size fed to the generator
        # Critic updates per generator update. NOTE(review): the original
        # assigned 5 and then immediately overwrote it with 10; the effective
        # value was 10, so only that assignment is kept.
        self.n_critic = 10
        optimizer = RMSprop(lr=0.00005)
        self.generator = build_generator()
        self.critic = build_critic()
        self.batch_size = 1024
        self.losslog = []
        self.nclasses = numberOfClass()
        ##########################################
        # -------------------------------
        #  Computational graph: critic
        # -------------------------------
        # Freeze generator's layers while training the critic.
        self.generator.trainable = False
        # Real feature input.
        real_features = Input(shape=self.features_shape)
        # Noise input.
        z_disc = Input(shape=(self.latent_dim,))
        # Conditioning attribute vector (85-dim), shared by generator and critic.
        label = Input(shape=(85,))
        # Fake features generated from noise + condition.
        fake_features = self.generator([z_disc, label])
        # Critic scores for fake and real samples.
        fake = self.critic([fake_features, label])
        valid = self.critic([real_features, label])
        # Random weighted average between real and fake features -- the point
        # at which the gradient penalty is evaluated (WGAN-GP).
        interpolated_features = RandomWeightedAverage()([real_features, fake_features])
        validity_interpolated = self.critic([interpolated_features, label])
        # partial() bakes the interpolated samples into the penalty loss, since
        # Keras loss functions only receive (y_true, y_pred).
        partial_gp_loss = partial(gradient_penalty_loss, averaged_samples=interpolated_features)
        partial_gp_loss.__name__ = 'gradient_penalty'  # Keras requires function names
        self.critic_model = Model(inputs=[real_features, label, z_disc],
                                  outputs=[valid, fake, validity_interpolated])
        self.critic_model.compile(loss=[wasserstein_loss,
                                        wasserstein_loss,
                                        partial_gp_loss],
                                  optimizer=optimizer,
                                  loss_weights=[1, 1, 10])
        # -------------------------------
        #  Computational graph: generator
        # -------------------------------
        from keras.models import load_model
        classificationLayer = load_model('./models/classifierLayer.h5')
        classificationLayer.name = 'modelClassifier'
        # For the generator we freeze the critic's layers + classification layer.
        self.critic.trainable = False
        classificationLayer.trainable = False
        self.generator.trainable = True
        # Noise + condition inputs for the generator graph.
        z_gen = Input(shape=(self.latent_dim,))
        label = Input(shape=(85,))
        features = self.generator([z_gen, label])
        # Critic judges realism; the frozen classifier judges class consistency.
        valid = self.critic([features, label])
        classx = classificationLayer(features)
        self.generator_model = Model([z_gen, label], [valid, classx])
        plot_model(self.generator_model, to_file="./models/model.pdf", show_shapes=True)
        self.generator_model.compile(loss=[wasserstein_loss, 'categorical_crossentropy'],
                                     optimizer=optimizer,
                                     loss_weights=[1, 1])

    def train(self, epochs, batch_size, sample_interval=50):
        """Run the adversarial training loop.

        Parameters
        ----------
        epochs : int
            Number of generator updates (each preceded by ``n_critic`` critic updates).
        batch_size : int
            Samples per batch for both critic and generator steps.
        sample_interval : int, optional
            Save generator/critic weights every this many epochs.
        """
        # Hoisted out of the training loop (was re-imported every epoch).
        import keras
        (x_train, y_train, a_train), (x_test, y_test, a_test), (x_val, y_val, a_val) = readH5file()
        # Adversarial ground truths (WGAN convention: -1 for real, +1 for fake).
        valid = -np.ones((batch_size, 1))
        fake = np.ones((batch_size, 1))
        dummy = np.zeros((batch_size, 1))  # dummy targets for the gradient-penalty output
        for epoch in range(epochs):
            # ---------------------
            #  Train Critic
            # ---------------------
            for _ in range(self.n_critic):
                # Random batch of real features and their attribute vectors.
                idx = np.random.randint(0, x_train.shape[0], batch_size)
                features, labels = x_train[idx], a_train[idx]
                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
                d_loss = self.critic_model.train_on_batch([features, labels, noise],
                                                          [valid, fake, dummy])
            # ---------------------
            #  Train Generator
            # ---------------------
            idx = np.random.randint(0, x_train.shape[0], batch_size)
            # Only class labels and attributes are needed here; the original also
            # sliced x_train[idx] into an unused variable (removed).
            labels, attr = y_train[idx], a_train[idx]
            # Fresh noise for the generator step -- the original reused the noise
            # left over from the critic's last inner iteration.
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            # NOTE(review): class count is hard-coded to 50 although
            # self.nclasses exists -- presumably they should agree; confirm
            # against numberOfClass() before changing.
            labels = keras.utils.to_categorical(labels, 50)
            g_loss = self.generator_model.train_on_batch([noise, attr], [valid, labels])
            # Plot the progress.
            print("%d [D loss: %f] [G loss: %f]" % (epoch, np.mean(d_loss), np.mean(g_loss)))
            # If at save interval => checkpoint the weights.
            if epoch % sample_interval == 0:
                self.generator.save_weights('generator', overwrite=True)
                self.critic.save_weights('discriminator', overwrite=True)