Merge pull request #4 from oval-group/camera-ready
release
lberrada authored Apr 24, 2019
2 parents 1528111 + 6d019a9 commit 1b31f03
Showing 35 changed files with 470 additions and 389 deletions.
4 changes: 2 additions & 2 deletions .gitmodules
@@ -1,3 +1,3 @@
[submodule "src/InferSent"]
path = src/InferSent
[submodule "experiments/InferSent"]
path = experiments/InferSent
url = https://github.com/lberrada/InferSent
122 changes: 68 additions & 54 deletions README.md
@@ -3,31 +3,72 @@
This repository contains the implementation of the paper [Deep Frank-Wolfe For Neural Network Optimization](https://arxiv.org/abs/1811.07591) in pytorch. If you use this work for your research, please cite the paper:

```
@Article{berrada2018deep,
@Article{berrada2019deep,
author = {Berrada, Leonard and Zisserman, Andrew and Kumar, M Pawan},
title = {Deep Frank-Wolfe For Neural Network Optimization},
journal = {Under review},
year = {2018},
journal = {International Conference on Learning Representations},
year = {2019},
}
```

The DFW algorithm is a first-order optimization algorithm for deep neural networks. To use it for your learning task, consider the two following requirements:
* the loss function has to be a convex piecewise linear function (e.g. multi-class SVM [as implemented here](src/losses/hinge.py#L5), or an l1 loss)
* the optimizer needs access to the value of the loss function on the current mini-batch [as shown here](src/epoch.py#L31)
## Requirements

Besides these requirements, the optimizer can be used as a plug-and-play module, and its standalone code is available in [src/optim/dfw.py](src/optim/dfw.py).
This code should work for pytorch >= 1.0 in python3. Detailed requirements are available in `requirements.txt`.

## Requirements
## Installation

* Clone this repository: `git clone --recursive https://github.com/oval-group/dfw` (note that the `--recursive` option is necessary to clone the submodules; these are needed to reproduce the experiments but not for the DFW implementation itself).
* Go to the directory and install the requirements: `cd dfw && pip install -r requirements.txt`
* Install the DFW package: `python setup.py install`
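
A quick way to check the installation (a minimal sketch; it only imports the package installed above):
```python
# sanity check: these imports should succeed after `python setup.py install`
from dfw import DFW
from dfw.losses import MultiClassHingeLoss

print(DFW, MultiClassHingeLoss)
```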

## Example of Usage

* Simple usage example:
```python
from dfw import DFW
from dfw.losses import MultiClassHingeLoss


# boilerplate code:
# `model` is a nn.Module
# `x` is an input sample, `y` is a label

# create loss function
svm = MultiClassHingeLoss()

# create DFW optimizer with learning rate of 0.1
optimizer = DFW(model.parameters(), eta=0.1)

# DFW can be used with standard pytorch syntax
optimizer.zero_grad()
loss = svm(model(x), y)
loss.backward()
# NB: DFW needs to have access to the current loss value,
# (this syntax is compatible with standard pytorch optimizers too)
optimizer.step(lambda: float(loss))
```

This code has been tested for pytorch 0.4.1 in python3. Detailed requirements are available in `requirements.txt`.
* Technical requirement: the DFW algorithm computes a custom step-size at each iteration. For this update to make sense, the loss function must be a convex piecewise linear function.
For instance, one can use a multi-class SVM loss or an l1 regression loss (a minimal l1 regression sketch is given after this list).

* Smoothing: sometimes the multi-class SVM loss does not fare well with a large number of classes.
This issue can be alleviated with dual smoothing, which is easy to plug into the code:
```python
from dfw.losses import set_smoothing_enabled
...
with set_smoothing_enabled(True):
    loss = svm(model(x), y)
```
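
For the l1 regression case mentioned above, a minimal sketch could look as follows (hypothetical `model`, `x` and `y`, as in the example above; it assumes that `nn.L1Loss` serves as the convex piecewise linear loss and reuses the same closure-based `step` call):
```python
import torch.nn as nn

from dfw import DFW


# boilerplate code (hypothetical):
# `model` is a nn.Module for regression
# `x` is an input batch, `y` holds real-valued targets

# an l1 loss is convex piecewise linear, as required by DFW
l1 = nn.L1Loss()

optimizer = DFW(model.parameters(), eta=0.1)

optimizer.zero_grad()
loss = l1(model(x), y)
loss.backward()
# as before, DFW needs the current loss value through a closure
optimizer.step(lambda: float(loss))
```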

## Reproducing the Results

* To reproduce the CIFAR experiments: `VISION_DATA=[path/to/your/cifar/data] python scripts/reproduce_cifar.py`
* To reproduce the SNLI experiments: follow the [preparation instructions](https://github.com/lberrada/InferSent/tree/c4ded441cf701c256126c5283e4381abb8271792) and run `python scripts/reproduce_snli.py`
![alt text](plot_cifar.png)

* To reproduce the CIFAR experiments: `VISION_DATA=[path/to/your/cifar/data] python reproduce/cifar.py`
* To reproduce the SNLI experiments: follow the [preparation instructions](https://github.com/lberrada/InferSent/tree/dfw#download-datasets) and run `python reproduce/snli.py`

Note that SGD benefits from a hand-designed learning rate schedule. In contrast, all the other optimizers (including DFW) automatically adapt their steps and rely on the tuning of the initial learning rate only.
On average, you should obtain similar results to the ones reported in the paper (there might be some variance on some instances of CIFAR experiments):
DFW largely outperforms all baselines that do not use a manual schedule for the learning rate.
The tables below show the performance on the CIFAR data sets when using data augmentation (AMSGrad, a variant of Adam, is the strongest baseline in our experiments), and on the SNLI data set.

### CIFAR-10:

@@ -37,23 +78,17 @@ On average, you should obtain similar results to the ones reported in the paper

| Optimizer | Test Accuracy (%) |
| --------- | :--------------: |
| Adagrad | 86.07 |
| Adam | 84.86 |
| AMSGrad | 86.08 |
| BPGrad | 88.62 |
| **DFW** | **90.18** |
| SGD | 90.08 |
| AMSGrad | 90.1 |
| **DFW** | **94.7** |
| SGD (with schedule) | 95.4 |

</td><td>

| Optimizer | Test Accuracy (%) |
| --------- | :--------------: |
| Adagrad | 87.32 |
| Adam | 88.44 |
| AMSGrad | 90.53 |
| **BPGrad**| **90.85** |
| DFW | 90.22 |
| **SGD** | **92.02** |
| AMSGrad | 91.8 |
| **DFW** | **94.9** |
| SGD (with schedule) | 95.3 |

</td></tr> </table>

@@ -65,43 +100,23 @@ On average, you should obtain similar results to the ones reported in the paper

| Optimizer | Test Accuracy (%) |
| --------- | :--------------: |
| Adagrad | 57.64 |
| Adam | 58.46 |
| AMSGrad | 60.73 |
| BPGrad | 60.31 |
| **DFW** | **67.83** |
| SGD | 66.78 |
| AMSGrad | 67.8 |
| **DFW** | **74.7** |
| SGD (with schedule) | 77.8 |

</td><td>

| Optimizer | Test Accuracy (%) |
| --------- | :--------------: |
| Adagrad | 56.47 |
| Adam | 64.61 |
| AMSGrad | 68.32 |
| BPGrad | 59.36 |
| **DFW** | **69.55** |
| **SGD** | **70.33** |
| AMSGrad | 69.6 |
| **DFW** | **73.2** |
| SGD (with schedule) | 76.3 |

</td></tr> </table>

### SNLI:

<table>
<tr><th>CE Loss</th><th>SVM Loss</th></tr>
<tr><td>

| Optimizer | Test Accuracy (%) |
| --------- | :--------------: |
| Adagrad | 83.8 |
| Adam | 84.5 |
| AMSGrad | 84.2 |
| BPGrad | 83.6 |
| DFW | - |
| SGD | 84.7 |
| SGD* | 84.5 |

</td><td>

| Optimizer | Test Accuracy (%) |
| --------- | :--------------: |
@@ -110,10 +125,9 @@ On average, you should obtain similar results to the ones reported in the paper
| AMSGrad | 85.1 |
| BPGrad | 84.2 |
| **DFW** | **85.2** |
| **SGD** | **85.2** |
| SGD* | - |
| SGD (with schedule) | 85.2 |

</td></tr> </table>
</table>

## Acknowledgments

1 change: 1 addition & 0 deletions dfw/__init__.py
@@ -0,0 +1 @@
from .dfw import DFW
1 change: 1 addition & 0 deletions dfw/baselines/__init__.py
@@ -0,0 +1 @@
from .bpgrad import BPGrad
11 changes: 6 additions & 5 deletions src/optim/bpgrad.py → dfw/baselines/bpgrad.py
@@ -77,10 +77,11 @@ def step(self, closure):
        for group in self.param_groups:
            L = group['L']
            mu = group['momentum']
            for p in group['params']:
                v = self.state[p]['v']
                v *= mu
                v -= step_size / L * p.grad.data
                p.data += v
            if mu:
                for p in group['params']:
                    v = self.state[p]['v']
                    v *= mu
                    v -= step_size / L * p.grad.data
                    p.data += v

        self.gamma = step_size
File renamed without changes.
1 change: 1 addition & 0 deletions dfw/losses/__init__.py
@@ -0,0 +1 @@
from .hinge import MultiClassHingeLoss, set_smoothing_enabled
File renamed without changes.
1 change: 1 addition & 0 deletions experiments/InferSent
Submodule InferSent added at a4c6b4
6 changes: 3 additions & 3 deletions src/cli.py → experiments/cli.py
@@ -27,9 +27,9 @@ def _add_dataset_parser(parser):
help="val data size")
d_parser.add_argument('--test-size', type=int, default=None,
help="test data size")
d_parser.add_argument('--no-data-augmentation', dest='data_aug',
action='store_false', help='no data augmentation')
d_parser.set_defaults(data_aug=True)
d_parser.add_argument('--data-augmentation', dest='augment',
action='store_true', help='use data augmentation')
d_parser.set_defaults(augment=False)


def _add_model_parser(parser):
2 changes: 1 addition & 1 deletion src/cuda.py → experiments/cuda.py
@@ -2,7 +2,7 @@
try:
    import waitGPU
    ngpu = int(os.environ['NGPU']) if 'NGPU' in os.environ else 1
    waitGPU.wait(nproc=0, interval=10, ngpu=ngpu)
    waitGPU.wait(nproc=0, interval=10, ngpu=ngpu, gpu_ids=[2,3])
except ImportError:
    print('Failed to import waitGPU --> no automatic scheduling on GPU')
    pass
File renamed without changes.
9 changes: 6 additions & 3 deletions src/data/loaders.py → experiments/data/loaders.py
@@ -61,7 +61,7 @@ def create_loaders(dataset_train, dataset_val, dataset_test,

def loaders_mnist(dataset, batch_size=64, cuda=0,
                  train_size=50000, val_size=10000, test_size=10000,
                  test_batch_size=1000, augment=False, **kwargs):
                  test_batch_size=1000, **kwargs):

    assert dataset == 'mnist'
    root = '{}/{}'.format(os.environ['VISION_DATA'], dataset)
@@ -86,11 +86,10 @@ def loaders_mnist(dataset, batch_size=64, cuda=0,


def loaders_cifar(dataset, batch_size, cuda,
                  train_size=45000, augment=False, val_size=5000, test_size=10000,
                  train_size=45000, augment=True, val_size=5000, test_size=10000,
                  test_batch_size=128, **kwargs):

    assert dataset in ('cifar10', 'cifar100')
    # assert topk is None or topk == 1, "Top-k not wanted for CIFAR for now"

    root = '{}/{}'.format(os.environ['VISION_DATA'], dataset)

@@ -105,12 +104,14 @@ def loaders_cifar(dataset, batch_size, cuda,
        normalize])

    if augment:
        print('Using data augmentation on CIFAR data set.')
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize])
    else:
        print('Not using data augmentation on CIFAR data set.')
        transform_train = transform_test

    # define two datasets in order to have different transforms
@@ -148,12 +149,14 @@ def loaders_svhn(dataset, batch_size, cuda,
        normalize])

    if augment:
        print('Using data augmentation on SVHN data set.')
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize])
    else:
        print('Not using data augmentation on SVHN data set.')
        transform_train = transform_test

    # define two datasets in order to have different transforms
File renamed without changes.
95 changes: 95 additions & 0 deletions experiments/epoch.py
@@ -0,0 +1,95 @@
import torch

from tqdm import tqdm
from dfw.losses import set_smoothing_enabled
from utils import accuracy, regularization


def train(model, loss, optimizer, loader, args, xp):

    model.train()

    for metric in xp.train.metrics():
        metric.reset()

    for x, y in tqdm(loader, disable=not args.tqdm, desc='Train Epoch',
                     leave=False, total=len(loader)):
        (x, y) = (x.cuda(), y.cuda()) if args.cuda else (x, y)

        # forward pass
        scores = model(x)

        # compute the loss function, possibly using smoothing
        with set_smoothing_enabled(args.smooth_svm):
            loss_value = loss(scores, y)

        # backward pass
        optimizer.zero_grad()
        loss_value.backward()

        # optimization step
        optimizer.step(lambda: float(loss_value))

        # monitoring
        batch_size = x.size(0)
        xp.train.acc.update(accuracy(scores, y), weighting=batch_size)
        xp.train.loss.update(loss(scores, y), weighting=batch_size)
        xp.train.gamma.update(optimizer.gamma, weighting=batch_size)

    xp.train.eta.update(optimizer.eta)
    w_norm = torch.sqrt(sum(p.norm() ** 2 for p in model.parameters()))
    xp.train.weight_norm.update(w_norm)
    xp.train.reg.update(0.5 * args.l2 * xp.train.weight_norm.value ** 2)
    xp.train.obj.update(xp.train.reg.value + xp.train.loss.value)
    xp.train.timer.update()

    print('\nEpoch: [{0}] (Train) \t'
          '({timer:.2f}s) \t'
          'Obj {obj:.3f}\t'
          'Loss {loss:.3f}\t'
          'Acc {acc:.2f}%\t'
          .format(int(xp.epoch.value),
                  timer=xp.train.timer.value,
                  acc=xp.train.acc.value,
                  obj=xp.train.obj.value,
                  loss=xp.train.loss.value))

    for metric in xp.train.metrics():
        metric.log(time=xp.epoch.value)


@torch.autograd.no_grad()
def test(model, loader, args, xp):
    model.eval()

    if loader.tag == 'val':
        xp_group = xp.val
    else:
        xp_group = xp.test

    for metric in xp_group.metrics():
        metric.reset()

    for x, y in tqdm(loader, disable=not args.tqdm,
                     desc='{} Epoch'.format(loader.tag.title()),
                     leave=False, total=len(loader)):
        (x, y) = (x.cuda(), y.cuda()) if args.cuda else (x, y)
        scores = model(x)
        xp_group.acc.update(accuracy(scores, y), weighting=x.size(0))

    xp_group.timer.update()

    print('Epoch: [{0}] ({tag})\t'
          '({timer:.2f}s) \t'
          'Obj ----\t'
          'Loss ----\t'
          'Acc {acc:.2f}% \t'
          .format(int(xp.epoch.value),
                  tag=loader.tag.title(),
                  timer=xp_group.timer.value,
                  acc=xp_group.acc.value))

    if loader.tag == 'val':
        xp.max_val.update(xp.val.acc.value).log(time=xp.epoch.value)

    for metric in xp_group.metrics():
        metric.log(time=xp.epoch.value)
2 changes: 1 addition & 1 deletion src/losses/__init__.py → experiments/losses/__init__.py
@@ -1,5 +1,5 @@
import torch.nn as nn
from losses.hinge import MultiClassHingeLoss, set_smoothing_enabled
from dfw.losses import MultiClassHingeLoss, set_smoothing_enabled


def get_loss(args):
