-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
94 lines (75 loc) · 2.82 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import tensorflow as tf
from trainer import train, predict, test_iterator
from preprocess import preprocess
flags = tf.flags

# Input data layout. These are bare directory names; how they are combined
# with ``data_dir`` is decided by preprocess/trainer (not visible here).
data_dir = "data"
label_dir = "LABELLED_DICOM"
mask_dir = "MASKS_DICOM"
mesh_dir = "MESHES_VTK"
patient_dir = "PATIENT_DICOM"
class_list = ["liver", "bone", "kidney"]
flags.DEFINE_string("label_dir", label_dir, "")
flags.DEFINE_string("mask_dir", mask_dir, "")
flags.DEFINE_string("mesh_dir", mesh_dir, "")
flags.DEFINE_string("patient_dir", patient_dir, "")
flags.DEFINE_list("class_list", class_list, "")
# One more than the number of named classes -- presumably the extra class is
# background; confirm against trainer.
# NOTE(review): this is computed once from the *default* class_list, so
# overriding --class_list on the command line does not update num_classes;
# pass --num_classes explicitly in that case.
flags.DEFINE_integer("num_classes", len(class_list) + 1, "")

# Output / artifact locations. All are relative paths, so they resolve
# against the current working directory.
# prepro_dir is not exposed as a flag; the record-file and CSV defaults below
# are derived from it.
prepro_dir = "prepro"
log_dir = "log"
checkpoint_dir = os.path.join(log_dir, "checkpoints")
tensorboard_dir = os.path.join(log_dir, "tensorboard")
out_dir = "out"
train_record_file = os.path.join(prepro_dir, "train.tfrecords")
val_record_file = os.path.join(prepro_dir, "val.tfrecords")
data_csv = os.path.join(prepro_dir, "data.csv")
pretrained_vgg = os.path.join(log_dir, "vgg_16.ckpt")
# directory config
flags.DEFINE_string("data_dir", data_dir, "")
flags.DEFINE_string("log_dir", log_dir, "")
flags.DEFINE_string("checkpoint_dir", checkpoint_dir, "")
flags.DEFINE_string("tensorboard_dir", tensorboard_dir, "")
flags.DEFINE_string("out_dir", out_dir, "")
# file config
flags.DEFINE_string("train_record_file", train_record_file, "")
flags.DEFINE_string("val_record_file", val_record_file, "")
flags.DEFINE_string("data_csv", data_csv, "")
flags.DEFINE_string("pretrained_vgg", pretrained_vgg, "")
# mode config -- the help text lists every mode that main() dispatches on.
flags.DEFINE_string("mode", "train", "train/debug/predict/iter/preprocess")
flags.DEFINE_integer("seed", 2019, "")
# training config
flags.DEFINE_integer("image_size", 512, "")
flags.DEFINE_integer("batch_size", 32, "")
flags.DEFINE_float("lr", 3e-4, "")
flags.DEFINE_integer("shuffle_buffer", 100, "")
# we have 2258 train samples in our split.
# NOTE(review): this equals the sample count, not samples / batch_size;
# confirm whether trainer treats train_steps as steps or as samples.
flags.DEFINE_integer("train_steps", 2258, "")
# we have 565 held-out (validation) samples in our split.
flags.DEFINE_integer("val_steps", 565, "")
flags.DEFINE_integer("save_summary_period", 20, "")
flags.DEFINE_integer("validation_period", 500, "")
flags.DEFINE_integer("save_model_period", 500, "")
flags.DEFINE_bool("use_augment", False, "")
# Make sure the preprocessing-output and log roots exist before any mode runs.
# NOTE(review): these are the module-level default paths, evaluated at import
# time before flags are parsed, so passing e.g. --log_dir=elsewhere does not
# cause that directory to be created here.
for _required_dir in (prepro_dir, log_dir):
    if not os.path.exists(_required_dir):
        os.makedirs(_required_dir)
del _required_dir
def main(_):
    """Run the pipeline stage selected by the ``--mode`` flag.

    Args:
        _: The argv list forwarded by ``tf.app.run``; unused.

    Raises:
        ValueError: If ``--mode`` is not one of ``train``, ``debug``,
            ``predict``, ``iter`` or ``preprocess``.
    """
    config = flags.FLAGS
    if config.mode == "train":
        train(config)
    elif config.mode == "debug":
        # Smoke-test run: a few steps on tiny batches, with summary,
        # validation and checkpoint periods of 1-2 so those code paths are
        # presumably hit within the first couple of steps.
        config.train_steps = 3
        config.val_steps = 1
        config.batch_size = 2
        config.save_summary_period = 1
        config.validation_period = 1
        config.save_model_period = 2
        train(config)
    elif config.mode == "predict":
        predict(config)
    elif config.mode == "iter":
        test_iterator(config)
    elif config.mode == "preprocess":
        preprocess(config)
    else:
        # Fail loudly on a typo instead of exiting successfully having done
        # nothing.
        raise ValueError(
            "Unknown --mode %r; expected one of: train, debug, predict, "
            "iter, preprocess" % (config.mode,))
if __name__ == "__main__":
    # tf.app.run parses the command-line flags defined above, then calls
    # main(argv). Naming main explicitly is equivalent to tf.app.run's default
    # lookup of __main__.main and makes the entry point obvious.
    tf.app.run(main=main)