-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
118 lines (93 loc) · 3.41 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
Main program (entry point).
Parses the command-line arguments given by the user and calls the function
for the selected mode.
See each function's docstring for details on its behavior, inputs and outputs.
"""
from argparse import ArgumentParser, Namespace, RawTextHelpFormatter
from configparser import ConfigParser
from shlex import quote
from sys import argv
from typing import Optional, Sequence

from evalute import evalute
from inference import inference
from initialize import initialize
from my_openfe import openfe_inference, openfe_train
from save import save_results
from train import cv_train_with_optuna, train, train_with_optuna
from utils import get_logger
# Command line of this run (script path followed by its arguments), captured at
# import time. Each token is shell-quoted so that the "Command: python3 ..." line
# logged by main() can be copy-pasted to reproduce the run even when an argument
# contains spaces or shell metacharacters. Plain tokens such as
# "configs/configs.ini" or "openfe-train" are left unquoted.
info = " ".join(quote(arg) for arg in argv)
def parse_args(args: Optional[Sequence[str]] = None) -> Namespace:
    """ Parse arguments.

    Args:
        args (Optional[Sequence[str]]): Argument strings to parse, without the
            program name. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is the behavior of a zero-argument call.

    Returns:
        Namespace: The arguments, with attributes ``config`` (str),
        ``mode`` (str, one of the declared choices) and
        ``use_openfe_data`` (bool).
    """
    parser = ArgumentParser(formatter_class=RawTextHelpFormatter)
    parser.add_argument("-c",
                        "--config",
                        default="configs/configs.ini",
                        type=str,
                        help="Path to the configs file.")
    parser.add_argument(
        "-m",
        "--mode",
        type=str,
        choices=[
            "cv_train_with_optuna", "evalute", "inference", "train",
            "train_with_optuna", "openfe-inference", "openfe-train"
        ],
        required=True,
        help=
        "Mode to run the program. \nNotice that: \n1. \"evalute\", \"inference\" and \"openfe-inference\" modes must load a checkpoint. \n2. Checkpoints only can be used in \"evalute\", \"inference\" and \"openfe-inference\" modes. There may be some unexpected errors if you use checkpoints in other modes."
    )
    parser.add_argument(
        "-uofed",
        "--use_openfe_data",
        action="store_true",
        help=
        "Add this argument if you want to use the training data generated by OpenFE."
    )
    # Passing None lets argparse fall back to sys.argv[1:].
    return parser.parse_args(args)
def parse_configs(configs_path: str,
                  encoding: Optional[str] = None) -> ConfigParser:
    """ Parse configs.

    Args:
        configs_path (str): The path of the INI-style configs file.
        encoding (Optional[str]): Text encoding used to read the file.
            ``None`` (the default) keeps ``ConfigParser.read``'s behavior of
            using the platform's default encoding.

    Returns:
        ConfigParser: The configs.

    Raises:
        FileNotFoundError: If the configs file is missing or cannot be read.
    """
    parser = ConfigParser()
    # ConfigParser.read silently skips files it cannot open and returns only
    # the files it did read. Without this check a bad path would yield an empty
    # parser and fail later with a confusing NoSectionError.
    read_files = parser.read(filenames=configs_path, encoding=encoding)
    if not read_files:
        raise FileNotFoundError(
            f"Configs file not found or unreadable: {configs_path}")
    return parser
def main():
    """ Run the program end to end for the mode chosen on the command line.

    Workflow:
        1. Parse the command-line arguments and the configs file they name.
        2. Create a logger whose file name is the ``version`` option of the
           ``GENERAL`` section plus ``.log``, and log the invoking command.
        3. Build the run parameters with ``initialize``.
        4. Call the handler registered for the selected mode.
        5. Call ``save_results`` unless the mode is "evalute",
           "openfe-train" or "openfe-inference".

    Raises:
        ValueError: If the given mode is unknown.
    """
    cli_args = parse_args()
    mode = cli_args.mode
    configs = parse_configs(configs_path=cli_args.config)

    version = configs.get(section="GENERAL", option="version")
    logger = get_logger(log_name=version + ".log")
    logger.info(msg=f"Command: python3 {info}")

    parameters = initialize(configs=configs,
                            mode=mode,
                            use_openfe_data=cli_args.use_openfe_data)

    # The handlers do not share a keyword name ("params" vs "parameters"), so
    # each entry wraps its exact call in a zero-argument lambda.
    mode_handlers = {
        "cv_train_with_optuna": lambda: cv_train_with_optuna(params=parameters),
        "evalute": lambda: evalute(parameters=parameters),
        "inference": lambda: inference(parameters=parameters),
        "train": lambda: train(params=parameters),
        "train_with_optuna": lambda: train_with_optuna(params=parameters),
        "openfe-inference": lambda: openfe_inference(params=parameters),
        "openfe-train": lambda: openfe_train(params=parameters),
    }
    handler = mode_handlers.get(mode)
    if handler is None:
        # Normally unreachable: parse_args already restricts --mode to these keys.
        raise ValueError(f"Unknown mode: {mode}")
    handler()

    if mode not in ("evalute", "openfe-train", "openfe-inference"):
        save_results(configs=configs, mode=mode, params=parameters)


if __name__ == "__main__":
    main()