Commit 58822f0

Move files to src folder
1 parent 8a780c4 commit 58822f0

3 files changed: +169 -0 lines changed


src/losses.py

Lines changed: 21 additions & 0 deletions
from metrics import dice_coefficient, tversky
from tensorflow.keras import backend as K
from tensorflow.keras.losses import binary_crossentropy


def dice_loss(y_true, y_pred):
    # Dice loss: 1 - Dice coefficient, so perfect overlap gives zero loss.
    loss = 1 - dice_coefficient(y_true, y_pred)
    return loss


def bce_dice_loss(y_true, y_pred):
    # Sum of binary cross-entropy and Dice loss.
    loss = binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)
    return loss


def tversky_loss(y_true, y_pred):
    # Tversky loss: 1 - Tversky index (an alpha-weighted generalisation of Dice).
    return 1 - tversky(y_true, y_pred)


def focal_tversky(y_true, y_pred, gamma=0.75):
    # Focal Tversky loss: raising (1 - Tversky index) to gamma < 1 puts more
    # weight on hard, poorly segmented examples.
    pt_1 = tversky(y_true, y_pred)
    return K.pow((1 - pt_1), gamma)
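
As a quick sanity check (not part of this commit), the losses can be evaluated eagerly on small tensors; this sketch assumes TensorFlow 2.x and that src/ is on the Python path.

import tensorflow as tf

from losses import dice_loss, bce_dice_loss, focal_tversky

# Toy ground truth and prediction for a single 4-voxel sample.
y_true = tf.constant([[1., 1., 0., 0.]])
y_pred = tf.constant([[0.9, 0.8, 0.2, 0.1]])

print(dice_loss(y_true, y_pred).numpy())      # close to 0 for a good prediction
print(bce_dice_loss(y_true, y_pred).numpy())  # per-sample BCE + Dice loss
print(focal_tversky(y_true, y_pred).numpy())  # (1 - Tversky index) ** 0.75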

src/metrics.py

Lines changed: 56 additions & 0 deletions
from tensorflow.keras import backend as K

# Smoothing constant to avoid division by zero and stabilise gradients.
smooth = 1.


def dice_coefficient(y_true, y_pred):
    # Dice coefficient: 2*|A∩B| / (|A| + |B|), computed on flattened tensors.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    score = (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
    return score


def confusion(y_true, y_pred):
    # Soft confusion-matrix counts, returned as (precision, recall).
    y_pred_pos = K.clip(y_pred, 0, 1)
    y_pred_neg = 1 - y_pred_pos
    y_pos = K.clip(y_true, 0, 1)
    y_neg = 1 - y_pos
    tp = K.sum(y_pos * y_pred_pos)
    fp = K.sum(y_neg * y_pred_pos)
    fn = K.sum(y_pos * y_pred_neg)
    prec = (tp + smooth) / (tp + fp + smooth)
    rec = (tp + smooth) / (tp + fn + smooth)
    return prec, rec


def recall(y_true, y_pred):
    # Recall = TP / (TP + FN), with smoothing.
    y_pred_pos = K.clip(y_pred, 0, 1)
    y_pred_neg = 1 - y_pred_pos
    y_pos = K.clip(y_true, 0, 1)
    tp = K.sum(y_pos * y_pred_pos)
    fn = K.sum(y_pos * y_pred_neg)
    rec = (tp + smooth) / (tp + fn + smooth)
    return rec


def precision(y_true, y_pred):
    # Precision = TP / (TP + FP), with smoothing.
    y_pred_pos = K.clip(y_pred, 0, 1)
    y_pos = K.clip(y_true, 0, 1)
    y_neg = 1 - y_pos
    tp = K.sum(y_pos * y_pred_pos)
    fp = K.sum(y_neg * y_pred_pos)
    prec = (tp + smooth) / (tp + fp + smooth)
    return prec


def tversky(y_true, y_pred, alpha=0.7):
    # Tversky index: generalises Dice by weighting false negatives (alpha)
    # and false positives (1 - alpha) differently.
    y_true_pos = K.flatten(y_true)
    y_pred_pos = K.flatten(y_pred)
    true_pos = K.sum(y_true_pos * y_pred_pos)
    false_neg = K.sum(y_true_pos * (1 - y_pred_pos))
    false_pos = K.sum((1 - y_true_pos) * y_pred_pos)
    return (true_pos + smooth) / (true_pos + alpha * false_neg + (1 - alpha) * false_pos + smooth)
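
A small consistency check (not part of this commit): with alpha = 0.5 the Tversky index weights false negatives and false positives equally and reduces to the Dice coefficient up to the smoothing constant, so the two values should be close.

import tensorflow as tf

from metrics import dice_coefficient, tversky, precision, recall

y_true = tf.constant([[1., 1., 0., 0.]])
y_pred = tf.constant([[0.9, 0.8, 0.2, 0.1]])

print(dice_coefficient(y_true, y_pred).numpy())    # ~0.88
print(tversky(y_true, y_pred, alpha=0.5).numpy())  # ~0.90, near Dice
print(precision(y_true, y_pred).numpy())
print(recall(y_true, y_pred).numpy())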

src/unet.py

Lines changed: 92 additions & 0 deletions
from tensorflow.keras import Model
from tensorflow.keras.layers import (Conv3D, Conv3DTranspose, Input, BatchNormalization,
                                     Activation, MaxPool3D, SpatialDropout3D, Concatenate)


class Unet3D:
    def __init__(self,
                 n_classes,
                 input_shape,
                 activation="relu",
                 n_base_filters=8,
                 batchnorm=False,
                 spatial_dropout=False,
                 batch_size=None,
                 model_depth=5,
                 name="3DUnet"):
        self.n_classes = n_classes
        self.input_shape = input_shape
        self.activation = activation
        self.n_base_filters = n_base_filters
        self.batchnorm = batchnorm
        self.spatial_dropout = spatial_dropout
        self.batch_size = batch_size
        self.model_depth = model_depth
        self.name = name

        # Encoder feature maps saved for the decoder's skip connections.
        self.skips = []

        # Shared arguments for every 3x3x3 convolution.
        self.conv_kwds = {
            "kernel_size": (3, 3, 3),
            "activation": None,
            "padding": "same",
            "kernel_initializer": "he_normal",
            # 'kernel_regularizer': tf.keras.regularizers.l2(0.001),
        }

        # Shared arguments for every transposed (upsampling) convolution.
        self.conv_transpose_kwds = {
            "kernel_size": (2, 2, 2),
            "strides": 2,
            "padding": "same",
            "kernel_initializer": "he_normal",
            # 'kernel_regularizer': tf.keras.regularizers.l2(0.001),
        }

    def encoder(self, inputs):
        # Contracting path: two conv blocks per level with doubling filter counts,
        # followed by max pooling at every level except the bottleneck.
        self.skips = []  # reset so the model can be rebuilt cleanly
        x = inputs
        for depth in range(self.model_depth):
            x = Conv3D(self.n_base_filters * (2 ** depth), **self.conv_kwds)(x)
            if self.batchnorm:
                x = BatchNormalization()(x)
            x = Activation(self.activation)(x)
            x = Conv3D(self.n_base_filters * (2 ** (depth + 1)), **self.conv_kwds)(x)
            if self.batchnorm:
                x = BatchNormalization()(x)
            x = Activation(self.activation)(x)
            if depth < self.model_depth - 1:
                self.skips.append(x)
                x = MaxPool3D(2)(x)
                if self.spatial_dropout:
                    x = SpatialDropout3D(0.5)(x)

        return x

    def decoder(self, x):
        # Expanding path: upsample, concatenate the matching skip connection,
        # then apply two conv blocks per level.
        for depth in range(self.model_depth - 1, 0, -1):
            x = Conv3DTranspose(self.n_base_filters * (2 ** (depth + 1)), **self.conv_transpose_kwds)(x)

            x = Concatenate(axis=-1)([self.skips[depth - 1], x])
            if self.spatial_dropout:
                x = SpatialDropout3D(0.5)(x)
            x = Conv3D(self.n_base_filters * (2 ** depth), **self.conv_kwds)(x)
            if self.batchnorm:
                x = BatchNormalization()(x)
            x = Activation(self.activation)(x)
            x = Conv3D(self.n_base_filters * (2 ** depth), **self.conv_kwds)(x)
            if self.batchnorm:
                x = BatchNormalization()(x)
            x = Activation(self.activation)(x)

        # 1x1x1 convolution maps the features to one channel per class.
        x = Conv3D(filters=self.n_classes, kernel_size=1)(x)
        return x

    def build_model(self):
        inputs = Input(shape=self.input_shape, batch_size=self.batch_size)
        x = self.encoder(inputs)
        x = self.decoder(x)

        # Sigmoid for binary segmentation, softmax for multi-class.
        final_activation = "sigmoid" if self.n_classes == 1 else "softmax"
        x = Activation(final_activation)(x)

        model = Model(inputs=inputs, outputs=x, name=self.name)
        return model
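
A hypothetical end-to-end sketch (not part of this commit) tying the three files together: build the 3D U-Net and compile it with the combined loss and the Dice metric defined above. The patch shape (64, 64, 64, 1) and the optimizer settings are illustrative assumptions, not values from this repository.

from tensorflow.keras.optimizers import Adam

from unet import Unet3D
from losses import bce_dice_loss
from metrics import dice_coefficient

# 64^3 single-channel patches; the 4 pooling steps of model_depth=5 still divide evenly.
model = Unet3D(n_classes=1,
               input_shape=(64, 64, 64, 1),
               n_base_filters=8,
               batchnorm=True,
               model_depth=5).build_model()

model.compile(optimizer=Adam(learning_rate=1e-4),
              loss=bce_dice_loss,
              metrics=[dice_coefficient])
model.summary()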
