diff --git a/sahi_inference.py b/sahi_inference.py index f3542d7e..207eba72 100644 --- a/sahi_inference.py +++ b/sahi_inference.py @@ -220,7 +220,9 @@ def main(args): model_type='torchvision', model=model, confidence_threshold=args['threshold'], - device=args['device'] + device=args['device'], + category_mapping={str(i): CLASSES[i] for i in range(1, len(CLASSES))}, + # category_remapping={CLASSES[i]: i for i in range(1, len(CLASSES))} ) COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))