-
Notifications
You must be signed in to change notification settings - Fork 1
/
sort.py
34 lines (28 loc) · 877 Bytes
/
sort.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import cv2
import os
import matplotlib.pyplot as plt
# File names of every entry in the 'faces' directory, in sorted order.
# os.listdir returns entries in an arbitrary, filesystem-dependent order; the
# train/test split in this script is decided by each file's position in this
# list, so sorting keeps that split reproducible across runs and machines.
images = sorted(os.listdir('faces'))
def smaller(img, scale=0.5):
    """Return *img* resized by *scale* along both width and height.

    Args:
        img: image array as returned by ``cv2.imread``.
        scale: resize factor applied to both axes. The default of 0.5
            halves each dimension, which matches the original fixed behaviour.

    Returns:
        The resized image produced by ``cv2.resize``.
    """
    # dsize=(0, 0) tells cv2.resize to compute the output size from fx/fy.
    return cv2.resize(img, (0, 0), fx=scale, fy=scale)
# Interactively label each image in 'faces' and save a half-size copy into
# data7/{test,train}/{male,female} or data7/other.
for index, image in enumerate(images):
    img = cv2.imread(os.path.join('faces', image))
    if img is None:
        # cv2.imread signals unreadable or non-image entries (hidden files,
        # subdirectories, unsupported formats) by returning None, not raising.
        print('Skipping unreadable file:', image)
        continue
    img = smaller(img)

    # Clear the previous image so artists don't accumulate on the axes.
    plt.cla()
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    plt.axis('off')
    plt.show(block=False)
    # Briefly run the GUI event loop so the figure is actually drawn before
    # input() blocks the interpreter.
    plt.pause(0.001)

    category = ''
    while category not in ('m', 'f', 'o'):
        print(index, '/', len(images))
        print(image)
        # Normalize so answers like 'M' or 'f ' are accepted.
        category = input('Category ? ').strip().lower()

    # Every 4th image (by position) goes to the test split, the rest to train.
    split_dir = 'data7/test' if index % 4 == 0 else 'data7/train'
    if category == 'm':
        dest_dir = os.path.join(split_dir, 'male')
    elif category == 'f':
        dest_dir = os.path.join(split_dir, 'female')
    else:
        # 'other' images are stored outside the train/test split.
        dest_dir = os.path.join('data7', 'other')

    # cv2.imwrite does not create missing directories.
    os.makedirs(dest_dir, exist_ok=True)
    dest_path = os.path.join(dest_dir, image)
    # cv2.imwrite reports failure via its return value, not an exception.
    if not cv2.imwrite(dest_path, img):
        raise OSError('Failed to write image: ' + dest_path)