-
Notifications
You must be signed in to change notification settings - Fork 0
/
discriminator.py
66 lines (58 loc) · 2.11 KB
/
discriminator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
'''PatchGAN discriminator model for TensorFlow.
Author: Emilio Morales (mil.mor.mor@gmail.com)
Mar 2023
'''
import tensorflow as tf
from tensorflow.keras import layers
def convBlock(filters, kernel_size=3, initializer='glorot_uniform'):
    """Build one downsampling unit: Conv2D (stride 2) -> BatchNorm -> LeakyReLU.

    Args:
        filters: Number of output channels of the convolution.
        kernel_size: Size of the square convolution kernel.
        initializer: Keras kernel initializer (name or instance) used for
            the convolution weights.

    Returns:
        A ``tf.keras.Sequential``. With ``'same'`` padding and stride 2 the
        spatial size goes from ``n`` to ``ceil(n / 2)``. The output is then
        batch-normalized and passed through a LeakyReLU with negative
        slope 0.2.
    """
    conv = layers.Conv2D(
        filters,
        kernel_size=kernel_size,
        strides=2,
        padding='same',
        # No bias: the BatchNormalization that follows has its own learned
        # offset (beta), so a conv bias would be redundant.
        use_bias=False,
        kernel_initializer=initializer,
    )
    return tf.keras.Sequential([
        conv,
        layers.BatchNormalization(),
        layers.LeakyReLU(0.2),
    ])
class Discriminator(tf.keras.models.Model):
    """PatchGAN-style discriminator that outputs a flattened grid of logits.

    Pipeline:
      * ``down_big``: four stride-2 convolutions with ``'same'`` padding.
        The first has no normalization; the other three come from
        ``convBlock`` and include BatchNorm. Together they take a spatial
        size ``n`` to ``ceil(n / 16)``.
      * ``down_small``: a 1x1 convolution to ``model_dim[4]`` channels
        (with BatchNorm and LeakyReLU), then a 4x4 ``'valid'`` convolution
        down to a single channel.
      * ``logits``: flattens that single-channel map to ``(batch, k)``.

    Because the final 4x4 convolution uses ``'valid'`` padding, the feature
    map after ``down_big`` must be at least 4x4. That requires
    ``ceil(n / 16) >= 4``, i.e. input height and width of at least 49.

    Args:
        model_dim: Channel widths. Only the first five entries are used:
            four for ``down_big``, one for the 1x1 conv in ``down_small``.
        initializer: Keras kernel initializer (name or instance) applied
            to every convolution.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).

    Raises:
        ValueError: If ``model_dim`` has fewer than five entries.
    """

    def __init__(self, model_dim=(32, 64, 128, 256, 512),
                 initializer='glorot_uniform', **kwargs):
        # A tuple default avoids the shared-mutable-default pitfall.
        # Indexing is unchanged, so list arguments keep working.
        super().__init__(**kwargs)
        if len(model_dim) < 5:
            raise ValueError(
                'model_dim must provide at least 5 channel widths, '
                f'got {len(model_dim)}: {list(model_dim)}'
            )

        self.down_big = tf.keras.Sequential([
            # First layer: no BatchNormalization after it.
            # NOTE(review): with use_bias=False and no norm, this conv has
            # no additive term at all. Confirm this is intended (changing
            # it would alter the checkpoint layout).
            layers.Conv2D(
                model_dim[0], kernel_size=3, strides=2, use_bias=False,
                kernel_initializer=initializer, padding='same'
            ),
            layers.LeakyReLU(0.2),
            convBlock(model_dim[1], kernel_size=3, initializer=initializer),
            convBlock(model_dim[2], kernel_size=3, initializer=initializer),
            convBlock(model_dim[3], kernel_size=3, initializer=initializer),
        ])

        self.down_small = tf.keras.Sequential([
            # 1x1 projection to model_dim[4] channels; spatial size is kept.
            layers.Conv2D(
                model_dim[4], kernel_size=1, strides=1, use_bias=False,
                kernel_initializer=initializer, padding='valid'
            ),
            layers.BatchNormalization(),
            layers.LeakyReLU(0.2),
            # 4x4 'valid' conv to a single channel: each output position
            # is the logit for one patch. The spatial size shrinks by 3.
            layers.Conv2D(
                1, kernel_size=4, strides=1, use_bias=False,
                kernel_initializer=initializer, padding='valid'
            ),
        ])

        # Logits: flatten the patch map. The float32 dtype keeps the output
        # in float32 even when a mixed-precision policy is active.
        self.logits = tf.keras.Sequential([
            layers.Flatten(),
            layers.Activation('linear', dtype='float32'),
        ])

    def call(self, img, training=None):
        """Score an image batch.

        Args:
            img: Input image batch, presumably NHWC ``(batch, H, W, C)``
                (the Keras Conv2D default). Verify against the caller.
            training: Optional bool forwarded to the sub-models so that
                BatchNormalization uses batch statistics (True) or moving
                statistics (False). ``None`` defers to the Keras call
                context, as before.

        Returns:
            A one-element list holding a float32 tensor of shape
            ``(batch, k)``, where ``k`` is the number of patch logits.
        """
        x = self.down_big(img, training=training)
        x = self.down_small(x, training=training)
        return [self.logits(x)]