-
Notifications
You must be signed in to change notification settings - Fork 1
/
3d_unet
58 lines (50 loc) · 2.4 KB
/
3d_unet
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
def conv_block(x, size, dropout):
    """Apply two 3x3x3 convolution + ReLU stages, then optional dropout.

    Args:
        x: Input tensor fed to the first convolution.
            (presumably a 5-D Keras tensor -- confirm against callers)
        size: Number of filters used by both convolutions.
        dropout: Dropout rate applied after the second activation. A value
            of 0 or less skips the dropout layer entirely.

    Returns:
        The output tensor of the last layer in the block.
    """
    out = x
    # Two identical stages: "same"-padded 3x3x3 convolution with
    # he_uniform initialisation, followed by a separate ReLU layer.
    for _ in range(2):
        out = layers.Conv3D(
            filters=size,
            kernel_size=(3, 3, 3),
            kernel_initializer='he_uniform',
            padding="same",
        )(out)
        out = layers.Activation("relu")(out)

    # Only add the Dropout layer when a positive rate is requested.
    if dropout > 0:
        out = layers.Dropout(dropout)(out)
    return out
def UNet_3D_Model(input_shape):
    """Build a five-level 3-D U-Net and print its summary.

    The contraction path applies ``conv_block`` followed by 2x2x2 max
    pooling at four resolution levels, using 16, 32, 64 and 128 filters.
    A fifth ``conv_block`` with 256 filters forms the bottleneck. The
    expansion path mirrors the contraction path: at each level the features
    are upsampled by 2, concatenated with the matching contraction-path
    output, and passed through another ``conv_block``. A final 1x1x1
    convolution with a linear activation produces a single output channel.

    Args:
        input_shape: Shape passed to ``layers.Input`` (batch axis excluded).
            NOTE(review): presumably (depth, height, width, channels), since
            the upsampling layers are configured as "channels_last" --
            confirm against the data pipeline.
            NOTE(review): spatial sizes probably need to be divisible by 16;
            otherwise the "same"-padded pooling followed by 2x upsampling
            looks like it yields shapes that cannot be concatenated with
            the skip tensors -- confirm.

    Returns:
        The Keras model named "UNet_3D_Model". ``model.summary()`` is called
        before returning.
    """
    filter_numb = 16  # number of filters for the first layer
    # Dropout rate for each resolution level, shallowest first; the last
    # entry belongs to the bottleneck. The same rate is reused by the
    # expansion-path block at the matching level.
    level_dropouts = (0.10, 0.15, 0.20, 0.25, 0.30)
    depth = len(level_dropouts) - 1  # number of pooling / upsampling steps

    inputs = layers.Input(input_shape, dtype=tf.float32)

    # Contraction path: convolution block + pooling at every level, keeping
    # the pre-pooling features for the skip connections.
    skips = []
    current = inputs
    for level in range(depth):
        features = conv_block(
            current,
            filter_numb * 2 ** level,
            dropout=level_dropouts[level],
        )
        skips.append(features)
        current = layers.MaxPooling3D((2, 2, 2), padding="same")(features)

    # Bottleneck: convolution only, no pooling.
    current = conv_block(
        current,
        filter_numb * 2 ** depth,
        dropout=level_dropouts[depth],
    )

    # Expansion path: walk the levels from deepest to shallowest.
    for level in reversed(range(depth)):
        upsampled = layers.UpSampling3D(
            (2, 2, 2), data_format="channels_last"
        )(current)
        merged = layers.concatenate([upsampled, skips[level]])
        current = conv_block(
            merged,
            filter_numb * 2 ** level,
            dropout=level_dropouts[level],
        )

    # Final 1x1x1 convolution down to one channel, linear activation.
    projected = layers.Conv3D(1, (1, 1, 1))(current)
    outputs = layers.Activation('linear')(projected)

    model = models.Model(
        inputs=[inputs], outputs=[outputs], name="UNet_3D_Model"
    )
    model.summary()
    return model