-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathabstract_generate.py
More file actions
142 lines (125 loc) · 6.26 KB
/
abstract_generate.py
File metadata and controls
142 lines (125 loc) · 6.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import json
import itertools
from tqdm import tqdm
import random
import argparse
import logging
from collections import Counter
from transformers import set_seed, AutoTokenizer, AutoModelForCausalLM
import torch
# huggingface-cli download meta-llama/Meta-Llama-3-8B-Instruct --exclude "original/*" --local-dir Llama3 --token $HF_TOKEN
# (Supply your own token via $HF_TOKEN or `huggingface-cli login`; never commit access tokens to source.)
# Few-shot instruction prompt asking the model for short, comma-separated
# "abstract intention" phrases, with two worked examples. The literal
# placeholder "[INTENTION]" is meant to be replaced with a concrete intention
# string (via str.replace) before the prompt is wrapped in the chat template.
# The string ends with "Your answer:" followed by a single newline.
PROMPT = '''I will give you an INTENTION. You need to give several phrases containing 1-3 words for the ABSTRACT INTENTION of this INTENTION.
You must return your answer in the following format: phrases1,phrases2,phrases3,...., which means you can't return anything other than answers.
These abstract intention words should fulfill the following requirements.
1. The ABSTRACT INTENTION phrases can well represent the INTENTION.
2. The ABSTRACT INTENTION phrases don't have a lot of less relevant word meanings. For example, "spring" is not a good abstract intention word because it can represent both a coiled metal device and the season of the year.
3. The ABSTRACT INTENTION phrases of the same INTENTION cannot be semantically similar with each other. For example, health and wellness are two close synonyms, so they can't be together.
INTENTION: Moisturize dry skin while enjoying a special effect bath.
Your answer: hydration,skincare
INTENTION: Create a festive atmosphere for a Christmas party.
Your answer: party planning,celebration,decorations,holiday spirit
INTENTION: [INTENTION].
Your answer:
'''
# Llama 3 Instruct chat-format fragments. A full prompt is meant to be
# assembled as:
#   system_start + prompt_start + <user text> + prompt_end + model_start
# The system turn ends at <|eot_id|>, so prompt_start supplies the one and
# only user header. (Previously system_start also ended with a user header,
# which duplicated it.) Every header is followed by a blank line ("\n\n"),
# as in Meta's documented Llama 3 prompt format. The model therefore starts
# generating the assistant content directly.
LLaMA3_CHAT_TEMPLATE = {
    "system_start": (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        "You are a helpful AI assistant.<|eot_id|>"
    ),
    "prompt_start": "<|start_header_id|>user<|end_header_id|>\n\n",
    "prompt_end": "<|eot_id|>",
    "model_start": "<|start_header_id|>assistant<|end_header_id|>\n\n",
}
def build_batch_data(sessions, batch_size):
    """Split ``sessions`` into consecutive chunks of at most ``batch_size`` items.

    The final chunk holds any remainder. Chunks are produced by slicing, so
    each one has the same sequence type as ``sessions``.
    """
    starts = range(0, len(sessions), batch_size)
    return [sessions[start:start + batch_size] for start in starts]
def batched_inference(tokenizer, model, inputs, max_new_tokens=128):
    """Sample a completion for each prompt and parse it into comma-separated phrases.

    Args:
        tokenizer: Hugging Face tokenizer used to encode ``inputs`` and decode
            the outputs. For batched generation with a decoder-only model it
            should be left-padded.
        model: causal LM exposing ``device`` and ``generate``.
        inputs: list of fully formatted prompt strings. Each prompt is expected
            to already start with the chat template's ``<|begin_of_text|>``.
        max_new_tokens: upper bound on the number of tokens generated per prompt.

    Returns:
        One list per input. Each list holds the stripped, non-empty
        comma-separated phrases from the last non-empty line of that input's
        generated continuation. An empty ``inputs`` yields ``[]``.
    """
    if not inputs:
        # The tokenizer cannot encode an empty batch.
        return []
    encoded = tokenizer(
        inputs,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024,
        # The prompts already contain the template's BOS token; letting the
        # tokenizer prepend its own would duplicate it.
        add_special_tokens=False,
    ).to(model.device)
    prompt_len = encoded["input_ids"].shape[1]
    with torch.no_grad():
        output_ids = model.generate(
            **encoded,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.pad_token_id,
        )
    # generate() returns prompt + continuation. Decode only the continuation
    # so that prompt text can never leak into the parsed answer.
    completions = tokenizer.batch_decode(output_ids[:, prompt_len:], skip_special_tokens=True)
    answers = []
    for text in completions:
        non_empty_lines = [line for line in text.split("\n") if line.strip()]
        last_line = non_empty_lines[-1] if non_empty_lines else ""
        answers.append([phrase.strip() for phrase in last_line.split(",") if phrase.strip()])
    return answers
def _configure_logging(logging_file):
    """Send INFO-level log records to the console and to ``logging_file``."""
    log_format = '%(asctime)s - %(levelname)s - %(message)s'
    logging.basicConfig(level=logging.INFO, format=log_format)
    file_handler = logging.FileHandler(logging_file)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(file_handler)


def _load_sessions(input_file, sample_num):
    """Read one JSON object per non-blank line, optionally keeping a random sample."""
    with open(input_file, "r", encoding="utf-8") as file:
        sessions = [json.loads(line) for line in file if line.strip()]
    if sample_num:
        # Clamp so that asking for more sessions than exist takes all of them,
        # instead of making random.sample raise ValueError.
        sessions = random.sample(sessions, min(sample_num, len(sessions)))
    return sessions


def _build_input(intention):
    """Fill PROMPT with one intention and wrap it in the Llama 3 chat template."""
    prompt = PROMPT.replace("[INTENTION]", intention)
    return (
        LLaMA3_CHAT_TEMPLATE["system_start"]
        + LLaMA3_CHAT_TEMPLATE["prompt_start"]
        + prompt
        + LLaMA3_CHAT_TEMPLATE["prompt_end"]
        + LLaMA3_CHAT_TEMPLATE["model_start"]
    )


def _load_model(model_name):
    """Load a bf16 causal LM and a tokenizer configured for batched generation."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    # Decoder-only models must be left-padded for batched generation. With
    # right padding, shorter prompts would continue from pad tokens.
    tokenizer.padding_side = "left"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    return tokenizer, model


def generate(args):
    """Generate abstract-intention phrases for every session and write JSON lines.

    ``args`` must provide ``input_file``, ``output_file``, ``logging_file``,
    ``sample_num``, ``model`` and ``batch_size``.

    Input handling:
        Each input line is a JSON object with an ``'Intentions'`` list.

    Output handling:
        Each session is written to ``output_file`` as one JSON line, augmented
        with ``'abstract_generation_result'``. That is one record per intention,
        holding the intention, its parsed phrases and the model identifier.

    Error handling:
        A batch whose inference raises is logged and skipped. Its sessions are
        not written to the output file.
    """
    _configure_logging(args.logging_file)
    sessions = _load_sessions(args.input_file, args.sample_num)
    batched_sessions = build_batch_data(sessions, args.batch_size)
    tokenizer, model = _load_model(args.model)
    set_seed(42)

    with open(args.output_file, "w", encoding="utf-8") as file:
        for batch in tqdm(batched_sessions):
            # Flatten all intentions of the batch so they run in a single generate() call.
            merged_intentions = [intention for session in batch for intention in session['Intentions']]
            inputs = [_build_input(intention) for intention in merged_intentions]
            try:
                answers = batched_inference(tokenizer, model, inputs)
            except Exception as e:
                # Best effort: the failure belongs to the whole batch, so log every intention in it.
                logging.error(
                    "Error occurred on batch with intentions: %s. Error message: %s",
                    merged_intentions,
                    e,
                )
                continue
            results = list(zip(merged_intentions, answers))
            # Hand the flattened results back out to their sessions, in order.
            offset = 0
            for session in batch:
                count = len(session['Intentions'])
                session['abstract_generation_result'] = [
                    {
                        "INTENTION": intention,
                        "ABSTRACT INTENTION": answer,
                        "model_type": args.model,
                    }
                    for intention, answer in results[offset:offset + count]
                ]
                offset += count
                json.dump(session, file)
                file.write("\n")
def main():
    """Parse command-line options and hand them to ``generate``.

    The stdlib ``random`` module is seeded with 8 before any other work runs.
    """
    random.seed(8)
    cli = argparse.ArgumentParser()
    # (flag, type, default, help), registered in this exact order.
    option_specs = (
        ('--input_file', str, "./sample.json", "Path to the input file."),
        ('--output_file', str, "./output.json", "Path to the output file."),
        ('--logging_file', str, "./generate_log_for_test.log", "Path to the logging file."),
        ('--sample_num', int, None, "Sample number of sessions."),
        ('--model', str, "./Llama3", "Model name in huggingface or local path."),
        ('--batch_size', int, 5, "Number of sessions processed at the same time."),
    )
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    generate(cli.parse_args())


if __name__ == "__main__":
    main()