-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
140 lines (130 loc) · 6.77 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import logging
import datetime
import argparse
from processDataset import clean_dataset
import time
from torch.utils.data import DataLoader
from keyphraseExtraction import keyphrase_selection
def greedy_type(value):
valid_options = {'FIRST', 'LONGEST', 'COMBINED', 'NONE'}
if value not in valid_options:
raise argparse.ArgumentTypeError(f"Invalid choice: {value}. Choose from {valid_options}")
return value
def model_version_type(value):
valid_options = {'base', 'small', 'large'}
if value not in valid_options:
raise argparse.ArgumentTypeError(f"Invalid choice: {value}. Choose from {valid_options}")
return value
def get_setting_dict(encoder_header: str, prompt: str, max_len: int, model_version: str,
enable_pos: bool, position_factor: float, length_factor: float):
setting_dict = {}
setting_dict["max_len"] = max_len
setting_dict["temp_en"] = encoder_header
setting_dict["temp_de"] = prompt
setting_dict["model"] = model_version
#setting_dict["enable_filter"] = False #TODO: implement enable_filter
setting_dict["enable_pos"] = enable_pos
setting_dict["position_factor"] = position_factor
setting_dict["length_factor"] = length_factor
return setting_dict
def parse_argument():
parser = argparse.ArgumentParser()
parser.add_argument('--regular_expresion', dest='regular_expresion_value', action='store_true',
help='Set the regular_expresion value to True.')
parser.add_argument('--no-regular_expresion', dest='regular_expresion_value', action='store_false',
help='Set the regular_expresion value to False.')
parser.set_defaults(regular_expresion_value=True)
parser.add_argument('--evaluation', dest='evaluation_value', action='store_true', help='Set to True the Evaluation of the Model in the dataset')
parser.add_argument('--no_evaluation', dest='evaluation_value', action='store_false', help='Set to False the Evaluation of the Model in the dataset')
parser.set_defaults(evaluation_value=False)
parser.add_argument("--greedy",
default="FIRST",
type=greedy_type,
required=False,
help="Method to be used while extracting candidates with regular expresion. LONGEST/FIRST/COMBINED/NONE(we will get all coincidences)")
parser.add_argument("--title_graph_candidates_extraction",
default="Extracción Candidatos",
type=str,
required=False,
help="Title for the grafic")
parser.add_argument("--batch_size",
default=128,
type=int,
required=False,
help="Batch size para evaluar el modelo")
parser.add_argument("--encoder_header",
default="Texto:",
type= str,
required=False,
help= "The text that is going to precede the input at the encoder")
parser.add_argument("--prompt",
default="Este texto habla principalmente de ",
type= str,
help= "The prompt that will precede the candidate")
parser.add_argument("--max_len",
default= 512,
type= int,
help= "Max length that the tokenizer will support for encoding the text")
parser.add_argument("--model_version",
default= "base",
type= model_version_type,
help= "The version of MT5 moder to be used")
parser.add_argument("--length_factor",
default=1.6,
type=float,
required=False,
help="Length factor for being more prone to big or small candidates")
parser.add_argument("--position_factor",
default=1.2e8,
type=float,
required=False,
help="Hyper parameter to regulate position penalty")
parser.add_argument("--enable_pos",
default=False,
type=bool,
required=False,
help="Enable position penalty")
parser.add_argument("--data_path",
default="data/docsutf8",
type=str,
required=False,
help="Path to the data directory")
parser.add_argument("--labels_path",
default="data/keys",
type=str,
required=False,
help="Path to the labels directory")
args = parser.parse_args()
return args
def main():
args = parse_argument()
logger = logging.getLogger(__name__)
setting_dict = get_setting_dict(args .encoder_header, args.prompt, args.max_len, args.model_version,
args.enable_pos, args.position_factor, args.length_factor)
start = time.time()
logging.basicConfig(filename='PromptRankLib.log', encoding='utf-8', filemode='w', level=logging.INFO)
logger.info(f"The main program has started at {datetime.datetime.now()}\n")
# TODO: PASAR EL LOGGER A CLEAN DATASET PARA IR HACIENDO UN RASTREO DE LA EJECUCIÓN
dataset, documents_list, labels, labels_stemed = clean_dataset(args.regular_expresion_value, args.title_graph_candidates_extraction, args.greedy,
args.encoder_header, args.prompt, args.max_len, args.model_version, data_path=args.data_path,
labels_path=args.labels_path, evaluation=args.evaluation_value)
dataloader = DataLoader(dataset, num_workers=4, batch_size=args.batch_size)
if args.evaluation_value is False:
file_name = "results/resultados_modelo.txt"
with open(file_name, "w", encoding="utf-8") as f:
f.write("")
f.write("RESULTADOS KEYPHRASES EXTRAIDAS\n")
f.write(f'FECHA EJECUCION: {datetime.datetime.now()}\n\n')
keyphrase_selection(setting_dict, documents_list, labels_stemed, labels, dataloader, logger, args.model_version, args.evaluation_value)
end = time.time()
log_setting(logger, setting_dict)
logger.info(f'The execution has finished {datetime.datetime.now()}')
logger.info("Processing time: {}".format(end-start))
def log_setting(logger: logging.Logger , setting_dict: dict) -> None:
for i, j in setting_dict.items():
if i == 'length_factor':
logger.info(i + ": {}\n".format(j))
else:
logger.info(i + ": {}".format(j))
if __name__ == "__main__":
main()