
Commit e88b9a4

Authored and committed by baishihao
add test_accuracy.py
1 parent ad29534 commit e88b9a4

File tree

1 file changed: +106 −0 lines changed


test/test_accuracy.py (+106 lines)
@@ -0,0 +1,106 @@
import argparse
import subprocess
import time
import os
import requests
import sys
import json

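# Parse command-line arguments: tensor-parallel size (--tp) and model directory (--model_dir).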
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--tp", type=int, required=True, help="Number of GPUs to use.")
    parser.add_argument("--model_dir", type=str, required=True, help="Directory of the model.")
    return parser.parse_args()

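# Launch the lightllm API server as a subprocess and return the Popen handle.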
def start_server(tp, model_dir):
    cmd = [
        "python",
        "-m", "lightllm.server.api_server",
        "--tp", str(tp),
        "--model_dir", model_dir,
        "--data_type", "fp16",
        "--mode", "triton_gqa_flashdecoding",
        "--trust_remote_code",
        "--tokenizer_mode", "fast",
        "--host", "0.0.0.0",
        "--port", "8080"
    ]
    process = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stderr)
    return process

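# Return True once the server's /health endpoint responds with HTTP 200.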
def check_health():
    health_url = "http://localhost:8080/health"
    try:
        r = requests.get(health_url, timeout=2)
        return r.status_code == 200
    except Exception:
        return False

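# Wait until the server is healthy, send each prompt to /generate, and append every response to output_file.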
def send_prompts(prompts, output_file):
    for prompt in prompts:
        while not check_health():
            time.sleep(1)

        request_data = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": 1024,
                "frequency_penalty": 1,
                "do_sample": False
            },
            "multimodal_params": {}
        }

        try:
            r = requests.post(
                "http://localhost:8080/generate",
                json=request_data,
                timeout=10
            )
            response_json = json.loads(r.text)
            generated_text = response_json["generated_text"][0] if "generated_text" in response_json else "No generated_text."
        except Exception as e:
            generated_text = f"ERROR: {str(e)}"

        with open(output_file, "a", encoding="utf-8") as f:
            f.write(f"===== prompt: {prompt} =====\n")
            f.write(f"{generated_text}\n\n")

    print(f"=================== Output saved in {output_file} ===========================")

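# Start the server, run the prompt sweep, then shut the server down.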
def main():
    # args
    args = parse_args()
    tp = args.tp
    model_dir = args.model_dir

    # output file
    output_file = "test_results.txt"

    if os.path.exists(output_file):
        os.remove(output_file)

    # start server
    process = start_server(tp, model_dir)

    # prompts
    prompts = [
        "What is machine learning?",
        "1+1等于几",
        "What role does attention play in transformer architectures?",
        "西红柿炒鸡蛋怎么做?",
        "Describe the concept of overfitting and underfitting.",
        "CPU和GPU的区别是什么?",
        "What is the role of a loss function in machine learning?",
    ]

    send_prompts(prompts, output_file)

    # shutdown server
    process.terminate()
    process.wait()


if __name__ == "__main__":
    main()

# python test_accuracy.py --tp 2 --model_dir /xx/xx
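For reference, each request the script issues is equivalent to the following minimal snippet; this is a sketch assuming the server launched by start_server is already healthy on localhost:8080, and it reuses one of the prompts from the list above:

    import requests

    payload = {
        "inputs": "What is the role of a loss function in machine learning?",
        "parameters": {"max_new_tokens": 1024, "frequency_penalty": 1, "do_sample": False},
        "multimodal_params": {},
    }
    # POST to the same /generate endpoint the test script uses.
    r = requests.post("http://localhost:8080/generate", json=payload, timeout=10)
    # The script reads the first entry of "generated_text" from the JSON response.
    print(r.json().get("generated_text", ["No generated_text."])[0])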
