Object tracking
ahao-laptop committed Sep 19, 2022
1 parent d8f94c5 commit 8d7f619
Showing 3 changed files with 130 additions and 62 deletions.
157 changes: 103 additions & 54 deletions controller/AiController.py
@@ -108,6 +108,9 @@ def __init__(self, view,start:bool) -> None:
super().__init__(view)
self.start_rec(start)

self.view.track.stateChanged.connect(self.toggleTrack)
self.view.tracktime.valueChanged.connect(self.changeTrackTime)

self.view.xRate.valueChanged.connect(self.changeRate)
self.view.yRate.valueChanged.connect(self.changeRate)
self.view.airegion.valueChanged.connect(self.changeAiRegion)
@@ -175,6 +178,12 @@ def changeThreshold(self):
self.recoginizer.ai.NMS_THRESHOLD=nms
self.recoginizer.ai.CONFIDENCE_THRESHOLD=confidence

def toggleTrack(self):
self.recoginizer.track = self.view.track.isChecked()

def changeTrackTime(self):
self.recoginizer.tracktime = self.view.tracktime.value()

class Recognizer(QObject):
status = None
start = pyqtSignal(bool)
@@ -217,6 +226,7 @@ def rec(self, run):
self.run = run
while self.run:
start_time = time.time()
self.start_time = start_time
self.recognize()
sleep(self.sleeptime)
end_time = time.time()
@@ -227,35 +237,43 @@ class AIRecoginzer(Recognizer):
xrate = 0.2
yrate = 0
box = None
track = True
tracktime = 0.5

def __init__(self,qt_comunicate=None, sleeptime=0.01,accuracy=0,provider='CPUExecutionProvider',):
super().__init__(qt_comunicate, sleeptime,accuracy)
try:
self.ai = ORDML(provider)
except:
traceback.print_exc()
self.ai = None
# self.ai = OVINO()
# self.ai = PTORCH()

try:
self.create_tracker()
except:
traceback.print_exc()
self.track = False

def create_tracker(self):
tracker_types = ['BOOSTING', 'MIL','KCF', 'TLD', 'MEDIANFLOW', 'GOTURN', 'MOSSE', 'CSRT']
tracker_type = tracker_types[4]
if tracker_type == 'BOOSTING':
tracker = cv.legacy.TrackerBoosting_create()
elif tracker_type == 'MIL':
tracker = cv.legacy.TrackerMIL_create()
elif tracker_type == 'KCF':
tracker = cv.legacy.TrackerKCF_create()
elif tracker_type == 'TLD':
tracker = cv.legacy.TrackerTLD_create()
elif tracker_type == 'MEDIANFLOW':
tracker = cv.legacy.TrackerMedianFlow_create()
elif tracker_type == 'GOTURN':
tracker = cv.legacy.TrackerGOTURN_create()
elif tracker_type == 'MOSSE':
tracker = cv.legacy.TrackerMOSSE_create()
elif tracker_type == "CSRT":
tracker = cv.TrackerCSRT_create()
self.tracker = tracker

def recognize(self):
if not self.ai:
@@ -270,33 +288,52 @@ def recognize(self):

img = screenshot(box)
cvimg = screenshot_to_cv(img)
# cvimg = resize_img(cvimg,self.resize_rate)

# target tracking
if self.track and self.box and time.time()-self.detecttime<self.tracktime:
trackimg = cv.cvtColor(cvimg,cv.COLOR_BGRA2BGR)
trackimg,simg = self.ai.circle_mask(trackimg)
ok,bbox = self.tracker.update(trackimg)
il,it,iw,ih = bbox
ip1 = (int(il),int(it))
ip2 = (int(il+iw),int(it+ih))
op1 = (
int(il-(self.ai.fixregion/2)*iw),
int(it-(self.ai.fixregion/3)*ih)
)
op2 = (
int(il+iw+(self.ai.fixregion/2)*iw),
int(it+ih+(self.ai.fixregion/3)*ih)
)

if ok and self.pointer_in_range(center,bbox):
cv.rectangle(simg,ip1,ip2,(0,255,0),3,cv.LINE_AA)
cv.putText(simg, "TARGET", ip1, cv.FONT_HERSHEY_SIMPLEX, 1, (0,0,255), 3,cv.LINE_AA)
cv.rectangle(simg,op1,op2,(255,178,50),1,cv.LINE_AA)

boxcenter = (
round(bbox[0]+(bbox[2]/2)),
round(bbox[1]+(bbox[3]/5)),
)
cv.line(simg,center,boxcenter,(255,255,255),3,cv.LINE_AA)
cv.circle(simg,center,5,(0,255,0),3,cv.LINE_AA)

movex = int((boxcenter[0]-center[0])*self.xrate)
movey = int((boxcenter[1]-center[1])*self.yrate)
self.qt_comunicate.update.emit({"move":(movex,movey)})

text = "{}fps".format(round(1/(time.time()-self.strat_time)))
cv.putText(simg,text,(5,50),cv.FONT_HERSHEY_SIMPLEX,1,(0,255,0),2)
qimg = cv_img_to_qimg(simg)
self.qt_comunicate.update.emit({"img":qimg}) if self.qt_comunicate else None
return
else:
self.qt_comunicate.update.emit({"move":(0,0)})
self.box = None

# AI object detection
img,box = self.ai.detect(cvimg)

if box:
boxcenter,bbox = self.findclose(box,center,width,height)
if boxcenter:
@@ -306,37 +343,49 @@ def recognize(self):
self.box = bbox
self.detecttime = time.time()

if self.track:
self.create_tracker()
self.tracker.init(trackimg,bbox)
else:
if self.pointer_in_range(center,bbox):
cv.line(img,center,boxcenter,(255,255,255),3,cv.LINE_AA)
movex = int((boxcenter[0]-center[0])*self.xrate)
movey = int((boxcenter[1]-center[1])*self.yrate)
self.qt_comunicate.update.emit({"move":(movex,movey)})
else:
self.qt_comunicate.update.emit({"move":(0,0)})
# filename = "E:/Video/ai/"+str(time.time())+".jpg"
# cv.imwrite(filename,img)
else:
self.box = None
self.qt_comunicate.update.emit({"move":(0,0)})

text = "{}fps".format(round(1/(time.time()-self.strat_time)))
cv.putText(img,text,(5,50),cv.FONT_HERSHEY_SIMPLEX,1,(0,255,0),2)
if box and bbox:
cv.rectangle(img,(bbox[0],bbox[1]),(bbox[0]+bbox[2],bbox[1]+bbox[3]),(0,255,0),3,cv.LINE_AA)
cv.circle(img,center,5,(0,255,0),3,cv.LINE_AA)
qimg = cv_img_to_qimg(img)
self.qt_comunicate.update.emit({"img":qimg}) if self.qt_comunicate else None


def pointer_in_range(self,point,xywh):
il,it,iw,ih = xywh
op1 = (
int(il-(self.ai.fixregion/2)*iw),
int(it-(self.ai.fixregion/3)*ih)
)
op2 = (
int(il+iw+(self.ai.fixregion/2)*iw),
int(it+ih+(self.ai.fixregion/3)*ih)
)
return point[0] > op1[0] and point[0]<op2[0] and point[1]>op1[1] and point[1]<op2[1]

def findclose(self,boxs,center,w,h):
count = 0
boxcenter = None
# find the target closest to the center
closebox = None
for b in boxs:
# if ((b[2]*b[3])/(w*h)<0.01):
# print(b[2]*b[3],w*h,(b[2]*b[3])/(w*h))
# continue
# print(b[2]*b[3],w*h,(b[2]*b[3])/(w*h))
boxcenter_t = (
round(b[0]+(b[2]/2)),
round(b[1]+(b[3]/5)),
@@ -345,15 +394,15 @@ def findclose(self,boxs,center,w,h):
y = int(boxcenter_t[1]-center[1])
sumxy = x*x+y*y
if count == 0 or sumxy < count:
count = sumxy
boxcenter = boxcenter_t
closebox = b
else:
continue
# when detection confidence is low, don't lock onto targets outside the crosshair region
# if center[0] > closebox[0]-(self.ai.fixregion)*closebox[2] and center[0] < closebox[0]+(self.ai.fixregion+1)*closebox[2] and center[1] > closebox[1]-(self.ai.fixregion)*closebox[3] and center[1] < closebox[1]+(self.ai.fixregion+1)*closebox[3]:
# return boxcenter,closebox
return boxcenter,closebox

def changeEngine(self,engine):
if engine in self.ai.providers.keys():
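For context on the AiController change: the commit interleaves a lightweight OpenCV tracker between full detector passes. A detection seeds the tracker, and for up to `tracktime` seconds later frames call `tracker.update()` instead of the ONNX model, which is what the new frame-rate-compensation checkbox toggles. Below is a minimal standalone sketch of that pattern, assuming `opencv-contrib-python` (for the `cv.legacy` trackers); `detect()` is a hypothetical stand-in for the real model, not part of this repo.

```python
import time
import cv2 as cv

def detect(frame):
    # Hypothetical stand-in for the ONNX detector; returns one (x, y, w, h) box or None.
    return (100, 100, 50, 80)

def run(frames, tracktime=0.5):
    tracker = None
    detecttime = 0.0
    for frame in frames:
        # While the last detection is fresh, reuse the cheap tracker.
        if tracker is not None and time.time() - detecttime < tracktime:
            ok, bbox = tracker.update(frame)
            if ok:
                continue  # tracked this frame without running the detector
            tracker = None  # tracking lost: fall through to detection
        box = detect(frame)
        if box:
            # Re-seed a fresh MedianFlow tracker from the new detection.
            tracker = cv.legacy.TrackerMedianFlow_create()
            tracker.init(frame, box)
            detecttime = time.time()
```

MedianFlow (`tracker_types[4]`) is the type the code selects; KCF, MOSSE, or CSRT are drop-in alternatives via the same factory functions in `create_tracker`.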
17 changes: 11 additions & 6 deletions myutils/AIwithoutTorch.py
@@ -119,16 +119,21 @@ def drawPred(self, frame, classId, conf, left, top, right, bottom):
cv2.rectangle(frame, (left, top), (right, bottom), self.RED, self.THICKNESS)
cv2.putText(frame, label, (left, top), self.FONT_FACE, self.FONT_SCALE, self.YELLOW, self.THICKNESS,cv2.LINE_AA)

il,it,iw,ih = (
left,
top,
right-left,
bottom-top,
)
op1 = (
int(il-self.fixregion/2*iw),
int(it-self.fixregion/3*ih),
)
op2 = (
int(il+iw+self.fixregion/2*iw),
int(it+ih+self.fixregion/3*ih),
)
cv2.rectangle(frame, op1, op2, self.BLUE, self.THICKNESS)

return frame

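The widened box drawn in `drawPred` mirrors the `pointer_in_range` check in AiController.py: the detection box is extended by `fixregion/2` of its width on each side and `fixregion/3` of its height above and below. A small self-contained sketch of that geometry (the numbers in the comment are illustrative; `fixregion` comes from the model class):

```python
def expanded_region(xywh, fixregion):
    # Widen an (x, y, w, h) box by fixregion/2 of its width horizontally
    # and fixregion/3 of its height vertically, as drawPred/pointer_in_range do.
    il, it, iw, ih = xywh
    op1 = (int(il - (fixregion / 2) * iw), int(it - (fixregion / 3) * ih))
    op2 = (int(il + iw + (fixregion / 2) * iw), int(it + ih + (fixregion / 3) * ih))
    return op1, op2

# A 50x80 box at (100, 100) with fixregion=1 widens to (75, 73)-(175, 206).
print(expanded_region((100, 100, 50, 80), 1))
```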
18 changes: 16 additions & 2 deletions view/widgets/AIWidget.py
@@ -1,4 +1,5 @@

from superqt import QLabeledSlider,QLabeledDoubleSlider
from model.Settings import Settings
from myutils.QtUtils import set_label_img,qimg_to_qpix
from controller.AiController import AIRecognizeController, AntiShakeController,BloodRecognizeController
@@ -74,9 +75,9 @@ def setUI(self):
grid =QGridLayout()

providerBox = QComboBox()
providerBox.addItem("CPUExecutionProvider","CPUExecutionProvider")
# providerBox.addItem("CPUExecutionProvider","CPUExecutionProvider")
providerBox.addItem("DmlExecutionProvider","DmlExecutionProvider")
providerBox.addItem("CUDAExecutionProvider","CUDAExecutionProvider")
# providerBox.addItem("CUDAExecutionProvider","CUDAExecutionProvider")
# providerBox.addItem("TensorrtExecutionProvider","TensorrtExecutionProvider")
self.providerBox = providerBox
# ordmlBTN = QRadioButton("onnxruntime DirectML (AMD-GPU)") #onnxruntime dml (GPU)
@@ -90,7 +91,20 @@ def setUI(self):
# self.ptorchBTN = ptorchBTN
# self.tensortBTN = tensortBTN

track = QCheckBox("Tracking compensates frame rate; tracking time (s):")
track.setChecked(True)
self.track = track

tracktime = QLabeledDoubleSlider(Qt.Orientation.Horizontal)
tracktime.setMaximum(3)
tracktime.setMinimum(0.1)
tracktime.setValue(0.5)
tracktime.setSingleStep(0.01)
self.tracktime = tracktime

grid.addWidget(providerBox)
grid.addWidget(track)
grid.addWidget(tracktime,1,1)
# grid.addWidget(ordmlBTN, 0,0)
# grid.addWidget(ovinoBTN, 0,1)
# grid.addWidget(ptorchBTN, 1,0)
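The new checkbox and slider only take effect through the `toggleTrack`/`changeTrackTime` slots added in AiController.py above. A minimal sketch of that wiring in isolation, assuming PyQt5 plus superqt; the `Recognizer` here is a bare stand-in for the real class:

```python
from PyQt5.QtCore import Qt
from PyQt5.QtWidgets import QApplication, QCheckBox, QVBoxLayout, QWidget
from superqt import QLabeledDoubleSlider

class Recognizer:
    track = True
    tracktime = 0.5

app = QApplication([])
rec = Recognizer()

panel = QWidget()
layout = QVBoxLayout(panel)
track = QCheckBox("Tracking compensates frame rate; tracking time (s):")
track.setChecked(True)
tracktime = QLabeledDoubleSlider(Qt.Orientation.Horizontal)
tracktime.setMinimum(0.1)
tracktime.setMaximum(3)
tracktime.setValue(0.5)
tracktime.setSingleStep(0.01)
layout.addWidget(track)
layout.addWidget(tracktime)

# Mirror the controller slots: push widget state into the recognizer.
track.stateChanged.connect(lambda _: setattr(rec, "track", track.isChecked()))
tracktime.valueChanged.connect(lambda v: setattr(rec, "tracktime", v))

panel.show()
app.exec_()
```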
