# data_process.py
# -*- coding: utf-8 -*-
"""
Created on Sun Sep 23 11:27:58 2018
@author: Yuxi1989
"""
import os,shutil
import numpy as np
import tensorflow as tf
import random
from PIL import Image
from PIL import ImageFile
import re
ImageFile.LOAD_TRUNCATED_IMAGES = True
def move_images():
    '''
    Flatten the catch-all ("other") category folder of the training set.

    Every entry found inside a direct sub-folder of the folder named by
    ``cwd`` is moved up into ``cwd`` itself, and the sub-folders are
    removed afterwards.

    Raises OSError (from ``os.rmdir``) if a sub-folder is not empty once
    the move pass has finished.
    '''
    cwd = os.path.join('..', 'data', 'train', '其他')
    for i in os.listdir(cwd):
        local_dir = os.path.join(cwd, i)
        if os.path.isdir(local_dir):
            for img in os.listdir(local_dir):
                img_path = os.path.join(local_dir, img)
                shutil.move(img_path, os.path.join(cwd, img))
    # Second pass: list the folder again and drop every directory left in it.
    for folder in os.listdir(cwd):
        folder_path = os.path.join(cwd, folder)
        if os.path.isdir(folder_path):
            # os.rmdir needs the joined path; a bare name would be resolved
            # against the process working directory instead of ``cwd``.
            os.rmdir(folder_path)
def get_class_correspond():
    '''
    Build the mapping from label name to integer class index.

    Each directory directly under ``../data/train`` is a label; plain
    files are ignored. Indices are assigned from 0 in the order in which
    ``os.listdir`` reports the directories.

    Returns a dict of ``{label_name: index}``.
    '''
    train_root = os.path.join('..', 'data', 'train')
    label_names = [
        entry
        for entry in os.listdir(train_root)
        if os.path.isdir(os.path.join(train_root, entry))
    ]
    return {name: index for index, name in enumerate(label_names)}
def make_tfrecords(corresponds,nrows,ncols):
    '''
    Serialize the images of ``../data/test2`` into ``../data/test2.tfrecords``.

    Only the test-set branch is active; the train/val branches are kept
    below as commented-out code.

    corresponds: label-name -> class-index mapping. It is only referenced
        by the commented-out train/val branches.
    nrows: target image height, passed to ``Image.resize``.
    ncols: target image width, passed to ``Image.resize``.

    Every ``.jpg`` file is written, in ascending order of the integer in
    front of the first dot of its name, as one ``tf.train.Example`` with
    'image_raw' (raw bytes of the resized RGB image) and 'label'
    (always 1 -- presumably a placeholder for the unlabeled test set,
    confirm against the reader).

    Raises ValueError if a ``.jpg`` file name does not start with an integer.
    '''
    # NOTE: disabled train/val generation, kept verbatim for reference.
    # train_writer=tf.python_io.TFRecordWriter('..\\data\\train.tfrecords')
    # val_writer=tf.python_io.TFRecordWriter('..\\data\\val.tfrecords')
    # train_img_paths=[]
    # train_img_labels=[]
    # cwd=os.path.join('..','data','train')
    # categorys=os.listdir(cwd)
    # for category in categorys:
    # if os.path.isdir(os.path.join(cwd,category)):
    # loop_num=1000//len(os.listdir(os.path.join(cwd,category)))
    # for _ in range(loop_num):
    # for img in os.listdir(os.path.join(cwd,category)):
    # if img.split('.')[-1]=='jpg':
    # train_img_paths.append(os.path.join(cwd,category,img))
    # train_img_labels.append(corresponds[category])
    # randnum=random.randint(0,len(train_img_paths))
    # random.seed(randnum)
    # random.shuffle(train_img_paths)
    # random.seed(randnum)
    # random.shuffle(train_img_labels)
    # for i in range(len(train_img_labels)):
    # img=Image.open(train_img_paths[i]).convert('RGB')
    # img=img.resize((ncols,nrows))
    # img_raw=img.tobytes()
    # label=train_img_labels[i]
    # example=tf.train.Example(features=tf.train.Features(feature={
    # 'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
    # 'label':tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))}))
    # train_writer.write(example.SerializeToString())
    # train_writer.close()
    #
    # cwd=os.path.join('..','data','val')
    # others=['变形','驳口','打白点','打磨印','返底','划伤','火山口','铝屑',
    # '喷涂碰伤','碰凹','气泡','拖烂','纹粗','油印','油渣','杂色','粘接']
    # for path in os.listdir(cwd):
    # if path.split('.')[-1]=='jpg':
    # img=Image.open(os.path.join(cwd,path))
    # img=img.resize((ncols,nrows))
    # img_raw=img.tobytes()
    # pattern=re.compile(r'[^\u4e00-\u9fa5]')
    # zh=pattern.split(path)
    # if zh[0] in others:
    # label=1
    # else:
    # label=corresponds[zh[0]]
    # example=tf.train.Example(features=tf.train.Features(feature={
    # 'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
    # 'label':tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))}))
    # val_writer.write(example.SerializeToString())
    # val_writer.close()
    cwd = os.path.join('..', 'data', 'test2')
    # Keep only the .jpg entries *before* sorting, so that an unrelated
    # entry with a non-numeric name cannot make int() fail in the sort key.
    jpg_names = [name for name in os.listdir(cwd) if name.split('.')[-1] == 'jpg']
    jpg_names.sort(key=lambda name: int(name.split('.')[0]))
    # Build the output path with os.path.join like every other path here,
    # instead of hard-coding Windows separators.
    test_writer = tf.python_io.TFRecordWriter(
        os.path.join('..', 'data', 'test2.tfrecords'))
    try:
        for name in jpg_names:
            with Image.open(os.path.join(cwd, name)) as src:
                # Force 3 channels (as the training branch does) so every
                # record has the same nrows*ncols*3 byte length.
                img = src.convert('RGB').resize((ncols, nrows))
            img_raw = img.tobytes()
            label = 1
            example = tf.train.Example(features=tf.train.Features(feature={
                'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))}))
            test_writer.write(example.SerializeToString())
    finally:
        # Close the writer even when an image fails to load.
        test_writer.close()
def split_train_val():
    '''
    Move a random sample of every training category into the validation set.

    For each category directory under ``../data/train``,
    ``max(1, int(count * 0.1))`` entries are drawn without replacement and
    moved into the flat ``../data/val`` directory, which is created when
    missing. Plain files next to the category directories and empty
    category directories are skipped.
    '''
    cwd = os.path.join('..', 'data')
    train_path = os.path.join(cwd, 'train')
    val_path = os.path.join(cwd, 'val')
    if not os.path.exists(val_path):
        os.mkdir(val_path)
    for category in os.listdir(train_path):
        local_path = os.path.join(train_path, category)
        # Only directories are categories; listing a stray file would raise.
        if not os.path.isdir(local_path):
            continue
        img_names = os.listdir(local_path)
        if not img_names:
            # np.random.choice cannot draw one sample from an empty array.
            continue
        val_count = max(1, int(len(img_names) * 0.1))
        val_names = np.random.choice(
            np.array(img_names), val_count, replace=False).tolist()
        for img in val_names:
            src_path = os.path.join(local_path, img)
            dst_path = os.path.join(val_path, img)
            shutil.move(src_path, dst_path)
if __name__ == '__main__':
    # Optional one-off preparation steps; both are currently disabled.
    # move_images()
    # split_train_val()
    # Build the label mapping, then write the records with images resized
    # to 240 rows by 320 columns.
    make_tfrecords(get_class_correspond(), nrows=240, ncols=320)