diff --git a/README.md b/README.md
index 8d380aa69..ffd8f1336 100644
--- a/README.md
+++ b/README.md
@@ -63,32 +63,40 @@ epoch main/loss validation/main/loss main/accuracy validation/main/acc
 ## Supported Classes
 ```
 pytorch_trainer/
-├── iterators
-│   ├── multiprocess_iterator.py
-│   ├── order_samplers.py
-│   └── serial_iterator.py
-├── reporter.py
-└── training
-    ├── extensions
-    │   ├── evaluator.py
-    │   ├── fail_on_nonnumber.py
-    │   ├── log_report.py
-    │   ├── micro_average.py
-    │   ├── plot_report.py
-    │   ├── print_report.py
-    │   ├── progress_bar.py
-    │   ├── snapshot_writers.py
-    │   └── value_observation.py
-    ├── trainer.py
-    ├── triggers
-    │   ├── early_stopping_trigger.py
-    │   ├── interval_trigger.py
-    │   ├── manual_schedule_trigger.py
-    │   ├── minmax_value_trigger.py
-    │   ├── once_trigger.py
-    │   └── time_trigger.py
-    └── updaters
-        └── standard_updater.py
+|-- iterators
+|   |-- multiprocess_iterator.py
+|   |-- multithread_iterator.py
+|   |-- order_samplers.py
+|   `-- serial_iterator.py
+|-- reporter.py
+`-- training
+    |-- extensions
+    |   |-- evaluator.py
+    |   |-- exponential_shift.py
+    |   |-- fail_on_nonnumber.py
+    |   |-- inverse_shift.py
+    |   |-- linear_shift.py
+    |   |-- log_report.py
+    |   |-- micro_average.py
+    |   |-- multistep_shift.py
+    |   |-- plot_report.py
+    |   |-- polynomial_shift.py
+    |   |-- print_report.py
+    |   |-- progress_bar.py
+    |   |-- snapshot_writers.py
+    |   |-- step_shift.py
+    |   |-- value_observation.py
+    |   `-- warmup_shift.py
+    |-- trainer.py
+    |-- triggers
+    |   |-- early_stopping_trigger.py
+    |   |-- interval_trigger.py
+    |   |-- manual_schedule_trigger.py
+    |   |-- minmax_value_trigger.py
+    |   |-- once_trigger.py
+    |   `-- time_trigger.py
+    `-- updaters
+        `-- standard_updater.py
 ```
 
 ## Test
@@ -98,7 +106,7 @@ pytest -s -v tests
 
 ## TODO
 
-- [ ] Scheduler
+- [x] Scheduler
 - [ ] DataLoader
 
 ## License
diff --git a/pytorch_trainer/iterators/__init__.py b/pytorch_trainer/iterators/__init__.py
index 6414677f4..98801a188 100644
--- a/pytorch_trainer/iterators/__init__.py
+++ b/pytorch_trainer/iterators/__init__.py
@@ -1,5 +1,6 @@
 # import classes and functions
 from pytorch_trainer.iterators.multiprocess_iterator import MultiprocessIterator  # NOQA
+from pytorch_trainer.iterators.multithread_iterator import MultithreadIterator  # NOQA
 from pytorch_trainer.iterators.serial_iterator import SerialIterator  # NOQA
 
 from pytorch_trainer.iterators.order_samplers import OrderSampler  # NOQA
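For orientation before the new file: `MultithreadIterator` mirrors the existing iterator API, so it can be dropped in wherever `SerialIterator` or `MultiprocessIterator` is used today. A minimal usage sketch (the list dataset and parameter values here are illustrative, not taken from the patch):

```python
from pytorch_trainer.iterators import MultithreadIterator

# Any object with __len__ and __getitem__ works as a dataset here;
# a plain list keeps the sketch self-contained.
dataset = [(i, i * i) for i in range(100)]

# Prefetch batches of 16 using 4 worker threads. With shuffle=None and no
# order_sampler, shuffling defaults to on (ShuffleOrderSampler).
it = MultithreadIterator(dataset, batch_size=16, n_threads=4)

batch = it.next()       # a list of 16 examples, loaded in parallel
print(it.epoch_detail)  # fractional epoch progress: 0.16 after one batch
```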
diff --git a/pytorch_trainer/iterators/multithread_iterator.py b/pytorch_trainer/iterators/multithread_iterator.py
new file mode 100644
index 000000000..d973975da
--- /dev/null
+++ b/pytorch_trainer/iterators/multithread_iterator.py
@@ -0,0 +1,191 @@
+from __future__ import division
+from multiprocessing import pool
+
+import numpy
+
+from pytorch_trainer.dataset import iterator
+from pytorch_trainer.iterators import _statemachine
+from pytorch_trainer.iterators.order_samplers import ShuffleOrderSampler
+
+
+class MultithreadIterator(iterator.Iterator):
+
+    """Dataset iterator that loads examples in parallel.
+    This is an implementation of :class:`~pytorch_trainer.dataset.Iterator`
+    that loads examples with worker threads. It uses the standard
+    :mod:`threading` module to parallelize the loading.
+    Note that this iterator effectively prefetches the examples for the next
+    batch asynchronously after the current batch is returned.
+    This iterator saves ``-1`` instead of ``None`` in snapshots since some
+    serializers do not support ``None``.
+    Args:
+        dataset (~pytorch_trainer.dataset.Dataset): Dataset to iterate.
+        batch_size (int): Number of examples within each batch.
+        repeat (bool): If ``True``, it infinitely loops over the dataset.
+            Otherwise, it stops iteration at the end of the first epoch.
+        shuffle (bool): If ``True``, the order of examples is shuffled at the
+            beginning of each epoch. Otherwise, examples are extracted in the
+            order of indexes. If ``None`` and no ``order_sampler`` is given,
+            the behavior is the same as the case with ``shuffle=True``.
+        n_threads (int): Number of worker threads.
+        order_sampler (callable): A callable that generates the order
+            of the indices to sample in the next epoch when an epoch finishes.
+            This function should take two arguments: the current order
+            and the current position of the iterator.
+            This should return the next order. The size of the order
+            should remain constant.
+            This option cannot be used when ``shuffle`` is not ``None``.
+    """
+
+    def __init__(self, dataset, batch_size, repeat=True, shuffle=None,
+                 n_threads=1, order_sampler=None):
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self._repeat = repeat
+        self._shuffle = shuffle
+
+        if self._shuffle is not None:
+            if order_sampler is not None:
+                raise ValueError('`shuffle` is not `None` and a custom '
+                                 '`order_sampler` is set. Please set '
+                                 '`shuffle` to `None` to use the custom '
+                                 'order sampler.')
+            else:
+                if self._shuffle:
+                    order_sampler = ShuffleOrderSampler()
+        else:
+            if order_sampler is None:
+                order_sampler = ShuffleOrderSampler()
+        self.order_sampler = order_sampler
+
+        self.n_threads = n_threads
+        self._pool = None
+
+        self.reset()
+
+    def reset(self):
+        if self.order_sampler is None:
+            order = None
+        else:
+            order = self.order_sampler(numpy.arange(len(self.dataset)), 0)
+        self._state = _statemachine.IteratorState(0, 0, False, order)
+        self._previous_epoch_detail = -1.
+
+        # reset internal state
+        self._next = None
+
+    def finalize(self):
+        pool = self._pool
+
+        self._next = None
+        self._pool = None
+        if pool is not None:
+            pool.terminate()
+
+    def __next__(self):
+        if self._next is None:
+            # load for the first iteration
+            self._invoke_prefetch()
+
+        batch = self._get()
+        self._invoke_prefetch()  # prefetch for the next iteration
+        return batch
+
+    next = __next__
+
+    @property
+    def current_position(self):
+        return self._state.current_position
+
+    @property
+    def epoch(self):
+        return self._state.epoch
+
+    @property
+    def is_new_epoch(self):
+        return self._state.is_new_epoch
+
+    @property
+    def epoch_detail(self):
+        return self.epoch + self.current_position / self._epoch_size
+
+    @property
+    def previous_epoch_detail(self):
+        # use -1 instead of None internally.
+        if self._previous_epoch_detail < 0:
+            return None
+        return self._previous_epoch_detail
+
+    def state_dict(self):
+        state_dict = {
+            'current_position': self.current_position,
+            'epoch': self.epoch,
+            'is_new_epoch': self.is_new_epoch,
+        }
+
+        # ``order`` is ``None`` when ``shuffle=False`` without a sampler.
+        order = self._state.order
+        state_dict['order'] = None if order is None else order.copy()
+
+        state_dict['previous_epoch_detail'] = self._previous_epoch_detail
+
+        return state_dict
+
+    def load_state_dict(self, state_dict):
+        current_position = state_dict['current_position']
+        epoch = state_dict['epoch']
+        is_new_epoch = state_dict['is_new_epoch']
+        order = state_dict['order']
+        self._state = _statemachine.IteratorState(
+            current_position, epoch, is_new_epoch, order)
+        self._previous_epoch_detail = state_dict['previous_epoch_detail']
+        # Old versions serialized ``None``.
+        if self._previous_epoch_detail is None:
+            self._previous_epoch_detail = -1.
+        self._next = None
+
+    @staticmethod
+    def _read(args):
+        dataset, index = args
+        return dataset[index]
+
+    def _invoke_prefetch(self):
+        assert self._next is None
+        self._next_state, indices = _statemachine.iterator_statemachine(
+            self._state, self.batch_size, self.repeat, self.order_sampler,
+            len(self.dataset))
+
+        if indices is None:
+            self._next = None
+        else:
+            if self._pool is None:
+                self._pool = pool.ThreadPool(self.n_threads)
+            args = [(self.dataset, index) for index in indices]
+            self._next = self._pool.map_async(MultithreadIterator._read, args)
+
+    def _get(self):
+        self._previous_epoch_detail = self.epoch_detail
+        self._state = self._next_state
+
+        next = self._next
+        if next is None:
+            raise StopIteration
+        self._next = None
+
+        while not next.ready():
+            next.wait(0.5)  # To avoid interruption bug in Python 2
+
+        batch = [data for data in next.get()]
+        return batch
+
+    @property
+    def _epoch_size(self):
+        order = self._state.order
+        if order is None:
+            epoch_size = len(self.dataset)
+        else:
+            epoch_size = len(order)
+        return epoch_size
+
+    @property
+    def repeat(self):
+        return self._repeat
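Two implementation details above are worth calling out: each `next()` returns the batch prefetched on the `ThreadPool` and immediately schedules the following one, and `state_dict()`/`load_state_dict()` allow checkpointing mid-epoch. A small resume sketch (the toy dataset and parameter values are illustrative):

```python
import copy

from pytorch_trainer.iterators import MultithreadIterator

dataset = list(range(6))
it = MultithreadIterator(dataset, batch_size=2, n_threads=2)

it.next()                               # consume one batch; the next is prefetched
state = copy.deepcopy(it.state_dict())  # snapshot position, epoch, and index order

resumed = MultithreadIterator(dataset, batch_size=2, n_threads=2)
resumed.load_state_dict(state)          # picks up mid-epoch
assert resumed.epoch_detail == it.epoch_detail  # both at 2/6
```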
diff --git a/pytorch_trainer/training/extensions/evaluator.py b/pytorch_trainer/training/extensions/evaluator.py
index 302af8fa4..f29521d36 100644
--- a/pytorch_trainer/training/extensions/evaluator.py
+++ b/pytorch_trainer/training/extensions/evaluator.py
@@ -114,7 +114,8 @@ def __init__(self, iterator, target, converter=convert.concat_examples,
         for key, iter in six.iteritems(iterator):
             if (isinstance(iter, (iterators.SerialIterator,
-                                  iterators.MultiprocessIterator)) and
+                                  iterators.MultiprocessIterator,
+                                  iterators.MultithreadIterator)) and
                     getattr(iter, 'repeat', False)):
                 msg = 'The `repeat` property of the iterator {} '
                 'is set to `True`. Typically, the evaluator sweeps '
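The check above exists because an evaluation pass must terminate: with `repeat=True` the iterator never raises `StopIteration`, so the evaluator would sweep forever. A sketch of the intended setup (`val_dataset` and `model` are placeholders, not part of the patch):

```python
from pytorch_trainer import iterators
from pytorch_trainer.training import extensions

# repeat=False makes the iterator stop after one pass over the data;
# shuffle=False keeps the evaluation order deterministic.
val_iter = iterators.MultithreadIterator(
    val_dataset, batch_size=32, repeat=False, shuffle=False, n_threads=4)

# `model` is the evaluation target, as in Evaluator(iterator, target, ...).
evaluator = extensions.Evaluator(val_iter, model)
```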
diff --git a/setup.py b/setup.py
index b07b9b1c8..a6b9114d6 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@
 setup(
     name='pytorch-trainer',
-    version='1.0.0',
+    version='1.3.0',
     packages=find_packages(),
     url='https://github.com/Hiroshiba/pytorch-trainer',
     author='Kazuyuki Hiroshiba',
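The new tests below also exercise the `order_sampler` hook described in the docstring: a callable mapping `(current_order, current_position)` to the next epoch's index order, with the constraint that the length stays constant. A hypothetical sampler that replays the indices in reverse every epoch:

```python
import numpy

from pytorch_trainer import iterators


def reverse_order_sampler(current_order, current_position):
    # Must return an index order of the same length as current_order.
    return numpy.arange(len(current_order))[::-1]


dataset = ['a', 'b', 'c', 'd']
it = iterators.MultithreadIterator(
    dataset, batch_size=2, order_sampler=reverse_order_sampler)
print(it.next())  # ['d', 'c'] on the first batch
```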
diff --git a/tests/iterators_tests/test_multithread_iterator.py b/tests/iterators_tests/test_multithread_iterator.py
new file mode 100644
index 000000000..a758f5df6
--- /dev/null
+++ b/tests/iterators_tests/test_multithread_iterator.py
@@ -0,0 +1,390 @@
+from __future__ import division
+import copy
+import unittest
+
+import numpy
+import six
+
+from pytorch_trainer import iterators
+from pytorch_trainer import testing
+
+
+@testing.parameterize(*testing.product({
+    'n_threads': [1, 2],
+    'order_sampler': [
+        None, lambda order, _: numpy.random.permutation(len(order))]
+}))
+class TestMultithreadIterator(unittest.TestCase):
+
+    def setUp(self):
+        self.options = {'n_threads': self.n_threads,
+                        'order_sampler': self.order_sampler}
+
+    def test_iterator_repeat(self):
+        dataset = [1, 2, 3, 4, 5, 6]
+        it = iterators.MultithreadIterator(dataset, 2, **self.options)
+        for i in range(3):
+            self.assertEqual(it.epoch, i)
+            self.assertAlmostEqual(it.epoch_detail, i + 0 / 6)
+            if i == 0:
+                self.assertIsNone(it.previous_epoch_detail)
+            else:
+                self.assertAlmostEqual(it.previous_epoch_detail, i - 2 / 6)
+            batch1 = it.next()
+            self.assertEqual(len(batch1), 2)
+            self.assertIsInstance(batch1, list)
+            self.assertFalse(it.is_new_epoch)
+            self.assertAlmostEqual(it.epoch_detail, i + 2 / 6)
+            self.assertAlmostEqual(it.previous_epoch_detail, i + 0 / 6)
+            batch2 = it.next()
+            self.assertEqual(len(batch2), 2)
+            self.assertIsInstance(batch2, list)
+            self.assertFalse(it.is_new_epoch)
+            self.assertAlmostEqual(it.epoch_detail, i + 4 / 6)
+            self.assertAlmostEqual(it.previous_epoch_detail, i + 2 / 6)
+            batch3 = it.next()
+            self.assertEqual(len(batch3), 2)
+            self.assertIsInstance(batch3, list)
+            self.assertTrue(it.is_new_epoch)
+            self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
+            self.assertAlmostEqual(it.epoch_detail, i + 6 / 6)
+            self.assertAlmostEqual(it.previous_epoch_detail, i + 4 / 6)
+
+    def test_iterator_list_type(self):
+        dataset = [[i, numpy.zeros((10,)) + i] for i in range(6)]
+        it = iterators.MultithreadIterator(dataset, 2, **self.options)
+        for i in range(3):
+            self.assertEqual(it.epoch, i)
+            self.assertAlmostEqual(it.epoch_detail, i)
+            if i == 0:
+                self.assertIsNone(it.previous_epoch_detail)
+            else:
+                self.assertAlmostEqual(it.previous_epoch_detail, i - 2 / 6)
+            batches = {}
+            for j in range(3):
+                batch = it.next()
+                self.assertEqual(len(batch), 2)
+                if j != 2:
+                    self.assertFalse(it.is_new_epoch)
+                else:
+                    self.assertTrue(it.is_new_epoch)
+                self.assertAlmostEqual(
+                    it.epoch_detail, (3 * i + j + 1) * 2 / 6)
+                self.assertAlmostEqual(
+                    it.previous_epoch_detail, (3 * i + j) * 2 / 6)
+                for x in batch:
+                    self.assertIsInstance(x, list)
+                    self.assertIsInstance(x[1], numpy.ndarray)
+                    batches[x[0]] = x[1]
+
+            self.assertEqual(len(batches), len(dataset))
+            for k, v in six.iteritems(batches):
+                numpy.testing.assert_allclose(dataset[k][1], v)
+
+    def test_iterator_tuple_type(self):
+        dataset = [(i, numpy.zeros((10,)) + i) for i in range(6)]
+        it = iterators.MultithreadIterator(dataset, 2, **self.options)
+        for i in range(3):
+            self.assertEqual(it.epoch, i)
+            self.assertAlmostEqual(it.epoch_detail, i)
+            if i == 0:
+                self.assertIsNone(it.previous_epoch_detail)
+            else:
+                self.assertAlmostEqual(it.previous_epoch_detail, i - 2 / 6)
+            batches = {}
+            for j in range(3):
+                batch = it.next()
+                self.assertEqual(len(batch), 2)
+                if j != 2:
+                    self.assertFalse(it.is_new_epoch)
+                else:
+                    self.assertTrue(it.is_new_epoch)
+                self.assertAlmostEqual(
+                    it.epoch_detail, (3 * i + j + 1) * 2 / 6)
+                self.assertAlmostEqual(
+                    it.previous_epoch_detail, (3 * i + j) * 2 / 6)
+                for x in batch:
+                    self.assertIsInstance(x, tuple)
+                    self.assertIsInstance(x[1], numpy.ndarray)
+                    batches[x[0]] = x[1]
+
+            self.assertEqual(len(batches), len(dataset))
+            for k, v in six.iteritems(batches):
+                numpy.testing.assert_allclose(dataset[k][1], v)
+
+    def test_iterator_dict_type(self):
+        dataset = [{i: numpy.zeros((10,)) + i} for i in range(6)]
+        it = iterators.MultithreadIterator(dataset, 2, **self.options)
+        for i in range(3):
+            self.assertEqual(it.epoch, i)
+            self.assertAlmostEqual(it.epoch_detail, i)
+            if i == 0:
+                self.assertIsNone(it.previous_epoch_detail)
+            else:
+                self.assertAlmostEqual(it.previous_epoch_detail, i - 2 / 6)
+            batches = {}
+            for j in range(3):
+                batch = it.next()
+                self.assertEqual(len(batch), 2)
+                if j != 2:
+                    self.assertFalse(it.is_new_epoch)
+                else:
+                    self.assertTrue(it.is_new_epoch)
+                self.assertAlmostEqual(
+                    it.epoch_detail, (3 * i + j + 1) * 2 / 6)
+                self.assertAlmostEqual(
+                    it.previous_epoch_detail, (3 * i + j) * 2 / 6)
+                for x in batch:
+                    self.assertIsInstance(x, dict)
+                    k = tuple(x)[0]
+                    v = x[k]
+                    self.assertIsInstance(v, numpy.ndarray)
+                    batches[k] = v
+
+            self.assertEqual(len(batches), len(dataset))
+            for k, v in six.iteritems(batches):
+                x = dataset[k][tuple(dataset[k])[0]]
+                numpy.testing.assert_allclose(x, v)
+
+    def test_iterator_repeat_not_even(self):
+        dataset = [1, 2, 3, 4, 5]
+        it = iterators.MultithreadIterator(dataset, 2,
+                                           **self.options)
+
+        batches = sum([it.next() for _ in range(5)], [])
+        self.assertEqual(sorted(batches), sorted(dataset * 2))
+
+    def test_iterator_not_repeat(self):
+        dataset = [1, 2, 3, 4, 5]
+        it = iterators.MultithreadIterator(
+            dataset, 2, repeat=False, **self.options)
+
+        batches = sum([it.next() for _ in range(3)], [])
+        self.assertEqual(sorted(batches), dataset)
+        for _ in range(2):
+            self.assertRaises(StopIteration, it.next)
+
+    def test_iterator_not_repeat_not_even(self):
+        dataset = [1, 2, 3, 4, 5]
+        it = iterators.MultithreadIterator(
+            dataset, 2, repeat=False, **self.options)
+
+        self.assertAlmostEqual(it.epoch_detail, 0 / 5)
+        self.assertIsNone(it.previous_epoch_detail)
+        batch1 = it.next()
+        self.assertAlmostEqual(it.epoch_detail, 2 / 5)
+        self.assertAlmostEqual(it.previous_epoch_detail, 0 / 5)
+        batch2 = it.next()
+        self.assertAlmostEqual(it.epoch_detail, 4 / 5)
+        self.assertAlmostEqual(it.previous_epoch_detail, 2 / 5)
+        batch3 = it.next()
+        self.assertAlmostEqual(it.epoch_detail, 5 / 5)
+        self.assertAlmostEqual(it.previous_epoch_detail, 4 / 5)
+        self.assertRaises(StopIteration, it.next)
+
+        self.assertEqual(len(batch3), 1)
+        self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
+
+    def test_iterator_shuffle_divisible(self):
+        dataset = list(range(10))
+        it = iterators.MultithreadIterator(
+            dataset, 10, **self.options)
+        self.assertNotEqual(it.next(), it.next())
+
+    def test_iterator_shuffle_nondivisible(self):
+        dataset = list(range(10))
+        it = iterators.MultithreadIterator(
+            dataset, 3, **self.options)
+        out = sum([it.next() for _ in range(7)], [])
+        self.assertNotEqual(out[0:10], out[10:20])
+
+    def test_copy_not_repeat(self):
+        dataset = [1, 2, 3, 4, 5]
+        it = iterators.MultithreadIterator(
+            dataset, 2, repeat=False, **self.options)
+        copy_it = copy.copy(it)
+        batches = sum([it.next() for _ in range(3)], [])
+        self.assertEqual(sorted(batches), dataset)
+        for _ in range(2):
+            self.assertRaises(StopIteration, it.next)
+        it = None
+
+        batches = sum([copy_it.next() for _ in range(3)], [])
+        self.assertEqual(sorted(batches), dataset)
+        for _ in range(2):
+            self.assertRaises(StopIteration, copy_it.next)
+
+    def test_reset(self):
+        dataset = [1, 2, 3, 4, 5]
+        it = iterators.MultithreadIterator(
+            dataset, 2, repeat=False, **self.options)
+
+        for trial in range(4):
+            batches = sum([it.next() for _ in range(3)], [])
+            self.assertEqual(sorted(batches), dataset)
+            for _ in range(2):
+                self.assertRaises(StopIteration, it.next)
+            it.reset()
+
+    def test_supported_reset_middle(self):
+        dataset = [1, 2, 3, 4, 5]
+        it = iterators.MultithreadIterator(
+            dataset, 2, repeat=False, **self.options)
+        it.next()
+        it.reset()
+
+    def test_supported_reset_repeat(self):
+        dataset = [1, 2, 3, 4]
+        it = iterators.MultithreadIterator(
+            dataset, 2, repeat=True, **self.options)
+        it.next()
+        it.next()
+        it.reset()
+
+    def test_supported_reset_finalized(self):
+        dataset = [1, 2, 3, 4]
+        it = iterators.MultithreadIterator(
+            dataset, 2, repeat=False, **self.options)
+        it.next()
+        it.next()
+        it.finalize()
+        it.reset()
+
+
+@testing.parameterize(*testing.product({
+    'n_threads': [1, 2],
+    'order_sampler': [
+        None, lambda order, _: numpy.random.permutation(len(order))]
+}))
+class TestMultithreadIteratorStateDict(unittest.TestCase):
+
+    def setUp(self):
+        self.options = {'n_threads': self.n_threads,
+                        'order_sampler': self.order_sampler}
+
+    def test_iterator_state_dict(self):
+        dataset = [1, 2, 3, 4, 5, 6]
+        it = iterators.MultithreadIterator(dataset, 2, **self.options)
+
+        self.assertEqual(it.epoch, 0)
+        self.assertAlmostEqual(it.epoch_detail, 0 / 6)
+        self.assertIsNone(it.previous_epoch_detail)
+        batch1 = it.next()
+        self.assertEqual(len(batch1), 2)
+        self.assertIsInstance(batch1, list)
+        self.assertFalse(it.is_new_epoch)
+        self.assertAlmostEqual(it.epoch_detail, 2 / 6)
+        self.assertAlmostEqual(it.previous_epoch_detail, 0 / 6)
+        batch2 = it.next()
+        self.assertEqual(len(batch2), 2)
+        self.assertIsInstance(batch2, list)
+        self.assertFalse(it.is_new_epoch)
+        self.assertAlmostEqual(it.epoch_detail, 4 / 6)
+        self.assertAlmostEqual(it.previous_epoch_detail, 2 / 6)
+
+        state_dict = copy.deepcopy(it.state_dict())
+
+        it = iterators.MultithreadIterator(dataset, 2, **self.options)
+        it.load_state_dict(state_dict)
+        self.assertFalse(it.is_new_epoch)
+        self.assertAlmostEqual(it.epoch_detail, 4 / 6)
+        self.assertAlmostEqual(it.previous_epoch_detail, 2 / 6)
+
+        batch3 = it.next()
+        self.assertEqual(len(batch3), 2)
+        self.assertIsInstance(batch3, list)
+        self.assertTrue(it.is_new_epoch)
+        self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
+        self.assertAlmostEqual(it.epoch_detail, 6 / 6)
+        self.assertAlmostEqual(it.previous_epoch_detail, 4 / 6)
+
+
+class TestMultithreadIteratorOrderSamplerEpochSize(unittest.TestCase):
+
+    def setUp(self):
+        def order_sampler(order, cur_pos):
+            return numpy.repeat(numpy.arange(3), 2)
+        self.options = {'order_sampler': order_sampler}
+
+    def test_iterator_repeat(self):
+        dataset = [1, 2, 3]
+        it = iterators.MultithreadIterator(dataset, 2, **self.options)
+        for i in range(3):
+            self.assertEqual(it.epoch, i)
+            self.assertAlmostEqual(it.epoch_detail, i + 0 / 6)
+            if i == 0:
+                self.assertIsNone(it.previous_epoch_detail)
+            else:
+                self.assertAlmostEqual(it.previous_epoch_detail, i - 2 / 6)
+            batch1 = it.next()
+            self.assertEqual(len(batch1), 2)
+            self.assertIsInstance(batch1, list)
+            self.assertFalse(it.is_new_epoch)
+            self.assertAlmostEqual(it.epoch_detail, i + 2 / 6)
+            self.assertAlmostEqual(it.previous_epoch_detail, i + 0 / 6)
+            batch2 = it.next()
+            self.assertEqual(len(batch2), 2)
+            self.assertIsInstance(batch2, list)
+            self.assertFalse(it.is_new_epoch)
+            self.assertAlmostEqual(it.epoch_detail, i + 4 / 6)
+            self.assertAlmostEqual(it.previous_epoch_detail, i + 2 / 6)
+            batch3 = it.next()
+            self.assertEqual(len(batch3), 2)
+            self.assertIsInstance(batch3, list)
+            self.assertTrue(it.is_new_epoch)
+            self.assertAlmostEqual(it.epoch_detail, i + 6 / 6)
+            self.assertAlmostEqual(it.previous_epoch_detail, i + 4 / 6)
+
+            self.assertEqual(
+                sorted(batch1 + batch2 + batch3), [1, 1, 2, 2, 3, 3])
+
+
+class NoSameIndicesOrderSampler(object):
+
+    def __init__(self, batchsize):
+        self.n_call = 0
+
+    def __call__(self, current_order, current_pos):
+        # all batches contain unique indices
+        remaining = current_order[current_pos:]
+        first = numpy.setdiff1d(numpy.arange(len(current_order)), remaining)
+        second = numpy.setdiff1d(numpy.arange(len(current_order)), first)
+        return numpy.concatenate((first, second))
+
+
+class TestMultithreadIteratorNoSameIndicesOrderSampler(unittest.TestCase):
+
+    def test_no_same_indices_order_sampler(self):
+        dataset = [1, 2, 3, 4, 5, 6]
+        batchsize = 5
+
+        it = iterators.MultithreadIterator(
+            dataset, batchsize,
+            order_sampler=NoSameIndicesOrderSampler(batchsize))
+        for _ in range(5):
+            batch = it.next()
+            self.assertEqual(len(numpy.unique(batch)), batchsize)
+
+
+class InvalidOrderSampler(object):
+
+    def __init__(self):
+        self.n_call = 0
+
+    def __call__(self, _order, _):
+        order = numpy.arange(len(_order) - self.n_call)
+        self.n_call += 1
+        return order
+
+
+class TestMultithreadIteratorInvalidOrderSampler(unittest.TestCase):
+
+    def test_invalid_order_sampler(self):
+        dataset = [1, 2, 3, 4, 5, 6]
+
+        with self.assertRaises(ValueError):
+            it = iterators.MultithreadIterator(
+                dataset, 6, order_sampler=InvalidOrderSampler())
+            it.next()
+
+
+testing.run_module(__name__, __file__)
diff --git a/tests/training_tests/extensions_tests/test_evaluator.py b/tests/training_tests/extensions_tests/test_evaluator.py
index 1e403994f..43e9e11bb 100644
--- a/tests/training_tests/extensions_tests/test_evaluator.py
+++ b/tests/training_tests/extensions_tests/test_evaluator.py
@@ -246,7 +246,8 @@ def test_evaluate(self):
 @testing.parameterize(*testing.product({
     'repeat': [True, False],
     'iterator_class': [iterators.SerialIterator,
-                       iterators.MultiprocessIterator]
+                       iterators.MultiprocessIterator,
+                       iterators.MultithreadIterator]
 }))
 class TestEvaluatorRepeat(unittest.TestCase):
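As a closing note on the bookkeeping these tests assert: `epoch_detail` is `epoch + current_position / epoch_size`, and `previous_epoch_detail` trails it by exactly one batch (reported as `None` before the first batch, since `-1` is only the internal encoding). A quick sketch of the arithmetic (toy values, not from the patch):

```python
from pytorch_trainer.iterators import MultithreadIterator

dataset = list(range(6))
it = MultithreadIterator(dataset, batch_size=2, shuffle=False)

assert it.epoch_detail == 0.0            # nothing consumed yet
assert it.previous_epoch_detail is None  # -1 internally, surfaced as None

it.next()
assert it.epoch_detail == 2 / 6          # one batch of two, epoch size six
assert it.previous_epoch_detail == 0.0
```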