codegen_stream.py
import json
import time

from aiohttp import web

from jaxformer.hf.sample import load_model, sampling
from gpt_j import gpt_load_model, gpt_generate_stream
from ChatGLM_6b import getAnswerFromChatGLM6b, getAnswerFromChatGLM6b_v2
from Vicuna_7b import getAnswerFromVicuna7b, getAnswerFromVicuna7b_v2
from LlaMA2_7b import getAnswerFromLLaMA_v2
from Qwen_7b import getAnswerFromQwen7b_v2
from Agent_6b import getAnswerFromAgent6b_v2

# Lazily loaded list of blocked substrings; populated on first call to
# filter_context() from filter.txt.
filter_string = None


def sampling_gptj(context, maxlength):
    # Loads the GPT-J model (gpt_load_model is expected to cache on repeat
    # calls) and streams a completion for the given context.
    gpt_load_model()
    return gpt_generate_stream(context, maxlength)


def filter_context(context):
    # Returns True if the context contains any blocked substring from
    # filter.txt (one phrase per line, loaded once and cached globally).
    global filter_string
    if filter_string is None:
        print("loading filter")
        try:
            with open('filter.txt', mode='r', encoding='utf-8') as f:
                text = f.read().rstrip()
                filter_string = text.split('\n')
        except FileNotFoundError:
            filter_string = []
    for line in filter_string:
        if line in context:
            return True
    return False
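
# A minimal sketch of the expected filter.txt format (an assumption inferred
# from the parsing above): one blocked substring per line, matched verbatim.
# Once loaded, filter_context behaves like:
#
#   filter_string = ["blocked phrase"]
#   filter_context("tell me about the blocked phrase")  # -> True
#   filter_context("an ordinary question")              # -> False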


async def codegen_stream(request):
    params = await request.json()
    context = params["context"]
    maxlength = params["maxlength"]
    modelname = params["modelname"]
    # Reject blocked prompts outright.
    if filter_context(context):
        return web.Response(
            content_type="application/json",
            text=json.dumps(
                # "请更换问题重新输入" = "Please rephrase your question and resubmit."
                {"result_en": "请更换问题重新输入", "result_ch": "请更换问题重新输入",
                 "time": 0, "stop": True}
            ),
        )
    start = time.perf_counter()
    print(time.strftime("%Y-%m-%d %H:%M:%S",
                        time.localtime()), "context : " + context)
    context = context.strip()
    # Treat the request as Chinese if it contains any CJK unified ideograph.
    flag_chs = any(u'\u4e00' <= ch <= u'\u9fff' for ch in context)
    stop = False
    if flag_chs:
        if modelname == 'vicuna-7b':
            result_en = getAnswerFromVicuna7b(context)
        else:
            result_en = getAnswerFromChatGLM6b(context)
        stop = result_en.endswith("[stop]")
        result_ch = result_en.replace("[stop]", "")
        if result_ch == "":
            result_ch = "思考中"  # "Thinking..."
        result_en = result_ch
    else:
        # Non-Chinese prompts go to the CodeGen sampler.
        result_en, stop = sampling(context, maxlength)
        result_ch = result_en
    end = time.perf_counter()
    print(time.strftime("%Y-%m-%d %H:%M:%S",
                        time.localtime()), "result : " + result_ch)
    return web.Response(
        content_type="application/json",
        text=json.dumps(
            {"result_en": result_en, "result_ch": result_ch,
             "time": end - start, "stop": stop}
        ),
    )
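
# For illustration, a request body this handler accepts (the route path is
# not shown in this file and is an assumption):
#
#   POST /codegen_stream
#   {"context": "def fib(n):", "maxlength": 128, "modelname": "codegen"}
#
# Chinese prompts are routed to a chat model; anything else goes to the
# CodeGen sampler imported from jaxformer.hf.sample.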


async def codegen_stream_v2(request):
    params = await request.json()
    context = params["context"]
    modelname = params["modelname"]
    prompt = context["prompt"]
    # Reject blocked prompts outright.
    if filter_context(prompt):
        return web.Response(
            content_type="application/json",
            text=json.dumps(
                # "请更换问题重新输入" = "Please rephrase your question and resubmit."
                {"response": "请更换问题重新输入",
                 "history": [],
                 "status": 403,
                 "time": 0,
                 "stop": True}
            ),
        )
    start = time.perf_counter()
    print(time.strftime("%Y-%m-%d %H:%M:%S",
                        time.localtime()), "request : " + prompt)
    stop = False
    # Dispatch to the requested model; ChatGLM-6b is the default.
    if modelname == 'vicuna-7b':
        result = getAnswerFromVicuna7b_v2(context)
    elif modelname == 'Llama-7b':
        result = getAnswerFromLLaMA_v2(context)
    elif modelname == 'Qwen-7b':
        result = getAnswerFromQwen7b_v2(context)
    elif modelname == 'Agent-6b':
        result = getAnswerFromAgent6b_v2(context)
    else:
        result = getAnswerFromChatGLM6b_v2(context)
    stop = result["response"].endswith("[stop]")
    if result["response"] == "":
        result["response"] = "思考中"  # "Thinking..."
    if stop:
        result["response"] = result["response"].replace("[stop]", "")
    # "以如下题目写一篇文章" = "write an essay on the following topic";
    # such answers get a "[gitclone.com at your service]" signature appended.
    if "以如下题目写一篇文章" in prompt:
        result["response"] = result["response"] + "[gitclone.com为您服务]"
    end = time.perf_counter()
    result["time"] = end - start
    result["stop"] = stop
    print(time.strftime("%Y-%m-%d %H:%M:%S",
                        time.localtime()), "result : " + result["response"])
    return web.Response(
        content_type="application/json",
        text=json.dumps(result),
    )
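

# A minimal sketch of how these handlers could be mounted with aiohttp.
# The route paths and port are assumptions; the actual routing is not part
# of this file.
if __name__ == "__main__":
    app = web.Application()
    app.add_routes([
        web.post("/codegen_stream", codegen_stream),
        web.post("/codegen_stream_v2", codegen_stream_v2),
    ])
    web.run_app(app, port=8080)  # hypothetical port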