Commit: add MultithreadIterator
Hiroshiba committed Nov 7, 2020
1 parent 75a7e90 commit 7ccc860
Showing 7 changed files with 622 additions and 30 deletions.
62 changes: 35 additions & 27 deletions README.md
@@ -63,32 +63,40 @@ epoch main/loss validation/main/loss main/accuracy validation/main/acc
## Supported Classes
```
pytorch_trainer/
-├── iterators
-│   ├── multiprocess_iterator.py
-│   ├── order_samplers.py
-│   └── serial_iterator.py
-├── reporter.py
-└── training
-    ├── extensions
-    │   ├── evaluator.py
-    │   ├── fail_on_nonnumber.py
-    │   ├── log_report.py
-    │   ├── micro_average.py
-    │   ├── plot_report.py
-    │   ├── print_report.py
-    │   ├── progress_bar.py
-    │   ├── snapshot_writers.py
-    │   └── value_observation.py
-    ├── trainer.py
-    ├── triggers
-    │   ├── early_stopping_trigger.py
-    │   ├── interval_trigger.py
-    │   ├── manual_schedule_trigger.py
-    │   ├── minmax_value_trigger.py
-    │   ├── once_trigger.py
-    │   └── time_trigger.py
-    └── updaters
-        └── standard_updater.py
+|-- iterators
+|   |-- multiprocess_iterator.py
+|   |-- multithread_iterator.py
+|   |-- order_samplers.py
+|   `-- serial_iterator.py
+|-- reporter.py
+`-- training
+    |-- extensions
+    |   |-- evaluator.py
+    |   |-- exponential_shift.py
+    |   |-- fail_on_nonnumber.py
+    |   |-- inverse_shift.py
+    |   |-- linear_shift.py
+    |   |-- log_report.py
+    |   |-- micro_average.py
+    |   |-- multistep_shift.py
+    |   |-- plot_report.py
+    |   |-- polynomial_shift.py
+    |   |-- print_report.py
+    |   |-- progress_bar.py
+    |   |-- snapshot_writers.py
+    |   |-- step_shift.py
+    |   |-- value_observation.py
+    |   `-- warmup_shift.py
+    |-- trainer.py
+    |-- triggers
+    |   |-- early_stopping_trigger.py
+    |   |-- interval_trigger.py
+    |   |-- manual_schedule_trigger.py
+    |   |-- minmax_value_trigger.py
+    |   |-- once_trigger.py
+    |   `-- time_trigger.py
+    `-- updaters
+        `-- standard_updater.py
```

## Test
@@ -98,7 +106,7 @@ pytest -s -v tests

## TODO

-- [ ] Scheduler
+- [x] Scheduler
- [ ] DataLoader
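
The Scheduler item is checked off because this version ports Chainer's optimizer-attribute shift extensions (the `*_shift.py` files listed in the tree above). A sketch of registering one such scheduler, assuming the port keeps Chainer's `ExponentialShift(attr, rate)` signature and that `trainer` is a `Trainer` built elsewhere; both are assumptions, not verified against this commit:

```
from pytorch_trainer.training import extensions

# Assumption: as in Chainer, ExponentialShift multiplies the optimizer
# attribute `attr` ('lr' here) by `rate` each time its trigger fires.
trainer.extend(extensions.ExponentialShift('lr', 0.5),
               trigger=(10, 'epoch'))  # halve the learning rate every 10 epochs
```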

## License
1 change: 1 addition & 0 deletions pytorch_trainer/iterators/__init__.py
@@ -1,5 +1,6 @@
# import classes and functions
from pytorch_trainer.iterators.multiprocess_iterator import MultiprocessIterator # NOQA
+from pytorch_trainer.iterators.multithread_iterator import MultithreadIterator # NOQA
from pytorch_trainer.iterators.serial_iterator import SerialIterator # NOQA

from pytorch_trainer.iterators.order_samplers import OrderSampler # NOQA
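The `# NOQA` markers silence the linter's unused-import warnings: these imports exist only to re-export the classes at the package level. After this change, both import paths below name the same class; a quick sketch:

```
from pytorch_trainer.iterators import MultithreadIterator as A
from pytorch_trainer.iterators.multithread_iterator import MultithreadIterator as B

assert A is B  # the package __init__ simply re-exports the module's class
```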
191 changes: 191 additions & 0 deletions pytorch_trainer/iterators/multithread_iterator.py
@@ -0,0 +1,191 @@
from __future__ import division
from multiprocessing import pool

import numpy

from pytorch_trainer.dataset import iterator
from pytorch_trainer.iterators import _statemachine
from pytorch_trainer.iterators.order_samplers import ShuffleOrderSampler


class MultithreadIterator(iterator.Iterator):

    """Dataset iterator that loads examples in parallel.

    This is an implementation of :class:`~chainer.dataset.Iterator` that loads
    examples with worker threads. It uses the standard :mod:`threading`
    module to parallelize the loading.

    Note that this iterator effectively prefetches the examples for the next
    batch asynchronously after the current batch is returned.

    This iterator saves ``-1`` instead of ``None`` in snapshots since some
    serializers do not support ``None``.

    Args:
        dataset (~chainer.dataset.Dataset): Dataset to iterate.
        batch_size (int): Number of examples within each batch.
        repeat (bool): If ``True``, it infinitely loops over the dataset.
            Otherwise, it stops iteration at the end of the first epoch.
        shuffle (bool): If ``True``, the order of examples is shuffled at the
            beginning of each epoch. Otherwise, examples are extracted in the
            order of indexes. If ``None`` and no ``order_sampler`` is given,
            the behavior is the same as the case with ``shuffle=True``.
        n_threads (int): Number of worker threads.
        order_sampler (callable): A callable that generates the order
            of the indices to sample in the next epoch when an epoch finishes.
            This function should take two arguments: the current order
            and the current position of the iterator.
            This should return the next order. The size of the order
            should remain constant.
            This option cannot be used when ``shuffle`` is not ``None``.

    """

    def __init__(self, dataset, batch_size, repeat=True, shuffle=None,
                 n_threads=1, order_sampler=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self._repeat = repeat
        self._shuffle = shuffle

        if self._shuffle is not None:
            if order_sampler is not None:
                raise ValueError('`shuffle` is not `None` and a custom '
                                 '`order_sampler` is set. Please set '
                                 '`shuffle` to `None` to use the custom '
                                 'order sampler.')
            else:
                if self._shuffle:
                    order_sampler = ShuffleOrderSampler()
        else:
            if order_sampler is None:
                order_sampler = ShuffleOrderSampler()
        self.order_sampler = order_sampler

        self.n_threads = n_threads
        self._pool = None

        self.reset()

    def reset(self):
        if self.order_sampler is None:
            order = None
        else:
            order = self.order_sampler(numpy.arange(len(self.dataset)), 0)
        self._state = _statemachine.IteratorState(0, 0, False, order)
        self._previous_epoch_detail = -1.

        # reset internal state
        self._next = None

    def finalize(self):
        pool = self._pool

        self._next = None
        self._pool = None
        if pool is not None:
            pool.terminate()

    def __next__(self):
        if self._next is None:
            # load for the first iteration
            self._invoke_prefetch()

        batch = self._get()
        self._invoke_prefetch()  # prefetch for the next iteration
        return batch

    next = __next__

    @property
    def current_position(self):
        return self._state.current_position

    @property
    def epoch(self):
        return self._state.epoch

    @property
    def is_new_epoch(self):
        return self._state.is_new_epoch

    @property
    def epoch_detail(self):
        return self.epoch + self.current_position / self._epoch_size

    @property
    def previous_epoch_detail(self):
        # use -1 instead of None internally.
        if self._previous_epoch_detail < 0:
            return None
        return self._previous_epoch_detail

    def state_dict(self):
        state_dict = {
            'current_position': self.current_position,
            'epoch': self.epoch,
            'is_new_epoch': self.is_new_epoch,
        }

        order = self._state.order.copy()
        state_dict['order'] = order

        state_dict['previous_epoch_detail'] = self._previous_epoch_detail

        return state_dict

    def load_state_dict(self, state_dict):
        current_position = state_dict['current_position']
        epoch = state_dict['epoch']
        is_new_epoch = state_dict['is_new_epoch']
        order = state_dict['order']
        self._state = _statemachine.IteratorState(
            current_position, epoch, is_new_epoch, order)
        self._previous_epoch_detail = state_dict['previous_epoch_detail']
        # Old version serialized ``None``.
        if self._previous_epoch_detail is None:
            self._previous_epoch_detail = -1.
        self._next = None

    @staticmethod
    def _read(args):
        dataset, index = args
        return dataset[index]

    def _invoke_prefetch(self):
        assert self._next is None
        self._next_state, indices = _statemachine.iterator_statemachine(
            self._state, self.batch_size, self.repeat, self.order_sampler,
            len(self.dataset))

        if indices is None:
            self._next = None
        else:
            if self._pool is None:
                self._pool = pool.ThreadPool(self.n_threads)
            args = [(self.dataset, index) for index in indices]
            self._next = self._pool.map_async(MultithreadIterator._read, args)

    def _get(self):
        self._previous_epoch_detail = self.epoch_detail
        self._state = self._next_state

        next = self._next
        if next is None:
            raise StopIteration
        self._next = None

        while not next.ready():
            next.wait(0.5)  # To avoid interruption bug in Python2

        batch = [data for data in next.get()]
        return batch

    @property
    def _epoch_size(self):
        order = self._state.order
        if order is None:
            epoch_size = len(self.dataset)
        else:
            epoch_size = len(order)
        return epoch_size

    @property
    def repeat(self):
        return self._repeat
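
The iterator above double-buffers: `_invoke_prefetch` hands the indices of the next batch to a `multiprocessing.pool.ThreadPool` via `map_async`, and `_get` blocks only until that asynchronous result is ready. A minimal usage sketch against a plain in-memory list (any object supporting `len()` and indexing should work; this example is illustrative, not taken from the repository's docs):

```
from pytorch_trainer.iterators import MultithreadIterator

dataset = [(i, i * 2) for i in range(1000)]  # any sequence-like dataset

it = MultithreadIterator(dataset, batch_size=32, repeat=True, shuffle=True,
                         n_threads=4)

batch = next(it)        # list of 32 examples; the next batch is already loading
print(it.epoch_detail)  # fractional epoch progress, 32 / 1000 = 0.032 here

# The iterator is resumable: capture its state and restore it into a new one.
state = it.state_dict()
it2 = MultithreadIterator(dataset, batch_size=32, n_threads=4)
it2.load_state_dict(state)  # resumes from the saved position and order

it.finalize()  # terminates the worker pool
it2.finalize()
```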
3 changes: 2 additions & 1 deletion pytorch_trainer/training/extensions/evaluator.py
@@ -114,7 +114,8 @@ def __init__(self, iterator, target, converter=convert.concat_examples,

        for key, iter in six.iteritems(iterator):
            if (isinstance(iter, (iterators.SerialIterator,
-                                 iterators.MultiprocessIterator)) and
+                                 iterators.MultiprocessIterator,
+                                 iterators.MultithreadIterator)) and
                    getattr(iter, 'repeat', False)):
                msg = 'The `repeat` property of the iterator {} '
                'is set to `True`. Typically, the evaluator sweeps '
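With the extended check, a `MultithreadIterator` whose `repeat` property is `True` now triggers the same warning as the other iterator types, since the evaluator sweeps an iterator until it stops. A sketch of an evaluation-side configuration; `val_dataset` is a placeholder name:

```
from pytorch_trainer.iterators import MultithreadIterator

# repeat=False makes the iterator raise StopIteration after one pass, so the
# evaluator's sweep terminates; shuffle=False keeps the evaluation order fixed.
val_iter = MultithreadIterator(val_dataset, batch_size=64,
                               repeat=False, shuffle=False, n_threads=4)
```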
2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

setup(
    name='pytorch-trainer',
-    version='1.0.0',
+    version='1.3.0',
    packages=find_packages(),
    url='https://github.com/Hiroshiba/pytorch-trainer',
    author='Kazuyuki Hiroshiba',