COPYBARA_INTEGRATE_REVIEW=#234 from ftramer:truth_serum fe44a07
PiperOrigin-RevId: 447573314
tensorflower-gardener authored and schien1729 committed May 9, 2022
1 parent 137f795 commit 97eec1a
Showing 11 changed files with 581 additions and 70 deletions.
39 changes: 17 additions & 22 deletions research/mi_lira_2021/README.md
@@ -2,10 +2,9 @@

This directory contains code to reproduce our paper:

**"Membership Inference Attacks From First Principles"**
https://arxiv.org/abs/2112.03570
by Nicholas Carlini, Steve Chien, Milad Nasr, Shuang Song, Andreas Terzis, and Florian Tramer.

**"Membership Inference Attacks From First Principles"** <br>
https://arxiv.org/abs/2112.03570 <br>
by Nicholas Carlini, Steve Chien, Milad Nasr, Shuang Song, Andreas Terzis, and Florian Tramèr.

### INSTALLING

@@ -18,21 +17,20 @@ with JAX + ObJAX so you will need to follow build instructions for that
https://github.com/google/objax
https://objax.readthedocs.io/en/latest/installation_setup.html
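
For a CPU-only setup, installing ObJAX from PyPI is usually enough (GPU
support additionally needs a CUDA build of jaxlib; see the links above):

> pip install objax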


### RUNNING THE CODE

#### 1. Train the models

The first step in our attack is to train shadow models. As a baseline
that should give most of the gains in our attack, you should start by
training 16 shadow models with the command
The first step in our attack is to train shadow models. As a baseline that
should give most of the gains in our attack, you should start by training 16
shadow models with the command

> bash scripts/train_demo.sh
or if you have multiple GPUs on your machine and want to train these models
in parallel, then modify and run
or if you have multiple GPUs on your machine and want to train these models in
parallel, then modify and run

> bash scripts/train_demo_multigpu.sh
This will train several CIFAR-10 wide ResNet models to ~91% accuracy each, and
will output a bunch of files under the directory exp/cifar10 with structure:
@@ -63,14 +61,13 @@ exp/cifar10/
--- 0000000100.npy
```

where this new file has shape (50000, 10) and stores the model's
output features for each example.

where this new file has shape (50000, 10) and stores the model's output features
for each example.
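
As a quick sanity check, you can load one of these feature files and verify
its shape (the run directory name below is hypothetical; substitute one that
exists under exp/cifar10/):

```
import numpy as np

# Hypothetical run directory; pick any experiment folder you trained.
feats = np.load("exp/cifar10/experiment-0_16/logits/0000000100.npy")
print(feats.shape)  # expected: (50000, 10)
```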

#### 3. Compute membership inference scores

Finally we take the output features and generate our logit-scaled membership inference
scores for each example for each model.
Finally we take the output features and generate our logit-scaled membership
inference scores for each example for each model.

> python3 score.py exp/cifar10/
@@ -85,7 +82,6 @@ exp/cifar10/

with shape (50000,) storing just our scores.
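
The transformation itself is the logit scaling described in the paper; a
minimal sketch of the idea (score.py additionally normalizes the raw model
outputs and handles the augmentation dimension) is:

```
import numpy as np

def logit_scale(probs, labels):
    # probs: (n, 10) softmax outputs; labels: (n,) integer class labels.
    # phi(p) = log(p / (1 - p)) maps confidences onto a roughly Gaussian
    # scale, which the parametric attacks rely on.
    p = probs[np.arange(len(labels)), labels]
    eps = 1e-45  # guard against log(0)
    return np.log(p + eps) - np.log(1 - p + eps)
```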


### PLOTTING THE RESULTS

Finally we can generate pretty pictures, and run the plotting code
@@ -94,7 +90,6 @@ Finally we can generate pretty pictures, and run the plotting code
which should give (something like) the following output


![Log-log ROC Curve for all attacks](fprtpr.png "Log-log ROC Curve")

```
@@ -111,9 +106,9 @@ Attack Global threshold
```

where the global threshold attack is the baseline, and our online,
online-with-fixed-variance, offline, and offline-with-fixed-variance
attack variants are the four other curves. Note that because we only
train a few models, the fixed variance variants perform best.
online-with-fixed-variance, offline, and offline-with-fixed-variance attack
variants are the four other curves. Note that because we only train a few
models, the fixed variance variants perform best.
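
For intuition, each of these variants reduces to a per-example Gaussian
likelihood-ratio test; a simplified sketch of the online test (not the
vectorized implementation in plot.py) is:

```
import numpy as np
import scipy.stats

def lira_score(target, in_scores, out_scores, fix_variance=False):
    # target: logit-scaled confidence of the target model on one example;
    # in_scores / out_scores: the same statistic from shadow models that
    # did / did not train on that example.
    mean_in, mean_out = np.mean(in_scores), np.mean(out_scores)
    if fix_variance:
        # Pool a single global variance; more stable when training few models.
        std_in = std_out = np.std(np.concatenate([in_scores, out_scores]))
    else:
        std_in, std_out = np.std(in_scores), np.std(out_scores)
    log_p_in = scipy.stats.norm.logpdf(target, mean_in, std_in + 1e-30)
    log_p_out = scipy.stats.norm.logpdf(target, mean_out, std_out + 1e-30)
    return log_p_in - log_p_out  # larger => more likely a member
```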

### Citation

@@ -126,4 +121,4 @@ You can cite this paper with
journal={arXiv preprint arXiv:2112.03570},
year={2021}
}
```
52 changes: 20 additions & 32 deletions research/mi_lira_2021/inference.py
@@ -12,32 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import os
from typing import Callable
import json
# pylint: skip-file
# pyformat: disable

import json
import os
import re
import jax
import jax.numpy as jn
import numpy as np
import tensorflow as tf # For data augmentation.
import tensorflow_datasets as tfds
from absl import app, flags
from tqdm import tqdm, trange
import pickle
from functools import partial

import numpy as np
import objax
from objax.jaxboard import SummaryWriter, Summary
from objax.util import EasyDict
from objax.zoo import convnet, wide_resnet

from dataset import DataSet
import tensorflow as tf # For data augmentation.
from absl import app
from absl import flags

from train import MemModule, network
from train import MemModule
from train import network

from collections import defaultdict
FLAGS = flags.FLAGS


@@ -56,7 +46,7 @@ def load(arch):
lr=.1,
batch=0,
epochs=0,
weight_decay=0)

def cache_load(arch):
thing = []
@@ -68,8 +58,8 @@ def fn():

xs_all = np.load(os.path.join(FLAGS.logdir,"x_train.npy"))[:FLAGS.dataset_size]
ys_all = np.load(os.path.join(FLAGS.logdir,"y_train.npy"))[:FLAGS.dataset_size]


def get_loss(model, xbatch, ybatch, shift, reflect=True, stride=1):
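# Query the model on shifted/mirrored copies of each batch and keep the
# output logits from every augmentation as features for the attack.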

outs = []
@@ -90,7 +80,7 @@ def get_loss(model, xbatch, ybatch, shift, reflect=True, stride=1):
def features(model, xbatch, ybatch):
return get_loss(model, xbatch, ybatch,
shift=0, reflect=True, stride=1)

for path in sorted(os.listdir(os.path.join(FLAGS.logdir))):
if re.search(FLAGS.regex, path) is None:
print("Skipping from regex")
@@ -99,9 +89,9 @@ def features(model, xbatch, ybatch):
hparams = json.load(open(os.path.join(FLAGS.logdir, path, "hparams.json")))
arch = hparams['arch']
model = cache_load(arch)()

logdir = os.path.join(FLAGS.logdir, path)

checkpoint = objax.io.Checkpoint(logdir, keep_ckpts=10, makedir=True)
max_epoch, last_ckpt = checkpoint.restore(model.vars())
if max_epoch == 0: continue
@@ -112,12 +102,12 @@ def features(model, xbatch, ybatch):
first = FLAGS.from_epoch
else:
first = max_epoch-1

for epoch in range(first,max_epoch+1):
if not os.path.exists(os.path.join(FLAGS.logdir, path, "ckpt", "%010d.npz"%epoch)):
# no checkpoint saved here
continue

if os.path.exists(os.path.join(FLAGS.logdir, path, "logits", "%010d.npy"%epoch)):
print("Skipping already generated file", epoch)
continue
@@ -127,7 +117,7 @@ def features(model, xbatch, ybatch):
except:
print("Fail to load", epoch)
continue

stats = []

for i in range(0,len(xs_all),N):
@@ -142,9 +132,7 @@ def features(model, xbatch, ybatch):
flags.DEFINE_string('dataset', 'cifar10', 'Dataset.')
flags.DEFINE_string('logdir', 'experiments/', 'Directory where to save checkpoints and tensorboard data.')
flags.DEFINE_string('regex', '.*experiment.*', 'keep files when matching')
flags.DEFINE_bool('random_labels', False, 'use random labels.')
flags.DEFINE_integer('dataset_size', 50000, 'size of dataset.')
flags.DEFINE_integer('from_epoch', None, 'which epoch to load from.')
flags.DEFINE_integer('seed_mod', None, 'keep mod seed.')
flags.DEFINE_integer('modulus', 8, 'modulus.')
app.run(main)

14 changes: 9 additions & 5 deletions research/mi_lira_2021/plot.py
@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
# pyformat: disable

import os
import scipy.stats

@@ -113,7 +116,7 @@ def generate_ours_offline(keep, scores, check_keep, check_scores, in_size=100000
dat_out.append(scores[~keep[:, j], j, :])

out_size = min(min(map(len,dat_out)), out_size)

dat_out = np.array([x[:out_size] for x in dat_out])

mean_out = np.median(dat_out, 1)
@@ -160,7 +163,7 @@ def do_plot(fn, keep, scores, ntest, legend='', metric='auc', sweep_fn=sweep, **
fpr, tpr, auc, acc = sweep_fn(np.array(prediction), np.array(answers, dtype=bool))

low = tpr[np.where(fpr<.001)[0][-1]]
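# 'low' is the TPR at the largest threshold whose FPR is still below 0.1%,
# i.e. the TPR@0.1%FPR metric reported in the line below.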

print('Attack %s AUC %.4f, Accuracy %.4f, TPR@0.1%%FPR of %.4f'%(legend, auc,acc, low))

metric_text = ''
@@ -206,7 +209,7 @@ def fig_fpr_tpr():
"Global threshold\n",
metric='auc'
)

plt.semilogx()
plt.semilogy()
plt.xlim(1e-5,1)
@@ -220,5 +223,6 @@ def fig_fpr_tpr():
plt.show()


load_data("exp/cifar10/")
fig_fpr_tpr()
if __name__ == '__main__':
load_data("exp/cifar10/")
fig_fpr_tpr()
25 changes: 14 additions & 11 deletions research/mi_lira_2021/train.py
@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
# pyformat: disable

import functools
import os
import shutil
@@ -24,12 +27,11 @@
import tensorflow as tf # For data augmentation.
import tensorflow_datasets as tfds
from absl import app, flags
from tqdm import tqdm, trange

import objax
from objax.jaxboard import SummaryWriter, Summary
from objax.util import EasyDict
from objax.zoo import convnet, wide_resnet, dnnet
from objax.zoo import convnet, wide_resnet

from dataset import DataSet

@@ -202,11 +204,11 @@ def get_data(seed):
data = tfds.as_numpy(tfds.load(name=FLAGS.dataset, batch_size=-1, data_dir=DATA_DIR))
inputs = data['train']['image']
labels = data['train']['label']

inputs = (inputs/127.5)-1
np.save(os.path.join(FLAGS.logdir, "x_train.npy"),inputs)
np.save(os.path.join(FLAGS.logdir, "y_train.npy"),labels)

nclass = np.max(labels)+1

np.random.seed(seed)
@@ -233,7 +235,7 @@ def get_data(seed):
aug = lambda x: augment(x, 0, mirror=False)
else:
raise

train = DataSet.from_arrays(xs, ys,
augment_fn=aug)
test = DataSet.from_tfds(tfds.load(name=FLAGS.dataset, split='test', data_dir=DATA_DIR), xs.shape[1:])
@@ -252,15 +254,15 @@ def main(argv):
import time
seed = np.random.randint(0, 1000000000)
seed ^= int(time.time())

args = EasyDict(arch=FLAGS.arch,
lr=FLAGS.lr,
batch=FLAGS.batch,
weight_decay=FLAGS.weight_decay,
augment=FLAGS.augment,
seed=seed)


if FLAGS.tunename:
logdir = '_'.join(sorted('%s=%s' % k for k in args.items()))
elif FLAGS.expid is not None:
@@ -269,7 +271,7 @@ def main(argv):
logdir = "experiment-"+str(seed)
logdir = os.path.join(FLAGS.logdir, logdir)

if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz"%10)):
if os.path.exists(os.path.join(logdir, "ckpt", "%010d.npz"%FLAGS.epochs)):
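# The checkpoint for the final epoch already exists, so this run finished
# previously; skip retraining.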
print(f"run {FLAGS.expid} already completed.")
return
else:
@@ -282,7 +284,7 @@ def main(argv):
os.makedirs(logdir)

train, test, xs, ys, keep, nclass = get_data(seed)

# Define the network and train_it
tm = MemModule(network(FLAGS.arch), nclass=nclass,
mnist=FLAGS.dataset == 'mnist',
@@ -303,8 +305,8 @@

tm.train(FLAGS.epochs, len(xs), train, test, logdir,
save_steps=FLAGS.save_steps, patience=FLAGS.patience)



if __name__ == '__main__':
flags.DEFINE_string('arch', 'cnn32-3-mean', 'Model architecture.')
@@ -327,3 +329,4 @@ def main(argv):
flags.DEFINE_integer('patience', None, 'Early stopping after this many epochs without progress')
flags.DEFINE_bool('tunename', False, 'Use tune name?')
app.run(main)
