-
Notifications
You must be signed in to change notification settings - Fork 0
/
preprocess.py
46 lines (38 loc) · 1.26 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from text_gan.data.squad1_ca_q import Squad1_CA_Q
from text_gan.data.squad1_ca_qc import SQuAD_CA_QC
from text_gan.data.squad_ca_preqc import SQuAD_CA_PreQC
from text_gan import cfg_from_file, cfg
import argparse
import multiprocessing
import logging
# Dataset identifiers accepted by the --dataset command-line option.
DATA = [
    "squadca-q",
    "squadca-qc",
    "squadca-preqc",
]


def parse_args(argv=None):
    """Parse the command-line options of the preprocessing script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so calling
            ``parse_args()`` with no arguments behaves as before.

    Returns:
        argparse.Namespace with two attributes: ``dataset`` (one of the
        identifiers in ``DATA``) and ``cfg`` (the config YAML filepath as
        a string, or ``None`` when the option is not given).
    """
    # The summary text belongs in ``description``. ``prog`` is the program
    # name printed in the usage line, so it is left to argparse's default
    # (the script name) instead of being set to a sentence.
    parser = argparse.ArgumentParser(description="Preprocess SQuAD 1 data")
    parser.add_argument(
        "--dataset", "-d",
        choices=DATA,
        required=True,
        help="Select dataset to preprocess",
    )
    parser.add_argument(
        "--cfg",
        type=str,
        default=None,
        help="Config YAML filepath",
    )
    return parser.parse_args(argv)
def main():
    """Run the preprocessing script for the dataset chosen on the CLI.

    Steps, in order: parse the command-line options, pass the optional
    config file path to ``cfg_from_file``, configure logging at the level
    held in ``cfg.LOG_LVL``, then instantiate the class that matches the
    requested dataset identifier with ``prepare=True``.
    """
    options = parse_args()

    # The config file is handed over before cfg.LOG_LVL is read below.
    config_path = options.cfg
    if config_path is not None:
        cfg_from_file(config_path)

    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    logging.basicConfig(level=cfg.LOG_LVL, format=log_format)

    # Map each dataset identifier to the class that handles it.
    dataset_classes = {
        "squadca-q": Squad1_CA_Q,
        "squadca-qc": SQuAD_CA_QC,
        "squadca-preqc": SQuAD_CA_PreQC,
    }
    dataset_cls = dataset_classes.get(options.dataset)
    if dataset_cls is not None:
        # The instance itself is not needed; presumably construction with
        # prepare=True performs the preprocessing -- confirm in the class.
        dataset_cls(prepare=True)
if __name__ == "__main__":
    # Select the "spawn" start method before any work begins; the
    # original author notes this is an option to support debugging.
    multiprocessing.set_start_method("spawn")
    main()