Official PyTorch implementation of DiWA | paper, openreview
Alexandre Ramé, Matthieu Kirchmeyer, Thibaud Rahier, Alain Rakotomamonjy, Patrick Gallinari, Matthieu Cord
To improve out-of-distribution generalization, we average diverse weights obtained from different training runs; this strategy is motivated by an extension of the bias-variance theory to weight averaging and is state-of-the-art on DomainBed.
Standard neural networks struggle to generalize under distribution shifts. For out-of-distribution generalization in computer vision, the best current approach averages the weights along a training run. In this paper, we propose Diverse Weight Averaging (DiWA) that makes a simple change to this strategy: DiWA averages the weights obtained from several independent training runs rather than from a single run. Perhaps surprisingly, averaging these weights performs well under soft constraints despite the network's nonlinearities. The main motivation behind DiWA is to increase the functional diversity across averaged models. Indeed, models obtained from different runs are more diverse than those collected along a single run thanks to differences in hyperparameters and training procedures. We motivate the need for diversity by a new bias-variance-covariance-locality decomposition of the expected error, exploiting similarities between DiWA and standard functional ensembling. Moreover, this decomposition highlights that DiWA succeeds when the variance term dominates, which we show happens when the marginal distribution changes at test time. Experimentally, DiWA consistently improves the state of the art on the competitive DomainBed benchmark without inference overhead.
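As a preview of the core operation, here is a minimal, hypothetical sketch of DiWA's averaging step: the state dicts of several fine-tuned networks (all started from the same initialization, which is what makes the average meaningful) are averaged entry-wise and loaded into a single network. The checkpoint paths and `MyNetwork` are placeholders, not this repository's API.

```python
# Minimal sketch of weight averaging across independent runs; the checkpoint
# paths and MyNetwork are placeholders, not this repository's API.
import torch

checkpoint_paths = ["run_0.pt", "run_1.pt", "run_2.pt"]  # one file per run
state_dicts = [torch.load(p, map_location="cpu") for p in checkpoint_paths]

# Entry-wise uniform average of the weights; all runs must share the same
# architecture and (crucially) the same initialization.
averaged = {
    key: torch.stack([sd[key].float() for sd in state_dicts]).mean(dim=0)
    for key in state_dicts[0]
}

# model = MyNetwork()               # same architecture as the averaged runs
# model.load_state_dict(averaged)   # a single network: no inference overhead
```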
Our code is adapted from the open-source DomainBed repository, a PyTorch benchmark that includes datasets and algorithms for out-of-distribution generalization. It was introduced in In Search of Lost Domain Generalization, ICLR 2021.
In addition to the newly added `domainbed/scripts/diwa.py` and `domainbed/algorithms_inference.py` files, we made only a few modifications to this codebase, all preceded by `## DiWA ##`:

- in `domainbed/hparams_registry.py`, to define our mild hyperparameter ranges;
- in `domainbed/train.py`, to handle the shared initialization and save the weights of the epoch with the highest validation accuracy;
- in `domainbed/algorithms.py`, to handle the shared initialization and the linear-probing approach, and to implement the MA baseline;
- in `domainbed/datasets.py`, to define the checkpoint frequency;
- in `domainbed/scripts/sweep.py`, to be able to force the test environment;
- in `domainbed/lib/misc.py`, to include some tools.
You should then be able to reproduce our main experiment (Table 1 in the paper) on the DomainBed benchmark with the following requirements:
- python == 3.7.10
- torch == 1.8.1
- torchvision == 0.9.1
- numpy == 1.20.2
We ran DiWA on the following datasets:
- VLCS (Fang et al., 2013)
- PACS (Li et al., 2017)
- OfficeHome (Venkateswara et al., 2017)
- A TerraIncognita (Beery et al., 2018) subset
- DomainNet (Peng et al., 2019)
- Colored MNIST (Arjovsky et al., 2019)
You can download the datasets with the following command:

```sh
python3 -m domainbed.scripts.download --data_dir=/my/data/dir
```
Our training procedure has three stages.
First, we fix the shared initialization.
```sh
python3 -m domainbed.scripts.train \
    --data_dir=/my/data/dir/ \
    --algorithm ERM \
    --dataset OfficeHome \
    --test_env ${test_env} \
    --init_step \
    --path_for_init ${path_for_init} \
    --steps ${steps}
```
In the paper, we proposed two initialization strategies:

- random initialization: set `steps` to `-1`; there will be no training.
- Linear Probing (LP, ICLR 2022): set `steps` to `0`; only the classifier will be trained.

The initialization is then saved at `${path_for_init}`, to be used in the subsequent sweep.
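For intuition, here is a minimal sketch of what the linear-probing initialization amounts to, assuming DomainBed's usual featurizer/classifier split; the network definition and optimizer below are illustrative, not the exact code in `domainbed/train.py`.

```python
import torch
import torch.nn as nn
import torchvision.models as models

num_classes = 65  # OfficeHome has 65 classes

# ImageNet-pretrained featurizer with the classification head removed.
featurizer = models.resnet50(pretrained=True)
featurizer.fc = nn.Identity()
classifier = nn.Linear(2048, num_classes)

# Linear probing: freeze the featurizer and train only the linear head, so
# that all subsequent runs fine-tune from the same shared initialization.
for param in featurizer.parameters():
    param.requires_grad = False
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)

# ... after training the classifier on the source domains, save everything;
# each run in the sweep then restores these weights via --path_for_init.
torch.save({"featurizer": featurizer.state_dict(),
            "classifier": classifier.state_dict()}, "my_init.pt")
```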
Second, we launch several ERM runs following the mild hyperparameter distributions defined in `domainbed/hparams_registry.py` (Table 5 in Appendix F.1 of the paper).
To do so, we leverage DomainBed's native `sweep` script.
```sh
python -m domainbed.scripts.sweep launch \
    --data_dir=/my/data/dir/ \
    --output_dir=/my/sweep/output/path \
    --command_launcher multi_gpu \
    --datasets OfficeHome \
    --test_env ${test_env} \
    --path_for_init ${path_for_init} \
    --algorithms ERM \
    --n_hparams 20 \
    --n_trials 3
```
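As a usage example, the sweep is typically launched once per held-out test environment; the loop below is illustrative (OfficeHome has four domains, indexed 0 to 3) and the per-environment paths are placeholders for your own setup. Each sweep trains `n_hparams × n_trials = 60` models.

```sh
# Illustrative: launch one sweep per held-out test environment of OfficeHome.
for test_env in 0 1 2 3; do
    python -m domainbed.scripts.sweep launch \
        --data_dir=/my/data/dir/ \
        --output_dir=/my/sweep/output/path_env${test_env} \
        --command_launcher multi_gpu \
        --datasets OfficeHome \
        --test_env ${test_env} \
        --path_for_init /my/init/path_env${test_env} \
        --algorithms ERM \
        --n_hparams 20 \
        --n_trials 3
done
```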
Finally, we average the weights obtained from this sweep.
```sh
python -m domainbed.scripts.diwa \
    --data_dir=/my/data/dir/ \
    --output_dir=/my/sweep/output/path \
    --dataset OfficeHome \
    --test_env ${test_env} \
    --weight_selection ${weight_selection} \
    --trial_seed ${trial_seed}
```
In the paper, we proposed three variants:

- DiWA-restricted: set `weight_selection` to `restricted` and `trial_seed` to an integer between `0` and `2`.
- DiWA-uniform: set `weight_selection` to `uniform` and `trial_seed` to an integer between `0` and `2`.
- DiWA$^\dagger$-uniform: set `weight_selection` to `uniform` and `trial_seed` to `-1` (to use the runs from all trials).
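For clarity, here is a hypothetical sketch of the two selection rules as we understand them: uniform averaging simply averages every checkpoint, while restricted selection ranks checkpoints by validation accuracy and greedily keeps only those that do not degrade the validation accuracy of the running average (in the spirit of greedy model soups). `validation_accuracy` is a placeholder for the repository's evaluation utilities, not its actual API.

```python
import torch

def average_state_dicts(state_dicts):
    """Entry-wise uniform average of a list of state dicts."""
    return {k: torch.stack([sd[k].float() for sd in state_dicts]).mean(dim=0)
            for k in state_dicts[0]}

def diwa_uniform(state_dicts):
    # Uniform selection: average every checkpoint collected in the sweep.
    return average_state_dicts(state_dicts)

def diwa_restricted(state_dicts, validation_accuracy):
    # Restricted selection (sketch): rank checkpoints by validation accuracy,
    # then greedily keep a checkpoint only if adding it does not hurt the
    # validation accuracy of the averaged model.
    ranked = sorted(state_dicts, key=validation_accuracy, reverse=True)
    selected = [ranked[0]]
    best_acc = validation_accuracy(average_state_dicts(selected))
    for sd in ranked[1:]:
        candidate = average_state_dicts(selected + [sd])
        acc = validation_accuracy(candidate)
        if acc >= best_acc:
            selected.append(sd)
            best_acc = acc
    return average_state_dicts(selected)
```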
You can reproduce the Moving Average (MA) baseline by replacing `ERM` with `MA` as the algorithm argument.
```sh
python -m domainbed.scripts.sweep launch \
    --data_dir=/my/data/dir/ \
    --output_dir=/my/sweep/output/path \
    --command_launcher multi_gpu \
    --datasets OfficeHome \
    --test_env ${test_env} \
    --algorithms MA \
    --n_hparams 20 \
    --n_trials 3
```
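The MA baseline averages weights along a single training trajectory rather than across runs. The sketch below shows the underlying running-average update under simplifying assumptions (updating every step from the start, ignoring BatchNorm buffers); it is not the exact implementation in `domainbed/algorithms.py`.

```python
import copy
import torch

@torch.no_grad()
def update_moving_average(ma_model, model, n_averaged):
    """Incremental uniform mean: fold one more checkpoint into ma_model."""
    for ma_p, p in zip(ma_model.parameters(), model.parameters()):
        ma_p.mul_(n_averaged / (n_averaged + 1)).add_(p / (n_averaged + 1))

# Usage inside a standard training loop:
# ma_model = copy.deepcopy(model)       # holds the running weight average
# for step, batch in enumerate(loader):
#     ...                               # usual forward/backward/optimizer step
#     update_moving_average(ma_model, model, n_averaged=step)
# Evaluate ma_model rather than model at the end of training.
```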
Then to view the results of your sweep:
```sh
python -m domainbed.scripts.collect_results --input_dir=/my/sweep/output/path
```
DiWA sets a new state of the art on DomainBed. All numbers below are test accuracies (%).
| Algorithm | Weight selection | Init | PACS | VLCS | OfficeHome | TerraInc | DomainNet | Avg |
|---|---|---|---|---|---|---|---|---|
| ERM | N/A | Random | 85.5 | 77.5 | 66.5 | 46.1 | 40.9 | 63.3 |
| Coral | N/A | Random | 86.2 | 78.8 | 68.7 | 47.6 | 41.5 | 64.6 |
| SWAD | Overfit-aware | Random | 88.1 | 79.1 | 70.6 | 50.0 | 46.5 | 66.9 |
| MA | Uniform | Random | 87.5 | 78.2 | 70.6 | 50.3 | 46.0 | 66.5 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| ERM | N/A | Random | 85.5 | 77.6 | 67.4 | 48.3 | 44.1 | 64.6 |
| DiWA | Restricted | Random | 87.9 | 79.2 | 70.5 | 50.5 | 46.7 | 67.0 |
| DiWA | Uniform | Random | 88.8 | 79.1 | 71.0 | 48.9 | 46.1 | 66.8 |
| DiWA$^{\dagger}$ | Uniform | Random | 89.0 | 79.4 | 71.6 | 49.0 | 46.3 | 67.1 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| ERM | N/A | LP | 85.9 | 78.1 | 69.4 | 50.4 | 44.3 | 65.6 |
| DiWA | Restricted | LP | 88.0 | 78.5 | 71.5 | 51.6 | 47.7 | 67.5 |
| DiWA | Uniform | LP | 88.7 | 78.4 | 72.1 | 51.4 | 47.4 | 67.6 |
| DiWA$^{\dagger}$ | Uniform | LP | 89.0 | 78.6 | 72.8 | 51.9 | 47.7 | 68.0 |
If you find this code useful for your research, please consider citing our work:
```bibtex
@inproceedings{rame2022diwa,
    title     = {Diverse Weight Averaging for Out-of-Distribution Generalization},
    author    = {Rame, Alexandre and Kirchmeyer, Matthieu and Rahier, Thibaud and Rakotomamonjy, Alain and Gallinari, Patrick and Cord, Matthieu},
    year      = {2022},
    booktitle = {NeurIPS}
}
```
Correspondence to alexandre.rame at sorbonne-universite dot fr