forked from UTSAVS26/PyVerse
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
35 lines (26 loc) · 921 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
"""
Trains a CNN model using tflearn wrapper for tensorflow
"""
import tflearn
import h5py
import numpy as np
from cnn_model import CNNModel
# Input locations and output artifact names. The run id, checkpoint path and
# saved-model path are derived from one name so the file that is saved and
# the name reported to the user cannot disagree.
TRAIN_DATA_PATH = '../data/train.h5'
VAL_DATA_PATH = '../data/val.h5'
RUN_ID = 'nodule3-classifier'
CHECKPOINT_PATH = RUN_ID + '.tfl.ckpt'
MODEL_PATH = RUN_ID + '.tfl'

# Load HDF5 datasets. The `with` statement guarantees both files are closed
# even if network construction, training or saving raises; previously the
# explicit close() calls at the end were skipped on any exception.
# NOTE(review): h5f['X'] etc. are h5py Dataset objects backed by the open
# file (not in-memory arrays), so every use of them must stay inside this
# block.
with h5py.File(TRAIN_DATA_PATH, 'r') as h5f, \
        h5py.File(VAL_DATA_PATH, 'r') as h5f2:
    X_train_images = h5f['X']
    Y_train_labels = h5f['Y']
    X_val_images = h5f2['X']
    Y_val_labels = h5f2['Y']

    # Model definition
    convnet = CNNModel()
    network = convnet.define_network(X_train_images)
    model = tflearn.DNN(network, tensorboard_verbose=0,
                        checkpoint_path=CHECKPOINT_PATH)
    model.fit(X_train_images, Y_train_labels, n_epoch=70, shuffle=True,
              validation_set=(X_val_images, Y_val_labels), show_metric=True,
              batch_size=96, snapshot_epoch=True, run_id=RUN_ID)
    model.save(MODEL_PATH)
    # Report the file actually written (this message previously named
    # "nodule2-classifier.tfl", which the script never saves).
    print("Network trained and saved as {}!".format(MODEL_PATH))