Tamaki Kojima(tamakoji@gmail.com)
PyTorch 1.0 support
This is an alternative implementation of "Synchronized Multi-GPU Batch Normalization", which computes global statistics across GPUs instead of statistics computed locally on each GPU. SyncBN is becoming important for tasks where the input images are large and multiple GPUs must be used to increase the minibatch size for training.
The code was inspired by Pytorch-Encoding and Inplace-ABN
- Unlike Pytorch-Encoding, you don't need custom nn.DataParallel
- Unlike Inplace-ABN, you can simply replace your nn.BatchNorm2d with this module's implementation, since it does not mark anything for in-place operation
- You can plug it into any module written in PyTorch to enable Synchronized BatchNorm
- Backward computation is rewritten and tested against behavior of nn.BatchNorm2d
For PyTorch, please refer to https://pytorch.org/
NOTE: The code is tested only with PyTorch v1.0.0 and CUDA 10 / cuDNN 7.4.2 on Ubuntu 18.04.
It uses the PyTorch JIT mechanism to compile seamlessly, using ninja. Please install ninja-build before use.
sudo apt-get install ninja-build
Also install all dependencies for python. For pip, run:
pip install -U -r requirements.txt
There is no need to build anything manually: just run your code and the JIT will take care of compilation. JIT and C++ extensions are supported since PyTorch 0.4; however, PyTorch >= 1.0 is highly recommended due to major design changes.
Please refer to test.py for testing the difference between nn.BatchNorm2d and modules.nn.BatchNorm2d
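import torch
import torch.nn as nn  # standard PyTorch layers (was missing in the original snippet)
from modules import nn as NN  # synchronized BatchNorm from this repository
num_gpu = torch.cuda.device_count()
# NN.BatchNorm2d is used wherever nn.BatchNorm2d would normally appear
model = nn.Sequential(
    nn.Conv2d(3, 3, 1, 1, bias=False),
    NN.BatchNorm2d(3),
    nn.ReLU(inplace=True),
    nn.Conv2d(3, 3, 1, 1, bias=False),
    NN.BatchNorm2d(3),
).cuda()
# The stock nn.DataParallel is sufficient; no custom wrapper is needed
model = nn.DataParallel(model, device_ids=list(range(num_gpu)))
x = torch.rand(num_gpu, 3, 2, 2).cuda()  # one sample per GPU
z = model(x)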
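- Gather all $\sum x_i$ and $\sum x_i^2$ from the workers to the master and compute the global statistics $\mu$ and $\sigma^2$, where

  $$\mu = \frac{1}{N}\sum_i x_i$$

  and

  $$\sigma^2 = \frac{1}{N}\sum_i x_i^2 - \mu^2$$

  with N the total number of elements per channel across all GPUs. The global stats are then shared with all GPUs, and running_mean and running_var are updated by moving average using the global stats.
- Forward batchnorm using the global stats by

  $$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$$

  and then

  $$y_i = \gamma \cdot \hat{x}_i + \beta$$

  where $\gamma$ is the weight parameter and $\beta$ is the bias parameter.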
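As a rough illustration of the forward steps, the sketch below simulates two GPUs with plain Python lists for a single channel. Each "device" reports only its partial sums, the master turns them into the global mean and (biased) variance, and every device normalizes with those shared values. The helper names (`local_sums`, `global_stats`, `normalize`), the sample values, and the momentum of 0.1 are illustrative assumptions, not this repository's actual API.

```python
import math

def local_sums(chunk):
    """Partial sums computed independently on one device: (sum x, sum x^2, count)."""
    return sum(chunk), sum(v * v for v in chunk), len(chunk)

def global_stats(partials):
    """Reduce the partial sums on the master into a global mean and biased variance."""
    total = sum(p[0] for p in partials)
    total_sq = sum(p[1] for p in partials)
    n = sum(p[2] for p in partials)
    mean = total / n
    var = total_sq / n - mean * mean
    return mean, var

def normalize(chunk, mean, var, gamma=1.0, beta=0.0, eps=1e-5):
    """Batch-normalize one device's chunk with the shared global statistics."""
    inv_std = 1.0 / math.sqrt(var + eps)
    return [gamma * (v - mean) * inv_std + beta for v in chunk]

# Two simulated devices holding very differently scaled slices of one channel.
chunks = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]

partials = [local_sums(c) for c in chunks]         # runs on every device
mean, var = global_stats(partials)                 # runs on the master
outputs = [normalize(c, mean, var) for c in chunks]  # runs on every device

# Moving-average update of the running mean (momentum value is an assumption).
momentum = 0.1
running_mean = 0.0
running_mean = (1.0 - momentum) * running_mean + momentum * mean
```

If each chunk were normalized with its own local statistics instead, both chunks here would produce nearly identical outputs, hiding the difference in scale between them; sharing the global statistics preserves it.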
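- Compute the sums below on each GPU,

  $$\sum_i \frac{dJ}{dy_i}$$

  and

  $$\sum_i \frac{dJ}{dy_i} \cdot \hat{x}_i$$

  where $\hat{x}_i$ is the normalized input from the forward pass. Then gather them at the master node to form the global sums, and normalize with N, where N is the total number of elements per channel. The global sums are then shared among all GPUs.
- Compute the gradients using the global stats, where

  $$\frac{dJ}{d\gamma} = \sum_i \frac{dJ}{dy_i} \cdot \hat{x}_i$$

  and

  $$\frac{dJ}{d\beta} = \sum_i \frac{dJ}{dy_i}$$

  and finally,

  $$\frac{dJ}{dx_i} = \frac{\gamma}{N\sqrt{\sigma^2 + \epsilon}} \left( N \frac{dJ}{dy_i} - \frac{dJ}{d\beta} - \hat{x}_i \cdot \frac{dJ}{d\gamma} \right)$$

  Note that in the implementation, normalization with N is performed in the gathering step above, so the equation and the implementation are not written identically but are mathematically the same. For a deeper explanation, see Kevin Zakka's Blog.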
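A minimal pure-Python sketch of the backward steps, for one channel split across two simulated devices. Each device computes its two partial sums, the master adds them up, and every device then forms its input gradients from the shared global sums. The function names (`forward`, `backward`), the sample inputs, and the upstream gradients `dy` are hypothetical values chosen for illustration; the repository's own kernels are organized differently.

```python
import math

EPS = 1e-5

def forward(chunks, gamma, beta):
    """Reference synchronized forward pass; returns y, x_hat and 1/sqrt(var + eps)."""
    flat = [v for c in chunks for v in c]
    n = len(flat)
    mean = sum(flat) / n
    var = sum((v - mean) ** 2 for v in flat) / n
    inv_std = 1.0 / math.sqrt(var + EPS)
    x_hat = [[(v - mean) * inv_std for v in c] for c in chunks]
    y = [[gamma * h + beta for h in hc] for hc in x_hat]
    return y, x_hat, inv_std

def backward(dy_chunks, x_hat, inv_std, gamma):
    """Synchronized backward pass; returns (dx per chunk, dgamma, dbeta)."""
    # Each device: partial sums of dJ/dy and dJ/dy * x_hat.
    partials = [
        (sum(dc), sum(d * h for d, h in zip(dc, hc)))
        for dc, hc in zip(dy_chunks, x_hat)
    ]
    # Master: reduce the partial sums into global sums.
    n = sum(len(dc) for dc in dy_chunks)
    sum_dy = sum(p[0] for p in partials)
    sum_dy_xhat = sum(p[1] for p in partials)
    # Each device: input gradient computed from the shared global sums.
    scale = gamma * inv_std / n
    dx = [
        [scale * (n * d - sum_dy - h * sum_dy_xhat) for d, h in zip(dc, hc)]
        for dc, hc in zip(dy_chunks, x_hat)
    ]
    return dx, sum_dy_xhat, sum_dy

# Hypothetical inputs and upstream gradients on two simulated devices.
chunks = [[0.5, -1.0, 2.0], [3.0, 0.0, -2.5]]
dy = [[0.1, -0.2, 0.3], [0.4, -0.5, 0.6]]
gamma, beta = 1.5, 0.2

y, x_hat, inv_std = forward(chunks, gamma, beta)
dx, dgamma, dbeta = backward(dy, x_hat, inv_std, gamma)
```

A quick sanity check for a sketch like this is to compare `dx` against central finite differences of the loss `sum(dy * y)`, in the same spirit as the README's comparison against nn.BatchNorm2d.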