-
Notifications
You must be signed in to change notification settings - Fork 0
/
Train_A03_Grad.py
41 lines (31 loc) · 1.23 KB
/
Train_A03_Grad.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import os
import shutil
import sys

from General_A03 import *
import A03
# Folder that receives this script's outputs, including RESULTS_RBC.txt.
# ``base_dir`` is not defined in this file; presumably it is provided by
# the ``General_A03`` star import -- confirm there.
# os.path.join (rather than "+ '/' +") avoids producing "/output_rbc" at the
# filesystem root when base_dir is empty, and also accepts path-like objects.
out_dir = os.path.join(base_dir, "output_rbc")
###############################################################################
# MAIN
###############################################################################
def main():
    """Run ``A03.find_RBC`` on the BCCD train/test splits and save the metrics.

    Steps:
      1. Load both splits with ``load_and_prepare_BCCD_data()``.
      2. (Re)create ``out_dir``. If it already exists, ask the user first;
         any answer other than y/yes (case-insensitive), or no stdin at
         all, prints "Exiting..." and exits with status 1.
      3. Call ``predict_dataset`` on each split for ``BCCD_TYPES.RBC``,
         passing ``out_dir`` as its output folder.
      4. Print the train/test metrics, then call ``print_metrics`` again
         with an open handle to ``RESULTS_RBC.txt`` inside ``out_dir``.
    """
    # Load the data *before* touching the output folder, so a loading
    # failure never destroys the results of a previous run.
    train_data, test_data = load_and_prepare_BCCD_data()

    if os.path.exists(out_dir):
        prompt = "Output folder exists; do you wish to overwrite it? (y/n) "
        try:
            answer = input(prompt)
        except EOFError:
            # Non-interactive run (stdin closed/empty): never delete silently.
            answer = ""
        # Accept "y", "Y", "yes", " y " ... rather than only an exact "y".
        if answer.strip().lower() not in ("y", "yes"):
            print("Exiting...")
            # sys.exit, not the site-provided exit(): the latter is meant for
            # the interactive shell and is absent under `python -S`.
            sys.exit(1)
        if os.path.isdir(out_dir) and not os.path.islink(out_dir):
            shutil.rmtree(out_dir)
        else:
            # A plain file or a symlink: rmtree would raise on it, and for a
            # symlink we only want to drop the link, not its target.
            os.remove(out_dir)

    # Create output directory (guaranteed absent at this point).
    os.makedirs(out_dir)

    # Predict for training and testing.
    train_metrics = predict_dataset(train_data, "TRAIN", out_dir,
                                    BCCD_TYPES.RBC.value, A03.find_RBC)
    test_metrics = predict_dataset(test_data, "TEST", out_dir,
                                   BCCD_TYPES.RBC.value, A03.find_RBC)

    # Report metrics on stdout, then write them to the results file.
    print_metrics(train_metrics, test_metrics)
    results_path = os.path.join(out_dir, "RESULTS_RBC.txt")
    with open(results_path, "w", encoding="utf-8") as f:
        print_metrics(train_metrics, test_metrics, f)


if __name__ == "__main__":
    main()