Skip to content

Commit c2e96ca

Browse files
authored
feat: configurable model precision (#41)
* feat: configurable model precision * fix
1 parent 6ad0f08 commit c2e96ca

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

main.py

+17-6
Original file line numberDiff line numberDiff line change
@@ -82,24 +82,29 @@ class Frame:
8282

8383

8484
class DetectionModel:
85-
def __init__(self, model: str, model_dir: str, export_dir: str, device: str):
85+
def __init__(self, model: str, model_precision: str, model_dir: str, model_export_dir: str, device: str):
8686
logging.info(f"Cuda is available: {torch.cuda.is_available()}")
8787
self.track_history = defaultdict(lambda: [])
8888

8989
try:
9090
is_gpu = device != "cpu"
9191
model_path = Path(model_dir) / model
92-
exported_model_path = Path(export_dir) / model_path.with_suffix('.engine').name
92+
exported_model_path = Path(model_export_dir) / model_path.with_suffix(f'.{model_precision}.engine').name
9393

9494
if is_gpu and not Path(exported_model_path).exists():
9595
logging.info(f"Exporting model for GPU usage: {exported_model_path}")
9696
temp_export_path = YOLO(model_path).export(
97-
format="engine", device=device, half=True
97+
format="engine",
98+
device=device,
99+
half=(model_precision == "fp16"),
100+
int8=(model_precision == "int8"),
98101
)
99102
os.makedirs(os.path.dirname(exported_model_path), exist_ok=True)
100103
shutil.move(temp_export_path, exported_model_path)
101104

102-
self.model = YOLO(exported_model_path if is_gpu else model_path)
105+
self.model = YOLO(
106+
model=exported_model_path if is_gpu else model_path,
107+
)
103108
logging.info(f"Successfully loaded model from {exported_model_path if is_gpu else model_path}")
104109
except Exception as e:
105110
logging.error(f"Failed to load YOLO model: {e}")
@@ -499,7 +504,7 @@ async def terminate(self):
499504

500505
class DetectionApp:
501506
def __init__(self, args):
502-
self.model = DetectionModel(args.model, args.model_dir, args.export_dir, args.device)
507+
self.model = DetectionModel(args.model, args.model_precision, args.model_dir, args.model_export_dir, args.device)
503508
self.source = FrameSource(args.source, args.frame_interval)
504509
self.args = args
505510
self.stop_processing: bool = False
@@ -567,7 +572,7 @@ def parse_args():
567572
help="Path to the directory containing the model weights file (default: './weights').",
568573
)
569574
parser.add_argument(
570-
"--export-dir",
575+
"--model-export-dir",
571576
type=str,
572577
default="./weights-optimized",
573578
help="Path to export the optimized model engine file (default: './weights-optimized'). Used for GPU acceleration.",
@@ -578,6 +583,12 @@ def parse_args():
578583
default="yolo11n.pt",
579584
help="Path to the model weights file (default: 'yolo11n.pt'). Model will be downloaded if not found.",
580585
)
586+
parser.add_argument(
587+
"--model-precision",
588+
type=str,
589+
default="fp16",
590+
help="Model precision for inference (default: 'fp16'). Options: 'fp32', 'fp16', 'int8'.",
591+
)
581592
parser.add_argument(
582593
"--frame-interval",
583594
type=int,

0 commit comments

Comments
 (0)