Skip to content

Commit ee1b712

Browse files
authored
feat: custom export dir
1 parent 1a2c2ac commit ee1b712

File tree

1 file changed

+24
-9
lines changed

1 file changed

+24
-9
lines changed

main.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,19 @@ class Frame:
6464

6565

6666
class DetectionModel:
67-
def __init__(self, model: str, device: str):
67+
def __init__(self, model: str, model_dir: str, export_dir: str, device: str):
6868
logging.info(f"Cuda is available: {torch.cuda.is_available()}")
6969
self.track_history = defaultdict(lambda: [])
7070
try:
71-
self.model = YOLO(model)
72-
if device != "cpu":
71+
model_path = f"{model_dir}/{model}"
72+
self.model = YOLO(model_path)
73+
exported_model_path = f"${export_dir}/${model}.engine"
74+
should_export = device != "cpu" and not Path(exported_model_path).exists()
75+
if should_export:
7376
self.model.export(format="engine", device=device, half=True)
74-
model = model.replace(".pt", ".engine")
75-
self.model = YOLO(model)
76-
logging.info(f"Loaded {model} model")
77+
self.model = YOLO(exported_model_path)
78+
model_path = exported_model_path
79+
logging.info(f"Loaded {model_path} model")
7780
except Exception as e:
7881
logging.error(f"Failed to load YOLO model: {e}")
7982
raise
@@ -462,7 +465,7 @@ async def terminate(self):
462465

463466
class DetectionApp:
464467
def __init__(self, args):
465-
self.model = DetectionModel(args.model, args.device)
468+
self.model = DetectionModel(args.model, args.model_dir, args.export_dir, args.device)
466469
self.source = FrameSource(args.source, args.frame_interval)
467470
self.args = args
468471
self.stop_processing: bool = False
@@ -523,11 +526,23 @@ def parse_args():
523526
default="0",
524527
help="Source for detection (default: '0'). Use '0' for the default webcam, an index (e.g., '1') for additional webcams, or specify a path to a video/image file or URL.",
525528
)
529+
parser.add_argument(
530+
"--model-dir",
531+
type=str,
532+
default="./weights",
533+
help="Path to the directory containing the model weights file (default: './weights').",
534+
)
535+
parser.add_argument(
536+
"--export-dir",
537+
type=str,
538+
default="./weights-optimized",
539+
help="Path to export the optimized model engine file (default: './weights-optimized'). Used for GPU acceleration.",
540+
)
526541
parser.add_argument(
527542
"--model",
528543
type=str,
529-
default="weights/yolo11n.pt",
530-
help="Path to model weights (.pt) file (default: 'weights/yolo11n.pt'). The model will be automatically downloaded if not found.",
544+
default="yolo11n.pt",
545+
help="Path to the model weights file (default: 'yolo11n.pt'). Model will be downloaded if not found.",
531546
)
532547
parser.add_argument(
533548
"--frame-interval",

0 commit comments

Comments
 (0)