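"""FastAPI model server for managing OpenCLIP, SentenceTransformers, and timm
models served through a local Triton Inference Server.

The endpoints below load models into Triton, evict them from an in-process LRU
cache, delete them from the model repository, list repository contents, and
download custom model weights for later loading.
"""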
from fastapi import FastAPI, Request, HTTPException
from starlette.responses import JSONResponse
import os
import asyncio
from inference.api_models.request_models import (
GenericModelRequest,
SentenceTransformerModelRequest,
OpenCLIPModelRequest,
TimmModelRequest,
DownloadCustomModelRequest,
)
from inference.api_models.response_models import (
GenericMessageResponse,
LoadedModelResponse,
RepositoryModelResponse,
)
from inference.triton_open_clip.clip_model import (
TritonCLIPModelClient,
)
from inference.triton_sentence_transformers.sentence_transformer_model import (
TritonSentenceTransformersModelClient,
)
from inference.triton_timm.timm_model import TritonTimmModelClient
from inference.model_cache import LRUModelCache
from inference.common import get_model_name, delete_model_from_repo
from inference.custom_model_utils import (
download_custom_open_clip_model,
download_custom_sentence_transformers_model,
download_custom_timm_model,
)
import tritonclient.grpc as grpcclient
from threading import Lock
from typing import Union, Literal
TRITON_GRPC_URL = "localhost:8001"
TRITON_CLIENT = grpcclient.InferenceServerClient(url=TRITON_GRPC_URL, verbose=False)
TRITON_MODEL_REPOSITORY_PATH = "model_repository"
CUSTOM_MODEL_DIR = "custom_model_files"
os.makedirs(CUSTOM_MODEL_DIR, exist_ok=True)
app = FastAPI()
# Model cache and lock
MODEL_CACHE = LRUModelCache(capacity=5)
MODEL_CACHE_LOCK = Lock()
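# The cache keeps at most five model clients alive at once; the lock guards
# concurrent access to it from the request handlers.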
def get_model_creation_client(
model_name: str,
pretrained: Union[str, None],
model_library: Literal["open_clip", "sentence_transformers", "timm"],
) -> Union[
TritonCLIPModelClient,
TritonTimmModelClient,
TritonSentenceTransformersModelClient,
None,
]:
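    """Return a Triton client for a model that Triton already reports as ready.

    For OpenCLIP models both the text and image encoder entries must be ready.
    When a ready model is found, a client is created, stored in the LRU cache,
    and returned; otherwise None is returned and the caller is expected to
    create and load the model from scratch.
    """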
nice_model_name = get_model_name(model_name, pretrained)
cache_key = (model_name, pretrained)
if (
TRITON_CLIENT.is_model_ready(nice_model_name)
and model_library == "sentence_transformers"
):
# if the model is ready, create a client for it
# the model name is used directly for sentence transformers
client = TritonSentenceTransformersModelClient(
triton_grpc_url=TRITON_GRPC_URL,
model=model_name,
triton_model_repository_path=TRITON_MODEL_REPOSITORY_PATH,
custom_model_dir=CUSTOM_MODEL_DIR,
)
with MODEL_CACHE_LOCK:
MODEL_CACHE.put(cache_key, client)
return client
if TRITON_CLIENT.is_model_ready(nice_model_name) and model_library == "timm":
# if the model is ready, create a client for it
# the model name is used directly for timm models
client = TritonTimmModelClient(
triton_grpc_url=TRITON_GRPC_URL,
model=model_name,
pretrained=pretrained,
triton_model_repository_path=TRITON_MODEL_REPOSITORY_PATH,
custom_model_dir=CUSTOM_MODEL_DIR,
)
with MODEL_CACHE_LOCK:
MODEL_CACHE.put(cache_key, client)
return client
if (
TRITON_CLIENT.is_model_ready(nice_model_name + "_text_encoder")
and TRITON_CLIENT.is_model_ready(nice_model_name + "_image_encoder")
and model_library == "open_clip"
):
# if the model is ready, create a client for it
# the model name must be split into text and image encoders for CLIP
client = TritonCLIPModelClient(
triton_grpc_url=TRITON_GRPC_URL,
model=model_name,
pretrained=pretrained,
triton_model_repository_path=TRITON_MODEL_REPOSITORY_PATH,
custom_model_dir=CUSTOM_MODEL_DIR,
)
with MODEL_CACHE_LOCK:
MODEL_CACHE.put(cache_key, client)
return client
return None
@app.middleware("http")
async def route_timeout_middleware(request: Request, call_next):
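    """Apply a 600-second timeout to the model loading routes, which can run
    for several minutes; every other route passes straight through."""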
if request.url.path in {
"/load_clip_model",
"/load_sentence_transformer_model",
"/load_timm_model",
}:
try:
return await asyncio.wait_for(call_next(request), timeout=600)
except asyncio.TimeoutError:
return JSONResponse({"error": "Request timed out"}, status_code=504)
else:
return await call_next(request)
@app.get("/health")
async def health() -> GenericMessageResponse:
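    """Liveness check for the model server."""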
return {"message": "The model server is running."}
@app.post("/load_clip_model")
async def load_clip_model(request: OpenCLIPModelRequest) -> GenericMessageResponse:
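    """Load an OpenCLIP model into Triton. If Triton already reports the model
    as ready, the existing client is reused and an "already loaded" message is
    returned."""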
model_name = request.name
pretrained = request.pretrained
cache_key = (model_name, pretrained)
client = get_model_creation_client(
model_name, pretrained, model_library="open_clip"
)
if client is None:
client = TritonCLIPModelClient(
triton_grpc_url=TRITON_GRPC_URL,
model=model_name,
pretrained=pretrained,
triton_model_repository_path=TRITON_MODEL_REPOSITORY_PATH,
custom_model_dir=CUSTOM_MODEL_DIR,
)
with MODEL_CACHE_LOCK:
MODEL_CACHE.put(cache_key, client)
return {
"message": f"Model {model_name} with checkpoint {pretrained} loaded successfully."
}
else:
client.load()
return {
"message": f"Model {model_name} with checkpoint {pretrained} is already loaded."
}
@app.post("/load_sentence_transformer_model")
async def load_sentence_transformer_model(
request: SentenceTransformerModelRequest,
) -> GenericMessageResponse:
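    """Load a SentenceTransformers model into Triton. These models have no
    separate pretrained checkpoint, so the cache key uses None in its place."""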
model_name = request.name
cache_key = (model_name, None)
client = get_model_creation_client(
model_name, None, model_library="sentence_transformers"
)
if client is None:
client = TritonSentenceTransformersModelClient(
triton_grpc_url=TRITON_GRPC_URL,
model=model_name,
triton_model_repository_path=TRITON_MODEL_REPOSITORY_PATH,
custom_model_dir=CUSTOM_MODEL_DIR,
)
with MODEL_CACHE_LOCK:
MODEL_CACHE.put(cache_key, client)
return {"message": f"Model {model_name} loaded successfully."}
else:
client.load()
return {"message": f"Model {model_name} is already loaded."}
@app.post("/load_timm_model")
async def load_timm_model(request: TimmModelRequest) -> GenericMessageResponse:
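    """Load a timm model into Triton, keyed by model name and pretrained tag."""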
model_name = request.name
pretrained = request.pretrained
cache_key = (model_name, pretrained)
client = get_model_creation_client(model_name, pretrained, model_library="timm")
if client is None:
client = TritonTimmModelClient(
triton_grpc_url=TRITON_GRPC_URL,
model=model_name,
pretrained=pretrained,
triton_model_repository_path=TRITON_MODEL_REPOSITORY_PATH,
custom_model_dir=CUSTOM_MODEL_DIR,
)
with MODEL_CACHE_LOCK:
MODEL_CACHE.put(cache_key, client)
return {"message": f"Model {model_name} loaded successfully."}
else:
client.load()
return {"message": f"Model {model_name} is already loaded."}
@app.post("/unload_model")
async def unload_model(request: GenericModelRequest) -> GenericMessageResponse:
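    """Evict a model client from the LRU cache, reporting whether the model was
    actually loaded by comparing the cache size before and after removal."""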
model_name = request.name
pretrained = request.pretrained
cache_key = (model_name, pretrained)
    with MODEL_CACHE_LOCK:
        # Measure the cache size inside the lock so concurrent loads cannot
        # skew the before/after comparison.
        original_cache_size = len(MODEL_CACHE)
        MODEL_CACHE.remove(cache_key)
        new_cache_size = len(MODEL_CACHE)
    if new_cache_size < original_cache_size:
return {
"message": f"Model {model_name} with checkpoint {pretrained} unloaded successfully."
}
else:
return {
"message": f"Model {model_name} with checkpoint {pretrained} is not loaded."
}
@app.delete("/delete_model")
async def delete_model(request: GenericModelRequest) -> GenericMessageResponse:
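    """Remove a model from the cache, if present, and delete its files from the
    Triton model repository."""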
model_name = request.name
pretrained = request.pretrained
cache_key = (model_name, pretrained)
    with MODEL_CACHE_LOCK:
        # Drop the cached client, if any, before removing the files on disk.
        if MODEL_CACHE.get(cache_key) is not None:
            MODEL_CACHE.remove(cache_key)
delete_model_from_repo(model_name, pretrained, TRITON_MODEL_REPOSITORY_PATH)
return {
"message": f"Model {model_name} with checkpoint {pretrained} deleted successfully."
}
@app.get("/loaded_models")
async def loaded_models() -> LoadedModelResponse:
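    """List the models in the Triton repository that are currently available;
    models in the UNAVAILABLE state are skipped."""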
model_repo_information = TRITON_CLIENT.get_model_repository_index(as_json=True)
loaded_models = []
for model in model_repo_information.get("models", []):
        if model.get("state") == "UNAVAILABLE":
continue
loaded_models.append(model["name"])
return {"models": loaded_models}
@app.get("/repository_models")
async def repository_models() -> RepositoryModelResponse:
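    """List every model in the Triton repository together with its state,
    whenever Triton reports one."""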
model_repo_information = TRITON_CLIENT.get_model_repository_index(as_json=True)
repository_models = []
for model in model_repo_information.get("models", []):
model_data = {"name": model["name"]}
if "state" in model:
model_data["state"] = model["state"]
repository_models.append(model_data)
return {"models": repository_models}
@app.post("/download_custom_model")
async def download_custom_model(
request: DownloadCustomModelRequest,
) -> GenericMessageResponse:
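    """Download custom model weights (plus, for sentence transformers, any
    auxiliary tokenizer and config files) into CUSTOM_MODEL_DIR so the model
    can later be loaded through the matching /load_* endpoint."""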
model_library = request.library
model_name = request.pretrained_name
model_url = request.safetensors_url
if model_library == "open_clip":
try:
download_custom_open_clip_model(
CUSTOM_MODEL_DIR,
model_name,
model_url,
mode=request.mode,
mean=request.mean,
std=request.std,
interpolation=request.interpolation,
resize_mode=request.resize_mode,
)
except ValueError as e:
raise HTTPException(
status_code=400,
                detail=f"Error downloading custom model: {str(e)}",
)
    elif model_library == "sentence_transformers":
try:
download_custom_sentence_transformers_model(
CUSTOM_MODEL_DIR,
model_name,
model_url,
config_json_url=request.config_json_url,
tokenizer_json_url=request.tokenizer_json_url,
tokenizer_config_json_url=request.tokenizer_config_json_url,
vocab_txt_url=request.vocab_txt_url,
special_tokens_map_json_url=request.special_tokens_map_json_url,
pooling_config_json_url=request.pooling_config_json_url,
sentence_bert_config_json_url=request.sentence_bert_config_json_url,
modules_json_url=request.modules_json_url,
)
except ValueError as e:
raise HTTPException(
status_code=400,
                detail=f"Error downloading custom model: {str(e)}",
)
    elif model_library == "timm":
try:
download_custom_timm_model(
CUSTOM_MODEL_DIR,
model_name,
model_url,
num_classes=request.num_classes,
)
except ValueError as e:
raise HTTPException(
status_code=400,
                detail=f"Error downloading custom model: {str(e)}",
)
return {"message": f"Custom model {model_name} downloaded successfully."}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="127.0.0.1", port=8687)
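
# Example requests (a hypothetical sketch; the model name and checkpoint below
# are placeholders, and the port matches the uvicorn.run call above):
#
#   curl -X POST http://127.0.0.1:8687/load_clip_model \
#        -H "Content-Type: application/json" \
#        -d '{"name": "ViT-B-32", "pretrained": "laion2b_s34b_b79k"}'
#
#   curl http://127.0.0.1:8687/loaded_models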