# generate_prompt.py
from utils import *
from transformers import GPT2Tokenizer
gpt2tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
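
# Note: the GPT-2 tokenizer is used as a proxy to count prompt tokens against the
# model's context window (4096 tokens for Codex; see get_max_prompt_length_* below).
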
def get_prompt_instance_finqa(instance,split='train',retrieved_text_idx=[]):
'''
Generate the prompt of an instance of FinQA.
Params:
instance (dict) : a FinQA instance.
split (str) : "train" or "test".
        retrieved_text_idx (list) : the indexes of the text paragraphs from the context to include in the prompt
'''
#Format table
table = json_to_pandas(instance)
#Linearize the table description as string
table_description = get_table_description(table)
    try:
        code = DSL_to_code_finqa(instance,table)
    except Exception:
        #Fall back to an empty completion if the DSL-to-code conversion fails
        code = ''
prompt='Read the following text and table, and then write code to answer the question:\n'
if split=='train':
#Train instance : include only the relevant text paragraphs.
for k,v in instance['qa']['gold_inds'].items():
if 'text' in k:
prompt += v + '\n'
else:
#Test instance : include the relevant text paragraphs according to the retriever
context_text = instance['pre_text'] + instance['post_text']
for i in retrieved_text_idx:
prompt+= context_text[int(i)] + '\n'
#Add the table description
prompt += table_description
#Add the question
prompt+= 'Question: ' + instance['qa']['question'] + '?\n'
if split=='train':
#Train instance : provide the answer
prompt+= 'Answer:\n#Python\n' + code
else:
prompt+= 'Answer:\n#Python\n'
return prompt
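
# Usage sketch (illustrative only): `train.json` is a hypothetical path to the
# FinQA train split; the instance layout follows the public FinQA release.
# import json
# with open('train.json') as f:
#     train = json.load(f)
# exemplar = get_prompt_instance_finqa(train[0])                  # train-style, includes gold code
# test_like = get_prompt_instance_finqa(train[0],'test',[0,2])    # test-style, paragraphs 0 and 2
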
def get_prompt_instance_tatqa(instance,context,split='train',retrieved_text_idx=[]):
'''
Generate the prompt of an instance of TAT-QA.
Params:
        instance (dict) : a TAT-QA instance.
        context (dict) : the context of the TAT-QA instance.
        split (str) : "train" or "test".
        retrieved_text_idx (list) : the indexes of the text paragraphs from the context to include in the prompt
'''
table = tatqa_table_to_pandas(context)
table_description = get_table_description(table)
    try:
        code = DSL_to_code_tatqa(instance,table)
    except Exception:
        #Fall back to an empty completion if the DSL-to-code conversion fails
        code = ''
prompt='Read the following text and table, and then write code to answer the question:\n'
if split=='train':
for v in instance['rel_paragraphs']:
prompt += context['paragraphs'][int(v)-1]['text'] + '\n'
else:
context_text = context['paragraphs']
for i in retrieved_text_idx:
prompt+= context_text[int(i)]['text'] + '\n'
#Add the table description
prompt += table_description
#Add the question
prompt+= 'Question: ' + instance['question'] + '?\n'
if split=='train':
prompt+= 'Answer:\n#Python\n' + code
else:
prompt+= 'Answer:\n#Python\n'
return prompt
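
# Usage sketch (illustrative only): TAT-QA groups several questions under one
# context; the file name below is hypothetical.
# import json
# with open('tatqa_dataset_train.json') as f:
#     train = json.load(f)
# context = train[0]
# instance = context['questions'][0]
# exemplar = get_prompt_instance_tatqa(instance,context)
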
def get_prompt_modality_finqa(instance,split='train',retrieved_text_idx=[],modality='table'):
'''
Generate the prompt of an instance of FinQA for the constraint module "Modality Prediction".
Params:
instance (dict) : a FinQA instance.
split (str) : "train" or "test".
        retrieved_text_idx (list) : the indexes of the text paragraphs from the context to include in the prompt
modality (str) : the ground truth modality, to be included for train instances.
'''
prompt = 'You are a financial analyst. Read the following document and question.\n'
if split=='train':
for k,v in instance['qa']['gold_inds'].items():
if 'text' in k:
prompt += v + '\n'
else:
context_text = instance['pre_text'] + instance['post_text']
for i in retrieved_text_idx:
prompt+= context_text[int(i)] + '\n'
table = json_to_pandas(instance)
table_description = get_table_description(table)
prompt += table_description
prompt+= 'Question: ' + instance['qa']['question'] + '?\n'
prompt+= ('Do you need data from the table, the text paragraphs, or both (hybrid) to answer this question? Answer by one of the following : table, text, hybrid.')
if split=='train':
#add the answer
prompt+= '\n' +modality + '\n'
return prompt
def get_prompt_modality_tatqa(instance,context,split='train',retrieved_text_idx=[],modality='table'):
'''
Generate the prompt of an instance of TAT-QA for the constraint module "Modality Prediction".
Params:
        instance (dict) : a TAT-QA instance.
        context (dict) : the context of the TAT-QA instance.
        split (str) : "train" or "test".
        retrieved_text_idx (list) : the indexes of the text paragraphs from the context to include in the prompt
modality (str) : the ground truth modality, to be included for train instances.
'''
prompt = 'You are a financial analyst. Read the following document and question.\n'
if split=='train':
for v in instance['rel_paragraphs']:
prompt += context['paragraphs'][int(v)-1]['text'] + '\n'
else:
context_text = context['paragraphs']
for i in retrieved_text_idx:
prompt+= context_text[int(i)]['text'] + '\n'
table = tatqa_table_to_pandas(context)
table_description = get_table_description(table)
prompt += table_description
prompt+= 'Question: ' + instance['question'] + '?\n'
prompt+= ('Do you need data from the table, the text paragraphs, or both (hybrid) to answer this question? Answer by one of the following : table, text, hybrid.')
if split=='train':
prompt+= '\n' +modality + '\n'
return prompt
def get_prompt_answer_type_tatqa(instance,context,split='train',retrieved_text_idx=[],answer_type='span'):
'''
Generate the prompt of an instance of TAT-QA for the constraint module "Answer Type Prediction".
Params:
        instance (dict) : a TAT-QA instance.
        context (dict) : the context of the TAT-QA instance.
        split (str) : "train" or "test".
        retrieved_text_idx (list) : the indexes of the text paragraphs from the context to include in the prompt
        answer_type (str) : the ground truth answer type, to be included for train instances.
'''
prompt = 'You are a financial analyst. Read the following document and question.\n'
if split=='train':
for v in instance['rel_paragraphs']:
prompt += context['paragraphs'][int(v)-1]['text'] + '\n'
else:
context_text = context['paragraphs']
for i in retrieved_text_idx:
prompt+= context_text[int(i)]['text'] + '\n'
table = tatqa_table_to_pandas(context)
table_description = get_table_description(table)
prompt += table_description
prompt+= 'Question: ' + instance['question'] + '?\n'
prompt+= 'Does this question require to extract spans from the document, to count, or to perform an arithmetic reasoning? Answer by one of the following : span, multi-span, count, arithmetic.'
if split=='train':
prompt+= '\n' +answer_type + '\n'
return prompt
def get_test_prompt_finqa(instance,train,few_shot_idx=[],retrieved_text_idx=[]):
'''
Generate the complete prompt for a test instance, including few-shot exemplars.
Params:
instance (dict) : a FinQA instance.
train (list) : the list object containing all train instances.
few_shot_idx (list) : the train indexes of the few-shot exemplars
        retrieved_text_idx (list) : the indexes of the text paragraphs from the context to include in the prompt
'''
prompt = ''
#Add few shot instances
for i in few_shot_idx[::-1]:
        #Present the exemplars in reverse order, from least similar to most similar
prompt+= get_prompt_instance_finqa(train[i]) + '\n'
#Add test instance
prompt += get_prompt_instance_finqa(instance,'test',retrieved_text_idx)
return prompt
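
# Usage sketch (illustrative only): `few_shot_idx` is assumed to come from an
# exemplar retriever, ordered most-similar first, so the reversal above places the
# most similar exemplar right before the test question.
# few_shot_idx = [12, 87, 3]      # hypothetical retrieved train indexes
# retrieved_text_idx = [0, 4]     # hypothetical retrieved paragraph indexes
# full_prompt = get_test_prompt_finqa(test_instance,train,few_shot_idx,retrieved_text_idx)
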
def get_test_prompt_modality_finqa(instance,train,train_dataframe,few_shot_idx=[],retrieved_text_idx=[]):
'''
Generate the complete prompt for the modality prediction of a test instance, including few-shot exemplars.
Params:
instance (dict) : a FinQA instance.
train (list) : the list object containing all train instances.
train_dataframe (pandas.DataFrame) : the dataframe containing metadata information about the train instances.
few_shot_idx (list) : the train indexes of the few-shot exemplars
        retrieved_text_idx (list) : the indexes of the text paragraphs from the context to include in the prompt
'''
prompt = ''
#Add few shot instances
for i in few_shot_idx:
mapping = {0:'table',1:'text',2:'hybrid'}
modality = mapping[train_dataframe.loc[i,'modality']]
prompt+= get_prompt_modality_finqa(train[i],'train',modality=modality) + '\n'
#Add test instance
prompt += get_prompt_modality_finqa(instance,'test',retrieved_text_idx)
return prompt
def get_test_prompt_tatqa(instance,context,train,train_dataframe,few_shot_idx=[],retrieved_text_idx=[]):
'''
Generate the complete prompt for a test instance, including few-shot exemplars.
Params:
instance (dict) : a TAT-QA instance.
context (dict) : the context of the TAT-QA instance.
        train (list) : the list object containing all train instances.
        train_dataframe (pandas.DataFrame) : the dataframe containing metadata information about the train instances.
        few_shot_idx (list) : the train indexes of the few-shot exemplars
        retrieved_text_idx (list) : the indexes of the text paragraphs from the context to include in the prompt
'''
prompt = ''
#Add few shot instances
for i in few_shot_idx[::-1]:
        #Present the exemplars in reverse order, from least similar to most similar
train_context = train[train_dataframe.loc[i,'context_index']]
train_instance = train_context['questions'][train_dataframe.loc[i,'instance_index']]
prompt+= get_prompt_instance_tatqa(train_instance,train_context) + '\n'
#Add test instance
prompt += get_prompt_instance_tatqa(instance,context,'test',retrieved_text_idx)
return prompt
def get_test_prompt_modality_tatqa(instance,context,train,train_dataframe,few_shot_idx=[],retrieved_text_idx=[]):
'''
Generate the complete prompt for the modality prediction of a test instance, including few-shot exemplars.
Params:
instance (dict) : a TAT-QA instance.
context (dict) : the context of the TAT-QA instance.
train (list) : the list object containing all train instances.
train_dataframe (pandas.DataFrame) : the dataframe containing metadata information about the train instances.
few_shot_idx (list) : the train indexes of the few-shot exemplars
        retrieved_text_idx (list) : the indexes of the text paragraphs from the context to include in the prompt
'''
prompt = ''
#Add few shot instances
for i in few_shot_idx[::-1]:
train_context = train[train_dataframe.loc[i,'context_index']]
train_instance = train_context['questions'][train_dataframe.loc[i,'instance_index']]
mapping = {0:'table',1:'text',2:'hybrid'}
modality = mapping[train_dataframe.loc[i,'modality']]
prompt+= get_prompt_modality_tatqa(train_instance,train_context,'train',modality=modality) + '\n'
#Add test instance
prompt += get_prompt_modality_tatqa(instance,context,'test',retrieved_text_idx)
return prompt
def get_test_prompt_answer_type_tatqa(instance,context,train,train_dataframe,few_shot_idx=[],retrieved_text_idx=[]):
'''
Generate the complete prompt for the answer type prediction of a test instance, including few-shot exemplars.
Params:
instance (dict) : a TAT-QA instance.
context (dict) : the context of the TAT-QA instance.
train (list) : the list object containing all train instances.
train_dataframe (pandas.DataFrame) : the dataframe containing metadata information about the train instances.
few_shot_idx (list) : the train indexes of the few-shot exemplars
        retrieved_text_idx (list) : the indexes of the text paragraphs from the context to include in the prompt
'''
prompt = ''
#Add few shot instances
for i in few_shot_idx[::-1]:
train_context = train[train_dataframe.loc[i,'context_index']]
train_instance = train_context['questions'][train_dataframe.loc[i,'instance_index']]
answer_type = train_dataframe.loc[i,'answer_type']
prompt+= get_prompt_answer_type_tatqa(train_instance,train_context,'train',answer_type=answer_type) + '\n'
#Add test instance
prompt += get_prompt_answer_type_tatqa(instance,context,'test',retrieved_text_idx)
return prompt
def get_test_messages_finqa(instance,train,few_shot_idx=[],retrieved_text_idx=[],query=False):
    '''
    Generate a test prompt using the OpenAI Chat API syntax.
    Returns a list of messages of the form [{'role':'user','content':prompt}, ...],
    with each few-shot exemplar split into a user turn (context and question) and an
    assistant turn (the gold code).
    '''
messages = []
for i in few_shot_idx[::-1]:
        prompt = get_prompt_instance_finqa(train[i])
messages.append({'role':'user','content':prompt.split('#Python\n')[0]+'#Python\n'})
messages.append({'role':'assistant','content':prompt.split('#Python\n')[1]})
#Add test instance
    prompt = get_prompt_instance_finqa(instance,'test',retrieved_text_idx)
messages.append({'role':'user','content':prompt.split('#Python\n')[0]+'#Python\n'})
return messages
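
# Usage sketch (illustrative only, assuming the legacy openai<1.0 Python SDK; the
# model name is a placeholder): exemplars become alternating user/assistant turns
# and the model completes the final '#Python' turn with code.
# import openai
# messages = get_test_messages_finqa(test_instance,train,few_shot_idx,retrieved_text_idx)
# response = openai.ChatCompletion.create(model='<model-name>',messages=messages)
# generated_code = response['choices'][0]['message']['content']
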
def get_test_messages_tatqa(instance,context,train,train_dataframe,few_shot_idx=[],retrieved_text_idx=[],query=False):
    '''
    Generate a test prompt using the OpenAI Chat API syntax.
    Returns a list of messages of the form [{'role':'user','content':prompt}, ...],
    with each few-shot exemplar split into a user turn (context and question) and an
    assistant turn (the gold code).
    '''
messages = []
for i in few_shot_idx[::-1]:
train_context = train[train_dataframe.loc[i,'context_index']]
train_instance = train_context['questions'][train_dataframe.loc[i,'instance_index']]
        prompt = get_prompt_instance_tatqa(train_instance,train_context) + '\n'
messages.append({'role':'user','content':prompt.split('#Python\n')[0]+'#Python\n'})
messages.append({'role':'assistant','content':prompt.split('#Python\n')[1]})
#Add test instance
    prompt = get_prompt_instance_tatqa(instance,context,'test',retrieved_text_idx)
messages.append({'role':'user','content':prompt.split('#Python\n')[0]+'#Python\n'})
return messages
def get_max_prompt_length_finqa(i,raw_dataset,text_filter_df,max_code_length,max_model_length=4096):
    '''
    Compute the token budget left for few-shot exemplars: the model context size
    (4096 tokens for Codex) minus the generated code budget and the length of the
    test prompt itself.
    '''
tf = [int(t) for t in text_filter_df.iloc[:,i].dropna().to_list()]
prompt = get_test_prompt_finqa(raw_dataset[i],'',[],tf)
inputs = gpt2tokenizer(prompt)
prompt_length = len(inputs['input_ids'])
return max_model_length - max_code_length - prompt_length
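
# Usage sketch (illustrative only): the returned budget bounds the tokens available
# for few-shot exemplars, e.g. exemplars can be appended while their cumulative
# GPT-2 token count stays below the budget. `text_filter_df` is assumed to hold one
# column of retrieved paragraph indexes per test instance.
# budget = get_max_prompt_length_finqa(0,test,text_filter_df,max_code_length=256)
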
def get_max_prompt_length_tatqa(i,test,test_dataframe,train,train_dataframe,text_filter_df,max_code_length,max_model_length=4096):
    '''
    Compute the token budget left for few-shot exemplars: the model context size
    (4096 tokens for Codex) minus the generated code budget and the length of the
    test prompt itself.
    '''
context = test[test_dataframe.loc[i,'context_index']]
instance = context['questions'][test_dataframe.loc[i,'instance_index']]
tf = [int(t) for t in text_filter_df.iloc[:,i].dropna().to_list()]
prompt = get_test_prompt_tatqa(instance,context,train,train_dataframe,[],tf)
inputs = gpt2tokenizer(prompt)
prompt_length = len(inputs['input_ids'])
return max_model_length - max_code_length - prompt_length
def remove_invalid_scripts_finqa(dataset):
'''
    Remove train instances that have not been correctly processed and do not constitute correct examples for the FinQA dataset.
'''
#Filter out train examples that have not been correctly converted to DSL
wrong_program_templates = ['add(X_0,X_1),add(X_2,#0),add(#1,constant),divide(#2,constant)',
'add(X_0,X_1),add(#0,X_2),add(#1,constant),divide(#2,constant)',
'divide(X_0,X_1),divide(#0,X_1)',
'add(X_0,X_1),add(#0,constant),divide(#1,constant)'
]
prompt = [get_prompt_instance_finqa(dataset[i]) for i in range(len(dataset))]
wrong_programs = [i for i in range(len(dataset)) if get_program_template(dataset[i]['qa']['program']) in wrong_program_templates ]
table_golds = [i for i in range(len(dataset)) if 'table_' in [g[:-1] for g in dataset[i]['qa']['gold_inds'].keys()]]
no_text_golds = [i for i in range(len(dataset)) if not 'text_' in [g[:-1] for g in dataset[i]['qa']['gold_inds'].keys()]] #Examples that do not use text
#Problem needs table but no query provided
filtered_instances = [p for p in table_golds if not 'table_query_0' in prompt[p]]
#The problem is table only but variables are instantiated
filtered_instances += [p for p in range(len(prompt)) if 'text_variable_' in prompt[p] and p in no_text_golds and p not in filtered_instances]
    #An error occurred in the prompt generation, 'ans' has not been generated
filtered_instances += [p for p in range(len(prompt)) if not 'ans' in prompt[p] and not p in filtered_instances]
    #Remove filtered instances and wrong programs
instances_to_keep = [p for p in range(len(prompt)) if p not in filtered_instances+wrong_programs]
return instances_to_keep
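
# Usage sketch (illustrative only): restrict the few-shot candidate pool to train
# instances whose DSL programs were converted to code correctly.
# valid_idx = remove_invalid_scripts_finqa(train)
# candidate_pool = [train[i] for i in valid_idx]
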
def remove_invalid_scripts_tatqa(dataset,dataframe):
'''
    Remove train instances that have not been correctly processed and do not constitute correct examples for the TAT-QA dataset.
'''
#Filter out train examples that have not been correctly converted to DSL
prompt = [get_prompt_instance_tatqa(dataset[dataframe.loc[i,'context_index']]['questions'][dataframe.loc[i,'instance_index']],dataset[dataframe.loc[i,'context_index']]) for i in range(len(dataframe))]
table_golds = [i for i in range(len(dataframe)) if dataframe.loc[i,'modality'] in ['table','table_text']]
no_text_golds = [i for i in range(len(dataframe)) if not dataframe.loc[i,'modality'] in ['text','table_text']] #Examples that do not use text
#Problem needs table but no query provided
filtered_instances = [p for p in table_golds if not 'table_query_0' in prompt[p]]
#The problem is table only but variables are instantiated
filtered_instances += [p for p in range(len(prompt)) if 'text_variable_' in prompt[p] and p in no_text_golds and p not in filtered_instances]
    #An error occurred in the prompt generation, 'ans' has not been generated
filtered_instances += [p for p in range(len(prompt)) if not 'ans' in prompt[p] and not p in filtered_instances]
instances_to_keep = [p for p in range(len(prompt)) if p not in filtered_instances]
return instances_to_keep