1
+ import argparse
2
+ import subprocess
3
+ import time
4
+ import os
5
+ import requests
6
+ import sys
7
+ import json
8
+
9
def parse_args():
    """Parse command-line options for the accuracy test driver.

    Returns:
        argparse.Namespace with ``tp`` (int, GPU count) and
        ``model_dir`` (str, path to the model weights).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tp",
        type=int,
        required=True,
        help="Number of GPUs to use.",
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        required=True,
        help="Directory of the model.",
    )
    return parser.parse_args()
14
+
15
def start_server(tp, model_dir, port=8080):
    """Launch the lightllm API server as a child process.

    Args:
        tp: Tensor-parallel degree (number of GPUs), passed as ``--tp``.
        model_dir: Path to the model weights directory.
        port: TCP port the server listens on. Defaults to 8080, which is
            the port hard-coded in this script's health/generate URLs.

    Returns:
        The ``subprocess.Popen`` handle, so the caller can later
        ``terminate()`` and ``wait()`` on the server.
    """
    cmd = [
        # Use the running interpreter rather than whatever "python"
        # resolves to on PATH (which may be a different environment).
        sys.executable,
        "-m", "lightllm.server.api_server",
        "--tp", str(tp),
        "--model_dir", model_dir,
        "--data_type", "fp16",
        "--mode", "triton_gqa_flashdecoding",
        "--trust_remote_code",
        "--tokenizer_mode", "fast",
        "--host", "0.0.0.0",
        "--port", str(port),
    ]
    # Inherit this process's stdout/stderr so server logs are visible.
    process = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stderr)
    return process
30
+
31
def check_health():
    """Return True when the local server's /health endpoint answers 200.

    Any failure — connection refused, timeout, DNS error — is treated
    as "not healthy" rather than raised.
    """
    url = "http://localhost:8080/health"
    try:
        response = requests.get(url, timeout=2)
    except Exception:
        return False
    return response.status_code == 200
38
+
39
def send_prompts(prompts, output_file):
    """Send each prompt to the local /generate endpoint, appending results.

    Before every request the function blocks until the server health
    check passes. Request failures (connection errors, timeouts, bad
    JSON) are recorded in the output file as ``ERROR: ...`` instead of
    aborting the run.

    Args:
        prompts: Iterable of prompt strings.
        output_file: Path of the text file results are appended to.
    """
    for prompt in prompts:
        # Wait for the server to come up (or recover) before each request.
        # NOTE(review): this spins forever if the server process dies —
        # consider a bounded wait.
        while not check_health():
            time.sleep(1)

        request_data = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": 1024,
                "frequency_penalty": 1,
                "do_sample": False,
            },
            "multimodal_params": {},
        }

        try:
            # NOTE(review): 10s may be tight for generating up to 1024
            # new tokens — confirm against observed latency.
            r = requests.post(
                "http://localhost:8080/generate",
                json=request_data,
                timeout=10,
            )
            # r.json() replaces json.loads(r.text): same result, one step.
            response_json = r.json()
            generated_text = (
                response_json["generated_text"][0]
                if "generated_text" in response_json
                else "No generated_text."
            )
        except Exception as e:
            generated_text = f"ERROR: {str(e)}"

        with open(output_file, "a", encoding="utf-8") as f:
            f.write(f"===== prompt: {prompt} =====\n")
            f.write(f"{generated_text}\n\n")

    # Fixed typo in the original message ("Ouput" -> "Output").
    print(f"===================Output saved in {output_file}===========================")
70
+
71
def main():
    """Run the end-to-end accuracy smoke test: start server, send prompts."""
    args = parse_args()

    # Start each run with a fresh results file.
    output_file = "test_results.txt"
    if os.path.exists(output_file):
        os.remove(output_file)

    # Bring up the lightllm API server as a child process.
    server = start_server(args.tp, args.model_dir)

    # Mixed English/Chinese prompts exercising general QA behavior.
    test_prompts = [
        "What is the machine learning?",
        "1+1等于几",
        "What role does attention play in transformer architectures?",
        "西红柿炒鸡蛋怎么做?",
        "Describe the concept of overfitting and underfitting.",
        "CPU和GPU的区别是什么?",
        "What is the role of a loss function in machine learning?",
    ]

    send_prompts(test_prompts, output_file)

    # Tear the server down once all prompts have been processed.
    server.terminate()
    server.wait()
102
+
103
# Entry point guard: run only when executed as a script, not on import.
if __name__ == "__main__":
    main()

# Usage example:
# python test_accuracy.py --tp 2 --model_dir /xx/xx