launch_triton_server.py
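
"""Launch Triton Inference Server for TensorRT-LLM models under mpirun.

One tritonserver rank is spawned per MPI process; rank 0 serves the
gRPC/HTTP/metrics endpoints, while the remaining ranks disable theirs and
load only the TensorRT-LLM model(s).
"""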
import argparse
import os
import subprocess
import sys
from pathlib import Path


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--world_size",
        type=int,
        default=1,
        help="world size; only tensor parallelism is supported for now",
    )
    parser.add_argument(
        "--tritonserver",
        type=str,
        help="path to the tritonserver executable",
        default="/opt/tritonserver/bin/tritonserver",
    )
    parser.add_argument(
        "--grpc_port",
        type=str,
        help="tritonserver gRPC port",
        default="8001",
    )
    parser.add_argument(
        "--http_port",
        type=str,
        help="tritonserver HTTP port",
        default="8000",
    )
    parser.add_argument(
        "--metrics_port",
        type=str,
        help="tritonserver metrics port",
        default="8002",
    )
    parser.add_argument(
        "--force",
        "-f",
        action="store_true",
        help="launch tritonserver regardless of other instances running",
    )
    parser.add_argument(
        "--log",
        action="store_true",
        help="log Triton server stats into log_file",
    )
    parser.add_argument(
        "--log-file",
        type=str,
        help="path to the Triton log file",
        default="triton_log.txt",
    )
    path = str(Path(__file__).parent.absolute()) + "/../all_models/gpt"
    parser.add_argument("--model_repo", type=str, default=path)
    parser.add_argument(
        "--tensorrt_llm_model_name",
        type=str,
        help="Name(s) of the tensorrt_llm Triton model(s) in the repo. "
        "Use commas to separate multiple model names.",
        default="tensorrt_llm",
    )
    parser.add_argument(
        "--multi-model",
        action="store_true",
        help="Enable support for multiple TRT-LLM models in the Triton model repository",
    )
    return parser.parse_args()
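
# Example invocation (illustrative; the model repo path is a placeholder):
#   python3 launch_triton_server.py --world_size=2 \
#       --model_repo=/path/to/model_repo --log --log-file=triton.log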


def get_cmd(
    world_size,
    tritonserver,
    grpc_port,
    http_port,
    metrics_port,
    model_repo,
    log,
    log_file,
    tensorrt_llm_model_name,
):
    cmd = ["mpirun", "--allow-run-as-root"]
    for i in range(world_size):
        rank_cmd = [
            tritonserver,
            f"--model-repository={model_repo}",
            "--disable-auto-complete-config",
            f"--backend-config=python,shm-region-prefix-name=prefix{i}_",
        ]
        if log and (i == 0):
            rank_cmd += ["--log-verbose=3", f"--log-file={log_file}"]
        # If rank is not 0, skip loading of models other than
        # `tensorrt_llm_model_name`.
        if i != 0:
            rank_cmd += ["--model-control-mode=explicit"]
            model_names = tensorrt_llm_model_name.split(",")
            for name in model_names:
                rank_cmd += [f"--load-model={name}"]
        if i == 0:
            # Only rank 0 exposes the gRPC/HTTP/metrics endpoints.
            rank_cmd += [
                f"--grpc-port={grpc_port}",
                f"--http-port={http_port}",
                f"--metrics-port={metrics_port}",
            ]
        else:
            rank_cmd += [
                "--allow-grpc=false",
                "--allow-http=false",
                "--allow-metrics=false",
            ]
        # mpirun MPMD syntax: colon-separated per-rank command blocks.
        cmd += ["-n", "1"] + rank_cmd + [":"]
    return cmd
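
# For world_size=2, get_cmd() assembles one MPMD-style mpirun command of the
# form below (illustrative; most per-rank flags elided):
#   mpirun --allow-run-as-root \
#       -n 1 tritonserver --model-repository=<repo> ... --grpc-port=8001 ... : \
#       -n 1 tritonserver --model-repository=<repo> ... --allow-grpc=false ... :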


if __name__ == "__main__":
    args = parse_arguments()
    # `pgrep -r R` matches only tritonserver processes in the running state.
    res = subprocess.run(
        ["pgrep", "-r", "R", "tritonserver"], capture_output=True, encoding="utf-8"
    )
    if res.stdout:
        pids = res.stdout.replace("\n", " ").rstrip()
        msg = f"tritonserver process(es) already found with PID(s): {pids}.\n\tUse `kill {pids}` to stop them."
        if args.force:
            print(msg, file=sys.stderr)
        else:
            raise RuntimeError(msg + " Or use --force.")
    cmd = get_cmd(
        int(args.world_size),
        args.tritonserver,
        args.grpc_port,
        args.http_port,
        args.metrics_port,
        args.model_repo,
        args.log,
        args.log_file,
        args.tensorrt_llm_model_name,
    )
    env = os.environ.copy()
    if args.multi_model:
        assert (
            args.world_size == 1
        ), "World size must be 1 when using multi-model. Processes will be spawned automatically to run the multi-GPU models"
        # Orchestrator mode: the backend spawns worker processes itself.
        env["TRTLLM_ORCHESTRATOR"] = "1"
    subprocess.Popen(cmd, env=env)
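
# Once rank 0 is up, readiness can be checked over HTTP via Triton's standard
# health endpoint (port matches --http_port; host assumed to be local):
#   curl -s -o /dev/null -w "%{http_code}" localhost:8000/v2/health/ready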