diff --git a/object_detection/object_detection/Detectors/RetinaNet.py b/object_detection/object_detection/Detectors/RetinaNet.py index 1005034..c2bf844 100755 --- a/object_detection/object_detection/Detectors/RetinaNet.py +++ b/object_detection/object_detection/Detectors/RetinaNet.py @@ -23,7 +23,7 @@ class RetinaNet(DetectorBase): - def __init(self): + def __init__(self): super.__init__() def build_model(self, model_dir_path, weight_file_name): diff --git a/object_detection/object_detection/Detectors/YOLOv8.py b/object_detection/object_detection/Detectors/YOLOv8.py index 9bb1918..41ebb9f 100755 --- a/object_detection/object_detection/Detectors/YOLOv8.py +++ b/object_detection/object_detection/Detectors/YOLOv8.py @@ -21,9 +21,9 @@ class YOLOv8(DetectorBase): - def __init__(self, conf_threshold=0.7): + def __init__(self): + super().__init__() - self.conf_threshold = conf_threshold def build_model(self, model_dir_path, weight_file_name): try: @@ -53,7 +53,7 @@ def get_predictions(self, cv_image): boxes = [] # Perform object detection on image - result = self.model.predict(self.frame, conf=self.conf_threshold, verbose=False) + result = self.model.predict(self.frame, verbose=False) # Perform object detection on image row = result[0].boxes.cpu() for box in row: diff --git a/object_detection/object_detection/ObjectDetection.py b/object_detection/object_detection/ObjectDetection.py index b2ff1e0..549954b 100644 --- a/object_detection/object_detection/ObjectDetection.py +++ b/object_detection/object_detection/ObjectDetection.py @@ -107,19 +107,23 @@ def detection_cb(self, img_msg): print("Image input from topic: {} is empty".format(self.input_img_topic)) else: for prediction in predictions: - x1, y1, x2, y2 = map(int, prediction['box']) + confidence = prediction['confidence'] - # Draw the bounding box - cv_image = cv2.rectangle(cv_image, (x1, y1), (x2, y2), (0, 255, 0), 1) + # Check if the confidence is above the threshold + if confidence >= self.confidence_threshold: + x1, y1, x2, y2 = map(int, prediction['box']) - # Show names of classes on the output image - class_id = int(prediction['class_id']) - class_name = self.detector.class_list[class_id] - label = f"{class_name}: {prediction['confidence']:.2f}" + # Draw the bounding box + cv_image = cv2.rectangle(cv_image, (x1, y1), (x2, y2), (0, 255, 0), 1) - cv_image = cv2.putText(cv_image, label, (x1, y1 - 5), - cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1) + # Show names of classes on the output image + class_id = int(prediction['class_id']) + class_name = self.detector.class_list[class_id] + label = f"{class_name} : {confidence:.2f}" + cv_image = cv2.putText(cv_image, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1) + + # Publish the modified image output = self.bridge.cv2_to_imgmsg(cv_image, "bgr8") self.img_pub.publish(output)