Add research code to reproduce Truth Serum #234

Open · wants to merge 12 commits into base: master
6 changes: 3 additions & 3 deletions research/mi_lira_2021/README.md
@@ -2,9 +2,9 @@

This directory contains code to reproduce our paper:

**"Membership Inference Attacks From First Principles"**
https://arxiv.org/abs/2112.03570
by Nicholas Carlini, Steve Chien, Milad Nasr, Shuang Song, Andreas Terzis, and Florian Tramer.
**"Membership Inference Attacks From First Principles"** <br>
https://arxiv.org/abs/2112.03570 <br>
by Nicholas Carlini, Steve Chien, Milad Nasr, Shuang Song, Andreas Terzis, and Florian Tramèr.


### INSTALLING
30 changes: 7 additions & 23 deletions research/mi_lira_2021/inference.py
@@ -12,32 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import os
from typing import Callable
import json

import os
import re
import jax
import jax.numpy as jn
import numpy as np
import tensorflow as tf # For data augmentation.
import tensorflow_datasets as tfds
from absl import app, flags
from tqdm import tqdm, trange
import pickle
from functools import partial

import numpy as np
import objax
from objax.jaxboard import SummaryWriter, Summary
from objax.util import EasyDict
from objax.zoo import convnet, wide_resnet

from dataset import DataSet
import tensorflow as tf # For data augmentation.
from absl import app
from absl import flags

from train import MemModule, network
from train import MemModule
from train import network

from collections import defaultdict
FLAGS = flags.FLAGS


@@ -142,9 +129,6 @@ def features(model, xbatch, ybatch):
flags.DEFINE_string('dataset', 'cifar10', 'Dataset.')
flags.DEFINE_string('logdir', 'experiments/', 'Directory where to save checkpoints and tensorboard data.')
flags.DEFINE_string('regex', '.*experiment.*', 'keep files when matching')
flags.DEFINE_bool('random_labels', False, 'use random labels.')
flags.DEFINE_integer('dataset_size', 50000, 'size of dataset.')
flags.DEFINE_integer('from_epoch', None, 'which epoch to load from.')
flags.DEFINE_integer('seed_mod', None, 'keep mod seed.')
flags.DEFINE_integer('modulus', 8, 'modulus.')
app.run(main)
5 changes: 3 additions & 2 deletions research/mi_lira_2021/plot.py
@@ -220,5 +220,6 @@ def fig_fpr_tpr():
plt.show()


load_data("exp/cifar10/")
fig_fpr_tpr()
if __name__ == '__main__':
load_data("exp/cifar10/")
fig_fpr_tpr()
5 changes: 2 additions & 3 deletions research/mi_lira_2021/train.py
@@ -24,12 +24,11 @@
import tensorflow as tf # For data augmentation.
import tensorflow_datasets as tfds
from absl import app, flags
from tqdm import tqdm, trange

import objax
from objax.jaxboard import SummaryWriter, Summary
from objax.util import EasyDict
from objax.zoo import convnet, wide_resnet, dnnet
from objax.zoo import convnet, wide_resnet

from dataset import DataSet

@@ -269,7 +268,7 @@ def main(argv):
logdir = "experiment-"+str(seed)
logdir = os.path.join(FLAGS.logdir, logdir)

if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz"%10)):
if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz"%FLAGS.epochs)):
print(f"run {FLAGS.expid} already completed.")
return
else:
118 changes: 118 additions & 0 deletions research/mi_poison_2022/README.md
@@ -0,0 +1,118 @@
## Truth Serum: Poisoning Machine Learning Models to Reveal Their Secrets

This directory contains code to reproduce results from the paper:

**"Truth Serum: Poisoning Machine Learning Models to Reveal Their Secrets"**<br>
https://arxiv.org/abs/2204.00032 <br>
by Florian Tramèr, Reza Shokri, Ayrton San Joaquin, Hoang Le, Matthew Jagielski, Sanghyun Hong and Nicholas Carlini

### INSTALLING

The experiments in this directory are built on top of the [LiRA
membership inference attack](../mi_lira_2021).

After following the [installation instructions](../mi_lira_2021#installing)
for LiRA, make sure the attack code is on your `PYTHONPATH`:

```bash
export PYTHONPATH="${PYTHONPATH}:../mi_lira_2021"
```


### RUNNING THE CODE

#### 1. Train the models

The first step in our attack is to train shadow models, with some data points
targeted by a poisoning attack. You can train 16 shadow models
with the command

> bash scripts/train_demo.sh

or if you have multiple GPUs on your machine and want to train these models
in parallel, then modify and run

> bash scripts/train_demo_multigpu.sh

This will train 16 CIFAR-10 wide ResNet models to ~91% accuracy each, with
250 points targeted for poisoning. For each of these 250 targeted points, the
attacker adds 8 mislabeled poisoned copies of the point into the training set.
The training run will output a number of files under the directory `exp/cifar10` with the following structure:

```
exp/cifar10/
- xtrain.npy
- ytrain.npy
- poison_pos.npy
- experiment_N_of_16
-- hparams.json
-- keep.npy
-- ckpt/
--- 0000000100.npz
-- tb/
```
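
As a quick sanity check, you can inspect these outputs from Python. The snippet
below is only an illustrative sketch: it assumes the LiRA conventions from
[`../mi_lira_2021`](../mi_lira_2021), where `keep.npy` is a boolean membership
mask over the training set and `poison_pos.npy` holds the indices of the
targeted points.

```python
import numpy as np

# Boolean mask: True where the example was part of this shadow model's
# training data (one mask per shadow model).
keep = np.load("exp/cifar10/experiment_0_of_16/keep.npy")
print("examples used by shadow model 0:", keep.sum(), "of", keep.size)

# Indices of the points targeted by the poisoning attack.
poison_pos = np.load("exp/cifar10/poison_pos.npy")
print("number of targeted points:", len(poison_pos))
```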

The following flags control the poisoning attack:
- `num_poison_targets (default=250)`. The number of targeted points.
- `poison_reps (default=8)`. The number of replicas per poison.
- `poison_pos_seed (default=0)`. The random seed used to choose the target points.

We recommend that `num_poison_targets * poison_reps < 5000` on CIFAR-10, as
otherwise the poisons introduce too much label noise and the model's
accuracy (and the attack's success rate) will be degraded.
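
For intuition, here is a minimal sketch of what these flags control, assuming
the strategy described above (mislabeled replicas of each targeted point). The
function name and the choice of wrong label are illustrative; the actual
implementation lives in `train_poison.py`.

```python
import numpy as np

def add_poisons(xtrain, ytrain, num_poison_targets=250, poison_reps=8,
                poison_pos_seed=0, num_classes=10):
    # Pick the targeted points with a fixed seed so that every shadow model
    # poisons the same examples.
    rng = np.random.RandomState(poison_pos_seed)
    poison_pos = rng.choice(len(xtrain), num_poison_targets, replace=False)

    x_poison, y_poison = [], []
    for idx in poison_pos:
        # Replicate the targeted point `poison_reps` times with a wrong label.
        wrong_label = (ytrain[idx] + 1) % num_classes
        x_poison.append(np.repeat(xtrain[idx][None], poison_reps, axis=0))
        y_poison.append(np.full(poison_reps, wrong_label, dtype=ytrain.dtype))

    xtrain = np.concatenate([xtrain] + x_poison, axis=0)
    ytrain = np.concatenate([ytrain] + y_poison, axis=0)
    return xtrain, ytrain, poison_pos
```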

#### 2. Perform inference and compute scores

Exactly as for LiRA, we then evaluate the models on the entire CIFAR-10 dataset,
and generate logit-scaled membership inference scores.
See [here](../mi_lira_2021#2-perform-inference)
and [here](../mi_lira_2021#3-compute-membership-inference-scores)
for details.

```bash
python3 -m inference --logdir=exp/cifar10/
python3 -m score exp/cifar10/
```
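
For reference, the `score` step turns the model's confidence in the true class
into a logit-scaled score, as in the LiRA paper. A minimal sketch (ignoring the
numerically careful implementation in `../mi_lira_2021`), assuming `p` is the
softmax probability of the correct label:

```python
import numpy as np

def logit_scale(p, eps=1e-45):
    # Logit-scaled membership score log(p / (1 - p)); eps avoids log(0).
    return np.log(p + eps) - np.log(1 - p + eps)
```

Across shadow models these scores are approximately Gaussian for each example,
which is what the LiRA likelihood-ratio test relies on.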

### PLOTTING THE RESULTS

Finally, we can generate pretty pictures by running the plotting code

```bash
python3 plot_poison.py
```

which should give (something like) the following output


![Log-log ROC Curve for all attacks](fprtpr.png "Log-log ROC Curve")

```
Attack No poison (LiRA)
AUC 0.7025, Accuracy 0.6258, TPR@0.1%FPR of 0.0544
Attack No poison (Global threshold)
AUC 0.6191, Accuracy 0.6173, TPR@0.1%FPR of 0.0012
Attack With poison (LiRA)
AUC 0.9943, Accuracy 0.9653, TPR@0.1%FPR of 0.4945
Attack With poison (Global threshold)
AUC 0.9922, Accuracy 0.9603, TPR@0.1%FPR of 0.3930
```

where the baselines are LiRA and a simple global threshold on the
membership scores, both without poisoning.
With poisoning, both LiRA and the global threshold attack are boosted
significantly. Note that because we only train a few models, we use
the fixed variance variant of LiRA.
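
For concreteness, a hedged sketch of the per-example LiRA test in its
fixed-variance form (not the exact code from `../mi_lira_2021/plot.py`):
instead of estimating a separate standard deviation for every example's IN and
OUT score distributions, one global standard deviation is shared across all
examples, which is more stable when only a handful of shadow models are
available.

```python
import numpy as np
from scipy.stats import norm

def lira_score_fixed_variance(score, mean_in, mean_out, std_global):
    # Log-likelihood ratio of the observed score under the IN vs. OUT
    # Gaussians; both share a single global standard deviation.
    log_in = norm.logpdf(score, loc=mean_in, scale=std_global)
    log_out = norm.logpdf(score, loc=mean_out, scale=std_global)
    return log_in - log_out  # larger => more likely a training member
```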

### Citation

You can cite this paper with

```
@article{tramer2022truth,
title={Truth Serum: Poisoning Machine Learning Models to Reveal Their Secrets},
author={Tramer, Florian and Shokri, Reza and San Joaquin, Ayrton and Le, Hoang and Jagielski, Matthew and Hong, Sanghyun and Carlini, Nicholas},
journal={arXiv preprint arXiv:2204.00032},
year={2022}
}
```
Binary file added research/mi_poison_2022/fprtpr.png
113 changes: 113 additions & 0 deletions research/mi_poison_2022/plot_poison.py
@@ -0,0 +1,113 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import numpy as np
import matplotlib.pyplot as plt
import functools

import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

# from mi_lira_2021
from plot import sweep, load_data, generate_ours, generate_global


def do_plot_all(fn, keep, scores, legend='', metric='auc', sweep_fn=sweep, **plot_kwargs):
"""
    Generate the ROC curves using each model in turn as the test model and
    the rest for training (full leave-one-out cross-validation).
"""

all_predictions = []
all_answers = []
for i in range(0, len(keep)):
mask = np.zeros(len(keep), dtype=bool)
mask[i:i+1] = True
prediction, answers = fn(keep[~mask],
scores[~mask],
keep[mask],
scores[mask])
all_predictions.extend(prediction)
all_answers.extend(answers)

fpr, tpr, auc, acc = sweep_fn(np.array(all_predictions),
np.array(all_answers, dtype=bool))

low = tpr[np.where(fpr < .001)[0][-1]]
print('Attack %s AUC %.4f, Accuracy %.4f, TPR@0.1%%FPR of %.4f'%(legend, auc, acc, low))

metric_text = ''
if metric == 'auc':
metric_text = 'auc=%.3f' % auc
elif metric == 'acc':
metric_text = 'acc=%.3f' % acc

plt.plot(fpr, tpr, label=legend+metric_text, **plot_kwargs)
return acc, auc


def fig_fpr_tpr(poison_mask, scores, keep):

plt.figure(figsize=(4, 3))

# evaluate LiRA on the points that were not targeted by poisoning
do_plot_all(functools.partial(generate_ours, fix_variance=True),
keep[:, ~poison_mask], scores[:, ~poison_mask],
"No poison (LiRA)\n",
metric='auc',
)

# evaluate the global-threshold attack on the points that were not targeted by poisoning
do_plot_all(generate_global,
keep[:, ~poison_mask], scores[:, ~poison_mask],
"No poison (Global threshold)\n",
metric='auc', ls="--", c=plt.gca().lines[-1].get_color()
)

# evaluate LiRA on the points that were targeted by poisoning
do_plot_all(functools.partial(generate_ours, fix_variance=True),
keep[:, poison_mask], scores[:, poison_mask],
"With poison (LiRA)\n",
metric='auc',
)

# evaluate the global-threshold attack on the points that were targeted by poisoning
do_plot_all(generate_global,
keep[:, poison_mask], scores[:, poison_mask],
"With poison (Global threshold)\n",
metric='auc', ls="--", c=plt.gca().lines[-1].get_color()
)

plt.semilogx()
plt.semilogy()
plt.xlim(1e-3, 1)
plt.ylim(1e-3, 1)
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.plot([0, 1], [0, 1], ls='--', color='gray')
plt.subplots_adjust(bottom=.18, left=.18, top=.96, right=.96)
plt.legend(fontsize=8)
plt.savefig("/tmp/fprtpr.png")
plt.show()


if __name__ == '__main__':
logdir = "exp/cifar10/"
scores, keep = load_data(logdir)
poison_pos = np.load(os.path.join(logdir, "poison_pos.npy"))
poison_mask = np.zeros(scores.shape[1], dtype=bool)
poison_mask[poison_pos] = True
fig_fpr_tpr(poison_mask, scores, keep)
30 changes: 30 additions & 0 deletions research/mi_poison_2022/scripts/train_demo.sh
@@ -0,0 +1,30 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15
32 changes: 32 additions & 0 deletions research/mi_poison_2022/scripts/train_demo_multigpu.sh
@@ -0,0 +1,32 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0 &
CUDA_VISIBLE_DEVICES='1' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1 &
CUDA_VISIBLE_DEVICES='2' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2 &
CUDA_VISIBLE_DEVICES='3' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3 &
CUDA_VISIBLE_DEVICES='4' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4 &
CUDA_VISIBLE_DEVICES='5' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5 &
CUDA_VISIBLE_DEVICES='6' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6 &
CUDA_VISIBLE_DEVICES='7' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7 &
wait;
CUDA_VISIBLE_DEVICES='0' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8 &
CUDA_VISIBLE_DEVICES='1' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9 &
CUDA_VISIBLE_DEVICES='2' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10 &
CUDA_VISIBLE_DEVICES='3' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11 &
CUDA_VISIBLE_DEVICES='4' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12 &
CUDA_VISIBLE_DEVICES='5' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13 &
CUDA_VISIBLE_DEVICES='6' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14 &
CUDA_VISIBLE_DEVICES='7' python3 -u train_poison.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15 &
wait;