-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathNetworks.py
75 lines (60 loc) · 4.04 KB
/
Networks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from tensorflow.keras.layers import Dense, Conv2D, MaxPooling2D, UpSampling2D, Flatten, Reshape, BatchNormalization, Dropout
from tensorflow.keras import regularizers
def Net_Encoder(input_img, weight_decay = 0.0001, add_dense=True):
    """Build the convolutional encoder graph on top of ``input_img``.

    The encoder is four stages with 32, 64, 128 and 256 filters. Each stage
    applies two 3x3 ReLU convolutions ('same' padding, L2 kernel
    regularization), each followed by batch normalization, then a 2x2 max
    pooling with stride 2 and a dropout of rate 0.1.

    Layer names are ``conv_<k>``, ``conv_<k>_2`` and ``pool_<k>`` for stage
    ``k`` in 1..4; the optional dense layer is named ``latent_feats``.

    Args:
        input_img: Keras tensor the first convolution is applied to.
        weight_decay: L2 factor used for every convolution kernel and for the
            dense layer's kernel.
        add_dense: When true, the last stage's output is flattened and passed
            through a 1024-unit ReLU dense layer.

    Returns:
        The dense layer's output when ``add_dense`` is true, otherwise the
        output of the final dropout layer.
    """
    def regularized_conv(filters, layer_name):
        # Each layer gets its own freshly created L2 regularizer object.
        return Conv2D(
            filters,
            (3, 3),
            activation='relu',
            padding='same',
            kernel_regularizer=regularizers.l2(weight_decay),
            name=layer_name,
        )

    features = input_img
    for stage, filters in enumerate((32, 64, 128, 256), start=1):
        conv_names = ('conv_{}'.format(stage), 'conv_{}_2'.format(stage))
        for layer_name in conv_names:
            features = regularized_conv(filters, layer_name)(features)
            features = BatchNormalization()(features)
        features = MaxPooling2D(
            (2, 2), strides=(2, 2), name='pool_{}'.format(stage)
        )(features)
        features = Dropout(0.1)(features)

    if not add_dense:
        return features

    flattened = Flatten()(features)
    return Dense(
        1024,
        activation='relu',
        kernel_regularizer=regularizers.l2(weight_decay),
        name='latent_feats',
    )(flattened)
def Net_Decoder(encoder, weight_decay = 0.0001, dense_added=True):
    """Build the convolutional decoder graph on top of ``encoder``.

    When ``dense_added`` is true the input is first reshaped to
    ``(2, 2, 256)``. Four stages follow, with 256, 128, 64 and 32 filters.
    Each stage is a 2x2 upsampling followed by two 3x3 ReLU convolutions
    ('same' padding, L2 kernel regularization), each followed by batch
    normalization. Every stage except the first inserts a dropout of rate 0.1
    between the upsampling and the first convolution. A final 3x3 convolution
    with 3 filters and a sigmoid activation (no regularizer) produces the
    output.

    Layer names are ``upsample_<k>``, ``upconv_<k>`` and ``upconv_<k>_2`` for
    stage ``k`` running 4 down to 1; the last layer is ``upconv_final``.

    Args:
        encoder: Keras tensor to decode.
        weight_decay: L2 factor used for the kernels of all stage convolutions.
        dense_added: Whether ``encoder`` must be reshaped to ``(2, 2, 256)``
            before upsampling.

    Returns:
        The output tensor of the ``upconv_final`` layer.
    """
    def conv_then_bn(tensor, filters, layer_name):
        # One regularized 3x3 ReLU convolution followed by batch normalization.
        convolved = Conv2D(
            filters,
            (3, 3),
            activation='relu',
            padding='same',
            kernel_regularizer=regularizers.l2(weight_decay),
            name=layer_name,
        )(tensor)
        return BatchNormalization()(convolved)

    decoded = Reshape((2,2,256))(encoder) if dense_added else encoder

    first_stage = 4
    for stage, filters in ((4, 256), (3, 128), (2, 64), (1, 32)):
        decoded = UpSampling2D(
            (2, 2), name='upsample_{}'.format(stage)
        )(decoded)
        # The first decoder stage has no dropout after its upsampling.
        if stage != first_stage:
            decoded = Dropout(0.1)(decoded)
        decoded = conv_then_bn(decoded, filters, 'upconv_{}'.format(stage))
        decoded = conv_then_bn(decoded, filters, 'upconv_{}_2'.format(stage))

    return Conv2D(
        3, (3, 3), activation='sigmoid', padding='same', name='upconv_final'
    )(decoded)