Merge pull request #9 from aagustinconti/oop
Oop
Showing 7 changed files with 870 additions and 1 deletion.
@@ -0,0 +1,59 @@
class Count():
    def __init__(self, ds_output, roi, names, count_out_classes, counted) -> None:

        self.count_out_classes = count_out_classes
        self.counted = counted

        for detection in ds_output:

            # Get variables
            ds_cpoint = detection[0]
            ds_id = detection[1]
            ds_class = detection[2]

            # Check whether the center point lies inside the ROI
            is_into_roi = (roi[0] < ds_cpoint[0] < roi[2]) and (
                roi[1] < ds_cpoint[1] < roi[3])

            # If it is inside the ROI
            if is_into_roi:

                # Fill the empty list with the first counted object
                if len(self.counted) == 0:
                    self.counted.append([ds_id, ds_class])

                    # Get the classes detected
                    self.count_out_classes = dict.fromkeys(
                        [elem[1] for elem in self.counted], 0)

                    # Count per class
                    for elem in self.counted:
                        self.count_out_classes[elem[1]] += 1

                else:
                    # If the id has not been counted yet
                    if ds_id not in [elem[0] for elem in self.counted]:
                        # Count the object
                        self.counted.append([ds_id, ds_class])

                        # Get the classes detected
                        self.count_out_classes = dict.fromkeys(
                            [elem[1] for elem in self.counted], 0)

                        # Count per class
                        for elem in self.counted:
                            self.count_out_classes[elem[1]] += 1

        # [class id, class name, count] for every class seen so far
        self.counter_text = [[key, names[key], self.count_out_classes[key]]
                             for key in self.count_out_classes.keys()]

    def __str__(self) -> str:
        output_text_counting = f"""
        COUNTING:\n
        Classes Detected: {self.count_out_classes}
        Counter output: {self.counter_text}
        """
        return output_text_counting
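For reference (not part of the commit), a minimal sketch of how Count might be driven with one frame of tracker output. The detections, ROI and names mapping below are made-up values, and counted / count_out_classes are the accumulators the caller is expected to carry across frames:

# Hypothetical values, only to illustrate the expected shapes.
names = {0: 'person', 2: 'car'}           # class id -> class name
roi = (100, 100, 500, 400)                # (xmin, ymin, xmax, ymax)
ds_output = [((250, 200), 7, 0),          # (center point, track id, class id) -> inside the ROI
             ((600, 300), 8, 2)]          # outside the ROI, so it is not counted

counted = []                              # persists across frames
count_out_classes = {}

counter = Count(ds_output, roi, names, count_out_classes, counted)
print(counter)                            # Classes Detected: {0: 1}, Counter output: [[0, 'person', 1]]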
@@ -0,0 +1,118 @@
# https://github.com/dongdv95/yolov5/blob/master/Yolov5_DeepSort_Pytorch/track.py

# Basics
import cv2
import numpy as np
import time


class DeepSortTrack():

    def __init__(self, coords, classes_detected, deepsort, frame, show_img, ds_color, names):

        self.ds_out_frame = frame
        self.ds_delta_time = time.time()
        self.ds_out_tracking = []

        try:

            xywhs = self.xyxy2xywh(np.array(coords))
            confs = np.array([[elem[2]] for elem in classes_detected])
            clss = np.array([[elem[1]] for elem in classes_detected])

            if coords != []:

                # Pass detections to DeepSORT
                start_time = time.time()
                outputs = list(deepsort.update(
                    xywhs, confs, clss, self.ds_out_frame))
                end_time = time.time()

                self.ds_delta_time = end_time - start_time

                # Draw the tracked centroids for visualization
                if len(outputs) > 0:
                    for j, (output, conf) in enumerate(zip(outputs, confs)):
                        ds_cpoint = tuple(self.xyxy2cxcy(output[0:4]))
                        id = output[4]
                        cls = output[5]

                        self.ds_out_tracking.append([ds_cpoint, id, cls])

                        if show_img:
                            cv2.circle(
                                self.ds_out_frame, (ds_cpoint[0], ds_cpoint[1]), radius=0, color=ds_color, thickness=3)
                            cv2.putText(self.ds_out_frame, f"{names[cls]}: {id}", (ds_cpoint[0]-10, ds_cpoint[1]-7), cv2.FONT_HERSHEY_SIMPLEX,
                                        0.5, ds_color, 1)

            else:
                # No detections in this frame: age the existing DeepSORT tracks
                start_time = time.time()
                deepsort.increment_ages()
                end_time = time.time()

                self.ds_delta_time = end_time - start_time

        except Exception as err:
            raise RuntimeError(
                'Error while trying to instantiate the tracking object. Please check the inputs.') from err

    def xyxy2xywh(self, x):
        """
        WHAT IT DOES:
        - Convert nx4 boxes from [x1, y1, x2, y2] to [x_center, y_center, w, h], where xy1=top-left, xy2=bottom-right.
        - Some boxes arrive with ymin and ymax inverted (or equal), which produces a non-positive height and
          makes the tracker's _resize method fail when it rescales the bounding boxes. This function also
          repairs those boxes.
        INPUTS:
        x = [xmin, ymin, xmax, ymax] -> Array of bounding-box coordinates, one row per box.
        OUTPUTS:
        y = [x_center, y_center, width, height] -> Array of converted boxes.
        """

        y = np.copy(x)

        # Swap inverted y coordinates so every box gets a positive height
        for i in range(len(x)):
            if x[i][3] <= x[i][1]:
                x[i][3] = y[i][1] + 1
                x[i][1] = y[i][3]

        y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
        y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
        y[:, 2] = x[:, 2] - x[:, 0]  # width
        y[:, 3] = x[:, 3] - x[:, 1]  # height

        return y

    def xyxy2cxcy(self, x):
        """
        WHAT IT DOES:
        Convert one box from [x1, y1, x2, y2] to [cx, cy], where xy1=top-left, xy2=bottom-right.
        INPUTS:
        x = [x1, y1, x2, y2] -> xy1=top-left, xy2=bottom-right.
        OUTPUTS:
        y = [cx, cy] -> Centroid of the bounding box.
        """
        y = np.copy(x[:2])
        y[0] = (x[2] - x[0]) / 2 + x[0]  # x center
        y[1] = (x[3] - x[1]) / 2 + x[1]  # y center

        return y

    def __str__(self):

        output_text_tracking = f"""
        TRACKING:\n
        Tracked objects: {self.ds_out_tracking}\n
        Exec. time DeepSort model: {self.ds_delta_time} [s]\n\n
        """

        return output_text_tracking
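As a quick sanity check (not part of the commit), the two coordinate helpers can be exercised on made-up boxes; DeepSortTrack.__new__ is used only to get an instance without running the tracking step, and the second box has its ymin/ymax inverted on purpose to exercise the repair branch:

import numpy as np

boxes = np.array([[10.0, 20.0, 50.0, 60.0],    # well-formed box
                  [10.0, 60.0, 50.0, 20.0]])   # ymin/ymax inverted

helper = DeepSortTrack.__new__(DeepSortTrack)  # skip __init__, only the conversions are needed
print(helper.xyxy2xywh(boxes))        # -> [[30. 40. 40. 40.], [30. 40.5 40. 41.]]
print(helper.xyxy2cxcy(boxes[0]))     # -> [30. 40.]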
@@ -0,0 +1,145 @@
# https://github.com/WongKinYiu/yolov7

# Pytorch
import torch
import torchvision
from torchvision import transforms

# Basics
import cv2
import numpy as np
import time

# Utilities
from utils.general import non_max_suppression
from utils.datasets import letterbox
from utils.plots import output_to_keypoint


class YoloDetect():

    def __init__(self, frame, model, device, names, show_img, color, img_sz, class_ids, conf_thres, iou_thres):

        try:
            # Frame
            self.det_out_frame = frame

            # Initialize output lists
            self.det_out_coords = []
            self.det_out_classes = []

            img = cv2.cvtColor(self.det_out_frame, cv2.COLOR_BGR2RGB)

            # Reshape the frame to the adequate width and height
            img = letterbox(img, img_sz, stride=64, auto=True)[0]

            # Keep the letterboxed image to use for rescaling
            img0 = img.copy()

            # Transform the image to a tensor and send it to the device
            img = transforms.ToTensor()(img)
            img = torch.tensor(np.array([img.numpy()]))
            img = img.to(device)
            img = img.half()

            # Time the inference to count FPS
            start_time = time.time()

            # Get the output of the model
            with torch.no_grad():
                pred, _ = model(img)

            end_time = time.time()

            self.det_delta_time = end_time - start_time

            # Remove the noise of the output (NMS: a technique to filter the predictions of object detectors)
            pred = non_max_suppression(pred, conf_thres, iou_thres)

            # Process the filtered output and return the main characteristics [batch_id, class_id, x, y, w, h, conf]
            self.det_output = output_to_keypoint(pred)

            # For each detection in the frame
            for idx in range(self.det_output.shape[0]):

                # Filter by class id (an empty class_ids list keeps every class)
                if (int(self.det_output[idx][1]) in class_ids) or (class_ids == []):

                    # Rescale boxes (rescale coords (xyxy) from img0 to the original frame)
                    self.det_output[idx][2:6] = self.scale_coords_custom(
                        img0.shape[0:2], self.det_output[idx][2:6], self.det_out_frame.shape).round()

                    # Generate the bounding-box corners from the center and size
                    xmin, ymin = (self.det_output[idx, 2]-self.det_output[idx, 4] /
                                  2), (self.det_output[idx, 3]-self.det_output[idx, 5]/2)
                    xmax, ymax = (self.det_output[idx, 2]+self.det_output[idx, 4] /
                                  2), (self.det_output[idx, 3]+self.det_output[idx, 5]/2)

                    # xyxy
                    coord_bb = [xmin, ymin, xmax, ymax]

                    # [class name, class id, confidence]
                    class_detected = [names[int(self.det_output[idx][1])], int(
                        self.det_output[idx][1]), round(self.det_output[idx][6], 2)]

                    # Fill the output lists
                    self.det_out_coords.append(coord_bb)
                    self.det_out_classes.append(class_detected)

                    # Draw bounding boxes, class names and confidence
                    if show_img:
                        self.draw_bbox(self.det_out_frame, coord_bb, color,
                                       class_detected[0], class_detected[2])

        except Exception as err:
            raise RuntimeError(
                'Error while trying to instantiate the detection object. Please check the inputs.') from err

    def scale_coords_custom(self, img1_shape, coords, img0_shape):

        gain = min(img1_shape[0] / img0_shape[0],
                   img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / \
            2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding

        coords[0] -= pad[0]  # x padding
        coords[2] -= pad[0]  # x padding
        coords[1] -= pad[1]  # y padding
        coords[3] -= pad[1]  # y padding
        coords[:] /= gain

        return coords

    def draw_bbox(self, frame, coords, color, names, confidence):

        # Draw the bounding box
        frame = cv2.rectangle(
            frame,
            (int(coords[0]), int(coords[1])),
            (int(coords[2]), int(coords[3])),
            color=color,
            thickness=1,
            lineType=cv2.LINE_AA
        )

        # Write the confidence and class name
        cv2.putText(frame, f"{names}: {confidence}", (int(coords[0]), int(coords[1])-5), cv2.FONT_HERSHEY_SIMPLEX,
                    0.5, color, 1)

        return True

    def __str__(self):

        output_text_detection = f"""
        DETECTION:\n
        BBxes coords: {self.det_out_coords}\n
        Classes Detected: {self.det_out_classes}\n
        Exec. time YOLOv7x model: {self.det_delta_time} [s]\n\n
        """

        return output_text_detection
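For context (not part of the commit), a per-frame driver for YoloDetect could look like the sketch below. The weight loading (attempt_load, the yolov7x.pt file) and the parameter values are assumptions based on the YOLOv7 repo layout, not something this diff defines:

import cv2
import torch
from models.experimental import attempt_load   # assumed YOLOv7 repo layout

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = attempt_load('yolov7x.pt', map_location=device)   # assumed weights file
model.half()
names = model.module.names if hasattr(model, 'module') else model.names

cap = cv2.VideoCapture(0)                      # webcam as an example source
while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break

    # One YoloDetect object per frame; parameter values are illustrative only
    detection = YoloDetect(frame, model, device, names, show_img=True, color=(0, 255, 0),
                           img_sz=640, class_ids=[0], conf_thres=0.5, iou_thres=0.65)

    cv2.imshow('detections', detection.det_out_frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()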
@@ -0,0 +1,40 @@
from yolov7_sort_count_oop import YoloSortCount
import time
import cv2

# Test
test = YoloSortCount()

# Source: 0 (webcam) | "https://www.youtube.com/watch?v=qP1y7Tdab7Y" | "http://IP/hls/stream_src.m3u8" | "img_bank/cows_for_sale.mp4"
test.video_path = 0

test.max_fps = 1000  # Max 1000
test.max_width = 720

# Show results
test.show_img = True
test.hold_img = False

test.auto_load_roi = True

test.ends_in_sec = 10

# Debug
test.show_configs = False
test.show_detection = False
test.show_tracking = False
test.show_count = False

# Detection model
test.class_ids = [0]
test.conf_thres = 0.5

# Frame
test.inv_h_frame = False

# Save
test.save_loc = "results/test_test"
test.save_vid = True

# Run
test.run()