Skip to content

Commit

Permalink
feat: main.py에서 argparse 분리
Browse files Browse the repository at this point in the history
  • Loading branch information
myeolinmalchi committed Jan 29, 2025
1 parent 293ca67 commit 532ce98
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 18 deletions.
80 changes: 80 additions & 0 deletions cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import argparse
from typing import Literal, TypedDict


class Args(TypedDict):
    """Parsed command-line options for the embedding server.

    Produced by ``init_server_args`` below; keys mirror the argparse
    ``dest`` names.
    """

    # Inference device; constrained by argparse choices to "cpu" or "cuda".
    device: Literal["cpu", "cuda"]
    # Inference backend; constrained by argparse choices to "onnx" or "llama_cpp".
    backend: Literal["onnx", "llama_cpp"]
    # NOTE(review): -m/--model-path has no default and is not marked
    # required=True, so at runtime this may be None despite the `str`
    # annotation — confirm.
    model_path: str

    # Batch size for inference (default 1).
    batch_size: int
    # Number of parallel session instances (default 1).
    sessions: int
    # Number of uvicorn worker processes (default 1).
    workers: int


def init_server_args(argv: "list[str] | None" = None) -> Args:
    """Parse command-line options for the embedding server.

    Parameters
    ----------
    argv:
        Argument list to parse. ``None`` (the default) means argparse
        reads ``sys.argv[1:]``, so existing callers are unaffected; a
        list can be passed for programmatic use and testing.

    Returns
    -------
    Args
        Typed mapping of the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d",
        "--device",
        dest="device",
        default="cpu",
        choices=["cpu", "cuda"],
        help="Device to use for inference. (cpu/cuda, default: cpu)"
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        dest="batch_size",
        # int default instead of the original "1": string defaults are
        # re-parsed through `type`, which works but obscures intent.
        default=1,
        type=int,
        help="Batch size for inference. (default: 1)"
    )
    parser.add_argument(
        "-f",
        "--backend",
        dest="backend",
        default="onnx",
        choices=["onnx", "llama_cpp"],
        help="Inference backend to use. (onnx/llama_cpp, default: onnx)"
    )
    parser.add_argument(
        "-s",
        "--sessions",
        dest="sessions",
        default=1,
        type=int,
        help="Number of session instances for parallel processing. (default: 1)"
    )
    parser.add_argument(
        "-w",
        "--workers",
        dest="workers",
        default=1,
        type=int,
        help="Number of uvicorn worker process. (default: 1)"
    )
    parser.add_argument(
        "-m",
        "--model-path",
        dest="model_path",
        type=str,
        # NOTE(review): help text says this is required, but required=True
        # is not set, so model_path may be None. Left unchanged to
        # preserve behavior — confirm whether it should be enforced.
        help="Path to the model file. Required for model inference."
    )

    args = parser.parse_args(argv)

    return Args(
        device=args.device,
        batch_size=args.batch_size,
        backend=args.backend,
        sessions=args.sessions,
        workers=args.workers,
        model_path=args.model_path
    )
43 changes: 25 additions & 18 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,14 @@
from contextlib import asynccontextmanager
from fastapi import Depends, FastAPI, HTTPException
from schemas.embed import EmbedRequest, EmbedResponse
from utils.text import preprocess, split_chunks
from utils.embed import init_runtime
import cli

from tqdm import tqdm

runtime = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan handler: build the inference runtime on startup
    and release it on shutdown."""
    # Hard-coded runtime configuration (this is the code the commit is
    # replacing with CLI arguments).
    args = {
        "device": "cpu",
        "backend": "llama_cpp",
        "batch_size": 4,
        "max_workers": 2,
        "model_path": "models/bge-m3-f16.gguf",
    }

    # NOTE(review): this binds a *local* `runtime`, shadowing the
    # module-level `runtime = None` defined above — a `global runtime`
    # statement appears to be missing, so other module code never sees
    # the initialized runtime. Confirm.
    runtime = init_runtime(**args)
    yield  # application serves requests between startup and shutdown
    runtime.release()  # free model/session resources on shutdown


app = FastAPI(lifespan=lifespan)
app = FastAPI()


@app.get("/")
Expand Down Expand Up @@ -66,3 +50,26 @@ async def embed(
raise HTTPException(
status_code=400, detail=f"Error has been occurred {e}"
)


if __name__ == "__main__":
    import uvicorn

    # Parse server options (device, backend, batch size, sessions,
    # uvicorn workers, model path) from the command line via cli.py.
    args = cli.init_server_args()

    # Initialize the inference runtime before starting the server.
    # NOTE(review): the return value of init_runtime is discarded here
    # (the removed lifespan kept it and called .release() on shutdown);
    # also, with workers > 1 uvicorn re-imports "main:app" in child
    # processes where this __main__ block does not run — confirm the
    # runtime is actually available to request handlers in that case.
    init_runtime(
        model_path=args['model_path'],
        batch_size=args['batch_size'],
        max_workers=args['sessions'],  # CLI "sessions" maps to max_workers
        backend=args['backend'],
        device=args['device']
    )

    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        reload=False,
        timeout_keep_alive=600,  # keep idle connections alive for 10 minutes
        workers=args['workers']
    )

0 comments on commit 532ce98

Please sign in to comment.