diff --git a/cli.py b/cli.py
new file mode 100644
index 0000000..66a55e4
--- /dev/null
+++ b/cli.py
@@ -0,0 +1,81 @@
+import argparse
+from typing import Literal, TypedDict
+
+
+class Args(TypedDict):
+    device: Literal["cpu", "cuda"]
+    backend: Literal["onnx", "llama_cpp"]
+    model_path: str
+
+    batch_size: int
+    sessions: int
+    workers: int
+
+
+def init_server_args() -> Args:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-d",
+        "--device",
+        dest="device",
+        action="store",
+        default="cpu",
+        choices=["cpu", "cuda"],
+        help="Device to use for inference. (cpu/cuda, default: cpu)"
+    )
+    parser.add_argument(
+        "-b",
+        "--batch-size",
+        dest="batch_size",
+        action="store",
+        default=1,
+        type=int,
+        help="Batch size for inference. (default: 1)"
+    )
+    parser.add_argument(
+        "-f",
+        "--backend",
+        dest="backend",
+        action="store",
+        default="onnx",
+        choices=["onnx", "llama_cpp"],
+        help="Inference backend to use. (onnx/llama_cpp, default: onnx)"
+    )
+    parser.add_argument(
+        "-s",
+        "--sessions",
+        dest="sessions",
+        action="store",
+        default=1,
+        type=int,
+        help="Number of session instances for parallel processing. (default: 1)"
+    )
+    parser.add_argument(
+        "-w",
+        "--workers",
+        dest="workers",
+        action="store",
+        default=1,
+        type=int,
+        help="Number of uvicorn worker processes. (default: 1)"
+    )
+    parser.add_argument(
+        "-m",
+        "--model-path",
+        dest="model_path",
+        action="store",
+        type=str,
+        required=True,
+        help="Path to the model file. Required for model inference."
+    )
+
+    args = parser.parse_args()
+
+    return Args(
+        device=args.device,
+        batch_size=args.batch_size,
+        backend=args.backend,
+        sessions=args.sessions,
+        workers=args.workers,
+        model_path=args.model_path
+    )
diff --git a/main.py b/main.py
index 7bf4ab5..766836c 100644
--- a/main.py
+++ b/main.py
@@ -1,30 +1,14 @@
-from contextlib import asynccontextmanager
 from fastapi import Depends, FastAPI, HTTPException
 from schemas.embed import EmbedRequest, EmbedResponse
 from utils.text import preprocess, split_chunks
 from utils.embed import init_runtime
+import cli
 from tqdm import tqdm
 
 
 runtime = None
 
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
-    args = {
-        "device": "cpu",
-        "backend": "llama_cpp",
-        "batch_size": 4,
-        "max_workers": 2,
-        "model_path": "models/bge-m3-f16.gguf",
-    }
-
-    runtime = init_runtime(**args)
-    yield
-    runtime.release()
-
-
-app = FastAPI(lifespan=lifespan)
+app = FastAPI()
 
 
 @app.get("/")
@@ -66,3 +50,27 @@ async def embed(
         raise HTTPException(
             status_code=400, detail=f"Error has been occurred {e}"
         )
+
+
+if __name__ == "__main__":
+    import uvicorn
+
+    args = cli.init_server_args()
+
+    # Initialize the inference runtime before the server starts handling requests.
+    init_runtime(
+        model_path=args['model_path'],
+        batch_size=args['batch_size'],
+        max_workers=args['sessions'],
+        backend=args['backend'],
+        device=args['device']
+    )
+
+    uvicorn.run(
+        "main:app",
+        host="0.0.0.0",
+        port=8000,
+        reload=False,
+        timeout_keep_alive=600,
+        workers=args['workers']
+    )
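
For reference, a hypothetical launch command after this change, assuming a model file exists at the path that the old hard-coded lifespan defaults used (actual paths and values will vary per deployment):

    python main.py -d cuda -f llama_cpp -b 4 -s 2 -w 1 -m models/bge-m3-f16.gguf

The long forms (--device, --backend, --batch-size, --sessions, --workers, --model-path) defined in cli.py work equivalently.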