Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Several speed & code updates #19

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,6 @@ ENV/
.mypy_cache/

*.pth

# vim
*.swp
34 changes: 23 additions & 11 deletions bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,27 +30,39 @@ def bboxloginv(dx,dy,dw,dh,axc,ayc,aww,ahh):
x1,x2,y1,y2 = xc-ww/2,xc+ww/2,yc-hh/2,yc+hh/2
return x1,y1,x2,y2

def nms(dets, thresh):
if 0==len(dets): return []
x1,y1,x2,y2,scores = dets[:, 0],dets[:, 1],dets[:, 2],dets[:, 3],dets[:, 4]
def nms(bboxlist:torch.Tensor, thresh:float) -> list:
"""Given an Nx5 tensor of bounding boxes, and a threshold,
return a list of the indexes of bounding boxes to keep.
"""
if len(bboxlist) == 0:
return []
x1 = bboxlist[:,0]
y1 = bboxlist[:,1]
x2 = bboxlist[:,2]
y2 = bboxlist[:,3]
scores = bboxlist[:,4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]

# Go through the boxes in order of decreasing score...
scores = np.asarray(scores)
order = scores.argsort()
order = np.asarray(list(order[::-1]))
keep = []
while order.size > 0:
while len(order) > 0:
i = order[0]
keep.append(i)
keep.append(i) # Keep this one.

# For all the remaining (lower score) bounding boxes, figure out something about the overlap I think
xx1,yy1 = np.maximum(x1[i], x1[order[1:]]),np.maximum(y1[i], y1[order[1:]])
xx2,yy2 = np.minimum(x2[i], x2[order[1:]]),np.minimum(y2[i], y2[order[1:]])

w,h = np.maximum(0.0, xx2 - xx1 + 1),np.maximum(0.0, yy2 - yy1 + 1)
ovr = w*h / (areas[i] + areas[order[1:]] - w*h)

ovr = w*h / (areas[i] + areas[order[1:]] - w*h) # looks like the overlap
inds = np.where(ovr <= thresh)[0]
order = order[inds + 1]
order = order[inds + 1] # eliminate the ones that don't meet the threshhold

return keep


def encode(matched, priors, variances):
"""Encode the variances from the priorbox layers into the ground truth boxes
we have matched (based on jaccard overlap) with the prior boxes.
Expand Down Expand Up @@ -92,4 +104,4 @@ def decode(loc, priors, variances):
priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
boxes[:, :2] -= boxes[:, 2:] / 2
boxes[:, 2:] += boxes[:, :2]
return boxes
return boxes
Binary file modified data/test01_output.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
73 changes: 73 additions & 0 deletions detect_faces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from typing import List, Tuple
torch.backends.cudnn.benchmark = True

import os,sys,cv2,random,datetime,time,math
import argparse
import numpy as np

from bbox import decode, nms

def detect_faces(net:nn.Module, img:np.ndarray, minscale:int=3, ovr_threshhold:float=0.3,
score_threshhold:float=0.5) -> List[Tuple]:
"""returns an list of tuples describing bounding boxes: [x1,y1,x2,y2,score].
Setting minscale to 0 finds the smallest faces, but takes the longest.
"""
bboxlist = detect(net, img, minscale)
keep_idx = nms(bboxlist, ovr_threshhold)
bboxlist = bboxlist[keep_idx,:]
out = []
for b in bboxlist:
x1,y1,x2,y2,s = b
if s<0.5:
continue
out.append((int(x1),int(y1),int(x2),int(y2),s))
return out


def detect(net:nn.Module, img:np.ndarray, minscale:int=3) -> torch.Tensor:
"""returns an Nx5 tensor describing bounding boxes: [x1,y1,x2,y2,score].
This will have LOTS of similar/overlapping regions. Need to call bbox.nms to reconcile them.
Setting minscale to 0 finds the smallest faces, but takes the longest.
"""
img = img - np.array([104,117,123])
img = img.transpose(2, 0, 1)
img = img.reshape((1,)+img.shape)

img = Variable(torch.from_numpy(img).float()).cuda()
BB,CC,HH,WW = img.size()
olist = net(img)

bboxlist = []
for i in range(minscale, len(olist)//2):
ocls = F.softmax(olist[i*2], dim=1).data
oreg = olist[i*2+1].data
FB,FC,FH,FW = ocls.size() # feature map size
stride = 2**(i+2) # 4,8,16,32,64,128
anchor = stride*4
# this workload is small enough that it's faster on CPU than GPU (~55ms vs ~65ms)
# but most of that time (40ms) is spend moving the data from GPU to CPU.
all_scores = ocls[0,1,:,:].cpu()
oreg = oreg.cpu()
# instead of running a sliding window, first find the places where score is big enough to bother
bigenough = torch.nonzero(all_scores > 0.05)
for hindex, windex in bigenough:
score = all_scores[hindex,windex]
loc = oreg[0,:,hindex,windex].contiguous().view(1,4)
axc,ayc = stride/2+windex*stride,stride/2+hindex*stride
priors = torch.Tensor([[axc/1.0,ayc/1.0,stride*4/1.0,stride*4/1.0]])
variances = [0.1,0.2]
box = decode(loc,priors,variances)
x1,y1,x2,y2 = box[0]*1.0
# cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
bboxlist.append([x1,y1,x2,y2,score])
if len(bboxlist) == 0:
bboxlist=torch.zeros((1, 5))
bboxlist = torch.Tensor(bboxlist)
return bboxlist

53 changes: 53 additions & 0 deletions livecam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
torch.backends.cudnn.benchmark = True

import os,sys,cv2,random,datetime,time,math
import argparse
import numpy as np

import net_s3fd
from detect_faces import detect_faces

parser = argparse.ArgumentParser(description='PyTorch face detect')
parser.add_argument('--net','-n', default='s3fd', type=str)
parser.add_argument('--model', required=True, type=str)
parser.add_argument('--path', default='CAMERA', type=str)

args = parser.parse_args()
use_cuda = torch.cuda.is_available()


net = getattr(net_s3fd,args.net)()
net.load_state_dict(torch.load(args.model))
net.cuda()
net.eval()


if args.path=='CAMERA':
cap = cv2.VideoCapture(0)
with torch.no_grad():
while(True):
if args.path=='CAMERA':
ret, img = cap.read()
else:
img = cv2.imread(args.path)

imgshow = np.copy(img)
start_time = time.time()
bboxlist = detect_faces(net, img, 3)
print(f"Running detect_faces took {1000*(time.time() - start_time):.1f}ms. Found {len(bboxlist)} faces.")
for b in bboxlist:
x1,y1,x2,y2,s = b
cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,255,0),1)
cv2.imshow('test',imgshow)

if args.path=='CAMERA':
if cv2.waitKey(1) & 0xFF == ord('q'): break
else:
cv2.imwrite(args.path[:-4]+'_output.png',imgshow)
if cv2.waitKey(0) or True: break
4 changes: 3 additions & 1 deletion net_s3fd.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def forward(self, x):
x = x / norm * self.weight.view(1,-1,1,1)
return x

class s3fd(nn.Module):
class S3fd_Model(nn.Module):
def __init__(self):
super(s3fd, self).__init__()
self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
Expand Down Expand Up @@ -120,3 +120,5 @@ def forward(self, x):
cls1 = torch.cat([bmax,chunk[3]],dim=1)

return [cls1,reg1,cls2,reg2,cls3,reg3,cls4,reg4,cls5,reg5,cls6,reg6]

s3fd = S3fd_Model
85 changes: 0 additions & 85 deletions test.py

This file was deleted.

22 changes: 22 additions & 0 deletions test_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import cv2
import numpy as np
import pytest
import torch

from detect_faces import detect_faces
from net_s3fd import S3fd_Model

def test_ellen_selfie():
model = S3fd_Model()
try:
state_dict = torch.load("s3fd_convert.pth")
model.load_state_dict(state_dict)
except:
print("Failed to load pre-trained model for test")
raise
model.cuda()
model.eval()
with torch.no_grad():
img = cv2.imread('data/test01.jpg')
faces = detect_faces(model, img)
assert len(faces) == 11
35 changes: 2 additions & 33 deletions wider_eval_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,38 +16,7 @@
from net_s3fd import s3fd
from bbox import *


def detect(net,img):
img = img - np.array([104,117,123])
img = img.transpose(2, 0, 1)
img = img.reshape((1,)+img.shape)

img = Variable(torch.from_numpy(img).float(),volatile=True).cuda()
BB,CC,HH,WW = img.size()
olist = net(img)

bboxlist = []
for i in range(len(olist)/2): olist[i*2] = F.softmax(olist[i*2])
for i in range(len(olist)/2):
ocls,oreg = olist[i*2].data.cpu(),olist[i*2+1].data.cpu()
FB,FC,FH,FW = ocls.size() # feature map size
stride = 2**(i+2) # 4,8,16,32,64,128
anchor = stride*4
for Findex in range(FH*FW):
windex,hindex = Findex%FW,Findex//FW
axc,ayc = stride/2+windex*stride,stride/2+hindex*stride
score = ocls[0,1,hindex,windex]
loc = oreg[0,:,hindex,windex].contiguous().view(1,4)
if score<0.05: continue
priors = torch.Tensor([[axc/1.0,ayc/1.0,stride*4/1.0,stride*4/1.0]])
variances = [0.1,0.2]
box = decode(loc,priors,variances)
x1,y1,x2,y2 = box[0]*1.0
# cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
bboxlist.append([x1,y1,x2,y2,score])
bboxlist = np.array(bboxlist)
if 0==len(bboxlist): bboxlist=np.zeros((1, 5))
return bboxlist
from detect_faces import detect

def flip_detect(net,img):
img = cv2.flip(img, 1)
Expand Down Expand Up @@ -134,4 +103,4 @@ def scale_detect(net,img,scale=2.0,facesize=None):
x1,y1,x2,y2,s = b
f.write('{:.1f} {:.1f} {:.1f} {:.1f} {:.3f}\n'.format(x1,y1,(x2-x1+1),(y2-y1+1),s))
f.close()
print('event:%d num:%d' % (index + 1, num + 1))
print('event:%d num:%d' % (index + 1, num + 1))