
Commit ea518f7

Merge pull request #15 from openai/documentation
Documentation
2 parents 1473af1 + 89cd479 commit ea518f7

4 files changed: +144 -19 lines


README.md

Lines changed: 2 additions & 1 deletion
@@ -7,6 +7,8 @@ benchmark machine learning systems' vulnerability to
[adversarial examples](http://karpathy.github.io/2015/03/30/breaking-convnets/)
.

+Note: this library is still in active development.
+
## Setting up `cleverhans`

### Dependencies
@@ -50,7 +52,6 @@ Bug fixes can be initiated through Github pull requests.
The following authors contributed to this library (by alphabetical order):
* Ian Goodfellow (OpenAI)
* Nicolas Papernot (Pennsylvania State University)
-* Ryan Sheatsley (Pennsylvania State University)

## Copyright
cleverhans/attacks.py

Lines changed: 11 additions & 1 deletion
@@ -13,7 +13,9 @@

def fgsm(x, predictions, eps, back='tf'):
    """
-
+    A wrapper for the Fast Gradient Sign Method.
+    It calls the right function, depending on the
+    user's backend.
    :param sess:
    :param x:
    :param y:
@@ -31,6 +33,14 @@ def fgsm(x, predictions, eps, back='tf'):
        raise NotImplementedError("Theano FGSM not implemented.")

def fgsm_tf(x, predictions, eps):
+    """
+    TensorFlow implementation of the Fast Gradient
+    Sign method.
+    :param x: the input placeholder
+    :param predictions: the model's output tensor
+    :param eps: the epsilon (input variation parameter)
+    :return: a tensor for the adversarial example
+    """
    # Define loss

    y = tf.to_float(tf.equal(predictions, tf.reduce_max(predictions, 1, keep_dims=True)))
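The diff above shows only the docstring and the first line of `fgsm_tf`, which builds one-hot labels from the model's most confident predictions. As a reading aid, here is a minimal sketch of how a TensorFlow FGSM graph of this kind can be completed; the cross-entropy formulation, the hypothetical name `fgsm_tf_sketch`, and the [0, 1] clipping range are assumptions for illustration, not the code added by this commit.

```
import tensorflow as tf

def fgsm_tf_sketch(x, predictions, eps):
    """Illustrative FGSM graph; not the cleverhans implementation."""
    # Use the model's most likely class as the label, as in the diff above
    y = tf.to_float(tf.equal(predictions,
                             tf.reduce_max(predictions, 1, keep_dims=True)))
    y = y / tf.reduce_sum(y, 1, keep_dims=True)

    # Cross-entropy between the probability predictions and these labels
    loss = -tf.reduce_mean(tf.reduce_sum(y * tf.log(predictions), 1))

    # Perturb the input in the direction of the sign of the loss gradient
    grad, = tf.gradients(loss, x)
    adv_x = tf.stop_gradient(x + eps * tf.sign(grad))

    # Assumed clipping to a valid [0, 1] input range
    return tf.clip_by_value(adv_x, 0.0, 1.0)
```

Taking the model's own most likely class as the label means the attack graph does not need the true labels of the inputs.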

cleverhans/utils_tf.py

Lines changed: 21 additions & 16 deletions
@@ -33,14 +33,16 @@ def tf_model_train(sess, x, y, predictions, X_train, Y_train, save=False,
                   predictions_adv=None):
    """
    Train a TF graph
-    :param sess:
-    :param x:
-    :param y:
-    :param model:
-    :param X_train:
-    :param Y_train:
-    :param save:
-    :return:
+    :param sess: TF session to use when training the graph
+    :param x: input placeholder
+    :param y: output placeholder (for labels)
+    :param predictions: model output predictions
+    :param X_train: numpy array with training inputs
+    :param Y_train: numpy array with training outputs
+    :param save: Boolean controlling the save operation
+    :param predictions_adv: if set with the adversarial example tensor,
+                            will run adversarial training
+    :return: True if model trained
    """
    print "Starting model training using TensorFlow."

@@ -93,14 +95,14 @@ def tf_model_train(sess, x, y, predictions, X_train, Y_train, save=False,

def tf_model_eval(sess, x, y, model, X_test, Y_test):
    """
-
-    :param sess:
-    :param x:
-    :param y:
-    :param model:
-    :param X_test:
-    :param Y_test:
-    :return:
+    Compute the accuracy of a TF model on some data
+    :param sess: TF session to use
+    :param x: input placeholder
+    :param y: output placeholder (for labels)
+    :param model: model output predictions
+    :param X_test: numpy array with test inputs
+    :param Y_test: numpy array with test outputs
+    :return: a float with the accuracy value
    """
    # Define symbolic for accuracy
    acc_value = keras.metrics.categorical_accuracy(y, model)
@@ -145,6 +147,9 @@ def tf_model_load(sess):
    return True

def batch_eval(sess, tf_inputs, tf_outputs, numpy_inputs):
+    """
+    A helper function that computes a tensor on numpy inputs by batches.
+    """
    n = len(numpy_inputs)
    assert n > 0
    assert n == len(tf_inputs)
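The commit documents `batch_eval`, but the diff context stops at its input assertions. Below is a minimal sketch of the batching pattern the new docstring describes; the `batch_size` argument, the slicing, and the `np.concatenate` reassembly are assumptions for illustration rather than the actual cleverhans implementation.

```
import numpy as np

def batch_eval_sketch(sess, tf_inputs, tf_outputs, numpy_inputs, batch_size=128):
    """Illustrative batched evaluation; not the cleverhans implementation."""
    n = len(numpy_inputs)
    assert n > 0
    assert n == len(tf_inputs)
    m = numpy_inputs[0].shape[0]

    out = [[] for _ in tf_outputs]
    for start in range(0, m, batch_size):
        end = start + batch_size
        # Feed the corresponding slice of every numpy input to its placeholder
        feed_dict = {inp: arr[start:end]
                     for inp, arr in zip(tf_inputs, numpy_inputs)}
        batch_results = sess.run(tf_outputs, feed_dict=feed_dict)
        for i, res in enumerate(batch_results):
            out[i].append(res)

    # Reassemble each output in the original example order
    return [np.concatenate(chunks, axis=0) for chunks in out]
```

The tutorial below uses it as `X_test_adv, = batch_eval(sess, [x], [adv_x], [X_test])`, i.e. one placeholder, one output tensor, and one numpy array evaluated batch by batch.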

tutorials/mnist_tutorial.md

Lines changed: 110 additions & 1 deletion
@@ -32,7 +32,116 @@ it is made up of multiple convolutional and ReLU layers.
You can find the model definition in the
[`utils_mnist` cleverhans module](https://github.com/openai/cleverhans/blob/master/cleverhans/utils_mnist.py).

-TODO(insert code snippet here)
+```
+# Define input TF placeholder
+x = tf.placeholder(tf.float32, shape=(None, 1, 28, 28))
+y = tf.placeholder(tf.float32, shape=(None, FLAGS.nb_classes))
+
+# Define TF model graph
+model = model_mnist()
+predictions = model(x)
+print "Defined TensorFlow model graph."
+```

## Training the model with TensorFlow

+The library includes a helper function that runs a
+TensorFlow optimizer to train models and another
+helper function to load the MNIST dataset.
+To train our MNIST model, we run the following:
+
+```
+# Get MNIST data
+X_train, Y_train, X_test, Y_test = data_mnist()
+
+# Train an MNIST model
+tf_model_train(sess, x, y, predictions, X_train, Y_train)
+```
+
+We can then evaluate the performance of this model
+using `tf_model_eval` included in `cleverhans.utils_tf`:
+
+```
+# Evaluate the accuracy of the MNIST model on legitimate test examples
+accuracy = tf_model_eval(sess, x, y, predictions, X_test, Y_test)
+assert X_test.shape[0] == 10000, X_test.shape
+print 'Test accuracy on legitimate test examples: ' + str(accuracy)
+```
+
+The accuracy returned should be above `97%`.
+
+## Crafting adversarial examples
+
+This tutorial applies the Fast Gradient Sign method
+introduced by [Goodfellow et al.](https://arxiv.org/abs/1412.6572).
+We first need to create the necessary graph elements by
+calling `cleverhans.attacks.fgsm` before using the helper
+function `cleverhans.utils_tf.batch_eval` to apply it to
+our test set. This gives the following:
+
+```
+# Craft adversarial examples using Fast Gradient Sign Method (FGSM)
+adv_x = fgsm(x, predictions, eps=0.3)
+X_test_adv, = batch_eval(sess, [x], [adv_x], [X_test])
+assert X_test_adv.shape[0] == 10000, X_test_adv.shape
+
+# Evaluate the accuracy of the MNIST model on adversarial examples
+accuracy = tf_model_eval(sess, x, y, predictions, X_test_adv, Y_test)
+print 'Test accuracy on adversarial examples: ' + str(accuracy)
+```
+
+The second part evaluates the accuracy of the model on
+adversarial examples in a way similar to that described
+previously for legitimate examples. It should be lower
+than the accuracy you obtained previously.
+
+
+## Improving robustness using adversarial training
+
+One defense strategy to mitigate adversarial examples is to use
+adversarial training, i.e. train the model with both the
+original data and adversarially modified data (with correct
+labels). You can use the training function `utils_tf.tf_model_train`
+with the optional argument `predictions_adv` set to the result
+of `cleverhans.attacks.fgsm` in order to perform adversarial
+training.
+
+In the following snippet, we first declare a new model (in a
+way similar to the one described previously) and then we train
+it with both legitimate and adversarial training points.
+
+```
+# Redefine TF model graph
+model_2 = model_mnist()
+predictions_2 = model_2(x)
+adv_x_2 = fgsm(x, predictions_2, eps=0.3)
+predictions_2_adv = model_2(adv_x_2)
+
+# Perform adversarial training
+tf_model_train(sess, x, y, predictions_2, X_train, Y_train, predictions_adv=predictions_2_adv)
+```
+
+We can then verify that (1) its accuracy on legitimate data is
+still comparable to that of the first model, and (2) its accuracy on newly
+generated adversarial examples is higher.
+
+```
+# Evaluate the accuracy of the adversarially trained MNIST model on
+# legitimate test examples
+accuracy = tf_model_eval(sess, x, y, predictions_2, X_test, Y_test)
+print 'Test accuracy on legitimate test examples: ' + str(accuracy)
+
+# Craft adversarial examples using Fast Gradient Sign Method (FGSM) on
+# the new model, which was trained using adversarial training
+X_test_adv_2, = batch_eval(sess, [x], [adv_x_2], [X_test])
+assert X_test_adv_2.shape[0] == 10000, X_test_adv_2.shape
+
+# Evaluate the accuracy of the adversarially trained MNIST model on
+# adversarial examples
+accuracy_adv = tf_model_eval(sess, x, y, predictions_2, X_test_adv_2, Y_test)
+print 'Test accuracy on adversarial examples: ' + str(accuracy_adv)
+```
+
+## Code
+
+The complete code for this tutorial is available [here](https://github.com/openai/cleverhans/blob/master/tutorials/mnist_tutorial.py).
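The tutorial snippets rely on a TensorFlow session `sess` and on `FLAGS.nb_classes`, neither of which is defined in the excerpted diff. A minimal sketch of the setup they assume is shown below; the flag definition and the call to `keras.backend.set_session` are assumptions about the tutorial's environment (Keras on the TensorFlow backend), not part of this commit.

```
import keras
import tensorflow as tf

# Assumed definition of the FLAGS object referenced by the placeholders
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_integer('nb_classes', 10, 'Number of classification classes')

# A single TF session shared by tf_model_train, tf_model_eval and batch_eval
sess = tf.Session()
keras.backend.set_session(sess)
```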
