forked from graphcore/examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
65 lines (50 loc) · 2.29 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# Copyright 2019 Graphcore Ltd.
from functools import partial
from typing import Callable, Tuple
import tensorflow as tf
def load_and_preprocess_data(img_path: str, img_width: int, img_height: int,
                             preprocess_fn: Callable, dtype: tf.DType) -> tf.Tensor:
    """Read a JPEG image from disk and pre-process it for inference.

    Args:
        img_path: Path to the image file on disk.
        img_width: Target width to resize the image to.
        img_height: Target height to resize the image to.
        preprocess_fn: Function that scales the input to the correct range
            (called with ``data_format='channels_last'``).
        dtype: TensorFlow dtype the final image tensor is cast to.

    Returns:
        tf.Tensor of shape [img_height, img_width, 3] containing the
        pre-processed image, cast to ``dtype``.
    """
    image = tf.read_file(img_path)
    # Force 3 channels so greyscale/CMYK JPEGs still produce RGB tensors.
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [img_height, img_width])
    image = preprocess_fn(image, data_format='channels_last')
    return tf.cast(image, dtype)
def get_dataset(image_filenames: Tuple, batch_size: int, preprocess_fn: Callable, img_width: int, img_height: int,
                loop: bool, dtype: tf.DType) -> tf.data.Dataset:
    """Build a `tf.data.Dataset` of batched, pre-processed images.

    Each filename is loaded, decoded, resized and pre-processed via
    `load_and_preprocess_data`, then the images are batched and prefetched.
    Note: only images are yielded — labels are implied by the filenames
    (each filename corresponds to its image's label) and are NOT attached
    to the dataset elements.

    Args:
        image_filenames: Tuple of image filenames, with each filename
            corresponding to the label of the image.
        batch_size: Number of images per batch.
        preprocess_fn: Pre-processing function to apply to each image.
        img_width: Expected width of each image after resizing.
        img_height: Expected height of each image after resizing.
        loop: If True, repeat the dataset indefinitely.
        dtype: Input data type the images are cast to.

    Returns:
        A `tf.data.Dataset` yielding batches of pre-processed images with
        shape [batch_size, img_height, img_width, 3]; partial batches are
        dropped (``drop_remainder=True``).
    """
    image_ds = tf.data.Dataset.from_tensor_slices(tf.constant([str(item) for item in image_filenames]))
    if loop:
        image_ds = image_ds.repeat()
    input_preprocess = partial(load_and_preprocess_data, img_width=img_width, img_height=img_height,
                               preprocess_fn=preprocess_fn, dtype=dtype)
    # High parallelism/prefetch depth keeps the accelerator fed with data.
    image_ds = image_ds.map(map_func=input_preprocess, num_parallel_calls=100)
    # drop_remainder=True guarantees a static batch dimension for compilation.
    image_ds = image_ds.batch(batch_size, drop_remainder=True)
    image_ds = image_ds.prefetch(buffer_size=100)
    return image_ds