Skip to content

Commit

Permalink
Simple documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
aagustinconti committed Nov 25, 2022
1 parent 513b194 commit a2ab2bb
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 25 deletions.
2 changes: 1 addition & 1 deletion count_oop.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@


class Count():
def __init__(self, ds_output, roi, names, count_out_classes,counted) -> None:
def __init__(self, ds_output, roi, names, count_out_classes, counted) -> None:

self.count_out_classes = count_out_classes
self.counted = counted
Expand Down
1 change: 0 additions & 1 deletion deepsort_oop.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def __init__(self, coords, classes_detected, deepsort, frame, show_img, ds_color
raise ImportError(
'Error while trying instantiate the tracking object. Please check that.')


def xyxy2xywh(self, x):
"""
WHAT IT DOES:
Expand Down
12 changes: 10 additions & 2 deletions detection_oop.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,12 @@ def __init__(self, frame, model, device, names, show_img, color, img_sz, class_i
raise ImportError(
'Error while trying instantiate the detection object. Please check that.')



def scale_coords_custom(self, img1_shape, coords, img0_shape):
"""
WHAT IT DOES:
- Scale cords obtained from the model to the original frame.
"""

gain = min(img1_shape[0] / img0_shape[0],
img1_shape[1] / img0_shape[1]) # gain = old / new
Expand All @@ -115,6 +118,11 @@ def scale_coords_custom(self, img1_shape, coords, img0_shape):
return coords

def draw_bbox(self, frame, coords, color, names, confidence):
"""
WHAT IT DOES:
- To draw the detection bboxes and text.
"""

# draw bounding box
frame = cv2.rectangle(
Expand Down
20 changes: 8 additions & 12 deletions test_oop.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from yolov7_sort_count_oop import YoloSortCount

#################### TEST ####################
#################### TEST ####################

# INSTANCIATE

test = YoloSortCount()



"""
###### AVAILABLE SOURCES ######
WebCamera: 0 ---> DEFAULT
Expand All @@ -17,7 +16,7 @@
Local video: "img_bank/cows_for_sale.mp4"
Local image: "img_bank/img.jpg" | "img_bank/img.png"
"""
test.video_path = "https://www.youtube.com/watch?v=emI8r2dfk6g"
test.video_path = 0


"""
Expand All @@ -28,9 +27,8 @@
- Invert the image (In case of your WebCamera is mirrored, IE)
"""
test.max_width = 720
test.max_fps = 25 #Max 1000
test.inv_h_frame = False

test.max_fps = 25 # Max 1000
test.inv_h_frame = True


"""
Expand All @@ -57,7 +55,6 @@
test.roi_color = (255, 255, 255)



"""
###### DETECTION MODEL ######
Expand All @@ -71,7 +68,7 @@
"""
test.model_path = 'pretrained_weights/yolov7.pt'
test.graphic_card = 0
test.class_ids = [0,2]
test.class_ids = [0]
test.img_sz = 640
test.color = (0, 255, 0)
test.conf_thres = 0.5
Expand Down Expand Up @@ -113,7 +110,6 @@
test.plot_bgr_color = (0, 0, 0)



"""
###### DEBUG TEXT ######
Expand All @@ -133,9 +129,9 @@
- Select if you want to save the results
- Select a location to save the results
"""
test.save_vid = True
test.save_loc = "results/messi"
test.save_vid = False
test.save_loc = "results/result"


# Run
test.run()
test.run()
64 changes: 55 additions & 9 deletions yolov7_sort_count_oop.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def __init__(self):
self.max_fps = 25
self.max_width = 720


# Pre defined
self.names = None

Expand Down Expand Up @@ -110,6 +109,11 @@ def __init__(self):
self.plot_bgr_color = (0, 0, 0)

def load_device(self, graphic_card):
"""
WHAT IT DOES:
- Load the torch device.
"""

try:
device = torch.device("cuda:"+str(graphic_card))
Expand All @@ -120,6 +124,12 @@ def load_device(self, graphic_card):
'Error while trying to use Graphic Card. Please check that it is available.')

def load_video_capture(self, video_path):
"""
WHAT IT DOES:
- Load the video capture.
- Resize the frames.
"""

try:

Expand Down Expand Up @@ -152,24 +162,29 @@ def load_video_capture(self, video_path):
orig_ratio = orig_h / orig_w

if orig_w > self.max_width:
logging.info('Capture has more width than max. width allowed. Rezising...')
logging.info(
'Capture has more width than max. width allowed. Rezising...')
cap = self.change_res(cap, self.max_width, orig_ratio)

orig_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
orig_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

logging.info(f'Capture has been resized to {(orig_w,orig_h)}')

orig_fps = cap.get(cv2.CAP_PROP_FPS) % 100


return cap, orig_w, orig_h, orig_fps

except Exception as err:
raise ImportError(
'Error while trying read the video. Please check that.')

def load_save_vid(self, save_loc, orig_w, orig_h):
"""
WHAT IT DOES:
- Load the result writer.
"""

try:

Expand All @@ -183,6 +198,11 @@ def load_save_vid(self, save_loc, orig_w, orig_h):
'Error while trying write the results. Please check that.')

def load_detection_model(self, model_path, device):
"""
WHAT IT DOES:
- Load the detection model and extracting the names of the classes and the model.
"""

try:
# Load all characteristics of YOLOv7x model
Expand All @@ -204,6 +224,12 @@ def load_detection_model(self, model_path, device):
'Error while trying to load the detection model. Please check that.')

def load_tracking_model(self, deep_sort_model, max_dist, max_iou_distance, max_age, n_init, nn_budget):
"""
WHAT IT DOES:
- To load the tracking model
"""

try:
deepsort = DeepSort(deep_sort_model,
max_dist=max_dist,
Expand All @@ -217,6 +243,12 @@ def load_tracking_model(self, deep_sort_model, max_dist, max_iou_distance, max_a
'Error while trying to load the tracking model. Please check that.')

def load_roi(self):
"""
WHAT IT DOES:
- To select the ROI, interactive way.
"""

cap_roi, _, _, _ = self.load_video_capture(self.video_path)
ret, select_roi_frame = cap_roi.read()

Expand All @@ -226,7 +258,7 @@ def load_roi(self):

ret, select_roi_frame = cap_roi.read()
frame_count_roi += 1

# To show image correctly (IE: web camera)
if self.inv_h_frame:
select_roi_frame = cv2.flip(select_roi_frame, 1)
Expand All @@ -244,6 +276,11 @@ def load_roi(self):
return roi

def plot_text(self, frame, frame_w, fps, plot_xmin, plot_ymin, padding, counter_text, plot_text_color, plot_bgr_color):
"""
WHAT IT DOES:
- Plot text into the output frame
"""

# Save the first xmin
aux_xmin = plot_xmin
Expand Down Expand Up @@ -288,15 +325,24 @@ def plot_text(self, frame, frame_w, fps, plot_xmin, plot_ymin, padding, counter_

return frame


def change_res(self, cap, max_width, orig_ratio):

cap.set(3,max_width)
"""
WHAT IT DOES:
- Change te resolution of the frame.
"""

cap.set(3, max_width)
cap.set(4, int(max_width * orig_ratio))

return cap

def run(self):
"""
WHAT IT DOES:
- Run the entire process of detection > tracking > count.
"""

# Debug
if self.show_configs:
Expand Down Expand Up @@ -401,7 +447,7 @@ def run(self):
lineType=cv2.LINE_AA
)

# draw fps and detections
# draw fps and detections
self.out_frame = self.plot_text(self.out_frame, self.orig_w, fps, self.plot_xmin, self.plot_ymin,
self.plot_padding, self.count.counter_text, self.plot_text_color, self.plot_bgr_color)

Expand Down

0 comments on commit a2ab2bb

Please sign in to comment.