-
Notifications
You must be signed in to change notification settings - Fork 41
/
run.py
86 lines (69 loc) · 2.75 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import numpy as np
import cv2
from glob import glob
import tensorflow as tf
from tensorflow.keras.metrics import Precision, Recall, MeanIoU
from tensorflow.keras.optimizers import Adam, Nadam, SGD
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, CSVLogger
from data_generator import DataGen
from unet import Unet
from resunet import ResUnet
from m_resunet import ResUnetPlusPlus
from metrics import dice_coef, dice_loss
if __name__ == "__main__":
    """Train a ResUnet++ model on the Kvasir segmentation dataset.

    Loads paired image/mask file paths for train and validation splits,
    builds the ResUnet++ architecture, and fits it with dice loss plus
    standard segmentation metrics, checkpointing the best weights.
    """
    ## Paths
    file_path = "files/"
    model_path = "files/resunetplusplus.h5"

    ## Create output folder for checkpoints/logs.
    # exist_ok avoids the former bare `except: pass`, which silently
    # swallowed *every* error (e.g. permission denied), not just "exists".
    os.makedirs("files", exist_ok=True)

    train_path = "new_data/kvasir_segmentation_dataset/train/"
    valid_path = "new_data/kvasir_segmentation_dataset/valid/"

    ## Training file lists — sorted so images and masks pair up by index.
    train_image_paths = sorted(glob(os.path.join(train_path, "images", "*")))
    train_mask_paths = sorted(glob(os.path.join(train_path, "masks", "*")))
    # train_image_paths = train_image_paths[:2000]
    # train_mask_paths = train_mask_paths[:2000]

    ## Validation file lists
    valid_image_paths = sorted(glob(os.path.join(valid_path, "images", "*")))
    valid_mask_paths = sorted(glob(os.path.join(valid_path, "masks", "*")))

    # Fail fast on a mismatched dataset rather than training on bad pairs.
    assert len(train_image_paths) == len(train_mask_paths), \
        "train images/masks count mismatch"
    assert len(valid_image_paths) == len(valid_mask_paths), \
        "valid images/masks count mismatch"

    ## Hyperparameters
    image_size = 256
    batch_size = 8
    lr = 1e-4
    epochs = 200

    # Floor division drops the final partial batch, matching generator length.
    train_steps = len(train_image_paths) // batch_size
    valid_steps = len(valid_image_paths) // batch_size

    ## Data generators
    train_gen = DataGen(image_size, train_image_paths, train_mask_paths, batch_size=batch_size)
    valid_gen = DataGen(image_size, valid_image_paths, valid_mask_paths, batch_size=batch_size)

    ## Architecture selection (alternatives kept for easy switching)
    # arch = Unet(input_size=image_size)
    # arch = ResUnet(input_size=image_size)
    arch = ResUnetPlusPlus(input_size=image_size)
    model = arch.build_model()

    # Keyword argument: the positional `Nadam(lr)` form is deprecated in TF2.
    optimizer = Nadam(learning_rate=lr)
    metrics = [Recall(), Precision(), dice_coef, MeanIoU(num_classes=2)]
    model.compile(loss=dice_loss, optimizer=optimizer, metrics=metrics)

    ## Callbacks
    # Log name now matches the architecture actually trained (was "unet_*").
    csv_logger = CSVLogger(f"{file_path}resunetplusplus_{batch_size}.csv", append=False)
    checkpoint = ModelCheckpoint(model_path, verbose=1, save_best_only=True)
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_lr=1e-6, verbose=1)
    early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=False)
    callbacks = [csv_logger, checkpoint, reduce_lr, early_stopping]

    # model.fit accepts generators directly; fit_generator is deprecated
    # and removed in recent TF2 releases.
    model.fit(train_gen,
              validation_data=valid_gen,
              steps_per_epoch=train_steps,
              validation_steps=valid_steps,
              epochs=epochs,
              callbacks=callbacks)