-
Notifications
You must be signed in to change notification settings - Fork 1
/
example.py
85 lines (63 loc) · 2.91 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from argparse import ArgumentParser
import chainer
import chainer.links as L
from chainer.training import extensions
from exponential_moving_average import ExponentialMovingAverage
from models import VGG
def run(args):
    """Train a VGG classifier on CIFAR-10, optionally evaluating an EMA model.

    Args:
        args: Parsed command-line namespace as produced by ``main``. Reads
            ``gpu`` (non-empty list of int device ids), ``batch_size``,
            ``epochs``, ``out``, ``ema_rate`` and ``resume``.

    Device selection is driven by ``args.gpu``:

    * one negative id     -> train on the CPU;
    * one non-negative id -> train on that GPU with ``StandardUpdater``;
    * two or more ids     -> data-parallel training with ``ParallelUpdater``,
      the first id being the main device.

    When ``args.ema_rate`` is non-zero, an ``ExponentialMovingAverage`` is
    registered both as an optimizer hook and as a trainer extension, and its
    ``shadow_target`` (instead of the trained ``model``) is the model that
    the validation ``Evaluator`` scores.

    Raises:
        ValueError: if two or more device ids are given and any of them is
            negative (the CPU cannot take part in multi-GPU training).
    """
    gpus = args.gpu
    main_device = gpus[0]
    multi_gpu = len(gpus) > 1
    # Fail fast, before building the model or loading the dataset: a negative
    # id mixed into a multi-GPU list would otherwise only surface as an
    # obscure error during device setup or inside ParallelUpdater.
    if multi_gpu and any(gpu < 0 for gpu in gpus):
        raise ValueError(
            'multi-GPU mode requires non-negative device ids, got {}'.format(gpus))

    model = L.Classifier(VGG())
    train, test = chainer.datasets.get_cifar10()

    if main_device >= 0:
        # Make the main GPU the current CUDA device.
        chainer.backends.cuda.get_device_from_id(main_device).use()
        if not multi_gpu:
            # Single-GPU mode: move the parameters onto the current device.
            # In multi-GPU mode, model placement is left to ParallelUpdater
            # (constructed below with an explicit device map).
            model.to_gpu()
    # Otherwise: CPU mode, nothing to set up.

    optimizer = chainer.optimizers.MomentumSGD(lr=0.05)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer_hooks.WeightDecay(5e-4))

    train_iter = chainer.iterators.SerialIterator(train, args.batch_size)
    # A single, ordered pass over the test set for each evaluation.
    test_iter = chainer.iterators.SerialIterator(
        test, args.batch_size, repeat=False, shuffle=False)

    if multi_gpu:
        # 'main' is the primary device; the remaining keys are just distinct
        # names for the additional devices.
        devices = {str(i): gpu for i, gpu in enumerate(gpus[1:])}
        devices['main'] = main_device
        updater = chainer.training.updater.ParallelUpdater(
            train_iter, optimizer, devices=devices)
    else:
        updater = chainer.training.updater.StandardUpdater(
            train_iter, optimizer, device=main_device)

    trainer = chainer.training.Trainer(updater, (args.epochs, 'epoch'), out=args.out)

    if args.ema_rate != 0.0:
        print('use ema')
        ema = ExponentialMovingAverage(target=model, rate=args.ema_rate, device=main_device)
        # Registered after WeightDecay, so it runs after that hook.
        optimizer.add_hook(ema)
        eval_model = ema.shadow_target
        trainer.extend(ema)
    else:
        print('no ema')
        eval_model = model

    # Here `eval_model` (the EMA's shadow target when EMA is enabled) is passed
    # to the evaluator instead of the ordinary `model`.
    trainer.extend(extensions.Evaluator(test_iter, eval_model, device=main_device))

    # Ordinary logging / reporting / checkpointing extensions.
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(['epoch', 'elapsed_time', 'main/loss', 'validation/main/loss',
                                           'main/accuracy', 'validation/main/accuracy']))
    trainer.extend(extensions.ProgressBar())
    trainer.extend(extensions.snapshot(), trigger=(1, 'epoch'))

    if args.resume:
        # Restore trainer state saved by a previous `extensions.snapshot()`.
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()
def main():
    """Parse command-line options and start training via ``run``."""
    parser = ArgumentParser(
        description='Exponential moving average (EMA) in Chainer: train VGG on CIFAR-10')
    parser.add_argument('--batch-size', type=int, default=64,
                        help='minibatch size (default: %(default)s)')
    parser.add_argument('--epochs', type=int, default=300,
                        help='number of training epochs (default: %(default)s)')
    parser.add_argument('--gpu', '-g', nargs='+', type=int, default=[-1],
                        help='device id(s): a single negative id selects the CPU; '
                             'several ids enable multi-GPU training with the first '
                             'as the main device (default: %(default)s)')
    parser.add_argument('--out', '-o', default='result',
                        help='output directory for logs and snapshots (default: %(default)s)')
    parser.add_argument('--ema-rate', type=float, default=0.99,
                        help='exponential moving average decay rate; '
                             'if 0, EMA is not applied (default: %(default)s)')
    parser.add_argument('--resume', default='',
                        help='path to a trainer snapshot to resume training from')
    run(parser.parse_args())
# Script entry point: only runs when executed directly, not when imported.
if __name__ == '__main__':
    main()