unet_def.py
import keras
from keras import Model
from keras.layers import (Activation, Add, AveragePooling2D,
BatchNormalization, Conv2D, Conv2DTranspose, Flatten,
Input, MaxPool2D, Reshape, UpSampling2D,
ZeroPadding2D, concatenate)
import keras.backend as backend
import keras.utils as keras_utils
'''
def get_segmentation_model(input, output):
    img_input = input
    o = output
    o_shape = Model(img_input, o).output_shape
    i_shape = Model(img_input, o).input_shape
    if IMAGE_ORDERING == 'channels_first':
        output_height = o_shape[2]
        output_width = o_shape[3]
        input_height = i_shape[2]
        input_width = i_shape[3]
        n_classes = o_shape[1]
        o = (Reshape((-1, output_height*output_width)))(o)
        o = (Permute((2, 1)))(o)
    elif IMAGE_ORDERING == 'channels_last':
        output_height = o_shape[1]
        output_width = o_shape[2]
        input_height = i_shape[1]
        input_width = i_shape[2]
        n_classes = o_shape[3]
        o = (Reshape((output_height*output_width, -1)))(o)
    o = (Activation('softmax'))(o)
    model = Model(img_input, o)
    model.output_width = output_width
    model.output_height = output_height
    model.n_classes = n_classes
    model.input_height = input_height
    model.input_width = input_width
    model.model_name = ""
    model.train = MethodType(train, model)
    model.predict_segmentation = MethodType(predict, model)
    model.predict_multiple = MethodType(predict_multiple, model)
    model.evaluate_segmentation = MethodType(evaluate, model)
    return model
'''
# Decoder for UNet is adapted from keras-segmentation:
# https://github.com/divamgupta/image-segmentation-keras/blob/master/keras_segmentation/models/unet.py
def UNet(f4, f3, f2, f1, output_height, output_width, l1_skip_conn=True, n_classes=2):
    """Build the UNet decoder on top of encoder feature maps f4 (deepest) to f1 (shallowest)."""
    o = f4
    IMAGE_ORDERING = 'channels_last'
    if IMAGE_ORDERING == 'channels_first':
        MERGE_AXIS = 1
    elif IMAGE_ORDERING == 'channels_last':
        MERGE_AXIS = -1

    # Stage 1: refine the deepest features, upsample 2x and merge with f3.
    o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
    o = (Conv2D(512, (3, 3), padding='valid', activation='relu', data_format=IMAGE_ORDERING))(o)
    o = (BatchNormalization())(o)
    o = (Conv2DTranspose(512, (2, 2), strides=(2, 2), data_format=IMAGE_ORDERING))(o)
    o = (concatenate([o, f3], axis=MERGE_AXIS))

    # Stage 2: refine, upsample 2x and merge with f2.
    o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
    o = (Conv2D(256, (3, 3), padding='valid', activation='relu', data_format=IMAGE_ORDERING))(o)
    o = (BatchNormalization())(o)
    o = (Conv2DTranspose(256, (2, 2), strides=(2, 2), data_format=IMAGE_ORDERING))(o)
    o = (concatenate([o, f2], axis=MERGE_AXIS))

    # Stage 3: refine, upsample 2x and (optionally) merge with the shallowest skip f1.
    o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
    o = (Conv2D(128, (3, 3), padding='valid', activation='relu', data_format=IMAGE_ORDERING))(o)
    o = (BatchNormalization())(o)
    o = (Conv2DTranspose(128, (2, 2), strides=(2, 2), data_format=IMAGE_ORDERING))(o)
    if l1_skip_conn:
        o = (concatenate([o, f1], axis=MERGE_AXIS))

    # Stage 4: refine the merged features and project to n_classes channels.
    o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
    o = (Conv2D(64, (3, 3), padding='valid', activation='relu', data_format=IMAGE_ORDERING))(o)
    o = (BatchNormalization())(o)
    o = Conv2D(n_classes, (3, 3), padding='same', data_format=IMAGE_ORDERING)(o)

    # Stage 5: two further 2x upsampling steps; the final transposed conv emits 2 channels.
    o = (Conv2DTranspose(64, (2, 2), strides=(2, 2), data_format=IMAGE_ORDERING))(o)
    o = (ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING))(o)
    o = (Conv2D(32, (3, 3), padding='valid', activation='relu', data_format=IMAGE_ORDERING))(o)
    o = (BatchNormalization())(o)
    o = (Conv2DTranspose(2, (2, 2), strides=(2, 2), data_format=IMAGE_ORDERING))(o)

    # o = (Reshape((output_height*output_width, -1)))(o)
    o = (Activation('softmax'))(o)
    return o
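

# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not part of the original project): one way the
# decoder above could be wired to an encoder. The encoder layout, filter
# counts, the 256x256 input size and the name build_example_model are
# illustrative only; they are chosen so that f1..f4 have the spatial sizes the
# skip connections expect (f3 = 2*f4, f2 = 4*f4, f1 = 8*f4, output = 32*f4).
def build_example_model(input_height=256, input_width=256, n_classes=2):
    img_input = Input(shape=(input_height, input_width, 3))

    def encoder_block(x, filters):
        # Two 3x3 convolutions followed by 2x2 max pooling (halves H and W).
        x = Conv2D(filters, (3, 3), padding='same', activation='relu')(x)
        x = Conv2D(filters, (3, 3), padding='same', activation='relu')(x)
        return MaxPool2D((2, 2))(x)

    x = encoder_block(img_input, 64)   # 128 x 128
    x = encoder_block(x, 64)           # 64 x 64
    f1 = x                             # shallowest skip, 1/4 resolution
    x = encoder_block(x, 128)          # 32 x 32
    f2 = x
    x = encoder_block(x, 256)          # 16 x 16
    f3 = x
    x = encoder_block(x, 512)          # 8 x 8
    f4 = x                             # deepest features, 1/32 resolution

    # The decoder upsamples 8 -> 256, so the output matches the input size.
    out = UNet(f4, f3, f2, f1, input_height, input_width, n_classes=n_classes)
    return Model(img_input, out)


if __name__ == '__main__':
    model = build_example_model()
    model.compile(optimizer='adam', loss='categorical_crossentropy')
    model.summary()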