diff --git a/count_oop.py b/count_oop.py index 49995eb..05c3237 100644 --- a/count_oop.py +++ b/count_oop.py @@ -1,7 +1,7 @@ class Count(): - def __init__(self, ds_output, roi, names, count_out_classes,counted) -> None: + def __init__(self, ds_output, roi, names, count_out_classes, counted) -> None: self.count_out_classes = count_out_classes self.counted = counted diff --git a/deepsort_oop.py b/deepsort_oop.py index 46b9e11..4795d3e 100644 --- a/deepsort_oop.py +++ b/deepsort_oop.py @@ -57,7 +57,6 @@ def __init__(self, coords, classes_detected, deepsort, frame, show_img, ds_color raise ImportError( 'Error while trying instantiate the tracking object. Please check that.') - def xyxy2xywh(self, x): """ WHAT IT DOES: diff --git a/detection_oop.py b/detection_oop.py index 3402552..0a3551e 100644 --- a/detection_oop.py +++ b/detection_oop.py @@ -97,9 +97,12 @@ def __init__(self, frame, model, device, names, show_img, color, img_sz, class_i raise ImportError( 'Error while trying instantiate the detection object. Please check that.') - - def scale_coords_custom(self, img1_shape, coords, img0_shape): + """ + WHAT IT DOES: + - Scale cords obtained from the model to the original frame. + + """ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new @@ -115,6 +118,11 @@ def scale_coords_custom(self, img1_shape, coords, img0_shape): return coords def draw_bbox(self, frame, coords, color, names, confidence): + """ + WHAT IT DOES: + - To draw the detection bboxes and text. + + """ # draw bounding box frame = cv2.rectangle( diff --git a/test_oop.py b/test_oop.py index 3c5f6e8..5d0507d 100644 --- a/test_oop.py +++ b/test_oop.py @@ -1,13 +1,12 @@ from yolov7_sort_count_oop import YoloSortCount -#################### TEST #################### +#################### TEST #################### # INSTANCIATE test = YoloSortCount() - """ ###### AVAILABLE SOURCES ###### WebCamera: 0 ---> DEFAULT @@ -17,7 +16,7 @@ Local video: "img_bank/cows_for_sale.mp4" Local image: "img_bank/img.jpg" | "img_bank/img.png" """ -test.video_path = "https://www.youtube.com/watch?v=emI8r2dfk6g" +test.video_path = 0 """ @@ -28,9 +27,8 @@ - Invert the image (In case of your WebCamera is mirrored, IE) """ test.max_width = 720 -test.max_fps = 25 #Max 1000 -test.inv_h_frame = False - +test.max_fps = 25 # Max 1000 +test.inv_h_frame = True """ @@ -57,7 +55,6 @@ test.roi_color = (255, 255, 255) - """ ###### DETECTION MODEL ###### @@ -71,7 +68,7 @@ """ test.model_path = 'pretrained_weights/yolov7.pt' test.graphic_card = 0 -test.class_ids = [0,2] +test.class_ids = [0] test.img_sz = 640 test.color = (0, 255, 0) test.conf_thres = 0.5 @@ -113,7 +110,6 @@ test.plot_bgr_color = (0, 0, 0) - """ ###### DEBUG TEXT ###### @@ -133,9 +129,9 @@ - Select if you want to save the results - Select a location to save the results """ -test.save_vid = True -test.save_loc = "results/messi" +test.save_vid = False +test.save_loc = "results/result" # Run -test.run() \ No newline at end of file +test.run() diff --git a/yolov7_sort_count_oop.py b/yolov7_sort_count_oop.py index be9dd46..3f567b1 100644 --- a/yolov7_sort_count_oop.py +++ b/yolov7_sort_count_oop.py @@ -70,7 +70,6 @@ def __init__(self): self.max_fps = 25 self.max_width = 720 - # Pre defined self.names = None @@ -110,6 +109,11 @@ def __init__(self): self.plot_bgr_color = (0, 0, 0) def load_device(self, graphic_card): + """ + WHAT IT DOES: + - Load the torch device. + + """ try: device = torch.device("cuda:"+str(graphic_card)) @@ -120,6 +124,12 @@ def load_device(self, graphic_card): 'Error while trying to use Graphic Card. Please check that it is available.') def load_video_capture(self, video_path): + """ + WHAT IT DOES: + - Load the video capture. + - Resize the frames. + + """ try: @@ -152,9 +162,10 @@ def load_video_capture(self, video_path): orig_ratio = orig_h / orig_w if orig_w > self.max_width: - logging.info('Capture has more width than max. width allowed. Rezising...') + logging.info( + 'Capture has more width than max. width allowed. Rezising...') cap = self.change_res(cap, self.max_width, orig_ratio) - + orig_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) orig_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) @@ -162,7 +173,6 @@ def load_video_capture(self, video_path): orig_fps = cap.get(cv2.CAP_PROP_FPS) % 100 - return cap, orig_w, orig_h, orig_fps except Exception as err: @@ -170,6 +180,11 @@ def load_video_capture(self, video_path): 'Error while trying read the video. Please check that.') def load_save_vid(self, save_loc, orig_w, orig_h): + """ + WHAT IT DOES: + - Load the result writer. + + """ try: @@ -183,6 +198,11 @@ def load_save_vid(self, save_loc, orig_w, orig_h): 'Error while trying write the results. Please check that.') def load_detection_model(self, model_path, device): + """ + WHAT IT DOES: + - Load the detection model and extracting the names of the classes and the model. + + """ try: # Load all characteristics of YOLOv7x model @@ -204,6 +224,12 @@ def load_detection_model(self, model_path, device): 'Error while trying to load the detection model. Please check that.') def load_tracking_model(self, deep_sort_model, max_dist, max_iou_distance, max_age, n_init, nn_budget): + """ + WHAT IT DOES: + - To load the tracking model + + """ + try: deepsort = DeepSort(deep_sort_model, max_dist=max_dist, @@ -217,6 +243,12 @@ def load_tracking_model(self, deep_sort_model, max_dist, max_iou_distance, max_a 'Error while trying to load the tracking model. Please check that.') def load_roi(self): + """ + WHAT IT DOES: + - To select the ROI, interactive way. + + """ + cap_roi, _, _, _ = self.load_video_capture(self.video_path) ret, select_roi_frame = cap_roi.read() @@ -226,7 +258,7 @@ def load_roi(self): ret, select_roi_frame = cap_roi.read() frame_count_roi += 1 - + # To show image correctly (IE: web camera) if self.inv_h_frame: select_roi_frame = cv2.flip(select_roi_frame, 1) @@ -244,6 +276,11 @@ def load_roi(self): return roi def plot_text(self, frame, frame_w, fps, plot_xmin, plot_ymin, padding, counter_text, plot_text_color, plot_bgr_color): + """ + WHAT IT DOES: + - Plot text into the output frame + + """ # Save the first xmin aux_xmin = plot_xmin @@ -288,15 +325,24 @@ def plot_text(self, frame, frame_w, fps, plot_xmin, plot_ymin, padding, counter_ return frame - def change_res(self, cap, max_width, orig_ratio): - - cap.set(3,max_width) + """ + WHAT IT DOES: + - Change te resolution of the frame. + + """ + + cap.set(3, max_width) cap.set(4, int(max_width * orig_ratio)) return cap def run(self): + """ + WHAT IT DOES: + - Run the entire process of detection > tracking > count. + + """ # Debug if self.show_configs: @@ -401,7 +447,7 @@ def run(self): lineType=cv2.LINE_AA ) - # draw fps and detections + # draw fps and detections self.out_frame = self.plot_text(self.out_frame, self.orig_w, fps, self.plot_xmin, self.plot_ymin, self.plot_padding, self.count.counter_text, self.plot_text_color, self.plot_bgr_color)