-
Notifications
You must be signed in to change notification settings - Fork 6
/
train.py
33 lines (27 loc) · 1.19 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# Implementation to train the CNN as detailed in:
# 'Segmentation of histological images and fibrosis identification with a convolutional neural network'
# https://doi.org/10.1016/j.compbiomed.2018.05.015
# https://arxiv.org/abs/1803.07301
import os
import logging
import sys
import network
logging.basicConfig(
    format='%(asctime)s %(levelname)s %(message)s',
    level=logging.INFO,
    stream=sys.stdout,
)

# Optional: uncomment to lower TensorFlow's native (C++) log verbosity.
# os.environ['TF_CPP_MIN_LOG_LEVEL']='2'

# --- Training configuration ------------------------------------------------
batch_size = 128        # Training batch size
n_train_data = 59904    # Number of RGB images (59904 = 468 * 128 batches)
n_epochs = 100          # Number of epochs to train for
restore = False         # Option to continue training from saved model
save = True             # Save model every epoch
h = 48                  # Image height
w = 48                  # Image width
keep_rate = 1.0         # 1 - dropout rate (1.0 means no dropout)

# Output directory for training predictions. The name (including the space)
# is kept unchanged, presumably because network.py writes into this exact
# path -- confirm before renaming.
predictions_dir = "predictions training"
# exist_ok=True creates the directory only if missing, without the
# check-then-create race of os.path.exists() followed by os.makedirs().
os.makedirs(predictions_dir, exist_ok=True)

# Train neural network.
# NOTE(review): training runs at import time; add an
# `if __name__ == "__main__":` guard if this module is ever imported.
logging.info("Training network")
convnet = network.CNN(keep_rate=keep_rate, train_mode=True)
t_net = network.TRAIN_CNN(convnet, batch_size, h, w)
# NOTE(review): batch_size is passed both to TRAIN_CNN above and to
# train_network here; confirm which one network.py actually uses.
t_net.train_network(n_train_data, batch_size, n_epochs, restore=restore, save=save)