import numpy as np
from tqdm import tqdm
import random
from abc import ABC, abstractmethod
import utils


class PromptOptimizer(ABC):
    def __init__(self, args, evaluator_fn, scorer, max_threads=1, bf_eval=None, predictor=None):
        self.opt = args
        self.evaluator_fn = evaluator_fn
        self.scorer = scorer
        self.max_threads = max_threads
        self.bf_eval = bf_eval
        self.predictor = predictor

    @abstractmethod
    def expand_candidates(self, prompts, task, gpt4, train_exs):
        pass

class ProTeGi(PromptOptimizer):
    """ ProTeGi: Prompt Optimization with Textual Gradients
    """

    def _sample_error_str(self, texts, labels, preds, task, n=4):
        """ Sample up to n misclassified examples from the given texts, labels,
        and preds, and format them as a single error string."""
        error_idxs = []
        for i, (l, p) in enumerate(zip(labels, preds)):
            if l != p:
                error_idxs.append(i)

        sample_idxs = random.sample(error_idxs, min(len(error_idxs), n))
        sample_texts = [texts[i] for i in sample_idxs]
        sample_labels = [labels[i] for i in sample_idxs]
        sample_preds = [preds[i] for i in sample_idxs]

        error_string = ''
        for i, (t, l, p) in enumerate(zip(sample_texts, sample_labels, sample_preds)):
            error_string += f'## Example {i+1}\n'
            error_string += f'Text: \"{t.strip()}\"\nLabel: {task.stringify_prediction(l)}\nPrediction: {task.stringify_prediction(p)}\n\n'
        return error_string.strip()
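
    # The error string built above is consumed by the gradient prompts below
    # and contains one block per sampled error, roughly:
    #
    #   ## Example 1
    #   Text: "<input text>"
    #   Label: <stringified gold label>
    #   Prediction: <stringified model prediction>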

    def parse_tagged_text(self, text, start_tag, end_tag):
        """ Parse text that is tagged with start and end tags."""
        texts = []
        while True:
            start_index = text.find(start_tag)
            if start_index == -1:
                break
            end_index = text.find(end_tag, start_index)
            if end_index == -1:
                break
            start_index += len(start_tag)
            texts.append(text[start_index:end_index].strip())
            text = text[end_index + len(end_tag):]
        return texts
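
    # For example, parse_tagged_text('a <START>x<END> b <START>y<END>',
    # '<START>', '<END>') returns ['x', 'y'].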

    def _get_gradients(self, prompt, error_string, num_feedbacks=5, n=1):
        """ Get "gradients" for a prompt based on the error string."""
        gradient_prompt = f"""
        I'm trying to write a zero-shot classifier prompt.
        My current prompt is:
        "{prompt}"
        But this prompt gets the following examples wrong:
        {error_string}
        give {num_feedbacks} reasons why the prompt could have gotten these examples wrong.
        Wrap each reason with <START> and <END>
        """
        gradient_prompt = '\n'.join([line.lstrip() for line in gradient_prompt.split('\n')])

        res = self.predictor.eval_multiple(gradient_prompt, n=n)
        feedbacks = []
        for r in res:
            feedbacks += self.parse_tagged_text(r, "<START>", "<END>")
        return feedbacks

    def apply_gradient(self, prompt, error_str, feedback_str, steps_per_gradient, n=1):
        """ Incorporate feedback gradient into a prompt."""
        transformation_prompt = f"""
        I'm trying to write a zero-shot classifier.
        My current prompt is:
        "{prompt}"
        But it gets the following examples wrong:
        {error_str}
        Based on these examples the problem with this prompt is that {feedback_str}
        Based on the above information, I wrote {steps_per_gradient} different improved prompts.
        Each prompt is wrapped with <START> and <END>.
        The {steps_per_gradient} new prompts are:
        """
        transformation_prompt = '\n'.join([line.lstrip() for line in transformation_prompt.split('\n')])

        res = self.predictor.eval_multiple(transformation_prompt, n=n)
        new_prompts = []
        for r in res:
            new_prompts += self.parse_tagged_text(r, "<START>", "<END>")
        return new_prompts

    def generate_synonyms(self, prompt_section, n=3):
        """ Generate synonyms for a prompt section."""
        rewriter_prompt = f"Generate a variation of the following instruction while keeping the semantic meaning.\n\nInput: {prompt_section}\n\nOutput:"
        new_instructions = self.predictor.eval_multiple(rewriter_prompt, n=n)
        new_instructions = [x for x in new_instructions if x]
        return new_instructions

    def get_gradients(self, prompt, task_section, task, gpt4, texts, labels, preds):
        """ Get "gradients" for a prompt based on sampled error strings."""
        prompt_feedbacks = []
        for _ in tqdm(range(self.opt['n_gradients']), total=self.opt['n_gradients'], desc='gradients..'):
            error_string = self._sample_error_str(
                texts, labels, preds, task, n=self.opt['errors_per_gradient'])
            gradients = self._get_gradients(
                task_section, error_string, self.opt['gradients_per_error'], n=1)
            prompt_feedbacks += [(t, error_string) for t in gradients]
        return prompt_feedbacks

    def expand_candidates(self, prompts, task, gpt4, train_exs):
        """ Expand a list of prompts by generating gradient-based successors and
        synonyms for each section.
        """
        minibatch = random.sample(train_exs, k=self.opt['minibatch_size'])

        new_prompts = []
        for prompt in tqdm(prompts, desc=f'expanding {len(prompts)} prompts'):
            sections = utils.parse_sectioned_prompt(prompt)
            task_section = sections['task'].strip()

            # evaluate prompt on minibatch
            _, _, _, texts, labels, preds = task.evaluate(gpt4, prompt, minibatch)

            # get gradients and rewrite the task section accordingly
            new_task_sections = []
            if self.opt['n_gradients'] > 0:
                gradients = self.get_gradients(prompt, task_section, task, gpt4, texts, labels, preds)
                for feedback, error_string in tqdm(gradients, desc='applying gradients'):
                    tmp = self.apply_gradient(
                        task_section, error_string, feedback, self.opt['steps_per_gradient'])
                    new_task_sections += tmp

            # generate synonyms
            mc_sampled_task_sections = []
            if self.opt['mc_samples_per_step'] > 0:
                for sect in tqdm(new_task_sections + [task_section], desc='mc samples'):
                    mc_sects = self.generate_synonyms(
                        sect, n=self.opt['mc_samples_per_step'])
                    mc_sampled_task_sections += mc_sects

            # combine
            new_sections = new_task_sections + mc_sampled_task_sections
            new_sections = list(set(new_sections))  # dedup
            new_sections = [sec.replace("{", "").replace("}", "") for sec in new_sections]
            tmp_new_prompts = [
                prompt.replace(task_section, tmp)
                for tmp in new_sections
            ]

            # filter a little
            if len(new_sections) > self.opt['max_expansion_factor']:
                if self.opt['reject_on_errors']:
                    error_exs = []
                    for t, l, p in zip(texts, labels, preds):
                        if l != p:
                            error_exs.append({'text': t, 'label': l})
                    error_exs = random.sample(error_exs, min(len(error_exs), 16))

                    # speed up a little
                    tmp_new_prompts = random.sample(
                        tmp_new_prompts, min(len(tmp_new_prompts), self.opt['max_expansion_factor'] * 2))

                    error_scores = self.bf_eval(tmp_new_prompts, error_exs, task, gpt4, self.scorer,
                                                max_threads=self.max_threads)
                    tmp_new_prompts = [tmp_new_prompts[i] for i in np.argsort(error_scores)[-self.opt['max_expansion_factor']:]]
                else:
                    tmp_new_prompts = random.sample(
                        tmp_new_prompts, k=self.opt['max_expansion_factor'])

            new_prompts += tmp_new_prompts

        new_prompts += prompts  # add the original prompts back in
        new_prompts = list(set(new_prompts))  # dedup
        return new_prompts

    def score_candidates(self, prompts, task, gpt4, train_exs):
        """ Score a list of prompts."""
        if len(prompts) == 1:
            return [1.0]

        evals = self.evaluator_fn(
            prompts, train_exs, task, gpt4,
            scorer=self.scorer,
            rounds=self.opt['eval_rounds'],
            num_prompts_per_round=self.opt['eval_prompts_per_round'],
            samples_per_eval=self.opt['samples_per_eval'],
            max_threads=self.max_threads
        )
        return evals
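
# A minimal usage sketch. The hyperparameter keys in `args` are the ones read by
# this module ('n_gradients', 'errors_per_gradient', 'gradients_per_error',
# 'steps_per_gradient', 'mc_samples_per_step', 'minibatch_size',
# 'max_expansion_factor', 'reject_on_errors', 'eval_rounds',
# 'eval_prompts_per_round', 'samples_per_eval'); the evaluator_fn, scorer,
# bf_eval, predictor, task, and gpt4 objects are assumed to be supplied by the
# surrounding project, so the names below are placeholders rather than part of
# this file:
#
#   optimizer = ProTeGi(args, evaluator_fn, scorer,
#                       max_threads=4, bf_eval=bf_eval, predictor=predictor)
#   candidates = optimizer.expand_candidates(initial_prompts, task, gpt4, train_exs)
#   scores = optimizer.score_candidates(candidates, task, gpt4, train_exs)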