-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathapi_utils.py
93 lines (87 loc) · 3.94 KB
/
api_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from header import *
from models import *
def parse_msg(request):
    """Parse an incoming XML chat message from an HTTP request.

    Args:
        request: request object (presumably a Flask request) whose ``data``
            attribute holds the raw XML body as bytes.

    Returns:
        Tuple ``(toUser, fromUser, msgType, content)`` with the text of the
        corresponding XML elements. ``content`` is None when the message has
        no ``<Content>`` element.

    Raises:
        AttributeError: if ToUserName, FromUserName or MsgType is missing.
        xml.etree.ElementTree.ParseError: if the body is not well-formed XML.
    """
    # Pass the raw bytes straight to the parser so an encoding declared in the
    # XML prolog is honored, rather than forcing a UTF-8 decode beforehand.
    root = ET.fromstring(request.data)
    to_user = root.find('ToUserName').text
    from_user = root.find('FromUserName').text
    msg_type = root.find('MsgType').text
    # Presumably non-text message types may omit <Content>; return None
    # instead of crashing on `None.text`.
    content_node = root.find('Content')
    content = content_node.text if content_node is not None else None
    return to_user, from_user, msg_type, content
def flask_load_agent(model, gpu, logger):
    """Init the agent: construct the dialogue agent named by ``model``.

    Args:
        model: agent identifier; one of 'bertretrieval', 'bertmc', 'bertirbi',
            'bertirbicomp', 'bertretrieval_multiview', 'gpt2', 'lccc',
            'when2talk', 'test', 'multiview'.
        gpu: passed through to each agent as its ``multi_gpu`` argument.
        logger: logger receiving the start/finish progress messages.

    Returns:
        The constructed agent, with its checkpoint loaded for the agents that
        have one.

    Raises:
        ValueError: if ``model`` is not a known agent name (a subclass of
            Exception, so existing ``except Exception`` handlers still work).
    """
    logger.info(f'[!] begin to init the {model} agent on GPU {gpu}')
    if model == 'bertretrieval':
        agent = BERTRetrievalAgent(gpu, run_mode='test', kb=False)
        agent.load_model('ckpt/zh50w/bertretrieval/best.pt')
    elif model == 'bertmc':
        # [model_type]: bertmc -> mc; bertmcf -> mcf
        agent = BERTMCAgent(gpu, kb=False, model_type='mc')
        agent.load_model('ckpt/zh50w/bertmc/best.pt')
    elif model in ('bertirbi', 'bertirbicomp'):
        # Separate name so `model` keeps the requested agent name, which the
        # checkpoint path below depends on.
        variant = 'no-compare' if model == 'bertirbi' else 'compare'
        agent = BERTBiEncoderAgent(gpu, None, run_mode='test', model=variant)
        agent.load_model(f'ckpt/zh50w/{model}/best.pt')
    elif model == 'bertretrieval_multiview':
        agent = BERTMULTIVIEWAgent(gpu, kb=False)
        agent.load_model('ckpt/zh50w/bertretrieval_multiview/best.pt')
    elif model == 'gpt2':
        # available run_mode: test, rerank, rerank_ir
        # NOTE(review): meaning of the leading 1000 is not visible here — confirm.
        agent = GPT2Agent(1000, gpu, run_mode='rerank_ir')
        agent.load_model('ckpt/train_generative/gpt2/best.pt')
    elif model == 'lccc':
        agent = LCCCAgent(gpu, run_mode='test')  # run_mode: test/rerank
    elif model == 'when2talk':
        agent = When2TalkAgent(1000, gpu, run_mode='test')
        agent.load_model('ckpt/when2talk/when2talk/best.pt')
    elif model == 'test':
        agent = TestAgent()
    elif model == 'multiview':
        agent = MultiViewTestAgent()
    else:
        raise ValueError(f'[!] obtain the unknown model name {model}')
    # Report completion through the same logger as the start message.
    logger.info(f'[!] init {model} agent on GPU {gpu} over ...')
    return agent
def chat(agent, content, args=None, logger=None):
    """Route one incoming message to the chat routine picked by ``args['chat_mode']``.

    Modes: 0 -> normal_chat_single_turn, 1 -> normal_chat_multi_turn,
    2 -> kg_driven_chat_multi_turn. ``args`` is passed through unchanged to the
    selected routine.

    Args:
        agent: dialogue agent forwarded to the routine.
        content: the incoming message text.
        args: request-scoped dict; must contain 'chat_mode' plus whatever keys
            the selected routine reads.
        logger: optional logger; used for the unknown-mode warning when given,
            otherwise the warning is printed.

    Returns:
        The routine's reply, or None for an unknown chat mode.

    Raises:
        TypeError: if ``args`` is None (it is required despite the default).
    """
    if args is None:
        raise TypeError('chat() requires an args dict containing "chat_mode"')
    mode = args['chat_mode']
    if mode == 0:
        return normal_chat_single_turn(agent, content, args=args)
    if mode == 1:
        return normal_chat_multi_turn(agent, content, args=args)
    if mode == 2:
        return kg_driven_chat_multi_turn(agent, content, args=args)
    message = f'[!] Unknown chat mode {mode}'
    if logger is not None:
        logger.warning(message)
    else:
        print(message)
    return None
def normal_chat_single_turn(agent, content, topic=None, args=None):
    """Ask the agent for a reply to ``content`` alone, without history.

    Args:
        agent: object exposing ``get_res(data)``.
        content: the incoming message text.
        topic: optional topic forwarded to the agent.
        args: optional request-scoped dict; when given, ``args['content']`` is
            set to the text sent to the agent (presumably read back by the
            caller — confirm).

    Returns:
        Whatever ``agent.get_res`` returns for
        ``{'topic': topic, 'msgs': [{'msg': content}]}``.
    """
    data = {
        'topic': topic,
        'msgs': [{'msg': content}],
    }
    # Guard the default: the original wrote into `args` unconditionally and
    # crashed when it was left as None.
    if args is not None:
        args['content'] = content
    return agent.get_res(data)
def normal_chat_multi_turn(agent, content, topic=None, args=None):
    """Ask the agent for a reply based on the user's recent conversation.

    History comes from ``args['table'].find(...)``: every record where
    ``args['fromUser']`` is either the sender or the receiver, in the order the
    table returns them. Only the last ``args['multi_turn_size']`` utterances
    are kept. ``content`` is not read here; presumably the current message is
    already stored in the table — confirm.

    Side effect: ``args['content']`` is set to the kept utterances joined
    with ' [SEP] '.

    Note: a ``multi_turn_size`` of 0 keeps the whole history, because
    ``xs[-0:]`` is the same as ``xs[0:]``.
    """
    user = args["fromUser"]
    records = args['table'].find({"$or": [{"fromUser": user}, {"toUser": user}]})
    history = [record['utterance'] for record in records]
    recent = history[-args['multi_turn_size']:]
    payload = {
        'topic': topic,
        'msgs': [{'msg': utterance} for utterance in recent],
    }
    args['content'] = ' [SEP] '.join(recent)
    return agent.get_res(payload)
def kg_driven_chat_multi_turn(agent, content, topic=None, args=None):
    """Ask the agent for a reply using recent history plus session KG state.

    History comes from ``args['table'].find(...)``: every record where
    ``args['fromUser']`` is either the sender or the receiver. Only the last
    ``args['multi_turn_size']`` records are kept, each forwarded with its
    sender. The knowledge-graph path and current node are read from
    ``args['session']`` ('kg_path' and 'node'; None when absent). ``content``
    is not read here; presumably the current message is already stored in the
    table — confirm.

    Side effect: ``args['content']`` is set to the kept utterance texts joined
    with ' [SEP] '.

    Returns:
        Whatever ``agent.get_res`` returns for the assembled request dict.
    """
    user = args["fromUser"]
    query = {"$or": [{"fromUser": user}, {"toUser": user}]}
    history = [(record['fromUser'], record['utterance']) for record in args['table'].find(query)]
    recent = history[-args['multi_turn_size']:]
    session = args['session']
    data = {
        'topic': topic,
        'msgs': [{'msg': utterance, 'fromUser': sender} for sender, utterance in recent],
        'path': session.get('kg_path'),
        'current_node': session.get('node'),
    }
    # Join only the utterance strings: joining the (sender, utterance) tuples
    # raised TypeError whenever the history was non-empty.
    args['content'] = ' [SEP] '.join(utterance for _, utterance in recent)
    return agent.get_res(data)