main.py
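
# main.py -- driver script for ProTeGi-style prompt optimization.
#
# Builds a task (climate_fever or politifacts), a scorer, an evaluator, and a
# vLLM-backed predictor from command-line arguments, then runs several rounds
# of candidate-prompt expansion, scoring, and beam selection, logging F1,
# accuracy, and confusion matrices for the surviving prompts after each round.
#
# Example invocation (all values shown are the defaults set in get_args()):
#   python main.py --task climate_fever --prompts prompts/climate_fever.md --out test_out.txt
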
import requests
import os
import evaluators
import concurrent.futures
from tqdm import tqdm
import time
import json
import argparse
import scorers
import tasks
import predictors
import optimizers
from dotenv import load_dotenv
load_dotenv()
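

# Factory helpers: resolve CLI string arguments to the corresponding classes
# from the local tasks, evaluators, and scorers modules.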
def get_task_class(task_name):
    if task_name == 'climate_fever':
        return tasks.ClimateBinaryTask
    elif task_name == 'politifacts':
        return tasks.PolitifactBinaryTask
    else:
        raise Exception(f'Unsupported task: {task_name}')
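

# Evaluator selection: 'bf' = brute force, 'ucb'/'ucb-e' = UCB bandit,
# 'sr'/'s-sr' = successive rejects, 'sh' = successive halving.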
def get_evaluator(evaluator):
    if evaluator == 'bf':
        return evaluators.BruteForceEvaluator
    elif evaluator in {'ucb', 'ucb-e'}:
        return evaluators.UCBBanditEvaluator
    elif evaluator in {'sr', 's-sr'}:
        return evaluators.SuccessiveRejectsEvaluator
    elif evaluator == 'sh':
        return evaluators.SuccessiveHalvingEvaluator
    else:
        raise Exception(f'Unsupported evaluator: {evaluator}')
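

# Scorer selection: '01' = cached 0/1 scorer, 'll' = cached log-likelihood scorer.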
def get_scorer(scorer):
    if scorer == '01':
        return scorers.Cached01Scorer
    elif scorer == 'll':
        return scorers.CachedLogLikelihoodScorer
    else:
        raise Exception(f'Unsupported scorer: {scorer}')
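

# Command-line arguments: task and prompt paths, model and decoding settings,
# and the ProTeGi search hyperparameters (rounds, beam size, gradients,
# minibatch size, evaluation budget, etc.).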
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', default='climate_fever')
    parser.add_argument('--prompts', default='prompts/climate_fever.md')
    # parser.add_argument('--config', default='default.json')
    parser.add_argument('--out', default='test_out.txt')
    parser.add_argument('--paraphraser', default=None)
    parser.add_argument('--max_threads', default=1, type=int)  # do not change this; vLLM only works single-threaded for now
    parser.add_argument('--max_tokens', default=4, type=int)
    parser.add_argument('--model', default="mistralai/Mistral-7B-Instruct-v0.2", type=str)
    parser.add_argument('--temperature', default=0.8, type=float)
    parser.add_argument('--rounds', default=6, type=int)
    parser.add_argument('--beam_size', default=4, type=int)
    parser.add_argument('--n_test_exs', default=200, type=int)
    parser.add_argument('--minibatch_size', default=64, type=int)
    parser.add_argument('--n_gradients', default=4, type=int)
    parser.add_argument('--errors_per_gradient', default=4, type=int)
    parser.add_argument('--gradients_per_error', default=1, type=int)
    parser.add_argument('--steps_per_gradient', default=1, type=int)
    parser.add_argument('--mc_samples_per_step', default=2, type=int)
    parser.add_argument('--max_expansion_factor', default=8, type=int)
    parser.add_argument('--engine', default="chatgpt", type=str)
    parser.add_argument('--evaluator', default="bf", type=str)
    parser.add_argument('--scorer', default="01", type=str)
    parser.add_argument('--eval_rounds', default=8, type=int)
    parser.add_argument('--eval_prompts_per_round', default=8, type=int)
    # calculated by the sr and s-sr evaluators
    parser.add_argument('--samples_per_eval', default=32, type=int)
    parser.add_argument('--c', default=1.0, type=float, help='exploration param for UCB. higher = more exploration')
    parser.add_argument('--knn_k', default=2, type=int)
    parser.add_argument('--knn_t', default=0.993, type=float)
    parser.add_argument('--reject_on_errors', action='store_true')
    args = parser.parse_args()
    return args
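

# Entry point: parse arguments, derive the total evaluation budget, assemble
# the pipeline components, and run the optimization loop.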
if __name__ == '__main__':
    args = get_args()
    config = vars(args)
    config['eval_budget'] = config['samples_per_eval'] * config['eval_rounds'] * config['eval_prompts_per_round']
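
    # Build the task, scorer, the selected evaluator plus a separate
    # brute-force evaluator, the vLLM predictor, and the ProTeGi optimizer.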
    task = get_task_class(args.task)(args.max_threads)
    scorer = get_scorer(args.scorer)()
    evaluator = get_evaluator(args.evaluator)(config)
    bf_eval = get_evaluator('bf')(config)
    predictor = predictors.VLLMPredictor(config)
    optimizer = optimizers.ProTeGi(
        config, evaluator, scorer, args.max_threads, bf_eval, predictor)
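
    # Load train/test examples, reset the output log, and read the seed
    # prompt(s) listed in --prompts.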
    train_exs = task.get_train_examples()
    test_exs = task.get_test_examples()

    if os.path.exists(args.out):
        os.remove(args.out)

    print(config)

    with open(args.out, 'a') as outf:
        outf.write(json.dumps(config) + '\n')

    candidates = [open(fp.strip()).read() for fp in args.prompts.split(',')]
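
    # Optimization loop: round 0 scores the seed prompts as-is; every later
    # round first expands the candidate pool, then scores and keeps the top
    # beam_size prompts.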
    for round in tqdm(range(config['rounds'] + 1)):
        print("STARTING ROUND ", round)
        start = time.time()

        # expand candidates
        if round > 0:
            candidates = optimizer.expand_candidates(candidates, task, predictor, train_exs)

        # score candidates
        scores = optimizer.score_candidates(candidates, task, predictor, train_exs)
        [scores, candidates] = list(zip(*sorted(list(zip(scores, candidates)), reverse=True)))

        # select candidates
        candidates = candidates[:config['beam_size']]
        scores = scores[:config['beam_size']]

        # record candidates, estimated scores, and true scores
        with open(args.out, 'a') as outf:
            outf.write(f"======== ROUND {round}\n")
            outf.write(f'{time.time() - start}\n')
            outf.write(f'{candidates}\n')
            outf.write(f'{scores}\n')
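
        # Evaluate each surviving candidate on up to n_test_exs held-out test
        # examples and log F1, accuracy, and the confusion matrix.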
        metrics = []
        accuracies = []
        cfmats = []
        for candidate, score in zip(candidates, scores):
            f1, accuracy, cfmat, texts, labels, preds = task.evaluate(predictor, candidate, test_exs, n=args.n_test_exs)
            metrics.append(f1)
            accuracies.append(accuracy)
            cfmats.append(cfmat)
        with open(args.out, 'a') as outf:
            outf.write(f'f1: {metrics}\n')
            outf.write(f'accuracy: {accuracies}\n')
            outf.write(f'cf_mat: {cfmats}\n')

    print("DONE!")