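"""Converts a folder of images into sharded TFRecord files.

Expects dataset_dir to contain train/<class_name>/*.jpg images for training
and test_imgs/*.jpg images for testing; writes pc_train / pc_val TFRecord
shards plus a label file mapping class ids to class names.
"""
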
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import glob
import os
from os.path import join
from random import shuffle

import contextlib2
import cv2
import tensorflow as tf

import dataset_utils
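
# Number of TFRecord shards to write per split; with 1 shard the output is a
# single file per split (the sharding helper is assumed to add a suffix such
# as pc_train-00000-of-00001).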
_NUM_TRAIN_FILES = 1


def _create_tfrecord_train(dataset_dir, tfrecord_writers, classes_map):
    """Encodes the training images as tf.Examples and writes them to shards."""
    count = 0
    all_image_paths, class_labels = [], []
    for cls_train in classes_map:
        cls_image_paths = glob.glob(join(dataset_dir, 'train', cls_train, '*.JPG')) + \
            glob.glob(join(dataset_dir, 'train', cls_train, '*.jpg'))
        all_image_paths.extend(cls_image_paths)
        class_labels.extend([classes_map[cls_train]] * len(cls_image_paths))
    # Shuffle paths and labels together so the shards are not ordered by class.
    combined = list(zip(all_image_paths, class_labels))
    shuffle(combined)
    all_image_paths, class_labels = zip(*combined)
    for im_path, label in zip(all_image_paths, class_labels):
        img = cv2.imread(im_path)
        if img is None:
            print('Skipping unreadable image: {}'.format(im_path))
            continue
        if img.shape != (256, 256, 3):
            print('Resizing {} from {}'.format(im_path, img.shape))
            img = cv2.resize(img, (256, 256))
        _, encoded_image = cv2.imencode('.jpg', img)
        example = dataset_utils.image_to_tfexample(
            encoded_image.tobytes(), b'jpg', 256, 256, label)
        # Distribute examples across the output shards round-robin.
        output_shard_index = count % _NUM_TRAIN_FILES
        tfrecord_writers[output_shard_index].write(example.SerializeToString())
        count += 1
    print('Processed {} images'.format(count))


def _create_tfrecord_test(dataset_dir, tfrecord_writers, classes_map):
    """Encodes the test images as tf.Examples with a dummy label of 0."""
    del classes_map  # Unused: the test images carry no ground-truth labels.
    count = 0
    # Sort the combined glob results so the test examples are written in a
    # deterministic order.
    image_paths = sorted(glob.glob(join(dataset_dir, 'test_imgs', '*.JPG')) +
                         glob.glob(join(dataset_dir, 'test_imgs', '*.jpg')))
    for im_path in image_paths:
        img = cv2.imread(im_path)
        if img is None:
            print('Skipping unreadable image: {}'.format(im_path))
            continue
        if img.shape != (256, 256, 3):
            print('Resizing {} from {}'.format(im_path, img.shape))
            img = cv2.resize(img, (256, 256))
        _, encoded_image = cv2.imencode('.jpg', img)
        example = dataset_utils.image_to_tfexample(
            encoded_image.tobytes(), b'jpg', 256, 256, 0)
        output_shard_index = count % _NUM_TRAIN_FILES
        tfrecord_writers[output_shard_index].write(example.SerializeToString())
        count += 1
    print('Processed {} images'.format(count))


def _get_output_filename(dataset_dir, split_name):
    """Creates the output filename.

    Args:
        dataset_dir: The dataset directory where the dataset is stored.
        split_name: The name of the train/test split.

    Returns:
        A file path, e.g. './data/pc_train'.
    """
    return '%s/pc_%s' % (dataset_dir, split_name)


def run(dataset_dir):
    """Runs the conversion operation.

    Args:
        dataset_dir: The dataset directory where the dataset is stored.
    """
    if not tf.gfile.Exists(dataset_dir):
        tf.gfile.MakeDirs(dataset_dir)
    training_filename = _get_output_filename(dataset_dir, 'train')
    # The test split is written under the 'val' name (pc_val).
    testing_filename = _get_output_filename(dataset_dir, 'val')

    # Class names are the sub-directories of train/, sorted so that the
    # name-to-id mapping is deterministic.
    classes_train = sorted(
        cls for cls in os.listdir(join(dataset_dir, 'train'))
        if os.path.isdir(join(dataset_dir, 'train', cls)))
    classes_map = {cls: idx for idx, cls in enumerate(classes_train)}

    with contextlib2.ExitStack() as tf_record_close_stack:
        train_writers = dataset_utils.open_sharded_output_tfrecords(
            tf_record_close_stack, training_filename, _NUM_TRAIN_FILES)
        _create_tfrecord_train(dataset_dir, train_writers, classes_map)
    with contextlib2.ExitStack() as tf_record_close_stack:
        test_writers = dataset_utils.open_sharded_output_tfrecords(
            tf_record_close_stack, testing_filename, _NUM_TRAIN_FILES)
        _create_tfrecord_test(dataset_dir, test_writers, classes_map)

    labels_to_class_names = dict(enumerate(classes_train))
    dataset_utils.write_label_file(labels_to_class_names, dataset_dir)


if __name__ == '__main__':
    run('./data')
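
# For reference, a minimal sketch of the dataset_utils helpers this script
# assumes. The signatures follow the TensorFlow models repo (TF-Slim's
# datasets/dataset_utils.py and object_detection's tf_record_creation_util);
# the actual module used here may differ.
#
# def int64_feature(value):
#     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
#
# def bytes_feature(value):
#     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
#
# def image_to_tfexample(image_data, image_format, height, width, class_id):
#     return tf.train.Example(features=tf.train.Features(feature={
#         'image/encoded': bytes_feature(image_data),
#         'image/format': bytes_feature(image_format),
#         'image/class/label': int64_feature(class_id),
#         'image/height': int64_feature(height),
#         'image/width': int64_feature(width),
#     }))
#
# def open_sharded_output_tfrecords(exit_stack, base_path, num_shards):
#     filenames = ['{}-{:05d}-of-{:05d}'.format(base_path, i, num_shards)
#                  for i in range(num_shards)]
#     return [exit_stack.enter_context(tf.python_io.TFRecordWriter(name))
#             for name in filenames]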