getdata.py
# feature = {'image/width': tf.FixedLenFeature([], dtype=tf.int64),
# 'image/object/class/label': tf.VarLenFeature(dtype=tf.int64),
# 'image/height': tf.FixedLenFeature([], dtype=tf.int64),
# 'image/object/class/text': tf.VarLenFeature(dtype=tf.string),
# 'image/source_id': tf.VarLenFeature(dtype=tf.string),
# 'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
# 'image/encoded': tf.VarLenFeature(dtype=tf.string)}
import tensorflow as tf
from PIL import Image
import numpy as np
import os
from io import BytesIO

def train_input_fn():
    # Build a tf.data pipeline over the TFRecord file and return the
    # get_next() op for a single parsed example.
    filenames = ["./training.record"]
    dataset = tf.data.TFRecordDataset(filenames)

    def parser(record):
        # FixedLenFeature entries come back as dense scalars; the VarLenFeature
        # entries (class text, source_id, bbox ymin) come back as SparseTensors.
        keys_to_features = {'image/width': tf.FixedLenFeature((), dtype=tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
                            'image/object/class/label': tf.FixedLenFeature((), dtype=tf.int64),
                            'image/height': tf.FixedLenFeature([], dtype=tf.int64),
                            'image/object/class/text': tf.VarLenFeature(dtype=tf.string),
                            'image/source_id': tf.VarLenFeature(dtype=tf.string),
                            'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
                            'image/encoded': tf.FixedLenFeature((), dtype=tf.string, default_value="")}
        parsed = tf.parse_single_example(record, keys_to_features)
        # image = tf.image.decode_jpeg(parsed["image/encoded"])
        image = parsed["image/encoded"]  # raw encoded bytes; decoded later with PIL
        label = parsed["image/object/class/label"]
        width = parsed["image/width"]
        height = parsed["image/height"]
        source_id = parsed['image/source_id']
        text = parsed['image/object/class/text']
        ymin = parsed['image/object/bbox/ymin']
        # label = tf.cast(parsed["image/object/class/label"], tf.int32)
        return image, label, source_id, width, height, text, ymin

    dataset = dataset.map(parser)
    # dataset = dataset.shuffle(buffer_size=10000)
    # dataset = dataset.batch(32)
    # dataset = dataset.repeat(10)
    iterator = dataset.make_one_shot_iterator()
    # image, label, source_id, width, height, text, ymin = iterator.get_next()
    # return image, label, source_id, width, height, text, ymin
    next_element = iterator.get_next()
    return next_element
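
# A minimal sketch of the batched pipeline hinted at by the commented-out
# shuffle/batch/repeat lines above, assuming the same ./training.record file
# and the TF 1.x API used in this script. It decodes the JPEG bytes in-graph
# with tf.image.decode_jpeg instead of handing raw bytes to PIL, and resizes
# to an arbitrary 224x224 so examples can be stacked into a batch. Defined
# for illustration only; it is never called below.
def batched_input_fn(batch_size=32, num_epochs=10):
    dataset = tf.data.TFRecordDataset(["./training.record"])

    def _parse_and_decode(record):
        keys_to_features = {
            'image/encoded': tf.FixedLenFeature((), dtype=tf.string, default_value=""),
            'image/object/class/label': tf.FixedLenFeature((), dtype=tf.int64)}
        parsed = tf.parse_single_example(record, keys_to_features)
        # Decode inside the graph and force a fixed shape so batching works.
        image = tf.image.decode_jpeg(parsed['image/encoded'], channels=3)
        image = tf.image.resize_images(image, [224, 224])
        return image, parsed['image/object/class/label']

    dataset = dataset.map(_parse_and_decode)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(num_epochs)
    return dataset.make_one_shot_iterator().get_next()

# Pull one example out of the original train_input_fn() pipeline, decode the
# JPEG bytes with PIL, and print every parsed field for inspection.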
sess = tf.InteractiveSession()
# image, label, source_id, width, height, text, ymin = train_input_fn()
next_element = train_input_fn()
image, label, source_id, width, height, text, ymin = sess.run(next_element)
# image = sess.run(image)
# label = sess.run(label)
# source_id = sess.run(source_id)
# width = sess.run(width)
# height = sess.run(height)
# text = sess.run(text)
# ymin = sess.run(ymin)
print(type(image))
# print(image)
img = Image.open(BytesIO(image))
w, h = img.size
print('w = ' + str(w) + ', h = ' + str(h))
print('size = ' + str(np.array(img).shape))
img.save(source_id.values[0].decode("utf-8"))
print('label:')
print(type(label))
print(label)
print('source_id:')
print(type(source_id))
print(source_id.values[0].decode("utf-8"))
print('width:')
print(type(width))
print(width.size)
print(width)
print('height:')
print(type(height))
print(height.size)
print(height)
print('text:')
print(type(text))
print(text)
print('ymin:')
print(type(ymin))
print(ymin)
# Iterate over the next 100 records and print their labels.
for i in range(100):
    image, label, source_id, width, height, text, ymin = sess.run(next_element)
    print('label ' + str(i + 1))
    print(type(label))
    print(label)

exit()