# autoencoder.py — 112 lines (91 loc), 3.73 KB
'''AutoencoderKL model for Tensorflow.
Author: Emilio Morales (mil.mor.mor@gmail.com)
Jun 2023
'''
import tensorflow as tf
from tensorflow.keras import layers
class Sampling(layers.Layer):
    '''Reparameterization trick: draw z ~ N(z_mean, exp(z_log_var)).

    The sample is expressed as a deterministic transform of unit Gaussian
    noise so gradients can flow through z_mean and z_log_var.
    '''
    def call(self, z_mean, z_log_var):
        # Use the dynamic tensor shape (tf.shape) rather than the static
        # `.shape` attribute: during symbolic model building the batch
        # dimension is None, which breaks random normal generation.
        # tf.random.normal also replaces the deprecated
        # tf.keras.backend.random_normal.
        epsilon = tf.random.normal(shape=tf.shape(z_mean))
        # std = exp(0.5 * log_var); z = mean + std * eps
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
def downBlock(filters, kernel_size=3, initializer='glorot_uniform'):
    '''Build a downsampling stage: a stride-2 conv that halves the spatial
    resolution, then a stride-1 refinement conv; each conv is followed by
    GroupNormalization and a swish activation.
    '''
    shared = dict(
        kernel_size=kernel_size, padding='same',
        use_bias=False, kernel_initializer=initializer,
    )
    stage = tf.keras.Sequential()
    stage.add(layers.Conv2D(filters, strides=2, **shared))
    stage.add(layers.GroupNormalization())
    stage.add(layers.Activation('swish'))
    stage.add(layers.Conv2D(filters, strides=1, **shared))
    stage.add(layers.GroupNormalization())
    stage.add(layers.Activation('swish'))
    return stage
class Encoder(tf.keras.models.Model):
    '''Encode an image into a diagonal-Gaussian latent.

    Returns a sampled latent together with the distribution parameters
    (z_mean, z_log_var) needed for the KL term of the VAE loss.
    '''
    def __init__(self, model_dim=(64, 128, 256), cuant_dim=4):
        # model_dim: channel widths of the successive downsampling stages.
        #   (tuple default avoids the mutable-default-argument pitfall;
        #   callers passing lists still work.)
        # cuant_dim: channels of the latent code; the head emits twice as
        #   many channels so they can be split into mean and log-variance.
        super().__init__()
        self.model_dim = model_dim
        self.ch_conv = layers.Conv2D(
            model_dim[0], kernel_size=3, strides=1, padding='same'
        )
        self.encoder = [downBlock(i) for i in model_dim]
        # 2 * cuant_dim output channels: half z_mean, half z_log_var.
        self.cuant_conv = layers.Conv2D(
            cuant_dim * 2, kernel_size=1, strides=1, padding='same'
        )
        self.sample = Sampling()

    def call(self, x, training=True):
        x = self.ch_conv(x)
        # Forward the training flag to sub-blocks (the original dropped it).
        for block in self.encoder:
            x = block(x, training=training)
        x = self.cuant_conv(x)
        z_mean, z_log_var = tf.split(x, 2, axis=-1)
        # Clamp log-variance to a numerically safe range so exp() in the
        # sampler cannot overflow/underflow.
        z_log_var = tf.clip_by_value(z_log_var, -30.0, 20.0)
        x = self.sample(z_mean, z_log_var)
        return x, z_mean, z_log_var
def upBlock(filters, kernel_size=3, initializer='glorot_uniform'):
    '''Build an upsampling stage: bilinear 2x upsample followed by two
    stride-1 convolutions, each with GroupNormalization and swish.
    '''
    def conv():
        # Both convs in the stage are configured identically.
        return layers.Conv2D(
            filters, kernel_size=kernel_size, padding='same',
            use_bias=False, kernel_initializer=initializer,
        )
    return tf.keras.Sequential([
        layers.UpSampling2D(2, interpolation='bilinear'),
        conv(),
        layers.GroupNormalization(),
        layers.Activation('swish'),
        conv(),
        layers.GroupNormalization(),
        layers.Activation('swish'),
    ])
class Decoder(tf.keras.models.Model):
    '''Decode a latent code back to a 3-channel image through a stack of
    upsampling blocks.
    '''
    def __init__(self, model_dim=(256, 128, 64)):
        # model_dim: channel widths of the successive upsampling stages.
        #   (tuple default avoids the mutable-default-argument pitfall.)
        super().__init__()
        self.model_dim = model_dim
        self.post_quant_conv = layers.Conv2D(
            model_dim[0], kernel_size=1, strides=1, padding='same'
        )
        self.decoder = [upBlock(i) for i in model_dim]
        # Final projection to 3 (RGB) output channels.
        self.ch_conv = layers.Conv2D(3, 3, strides=1, padding='same')

    def call(self, x, training=True):
        # Default training=True matches the sibling Encoder; callers that
        # passed training positionally are unaffected.
        x = self.post_quant_conv(x)
        # Forward the training flag to sub-blocks (the original dropped it).
        for block in self.decoder:
            x = block(x, training=training)
        return self.ch_conv(x)
class Autoencoder(tf.keras.models.Model):
    '''KL-regularized autoencoder: the encoder yields a sampled latent plus
    its Gaussian parameters; the decoder reconstructs the image from the
    sampled latent.
    '''
    def __init__(self, e_dim, d_dim, cuant_dim=4):
        # e_dim / d_dim: stage widths for the Encoder / Decoder.
        # cuant_dim: latent channel count handed to the Encoder.
        super().__init__()
        self.encoder = Encoder(e_dim, cuant_dim)
        self.decoder = Decoder(d_dim)

    def call(self, x, training=True):
        latent, z_mean, z_log_var = self.encoder(x, training)
        reconstruction = self.decoder(latent, training)
        return reconstruction, z_mean, z_log_var