-
Notifications
You must be signed in to change notification settings - Fork 83
/
tensorpack_extension.py
30 lines (23 loc) · 1.1 KB
/
tensorpack_extension.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# -*- coding: utf-8 -*-
#!/usr/bin/env python
import tensorflow as tf
from tensorpack.utils import logger
from tensorpack.input_source.input_source import QueueInput, EnqueueThread
class FlexibleQueueInput(QueueInput):
    """ Extend QueueInput to set queue capacity.

    When no queue is supplied, ``_setup`` creates a ``tf.FIFOQueue`` whose
    capacity is the ``capacity`` given to the constructor.
    """

    def __init__(self, ds, capacity=1000, queue=None):
        """
        Args:
            ds: the input dataflow, forwarded unchanged to ``QueueInput``.
            capacity (int): capacity of the ``tf.FIFOQueue`` created in
                ``_setup`` when no queue was supplied.
            queue: optional pre-built queue, forwarded to ``QueueInput``.
                When given, it is used as-is and ``capacity`` is ignored.
        """
        # Forward ``queue`` rather than silently dropping it, so callers can
        # still supply their own queue exactly as with the base class.
        super(FlexibleQueueInput, self).__init__(ds, queue=queue)
        self.capacity = capacity

    def _setup(self, inputs):
        """Build the input placeholders, the queue and the enqueue thread.

        Args:
            inputs: input descriptors; each must provide
                ``build_placeholder_reuse()``.
        """
        self._input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
        assert self._input_placehdrs, \
            "QueueInput has to be used with some inputs!"
        with self.cached_name_scope():
            if self.queue is None:
                # Only build a queue when none was supplied; this is where the
                # configurable capacity takes effect. One dtype per placeholder.
                self.queue = tf.FIFOQueue(
                    self.capacity, [x.dtype for x in self._input_placehdrs],
                    name='input_queue')
            logger.info("Setting up the queue '{}' for CPU prefetching ...".format(self.queue.name))
            # Thread that feeds datapoints from the (repeated) dataflow into the queue.
            self.thread = EnqueueThread(self.queue, self._inf_ds, self._input_placehdrs)
            # NOTE(review): name suggests the base class uses this op when
            # resetting/draining the queue — confirm against QueueInput.
            self._dequeue_op = self.queue.dequeue(name='dequeue_for_reset')